GDAS Conventional Observation DataFrameSource#795

Open

NickGeneva wants to merge 2 commits into NVIDIA:main from NickGeneva:feature/gdas-prepbufr-datasource
Conversation

@NickGeneva (Collaborator) commented Apr 6, 2026

Summary

Add NomadsGDASObsConv DataFrameSource that fetches conventional (in-situ) observations from NOAA's NOMADS GDAS PrepBUFR files in real time (within past 6 hours).
This is public domain data. https://www.weather.gov/disclaimer

  • Supports temperature, humidity, pressure, and wind (u/v) observations from surface stations, upper-air radiosondes, aircraft, ships, buoys, satellite-derived winds, and GPS precipitable water
  • Decodes PrepBUFR files using pybufrkit with parallel process-pool decoding
  • Outputs long-format DataFrame with per-observation lat/lon, pressure level, station ID, observation class, quality markers, and unit-converted values
  • Configurable time tolerance, async NOMADS HTTP access with retry, and optional disk caching
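The lexicon unit conversions mentioned above (TOB→K, QOB→kg/kg, POB→Pa) are simple scalar transforms. A minimal sketch for illustration — the function names here are hypothetical and not the actual `GDASObsConvLexicon` API:

```python
# PrepBUFR reports TOB in deg C, QOB in mg/kg, and POB in hPa (mb);
# the lexicon maps each to SI units.

def tob_to_kelvin(tob_degc: float) -> float:
    """TOB (deg C) -> temperature (K)."""
    return tob_degc + 273.15

def qob_to_kgkg(qob_mgkg: float) -> float:
    """QOB (mg/kg) -> specific humidity (kg/kg)."""
    return qob_mgkg * 1e-6

def pob_to_pa(pob_hpa: float) -> float:
    """POB (hPa) -> pressure (Pa)."""
    return pob_hpa * 100.0

print(tob_to_kelvin(15.0))    # 288.15
print(qob_to_kgkg(10_000.0))  # 0.01
```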

Changes

New files

| File | Description |
| --- | --- |
| `earth2studio/data/gdas.py` | `NomadsGDASObsConv` DataFrameSource (~1580 lines) |
| `earth2studio/lexicon/gdas.py` | `GDASObsConvLexicon` with PrepBUFR mnemonic mappings and unit conversions (TOB→K, QOB→kg/kg, POB→Pa) |
| `test/data/test_gdas.py` | Unit tests (offline + network/xfail) |

Modified files

| File | Change |
| --- | --- |
| `earth2studio/lexicon/base.py` | Added quality (uint16) field to `E2STUDIO_SCHEMA` for QC markers |
| `earth2studio/lexicon/__init__.py` | Registered `GDASObsConvLexicon` export |
| `earth2studio/data/__init__.py` | Registered `NomadsGDASObsConv` export |
| `docs/modules/datasources_dataframe.rst` | Added `data.NomadsGDASObsConv` to autosummary |
| `CHANGELOG.md` | Added entries under 0.14.0a0 |
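The new quality field uses an unsigned 16-bit dtype so missing QC markers can be represented without falling back to float. A sketch with pandas' nullable integer dtype — the actual layout of `E2STUDIO_SCHEMA` is not shown in this PR, so the surrounding structure here is assumed:

```python
import pandas as pd

# Nullable unsigned 16-bit column: holds 4-bit PrepBUFR QC markers (0-15)
# while still allowing missing values (pd.NA) where no marker was decoded.
quality = pd.array([0, 2, None, 3], dtype="UInt16")
df = pd.DataFrame({"quality": quality})

print(df["quality"].dtype)         # UInt16
print(df["quality"].isna().sum())  # 1
```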

Usage

```python
from earth2studio.data import NomadsGDASObsConv
from datetime import datetime, timedelta

ds = NomadsGDASObsConv(
    time_tolerance=timedelta(minutes=180),  # ±3h window
    cache=True,
)
df = ds(datetime(2026, 4, 4, 0), ["t", "pres", "u", "v"])
# DataFrame columns: time, lat, lon, variable, observation,
#   pres, station_id, elevation, class, quality
```
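A common follow-up is to keep only good-quality observations and split the long-format result by variable. A sketch against a mock of the column layout above (the `quality <= 3` convention follows the PrepBUFR QC-marker usage described later in this thread):

```python
import pandas as pd

# Mock of the long-format result returned by the data source.
df = pd.DataFrame({
    "variable": ["t", "t", "u", "v"],
    "observation": [288.2, 250.1, 5.0, -3.2],
    "quality": pd.array([1, 9, 2, None], dtype="UInt16"),
})

# QC marks <= 3 are conventionally treated as usable; treat missing as bad.
good = df[df["quality"].fillna(99) <= 3]
by_var = {v: sub for v, sub in good.groupby("variable")}

print(len(good))       # 2
print(sorted(by_var))  # ['t', 'u']
```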

@greptile-apps bot (Contributor) commented Apr 6, 2026

Greptile Summary

This PR adds NomadsGDASObsConv, a DataFrameSource (~1593 lines) that fetches real-time conventional observations (radiosondes, surface stations, aircraft, ships, GPS-PW) from NOAA NOMADS PrepBUFR files, plus the companion GDASObsConvLexicon with unit conversions (TOB→K, QOB→kg/kg, POB→Pa). The two P1 issues from prior review rounds—per-variable quality-marker mis-assignment and an infinite loop on zero-length BUFR messages—are both confirmed resolved in this version.

Confidence Score: 5/5

Safe to merge; all prior P1 findings have been addressed and no new blocking issues found.

Both previously-flagged P1 bugs are confirmed fixed. The decode pipeline, per-variable quality-marker mapping, DX table extraction, parallel worker initialization, lexicon unit conversions, and schema addition are all correct. Remaining observations are minor P2 style items (module-level constant allocation, single-component wind edge case) that do not affect correctness or data integrity.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `earth2studio/data/gdas.py` | New `NomadsGDASObsConv` DataFrameSource; all previously-flagged P1 bugs (quality-marker mis-assignment, zero-length BUFR infinite loop, Table D parsing) are confirmed resolved; decode pipeline and parallel worker init are correct |
| `earth2studio/lexicon/gdas.py` | New `GDASObsConvLexicon` with correct unit conversions: TOB→K, QOB→kg/kg, POB→Pa, UOB/VOB identity (already in m s⁻¹) |
| `earth2studio/lexicon/base.py` | Added nullable uint16 `quality` field to `E2STUDIO_SCHEMA`; correct type and range for 4-bit PrepBUFR QC markers |
| `test/data/test_gdas.py` | Comprehensive offline suite (schema, exceptions, URL builder, lexicon modifiers, cache, mock end-to-end); network tests correctly marked xfail |
| `earth2studio/data/__init__.py` | Registered `NomadsGDASObsConv` export |
| `earth2studio/lexicon/__init__.py` | Registered `GDASObsConvLexicon` export |

Reviews (2): last reviewed commit "Greptile fixes"

@NickGeneva (Collaborator, Author) commented Apr 6, 2026

Sanity Check Validation

Ran a full PrepBUFR sanity check using NomadsGDASObsConv for a recent ~12h-ago GDAS cycle with ±3h tolerance and 16 decode workers. Validated physical ranges, spatial coverage, observation classes, quality marks, and generated GFS overlay scatter plots.

| Metric | Value |
| --- | --- |
| Source | NomadsGDASObsConv (NOMADS PrepBUFR) |
| Variables | t, q, pres, u, v |
| Time tolerance | ±3h |
| GFS overlay tolerance | ±1h |
| Decode workers | 16 |

Key findings:

  • All 5 variables produce physically reasonable observation values within expected bounds
  • Global station coverage spans all 4 hemispheres with thousands of unique stations
  • Observation classes match known PrepBUFR types (ADPSFC, SFCSHP, ADPUPA, AIRCFT, AIRCAR, SATWND, etc.)
  • Quality marks predominantly ≤ 3 (good) as expected for GDAS QC'd data
  • Temperature, wind, pressure, and humidity observations visually match GFS analysis fields in the overlay plots
  • Vertical profiles show physically consistent temperature/humidity lapse rates

Diagnostic plots

Station map — obs locations by variable:

station_map

Observation value distributions:

obs_histograms

Observation class distribution:

class_distribution

Quality mark distribution:

quality_marks

Vertical profiles (temperature and humidity vs pressure):

vertical_profiles

GFS overlay scatter plots

Temperature obs on GFS T2m:

obs_temp_on_gfs_t2m

Wind speed obs on GFS 10m wind:

obs_wind_on_gfs_wspd

Surface pressure obs on GFS SP:

obs_pres_on_gfs_sp

Humidity obs on GFS Q2m:

obs_q_on_gfs_q2m
Sanity check script:
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Sanity check for NomadsGDASObsConv PrepBUFR data source.

Uses the NomadsGDASObsConv data source to fetch recent GDAS observations
from NOMADS, validates that the numerical ranges and spatial coverage are
physically reasonable, and generates diagnostic plots.

Usage
-----
    python prepbufr_research/sanity_check.py [--hours-ago 12] [--workers N]
"""

from __future__ import annotations

import argparse
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from earth2studio.data.gdas import NomadsGDASObsConv, PREPBUFR_OBS_TYPES
from earth2studio.data.gfs import GFS

# ── Physical bounds for each variable ────────────────────────────────
# (variable_name, unit, min_sane, max_sane)
# NOTE: Observation values are already unit-converted by the lexicon
#       (TOB→K, QOB→kg/kg, POB→Pa), so bounds are given in SI units.
#       The pres *column* in the DataFrame is the level pressure in Pa.
VARIABLE_BOUNDS: dict[str, tuple[str, float, float]] = {
    "t": ("K", 178.0, 333.0),
    "q": ("kg/kg", 0.0, 0.045),
    "pres": ("Pa", 0.0, 110_000.0),
    "u": ("m/s", -150.0, 150.0),
    "v": ("m/s", -150.0, 150.0),
}

# Spatial bounds
LAT_RANGE = (-90.0, 90.0)
LON_RANGE = (-180.0, 360.0)  # PrepBUFR can use [0,360] or [-180,180]
ELEV_RANGE = (-500.0, 15000.0)  # m (Dead Sea ~-430, aircraft ~13700, satellite ~14600)
# PrepBUFR uses 9999.0 as a missing value sentinel for elevation
ELEV_MISSING_SENTINEL = 9999.0


def fetch_all_variables(
    obs_time: datetime,
    workers: int = 8,
    tolerance_hours: float = 3.0,
) -> dict[str, pd.DataFrame]:
    """Fetch GDAS observations via the NomadsGDASObsConv data source.

    Parameters
    ----------
    obs_time : datetime
        Centre of the observation time window (UTC, naive).
    workers : int
        Number of parallel BUFR decode workers.
    tolerance_hours : float
        Symmetric time tolerance (hours) around *obs_time*.

    Returns a dict mapping variable name → DataFrame.
    """
    primary_vars = ["t", "q", "pres", "u", "v"]

    print(f"\nFetching variables {primary_vars}")
    print(f"Observation time: {obs_time}")
    print(f"Time tolerance: ±{tolerance_hours}h")
    print(f"Decode workers: {workers}")

    ds = NomadsGDASObsConv(
        time_tolerance=np.timedelta64(int(tolerance_hours * 60), "m"),
        decode_workers=workers,
        cache=True,
        verbose=True,
    )

    t0 = time.perf_counter()
    df = ds(time=obs_time, variable=primary_vars)
    elapsed = time.perf_counter() - t0

    print(f"Fetch + decode time: {elapsed:.1f}s")
    print(f"Total rows: {len(df):,}")

    # Split by variable
    result: dict[str, pd.DataFrame] = {}
    for var in primary_vars:
        sub = df[df["variable"] == var]
        result[var] = sub
        print(f"  {var:>6s}: {len(sub):>8,} rows")

    return result


def validate_ranges(dfs: dict[str, pd.DataFrame]) -> list[str]:
    """Check that observation values fall within sane physical bounds.

    Returns a list of failure messages (empty = all OK).
    """
    failures: list[str] = []

    for var, df in dfs.items():
        if df.empty:
            failures.append(f"[{var}] No data returned")
            continue

        obs = df["observation"].dropna()
        if obs.empty:
            failures.append(f"[{var}] All observations are NaN")
            continue

        bounds = VARIABLE_BOUNDS.get(var)
        if bounds is None:
            continue

        unit, lo, hi = bounds

        obs_min, obs_max = obs.min(), obs.max()
        obs_mean = obs.mean()
        obs_median = obs.median()

        # Check bounds
        below = (obs < lo).sum()
        above = (obs > hi).sum()
        total = len(obs)

        print(f"\n  [{var}] unit={unit}")
        print(f"    count  = {total:,}")
        print(f"    min    = {obs_min:.4f}")
        print(f"    max    = {obs_max:.4f}")
        print(f"    mean   = {obs_mean:.4f}")
        print(f"    median = {obs_median:.4f}")
        print(f"    below {lo}: {below} ({100 * below / total:.2f}%)")
        print(f"    above {hi}: {above} ({100 * above / total:.2f}%)")

        # Allow small fraction of outliers (PrepBUFR has some extreme obs)
        outlier_pct = 100 * (below + above) / total
        if outlier_pct > 1.0:
            failures.append(
                f"[{var}] {outlier_pct:.1f}% of observations outside "
                f"[{lo}, {hi}]: min={obs_min:.2f}, max={obs_max:.2f}"
            )

    return failures


def validate_locations(dfs: dict[str, pd.DataFrame]) -> list[str]:
    """Check that lat/lon/elevation are physically reasonable."""
    failures: list[str] = []

    # Combine all DataFrames for spatial checks
    all_df = pd.concat(dfs.values(), ignore_index=True)
    if all_df.empty:
        failures.append("No data at all")
        return failures

    lat = all_df["lat"].dropna()
    lon = all_df["lon"].dropna()
    elev_raw = all_df["station_elev"].dropna()
    # Filter out PrepBUFR missing sentinel value (9999.0)
    elev = elev_raw[elev_raw != ELEV_MISSING_SENTINEL]
    n_elev_missing = (elev_raw == ELEV_MISSING_SENTINEL).sum()

    print("\n  [lat/lon/elev] spatial coverage:")
    print(f"    lat  range: [{lat.min():.2f}, {lat.max():.2f}]")
    print(f"    lon  range: [{lon.min():.2f}, {lon.max():.2f}]")
    if not elev.empty:
        print(f"    elev range: [{elev.min():.1f}, {elev.max():.1f}] m")
        print(f"    elev missing (9999 sentinel): {n_elev_missing:,}")

    # Check latitude
    bad_lat = ((lat < LAT_RANGE[0]) | (lat > LAT_RANGE[1])).sum()
    if bad_lat > 0:
        failures.append(f"[lat] {bad_lat} values outside {LAT_RANGE}")

    # Check longitude
    bad_lon = ((lon < LON_RANGE[0]) | (lon > LON_RANGE[1])).sum()
    if bad_lon > 0:
        failures.append(f"[lon] {bad_lon} values outside {LON_RANGE}")

    # Check elevation (excluding sentinel values)
    if not elev.empty:
        bad_elev = ((elev < ELEV_RANGE[0]) | (elev > ELEV_RANGE[1])).sum()
        if bad_elev > 0:
            pct = 100 * bad_elev / len(elev)
            # Allow small fraction of outliers (aircraft at cruise altitude)
            if pct > 1.0:
                failures.append(
                    f"[elev] {bad_elev} ({pct:.2f}%) values outside "
                    f"{ELEV_RANGE} (excluding 9999 sentinel)"
                )

    # Global coverage check: expect observations in all 4 hemispheres
    has_north = (lat > 10).any()
    has_south = (lat < -10).any()
    has_east = (lon > 10).any()
    has_west = ((lon < -10) | (lon > 350)).any()

    if not (has_north and has_south):
        failures.append("[coverage] Missing Northern or Southern hemisphere data")
    if not (has_east and has_west):
        failures.append("[coverage] Missing Eastern or Western hemisphere data")

    # Check that we have a reasonable number of unique stations
    stations = all_df["station"].dropna().unique()
    print(f"    unique stations: {len(stations):,}")
    if len(stations) < 100:
        failures.append(
            f"[stations] Only {len(stations)} unique stations (expected >100)"
        )

    return failures


def validate_obs_classes(dfs: dict[str, pd.DataFrame]) -> list[str]:
    """Check that observation classes match known PrepBUFR types."""
    failures: list[str] = []
    known_classes = set(PREPBUFR_OBS_TYPES.values())

    all_df = pd.concat(dfs.values(), ignore_index=True)
    obs_classes = all_df["class"].dropna().unique()

    print(f"\n  [obs classes] found: {sorted(obs_classes)}")
    for cls in obs_classes:
        if cls not in known_classes:
            failures.append(f"[class] Unknown observation class: {cls}")

    return failures


def validate_quality_marks(dfs: dict[str, pd.DataFrame]) -> list[str]:
    """Check quality mark distribution."""
    failures: list[str] = []
    all_df = pd.concat(dfs.values(), ignore_index=True)
    qm = all_df["quality"].dropna()

    if qm.empty:
        failures.append("[quality] No quality marks found")
        return failures

    print("\n  [quality marks]")
    print(f"    count: {len(qm):,}")
    print(f"    range: [{qm.min()}, {qm.max()}]")
    vc = qm.value_counts().sort_index().head(10)
    for val, cnt in vc.items():
        print(f"    qm={val}: {cnt:>8,} ({100 * cnt / len(qm):.1f}%)")

    # PrepBUFR quality marks: 0-15, with 0=good, 1=good, 2=neutral,
    # 3=suspect, 9-15=rejected. Most should be 0-3.
    if qm.max() > 15:
        failures.append(f"[quality] Max quality mark {qm.max()} > 15")

    good_pct = 100 * (qm <= 3).sum() / len(qm)
    print(f"    good (<=3): {good_pct:.1f}%")
    if good_pct < 50:
        failures.append(f"[quality] Only {good_pct:.1f}% of obs have quality mark <= 3")

    return failures


def _snap_to_gfs_cycle(dt: datetime) -> datetime:
    """Snap a datetime to the nearest past 6-hour GFS cycle (00/06/12/18 UTC)."""
    hour_6 = (dt.hour // 6) * 6
    return dt.replace(hour=hour_6, minute=0, second=0, microsecond=0)


def _fetch_gfs_fields(
    gfs_time: datetime,
    variables: list[str],
) -> xr.DataArray:
    """Fetch GFS analysis fields for the given time and variables.

    Parameters
    ----------
    gfs_time : datetime
        Analysis time (must be on a 6-hour boundary).
    variables : list[str]
        Variable names from GFSLexicon (e.g. ["t2m", "u10m", "v10m"]).

    Returns
    -------
    xr.DataArray
        Shape (time=1, variable, lat=721, lon=1440).
    """
    gfs = GFS(source="aws", cache=True, verbose=True)
    return gfs(gfs_time, variables)


def _filter_obs_by_time(
    df: pd.DataFrame,
    target: datetime,
    tolerance_hours: float = 1.0,
) -> pd.DataFrame:
    """Keep only observations within *tolerance_hours* of *target*."""
    t_min = target - timedelta(hours=tolerance_hours)
    t_max = target + timedelta(hours=tolerance_hours)
    return df[(df["time"] >= t_min) & (df["time"] <= t_max)]


def plot_obs_on_gfs(
    dfs: dict[str, pd.DataFrame],
    gfs_time: datetime,
    output_dir: Path,
    tolerance_hours: float = 1.0,
) -> None:
    """Scatter-plot observations coloured by value on top of GFS contour fields.

    Generates four figures:
      1. Temperature obs (coloured) on GFS t2m contour
      2. Wind-speed obs (coloured) on GFS wind-speed contour
      3. Surface pressure obs (coloured) on GFS sp contour
      4. Specific humidity obs (coloured) on GFS q2m contour

    Parameters
    ----------
    dfs : dict
        Variable → DataFrame mapping returned by ``decode_all_variables``.
    gfs_time : datetime
        GFS analysis time (snapped to 6-hr cycle).
    output_dir : Path
        Directory for output PNGs.
    tolerance_hours : float
        Keep observations within this many hours of *gfs_time*.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Fetch GFS fields ──────────────────────────────────────────
    print(f"\n  Fetching GFS fields for {gfs_time} …")
    gfs_da = _fetch_gfs_fields(gfs_time, ["t2m", "u10m", "v10m", "sp", "q2m"])
    # Extract 2-D numpy arrays
    gfs_lat = gfs_da.coords["lat"].values  # 90 → −90
    gfs_lon = gfs_da.coords["lon"].values  # 0 → 359.75
    gfs_t2m = gfs_da.sel(variable="t2m").values.squeeze()  # (lat, lon)
    gfs_u10 = gfs_da.sel(variable="u10m").values.squeeze()
    gfs_v10 = gfs_da.sel(variable="v10m").values.squeeze()
    gfs_wspd = np.sqrt(gfs_u10**2 + gfs_v10**2)
    gfs_sp = gfs_da.sel(variable="sp").values.squeeze()  # Pa
    gfs_q2m = gfs_da.sel(variable="q2m").values.squeeze()  # kg/kg

    lon2d, lat2d = np.meshgrid(gfs_lon, gfs_lat)

    # ── Helper: normalise obs lon to [0, 360) to match GFS grid ───
    def _norm_lon(s: pd.Series) -> pd.Series:
        return s % 360.0

    # ── 1. Temperature scatter on GFS t2m contour ─────────────────
    t_df = dfs.get("t", pd.DataFrame())
    if not t_df.empty:
        t_near = _filter_obs_by_time(t_df, gfs_time, tolerance_hours)
        # Keep only good-quality obs (quality mark <= 3)
        t_near = t_near[t_near["quality"].fillna(99) <= 3]
        print(f"  Temperature obs within ±{tolerance_hours}h (qm<=3): {len(t_near):,}")
        if not t_near.empty:
            fig, ax = plt.subplots(figsize=(16, 8))
            # Shared colour scale across GFS contour and obs scatter
            t_vmin = min(float(gfs_t2m.min()), float(t_near["observation"].min()))
            t_vmax = max(float(gfs_t2m.max()), float(t_near["observation"].max()))
            if t_vmin >= t_vmax:
                t_vmax = t_vmin + 1.0
            t_levels = np.linspace(t_vmin, t_vmax, 31)
            # Contour fill background
            cf = ax.contourf(
                lon2d,
                lat2d,
                gfs_t2m,
                levels=t_levels,
                cmap="RdYlBu_r",
                alpha=0.55,
            )
            # Scatter obs (same scale)
            ax.scatter(
                _norm_lon(t_near["lon"]),
                t_near["lat"],
                c=t_near["observation"],
                cmap="RdYlBu_r",
                s=8,
                alpha=0.8,
                edgecolors="k",
                linewidths=0.2,
                vmin=t_vmin,
                vmax=t_vmax,
            )
            plt.colorbar(cf, ax=ax, label="Temperature (K)", shrink=0.7)
            ax.set_xlim(0, 360)
            ax.set_ylim(-90, 90)
            ax.set_xlabel("Longitude (°E)")
            ax.set_ylabel("Latitude")
            ax.set_title(
                f"PrepBUFR temperature obs on GFS T2m — "
                f"{gfs_time:%Y-%m-%d %H:%M} UTC  (±{tolerance_hours}h)",
            )
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(output_dir / "obs_temp_on_gfs_t2m.png", dpi=150)
            plt.close(fig)
            print(f"  Saved obs_temp_on_gfs_t2m.png")
    else:
        print("  No temperature observations — skipping temp scatter plot.")

    # ── 2. Wind scatter on GFS wind-speed contour ─────────────────
    u_df = dfs.get("u", pd.DataFrame())
    v_df = dfs.get("v", pd.DataFrame())
    if not u_df.empty and not v_df.empty:
        u_near = _filter_obs_by_time(u_df, gfs_time, tolerance_hours)
        v_near = _filter_obs_by_time(v_df, gfs_time, tolerance_hours)
        # Keep only good-quality obs (quality mark <= 3)
        u_near = u_near[u_near["quality"].fillna(99) <= 3]
        v_near = v_near[v_near["quality"].fillna(99) <= 3]
        # Exclude satellite / radar-derived winds (not directly comparable to GFS 10m)
        _wind_exclude = {"SATWND", "VADWND", "ASCATW"}
        u_near = u_near[~u_near["class"].isin(_wind_exclude)]
        v_near = v_near[~v_near["class"].isin(_wind_exclude)]
        # Merge u and v on matching obs locations
        u_sel = u_near[["time", "lat", "lon", "pres", "observation"]].rename(
            columns={"observation": "u"}
        )
        v_sel = v_near[["time", "lat", "lon", "pres", "observation"]].rename(
            columns={"observation": "v"}
        )
        uv = pd.merge(u_sel, v_sel, on=["time", "lat", "lon", "pres"], how="inner")
        obs_wspd = np.sqrt(uv["u"] ** 2 + uv["v"] ** 2)
        print(f"  Wind obs pairs within ±{tolerance_hours}h: {len(uv):,}")
        if not uv.empty:
            fig, ax = plt.subplots(figsize=(16, 8))
            # Shared colour scale — cap at 30 m/s to avoid outlier stations
            # blowing out the range (e.g. station 48698 reports ~240 m/s)
            w_vmin = 0.0
            w_vmax = 30.0
            w_levels = np.linspace(w_vmin, w_vmax, 31)
            cf = ax.contourf(
                lon2d,
                lat2d,
                gfs_wspd,
                levels=w_levels,
                cmap="YlOrRd",
                alpha=0.55,
                extend="max",
            )
            ax.scatter(
                _norm_lon(uv["lon"]),
                uv["lat"],
                c=obs_wspd.clip(upper=w_vmax),
                cmap="YlOrRd",
                s=8,
                alpha=0.8,
                edgecolors="k",
                linewidths=0.2,
                vmin=w_vmin,
                vmax=w_vmax,
            )
            plt.colorbar(cf, ax=ax, label="Wind speed (m/s)", shrink=0.7)
            ax.set_xlim(0, 360)
            ax.set_ylim(-90, 90)
            ax.set_xlabel("Longitude (°E)")
            ax.set_ylabel("Latitude")
            ax.set_title(
                f"PrepBUFR wind obs on GFS 10 m wind speed — "
                f"{gfs_time:%Y-%m-%d %H:%M} UTC  (±{tolerance_hours}h)",
            )
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(output_dir / "obs_wind_on_gfs_wspd.png", dpi=150)
            plt.close(fig)
            print(f"  Saved obs_wind_on_gfs_wspd.png")
    else:
        print("  No wind observations — skipping wind scatter plot.")

    # ── 3. Pressure scatter on GFS surface-pressure contour ───────
    p_df = dfs.get("pres", pd.DataFrame())
    if not p_df.empty:
        p_near = _filter_obs_by_time(p_df, gfs_time, tolerance_hours)
        p_near = p_near[p_near["quality"].fillna(99) <= 3]
        # Only keep surface obs (ADPSFC / SFCSHP) for fair comparison to GFS sp
        p_near = p_near[p_near["class"].isin({"ADPSFC", "SFCSHP"})]
        print(
            f"  Pressure obs within ±{tolerance_hours}h (qm<=3, sfc): {len(p_near):,}"
        )
        if not p_near.empty:
            fig, ax = plt.subplots(figsize=(16, 8))
            # Obs pres observation is in Pa; GFS sp is in Pa
            p_vmin = min(float(gfs_sp.min()), float(p_near["observation"].min()))
            p_vmax = max(float(gfs_sp.max()), float(p_near["observation"].max()))
            # Narrow to a reasonable surface range to see detail
            p_vmin = max(p_vmin, 950000.0)
            p_vmax = min(p_vmax, 1060000.0)
            if p_vmin >= p_vmax:
                # Fallback: data outside expected range; use full data extent
                p_vmin = min(float(gfs_sp.min()), float(p_near["observation"].min()))
                p_vmax = max(float(gfs_sp.max()), float(p_near["observation"].max()))
            if p_vmin >= p_vmax:
                p_vmax = p_vmin + 1.0  # degenerate case: ensure strictly increasing
            p_levels = np.linspace(p_vmin, p_vmax, 31)
            cf = ax.contourf(
                lon2d,
                lat2d,
                gfs_sp,
                levels=p_levels,
                cmap="viridis",
                alpha=0.55,
                extend="both",
            )
            ax.scatter(
                _norm_lon(p_near["lon"]),
                p_near["lat"],
                c=p_near["observation"].clip(lower=p_vmin, upper=p_vmax),
                cmap="viridis",
                s=8,
                alpha=0.8,
                edgecolors="k",
                linewidths=0.2,
                vmin=p_vmin,
                vmax=p_vmax,
            )
            plt.colorbar(cf, ax=ax, label="Surface pressure (Pa)", shrink=0.7)
            ax.set_xlim(0, 360)
            ax.set_ylim(-90, 90)
            ax.set_xlabel("Longitude (°E)")
            ax.set_ylabel("Latitude")
            ax.set_title(
                f"PrepBUFR surface pressure obs on GFS SP — "
                f"{gfs_time:%Y-%m-%d %H:%M} UTC  (±{tolerance_hours}h)",
            )
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(output_dir / "obs_pres_on_gfs_sp.png", dpi=150)
            plt.close(fig)
            print(f"  Saved obs_pres_on_gfs_sp.png")
    else:
        print("  No pressure observations — skipping pressure scatter plot.")

    # ── 4. Specific humidity scatter on GFS q2m contour ───────────
    q_df = dfs.get("q", pd.DataFrame())
    if not q_df.empty:
        q_near = _filter_obs_by_time(q_df, gfs_time, tolerance_hours)
        q_near = q_near[q_near["quality"].fillna(99) <= 3]
        print(f"  Humidity obs within ±{tolerance_hours}h (qm<=3): {len(q_near):,}")
        if not q_near.empty:
            fig, ax = plt.subplots(figsize=(16, 8))
            # Obs q is in kg/kg; GFS q2m is in kg/kg
            q_vmin = 0.0
            q_vmax = max(float(gfs_q2m.max()), float(q_near["observation"].max()))
            if q_vmin >= q_vmax:
                q_vmax = q_vmin + 1.0
            q_levels = np.linspace(q_vmin, q_vmax, 31)
            cf = ax.contourf(
                lon2d,
                lat2d,
                gfs_q2m,
                levels=q_levels,
                cmap="YlGnBu",
                alpha=0.55,
                extend="max",
            )
            ax.scatter(
                _norm_lon(q_near["lon"]),
                q_near["lat"],
                c=q_near["observation"].clip(upper=q_vmax),
                cmap="YlGnBu",
                s=8,
                alpha=0.8,
                edgecolors="k",
                linewidths=0.2,
                vmin=q_vmin,
                vmax=q_vmax,
            )
            plt.colorbar(cf, ax=ax, label="Specific humidity (kg/kg)", shrink=0.7)
            ax.set_xlim(0, 360)
            ax.set_ylim(-90, 90)
            ax.set_xlabel("Longitude (°E)")
            ax.set_ylabel("Latitude")
            ax.set_title(
                f"PrepBUFR humidity obs on GFS Q2m — "
                f"{gfs_time:%Y-%m-%d %H:%M} UTC  (±{tolerance_hours}h)",
            )
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(output_dir / "obs_q_on_gfs_q2m.png", dpi=150)
            plt.close(fig)
            print(f"  Saved obs_q_on_gfs_q2m.png")
    else:
        print("  No humidity observations — skipping humidity scatter plot.")


def plot_diagnostics(dfs: dict[str, pd.DataFrame], output_dir: Path) -> None:
    """Generate diagnostic plots."""
    output_dir.mkdir(parents=True, exist_ok=True)

    all_df = pd.concat(dfs.values(), ignore_index=True)

    # ── 1. Station map ────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(14, 7))
    for var, color in [
        ("t", "red"),
        ("pres", "blue"),
        ("u", "green"),
        ("q", "orange"),
    ]:
        df = dfs.get(var, pd.DataFrame())
        if df.empty:
            continue
        # Sample for plotting speed
        sample = df.sample(n=min(5000, len(df)), random_state=42)
        ax.scatter(
            sample["lon"],
            sample["lat"],
            s=1,
            alpha=0.3,
            label=f"{var} ({len(df):,})",
            color=color,
        )
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.set_title("Observation Locations by Variable")
    ax.legend(markerscale=8, loc="lower left")
    ax.set_xlim(0, 360)
    ax.set_ylim(-90, 90)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_dir / "station_map.png", dpi=150)
    plt.close(fig)
    print(f"  Saved station_map.png")

    # ── 2. Observation histograms ─────────────────────────────────
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    axes_flat = axes.flatten()
    for idx, var in enumerate(["t", "q", "pres", "u", "v", "gps"]):
        ax = axes_flat[idx]
        df = dfs.get(var, pd.DataFrame())
        if df.empty:
            ax.set_title(f"{var}: no data")
            continue
        obs = df["observation"].dropna()
        bounds = VARIABLE_BOUNDS.get(var, ("", obs.min(), obs.max()))

        # Clip extreme outliers for better visualization
        p01, p99 = np.percentile(obs, [0.5, 99.5])
        clipped = obs[(obs >= p01) & (obs <= p99)]

        ax.hist(clipped, bins=100, alpha=0.7, edgecolor="none")
        ax.axvline(
            obs.mean(), color="red", linestyle="--", label=f"mean={obs.mean():.1f}"
        )
        ax.axvline(
            obs.median(), color="green", linestyle="--", label=f"med={obs.median():.1f}"
        )
        ax.set_title(f"{var} (n={len(obs):,})")
        ax.set_xlabel(bounds[0] if bounds else var)
        ax.legend(fontsize=8)

    fig.suptitle("Observation Value Distributions", fontsize=14, y=1.01)
    fig.tight_layout()
    fig.savefig(output_dir / "obs_histograms.png", dpi=150)
    plt.close(fig)
    print(f"  Saved obs_histograms.png")

    # ── 3. Observation class distribution ─────────────────────────
    fig, ax = plt.subplots(figsize=(10, 5))
    class_counts = all_df.groupby(["variable", "class"]).size().unstack(fill_value=0)
    class_counts.plot(kind="bar", ax=ax, stacked=True)
    ax.set_title("Observation Count by Variable and Class")
    ax.set_xlabel("Variable")
    ax.set_ylabel("Count")
    ax.legend(title="Obs Class", bbox_to_anchor=(1.02, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(output_dir / "class_distribution.png", dpi=150)
    plt.close(fig)
    print(f"  Saved class_distribution.png")

    # ── 4. Quality mark distribution ──────────────────────────────
    fig, ax = plt.subplots(figsize=(10, 5))
    for var in ["t", "q", "pres", "u", "v"]:
        df = dfs.get(var, pd.DataFrame())
        if df.empty:
            continue
        qm = df["quality"].dropna()
        if qm.empty:
            continue
        vc = qm.value_counts().sort_index()
        ax.bar(
            vc.index + 0.1 * list(VARIABLE_BOUNDS.keys()).index(var),
            vc.values,
            width=0.15,
            alpha=0.7,
            label=var,
        )
    ax.set_xlabel("Quality Mark")
    ax.set_ylabel("Count")
    ax.set_title("Quality Mark Distribution by Variable")
    ax.legend()
    ax.set_xticks(range(16))
    fig.tight_layout()
    fig.savefig(output_dir / "quality_marks.png", dpi=150)
    plt.close(fig)
    print(f"  Saved quality_marks.png")

    # ── 5. Pressure level profile (for temperature) ───────────────
    fig, axes = plt.subplots(1, 2, figsize=(12, 8))

    # Temperature vs pressure
    t_df = dfs.get("t", pd.DataFrame())
    if not t_df.empty:
        ax = axes[0]
        t_with_pres = t_df.dropna(subset=["observation", "pres"])
        if not t_with_pres.empty:
            sample = t_with_pres.sample(n=min(10000, len(t_with_pres)), random_state=42)
            ax.scatter(sample["observation"], sample["pres"] / 100, s=1, alpha=0.2)
            ax.set_xlabel("Temperature (°C)")
            ax.set_ylabel("Pressure (hPa)")
            ax.set_title(f"Temperature Profile (n={len(t_with_pres):,})")
            ax.invert_yaxis()
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)

    # Humidity vs pressure
    q_df = dfs.get("q", pd.DataFrame())
    if not q_df.empty:
        ax = axes[1]
        q_with_pres = q_df.dropna(subset=["observation", "pres"])
        if not q_with_pres.empty:
            sample = q_with_pres.sample(n=min(10000, len(q_with_pres)), random_state=42)
            ax.scatter(sample["observation"], sample["pres"] / 100, s=1, alpha=0.2)
            ax.set_xlabel("Specific Humidity (mg/kg)")
            ax.set_ylabel("Pressure (hPa)")
            ax.set_title(f"Humidity Profile (n={len(q_with_pres):,})")
            ax.invert_yaxis()
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)

    fig.suptitle("Vertical Profiles", fontsize=14)
    fig.tight_layout()
    fig.savefig(output_dir / "vertical_profiles.png", dpi=150)
    plt.close(fig)
    print("  Saved vertical_profiles.png")

    # ── 6. Wind hodograph (u vs v) ───────────────────────────────
    u_df = dfs.get("u", pd.DataFrame())
    v_df = dfs.get("v", pd.DataFrame())
    if not u_df.empty and not v_df.empty:
        fig, ax = plt.subplots(figsize=(8, 8))
        # Merge u and v on shared keys
        u_obs = u_df[["time", "lat", "lon", "pres", "observation"]].rename(
            columns={"observation": "u"}
        )
        v_obs = v_df[["time", "lat", "lon", "pres", "observation"]].rename(
            columns={"observation": "v"}
        )
        uv = pd.merge(u_obs, v_obs, on=["time", "lat", "lon", "pres"], how="inner")
        if not uv.empty:
            sample = uv.sample(n=min(10000, len(uv)), random_state=42)
            sc = ax.scatter(
                sample["u"],
                sample["v"],
                c=sample["pres"] / 100,
                cmap="viridis_r",
                s=2,
                alpha=0.3,
            )
            plt.colorbar(sc, label="Pressure (hPa)")
            ax.set_xlabel("U wind (m/s)")
            ax.set_ylabel("V wind (m/s)")
            ax.set_title(f"Wind Components (n={len(uv):,} matched pairs)")
            ax.axhline(0, color="gray", linewidth=0.5)
            ax.axvline(0, color="gray", linewidth=0.5)
            ax.set_aspect("equal")
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_dir / "wind_scatter.png", dpi=150)
        plt.close(fig)
        print("  Saved wind_scatter.png")


def main() -> None:
    parser = argparse.ArgumentParser(description="PrepBUFR sanity check")
    parser.add_argument(
        "--hours-ago",
        type=float,
        default=12.0,
        help="Hours before now for observation centre time (default: 12)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=16,
        help="Number of decode workers (default: 16)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="prepbufr_research/sanity_plots",
        help="Output directory for plots",
    )
    parser.add_argument(
        "--gfs-hours-ago",
        type=float,
        default=12.0,
        help="Hours before now for GFS overlay timestamp (default: 12)",
    )
    parser.add_argument(
        "--gfs-tolerance",
        type=float,
        default=1.0,
        help="Obs time tolerance in hours for GFS overlay (default: 1)",
    )
    parser.add_argument(
        "--skip-gfs",
        action="store_true",
        help="Skip the GFS overlay scatter plots",
    )
    args = parser.parse_args()

    output_dir = Path(args.output)

    # Compute observation target time (naive UTC)
    now_utc = datetime.now(timezone.utc)
    obs_target = (now_utc - timedelta(hours=args.hours_ago)).replace(tzinfo=None)
    print(f"Observation target: {obs_target} UTC (~{args.hours_ago}h ago)")
    print(f"Output dir: {output_dir}")

    # ── Step 1: Fetch & decode via data source ────────────────────
    print("\n" + "=" * 60)
    print("STEP 1: Fetching observations from NOMADS")
    print("=" * 60)
    dfs = fetch_all_variables(obs_target, workers=args.workers)

    # ── Step 2: Validate ──────────────────────────────────────────
    print("\n" + "=" * 60)
    print("STEP 2: Validating observation ranges")
    print("=" * 60)
    failures: list[str] = []
    failures.extend(validate_ranges(dfs))
    failures.extend(validate_locations(dfs))
    failures.extend(validate_obs_classes(dfs))
    failures.extend(validate_quality_marks(dfs))

    # ── Step 3: Plot ──────────────────────────────────────────────
    print("\n" + "=" * 60)
    print("STEP 3: Generating diagnostic plots")
    print("=" * 60)
    plot_diagnostics(dfs, output_dir)

    # ── Step 4: GFS overlay scatter plots ─────────────────────────
    if not args.skip_gfs:
        print("\n" + "=" * 60)
        print("STEP 4: GFS overlay scatter plots")
        print("=" * 60)
        now_utc = datetime.now(timezone.utc)
        target = now_utc - timedelta(hours=args.gfs_hours_ago)
        gfs_time = _snap_to_gfs_cycle(target)
        # GFS __call__ expects a naive datetime (treated as UTC)
        gfs_time_naive = gfs_time.replace(tzinfo=None)
        print(
            f"  Target time: ~{args.gfs_hours_ago}h ago → {target:%Y-%m-%d %H:%M} UTC"
        )
        print(f"  Snapped GFS cycle: {gfs_time_naive}")
        print(f"  Obs tolerance: ±{args.gfs_tolerance}h")
        plot_obs_on_gfs(
            dfs,
            gfs_time=gfs_time_naive,
            output_dir=output_dir,
            tolerance_hours=args.gfs_tolerance,
        )

    # ── Summary ───────────────────────────────────────────────────
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    if failures:
        print(f"\n  FAILURES ({len(failures)}):")
        for f in failures:
            print(f"    ✗ {f}")
        sys.exit(1)
    else:
        print("\n  ALL CHECKS PASSED")
        print(f"  Plots saved to: {output_dir.resolve()}")
        sys.exit(0)


if __name__ == "__main__":
    main()

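Section 6 of the script matches u and v observations with an inner join on shared keys before plotting. As a minimal standalone sketch of that matching technique (with synthetic, hypothetical values rather than real PrepBUFR output), unmatched rows are dropped and derived quantities such as wind speed follow directly from the paired columns:

```python
import pandas as pd

# Synthetic u/v observations; the third u row has no v counterpart
u = pd.DataFrame(
    {"time": [0, 0, 1], "lat": [10.0, 20.0, 10.0], "lon": [5.0, 6.0, 5.0],
     "pres": [85000.0, 85000.0, 85000.0], "u": [3.0, -1.0, 2.0]}
)
v = pd.DataFrame(
    {"time": [0, 0], "lat": [10.0, 20.0], "lon": [5.0, 6.0],
     "pres": [85000.0, 85000.0], "v": [4.0, 0.5]}
)

# Inner join keeps only rows present in both frames
uv = pd.merge(u, v, on=["time", "lat", "lon", "pres"], how="inner")
uv["speed"] = (uv["u"] ** 2 + uv["v"] ** 2) ** 0.5  # wind speed from matched pairs
```

Matching on exact float keys works here because u and v come from the same decoded reports; observations from different platforms would need a tolerance-based join instead.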
All validation checks passed. Plots were generated with the script above and saved locally in prepbufr_research/sanity_plots/.
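The `_snap_to_gfs_cycle` helper referenced in `main()` is not shown above. A minimal sketch of what such a helper might do, assuming GFS cycles run at 00/06/12/18 UTC (the function name and exact rounding behavior here are illustrative, not the PR's implementation):

```python
from datetime import datetime

def snap_to_gfs_cycle(t: datetime) -> datetime:
    """Round a UTC datetime down to the most recent 6-hourly GFS cycle."""
    cycle_hour = (t.hour // 6) * 6
    return t.replace(hour=cycle_hour, minute=0, second=0, microsecond=0)
```

Note this always rounds down; in practice a real helper may also need to step back one extra cycle when the latest cycle's files have not yet been published on NOMADS.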

@NickGeneva
Collaborator Author

@greptile-apps
