{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualization of map_raster test\n",
    "\n",
    "This notebook visualizes the test for `map_raster` using:\n",
    "- Real S1A SAR corner coordinates (Gulf of California), bilinearly interpolated into a full synthetic grid\n",
    "- Synthetic ECMWF wind data\n",
    "\n",
    "Source: `tests/test_map_raster_4cases.py` (comprehensive 4-case test suite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Add parent directories to path to import from tests and mapraster\n",
    "sys.path.insert(0, str(Path.cwd().parent.parent))\n",
    "\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import matplotlib.pyplot as plt\n",
    "from mapraster.main import map_raster\n",
    "from tests.tools_test import fake_ecmwf_0100_1h, build_footprint\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Create SAR dataset with real S1A coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_real_sar_dataset(nlines=167, nsamples=256):\n",
    "    \"\"\"\n",
    "    Create SAR dataset with real S1A coordinates from Gulf of California.\n",
    "    Based on s1a-iw-owi-dv-20210909t130650-20210909t130715-039605-04AE83.nc\n",
    "\n",
    "    Only the four corner points are real; every other grid point is\n",
    "    synthetic, obtained by bilinear interpolation between those corners.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nlines : int, optional\n",
    "        Number of grid rows (``line`` dimension). Defaults to 167.\n",
    "    nsamples : int, optional\n",
    "        Number of grid columns (``sample`` dimension). Defaults to 256.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    xarray.Dataset\n",
    "        Dataset holding 2D ``longitude`` and ``latitude`` variables on\n",
    "        (``line``, ``sample``), with integer ``line`` / ``sample`` coordinates.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``nlines`` or ``nsamples`` is smaller than 2, because the index\n",
    "        normalisation below divides by ``n - 1``.\n",
    "    \"\"\"\n",
    "    if nlines < 2 or nsamples < 2:\n",
    "        raise ValueError(\n",
    "            f\"nlines and nsamples must both be >= 2, got nlines={nlines}, nsamples={nsamples}\"\n",
    "        )\n",
    "\n",
    "    # Real corner coordinates from S1A SAR, in this order:\n",
    "    #   [0] first line / first sample    [1] first line / last sample\n",
    "    #   [2] last line  / last sample     [3] last line  / first sample\n",
    "    corners = {\n",
    "        'lon': np.array([-106.73246, -109.26672, -109.43854, -106.90172]),\n",
    "        'lat': np.array([21.72079, 22.14191, 20.64836, 20.22559])\n",
    "    }\n",
    "\n",
    "    # Integer pixel indices along each dimension\n",
    "    lines = np.arange(nlines)\n",
    "    samples = np.arange(nsamples)\n",
    "\n",
    "    # Fractional position (0 at the first index, 1 at the last) along each axis\n",
    "    s_norm = samples / (nsamples - 1)\n",
    "    l_norm = lines / (nlines - 1)\n",
    "\n",
    "    # Top edge (line=0): interpolate between corners[0] and corners[1]\n",
    "    lon_top = corners['lon'][0] + s_norm * (corners['lon'][1] - corners['lon'][0])\n",
    "    lat_top = corners['lat'][0] + s_norm * (corners['lat'][1] - corners['lat'][0])\n",
    "\n",
    "    # Bottom edge (line=-1): interpolate between corners[3] and corners[2]\n",
    "    lon_bottom = corners['lon'][3] + s_norm * (corners['lon'][2] - corners['lon'][3])\n",
    "    lat_bottom = corners['lat'][3] + s_norm * (corners['lat'][2] - corners['lat'][3])\n",
    "\n",
    "    # Full grid: interpolate between top and bottom edges; broadcasting the\n",
    "    # (1, nsamples) edges against the (nlines, 1) fractions gives (nlines, nsamples)\n",
    "    lon_grid = lon_top[None, :] + l_norm[:, None] * (lon_bottom[None, :] - lon_top[None, :])\n",
    "    lat_grid = lat_top[None, :] + l_norm[:, None] * (lat_bottom[None, :] - lat_top[None, :])\n",
    "\n",
    "    # Create xarray Dataset\n",
    "    ds = xr.Dataset(\n",
    "        {\n",
    "            'longitude': (['line', 'sample'], lon_grid),\n",
    "            'latitude': (['line', 'sample'], lat_grid),\n",
    "        },\n",
    "        coords={\n",
    "            'line': lines,\n",
    "            'sample': samples,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    return ds\n",
    "\n",
    "# Build the SAR grid once and print a quick sanity summary\n",
    "sar_dataset = create_real_sar_dataset()\n",
    "print(f\"SAR shape: {sar_dataset['longitude'].shape}\")\n",
    "for short_name, var_name in (('lon', 'longitude'), ('lat', 'latitude')):\n",
    "    var_min = sar_dataset[var_name].min().values\n",
    "    var_max = sar_dataset[var_name].max().values\n",
    "    print(f\"SAR {short_name} range: [{var_min:.2f}, {var_max:.2f}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare last expression: the notebook displays the dataset for inspection\n",
    "sar_dataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create synthetic ECMWF wind data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic ECMWF raster from the test helpers\n",
    "# NOTE(review): `to180=True` presumably yields longitudes in [-180, 180] - confirm in tests/tools_test.py\n",
    "ecmwf_dataset = fake_ecmwf_0100_1h(to180=True, with_nan=False)\n",
    "print(f\"ECMWF shape: {ecmwf_dataset['U10'].shape}\")\n",
    "for axis_name in ('x', 'y'):\n",
    "    axis_min = ecmwf_dataset[axis_name].min().values\n",
    "    axis_max = ecmwf_dataset[axis_name].max().values\n",
    "    print(f\"ECMWF {axis_name} range: [{axis_min:.2f}, {axis_max:.2f}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare last expression: the notebook displays the dataset for inspection\n",
    "ecmwf_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Run map_raster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Footprint built from the SAR grid by the test helper\n",
    "footprint = build_footprint(sar_dataset)\n",
    "\n",
    "# Map the ECMWF raster onto the SAR (line, sample) grid\n",
    "result = map_raster(\n",
    "    raster_ds=ecmwf_dataset,\n",
    "    originalDataset=sar_dataset,\n",
    "    footprint=footprint,\n",
    "    cross_antimeridian=False,\n",
    ")\n",
    "\n",
    "print(f\"Result shape: {result['U10'].shape}\")\n",
    "for component in ('U10', 'V10'):\n",
    "    comp_min = result[component].min().values\n",
    "    comp_max = result[component].max().values\n",
    "    print(f\"Result {component} range: [{comp_min:.2f}, {comp_max:.2f}]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Visualizations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 SAR grid geometry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# One panel per coordinate field: (dataset variable, colormap, display name)\n",
    "geometry_panels = [\n",
    "    ('longitude', 'viridis', 'Longitude'),\n",
    "    ('latitude', 'plasma', 'Latitude'),\n",
    "]\n",
    "\n",
    "for ax, (var_name, cmap_name, display_name) in zip(axes, geometry_panels):\n",
    "    # Draw the field in image (sample, line) space\n",
    "    mesh = ax.pcolormesh(\n",
    "        sar_dataset['sample'],\n",
    "        sar_dataset['line'],\n",
    "        sar_dataset[var_name],\n",
    "        shading='auto',\n",
    "        cmap=cmap_name,\n",
    "    )\n",
    "    ax.set_xlabel('Sample')\n",
    "    ax.set_ylabel('Line')\n",
    "    ax.set_title(f'SAR {display_name} Grid')\n",
    "    plt.colorbar(mesh, ax=ax, label=f'{display_name} (°)')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 ECMWF wind fields (global)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(3, 1, figsize=(14, 12))\n",
    "\n",
    "# Wind speed magnitude from the two wind components\n",
    "wind_speed_ecmwf = np.sqrt(ecmwf_dataset['U10']**2 + ecmwf_dataset['V10']**2)\n",
    "\n",
    "# One row per field: (data, colormap, vmin, vmax, title stem, colorbar label)\n",
    "global_panels = [\n",
    "    (wind_speed_ecmwf, 'viridis', 0, 15, 'Wind Speed', 'Wind Speed (m/s)'),\n",
    "    (ecmwf_dataset['U10'], 'RdBu_r', -10, 10, 'U10 Wind Component', 'U10 (m/s)'),\n",
    "    (ecmwf_dataset['V10'], 'RdBu_r', -10, 10, 'V10 Wind Component', 'V10 (m/s)'),\n",
    "]\n",
    "\n",
    "for ax, (field, cmap_name, color_min, color_max, title_stem, cbar_label) in zip(axes, global_panels):\n",
    "    mesh = ax.pcolormesh(\n",
    "        ecmwf_dataset['x'],\n",
    "        ecmwf_dataset['y'],\n",
    "        field,\n",
    "        shading='auto',\n",
    "        cmap=cmap_name,\n",
    "        vmin=color_min,\n",
    "        vmax=color_max,\n",
    "    )\n",
    "    ax.set_xlabel('Longitude (°)')\n",
    "    ax.set_ylabel('Latitude (°)')\n",
    "    ax.set_title(f'ECMWF {title_stem} (global)')\n",
    "    plt.colorbar(mesh, ax=ax, label=cbar_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 ECMWF wind fields (SAR region zoom)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _bounds_slice(coord, low, high):\n",
    "    \"\"\"Return a label slice covering [low, high] that follows the coordinate's sort direction.\n",
    "\n",
    "    A label-based `slice(low, high)` selects nothing on a descending coordinate,\n",
    "    so the bounds are swapped when the coordinate decreases.\n",
    "    \"\"\"\n",
    "    values = coord.values\n",
    "    descending = values.size > 1 and values[0] > values[-1]\n",
    "    return slice(high, low) if descending else slice(low, high)\n",
    "\n",
    "\n",
    "# Bounding box of the SAR grid\n",
    "lon_min, lon_max = sar_dataset['longitude'].min().values, sar_dataset['longitude'].max().values\n",
    "lat_min, lat_max = sar_dataset['latitude'].min().values, sar_dataset['latitude'].max().values\n",
    "\n",
    "# Extract SAR region from ECMWF, padded by 1 on every side\n",
    "ecmwf_subset = ecmwf_dataset.sel(\n",
    "    x=_bounds_slice(ecmwf_dataset['x'], lon_min - 1, lon_max + 1),\n",
    "    y=_bounds_slice(ecmwf_dataset['y'], lat_min - 1, lat_max + 1),\n",
    ")\n",
    "\n",
    "# Closed SAR footprint outline: visit the four grid corners, then return to the first.\n",
    "# These names are reused by the before/after comparison cell further down.\n",
    "corner_rows = [0, 0, -1, -1, 0]\n",
    "corner_cols = [0, -1, -1, 0, 0]\n",
    "lon_corners = sar_dataset['longitude'].values[corner_rows, corner_cols]\n",
    "lat_corners = sar_dataset['latitude'].values[corner_rows, corner_cols]\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# One zoomed panel per wind component, each with the footprint overlaid\n",
    "for ax, component in zip(axes, ('U10', 'V10')):\n",
    "    mesh = ax.pcolormesh(\n",
    "        ecmwf_subset['x'],\n",
    "        ecmwf_subset['y'],\n",
    "        ecmwf_subset[component],\n",
    "        shading='auto',\n",
    "        cmap='RdBu_r'\n",
    "    )\n",
    "    ax.set_xlabel('Longitude (°)')\n",
    "    ax.set_ylabel('Latitude (°)')\n",
    "    ax.set_title(f'ECMWF {component} (SAR region)')\n",
    "    plt.colorbar(mesh, ax=ax, label=f'{component} (m/s)')\n",
    "\n",
    "    # Add SAR footprint\n",
    "    ax.plot(lon_corners, lat_corners, 'k-', linewidth=2, label='SAR footprint')\n",
    "    ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.4 Interpolated wind on SAR grid (result of map_raster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# One panel per interpolated wind component, drawn in image (sample, line) space\n",
    "for ax, component in zip(axes, ('U10', 'V10')):\n",
    "    mesh = ax.pcolormesh(\n",
    "        sar_dataset['sample'],\n",
    "        sar_dataset['line'],\n",
    "        result[component],\n",
    "        shading='auto',\n",
    "        cmap='RdBu_r',\n",
    "    )\n",
    "    ax.set_xlabel('Sample')\n",
    "    ax.set_ylabel('Line')\n",
    "    ax.set_title(f'Interpolated {component} on SAR grid')\n",
    "    plt.colorbar(mesh, ax=ax, label=f'{component} (m/s)')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.5 Before/After comparison in lon/lat coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
    "\n",
    "# Extract SAR coordinates\n",
    "sar_lon = sar_dataset['longitude'].values\n",
    "sar_lat = sar_dataset['latitude'].values\n",
    "\n",
    "# One row per wind component: left = ECMWF input, right = map_raster output.\n",
    "# Relies on `ecmwf_subset`, `lon_corners` and `lat_corners` from the SAR-region cell above.\n",
    "for (ax_before, ax_after), component in zip(axes, ('U10', 'V10')):\n",
    "    cbar_label = f'{component} (m/s)'\n",
    "\n",
    "    # Before (ECMWF subset), with the SAR footprint outlined\n",
    "    mesh = ax_before.pcolormesh(\n",
    "        ecmwf_subset['x'],\n",
    "        ecmwf_subset['y'],\n",
    "        ecmwf_subset[component],\n",
    "        shading='auto',\n",
    "        cmap='RdBu_r',\n",
    "    )\n",
    "    ax_before.plot(lon_corners, lat_corners, 'k-', linewidth=2, label='SAR footprint')\n",
    "    ax_before.set_xlabel('Longitude (°)')\n",
    "    ax_before.set_ylabel('Latitude (°)')\n",
    "    ax_before.set_title(f'ECMWF {component} (Before interpolation)')\n",
    "    ax_before.legend()\n",
    "    plt.colorbar(mesh, ax=ax_before, label=cbar_label)\n",
    "\n",
    "    # After (interpolated on SAR grid): reuse the \"before\" colour limits so\n",
    "    # both panels of a row share the same colour scale\n",
    "    color_min, color_max = mesh.get_clim()\n",
    "    points = ax_after.scatter(\n",
    "        sar_lon.flatten(),\n",
    "        sar_lat.flatten(),\n",
    "        c=result[component].values.flatten(),\n",
    "        s=5,\n",
    "        cmap='RdBu_r',\n",
    "        vmin=color_min,\n",
    "        vmax=color_max,\n",
    "    )\n",
    "    ax_after.set_xlabel('Longitude (°)')\n",
    "    ax_after.set_ylabel('Latitude (°)')\n",
    "    ax_after.set_title(f'Interpolated {component} (After map_raster)')\n",
    "    plt.colorbar(points, ax=ax_after, label=cbar_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_xsar",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
