diff --git a/earth2studio/data/nclimgrid.py b/earth2studio/data/nclimgrid.py
new file mode 100644
index 000000000..045936c37
--- /dev/null
+++ b/earth2studio/data/nclimgrid.py
@@ -0,0 +1,381 @@
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import s3fs
+import xarray as xr
+from tqdm.auto import tqdm
+
+from earth2studio.lexicon.nclimgrid import NClimGridLexicon
+
+logger = logging.getLogger("earth2studio.nclimgrid")
+logger.setLevel(logging.INFO)
+
+
+class NClimGrid:
+    """
+    Earth2Studio NClimGrid gridded datasource.
+
+    Streams monthly NetCDF files from the NOAA NClimGrid-Daily bucket on S3.
+
+    Key features
+    ------------
+    - Canonical variable mapping via NClimGridLexicon
+    - Strong schema validation
+    - Input normalization for datetime, list-like, and slice time requests
+    - Windowed parquet caching per variable × timestep
+    - Chunk-stream DataFrame generation to avoid loading large time windows at once
+    """
+
+    SOURCE_ID = "earth2studio.data.nclimgrid"
+
+    SCHEMA = pa.schema(
+        [
+            pa.field("time", pa.timestamp("ns")),
+            pa.field("lat", pa.float32()),
+            pa.field("lon", pa.float32()),
+            pa.field("observation", pa.float32()),
+            pa.field("variable", pa.string()),
+        ]
+    )
+
+    # -------------------------------------------------------
+    # Schema utilities
+    # -------------------------------------------------------
+
+    @classmethod
+    def resolve_fields(cls, fields: Any) -> pa.Schema:
+        """
+        Resolve requested output fields into a validated Arrow schema.
+        """
+        if fields is None:
+            return cls.SCHEMA
+
+        if isinstance(fields, str):
+            fields = [fields]
+
+        if isinstance(fields, pa.Schema):
+            for field in fields:
+                if field.name not in cls.SCHEMA.names:
+                    raise KeyError(
+                        f"Field '{field.name}' not in schema. "
+                        f"Valid fields: {cls.SCHEMA.names}"
+                    )
+                expected = cls.SCHEMA.field(field.name).type
+                if field.type != expected:
+                    raise TypeError(
+                        f"Field '{field.name}' has type {field.type}, expected {expected}"
+                    )
+            return fields
+
+        selected = []
+        for f in fields:
+            if f not in cls.SCHEMA.names:
+                raise KeyError(
+                    f"Field '{f}' not in schema. " f"Valid fields: {cls.SCHEMA.names}"
+                )
+            selected.append(cls.SCHEMA.field(f))
+
+        return pa.schema(selected)
+
+    # -------------------------------------------------------
+    # Input normalization utilities
+    # -------------------------------------------------------
+
+    @staticmethod
+    def _normalize_time(time: Any) -> Any:
+        """
+        Normalize time input into one of:
+        - None
+        - slice(pd.Timestamp, pd.Timestamp, step)
+        - list[pd.Timestamp]
+        """
+        if time is None:
+            return None
+
+        if isinstance(time, slice):
+            start = pd.to_datetime(time.start) if time.start is not None else None
+            stop = pd.to_datetime(time.stop) if time.stop is not None else None
+            return slice(start, stop, time.step)
+
+        if isinstance(time, datetime):
+            return [pd.Timestamp(time)]
+
+        if isinstance(time, (list, tuple, np.ndarray, pd.DatetimeIndex, pd.Index)):
+            return list(pd.to_datetime(time))
+
+        raise TypeError(
+            "Invalid time input. Expected None, datetime, slice, or list-like of datetimes."
+        )
+
+    @staticmethod
+    def _normalize_variable(variable: Any) -> list[str]:
+        """
+        Normalize variable input into list[str].
+        """
+        if isinstance(variable, str):
+            return [variable]
+
+        if isinstance(variable, (list, tuple, np.ndarray, pd.Index)):
+            return [str(v) for v in variable]
+
+        raise TypeError("Invalid variable input. Expected str or list-like of strings.")
+
+    @staticmethod
+    def _resolve_requested_times(times: Any) -> list[pd.Timestamp]:
+        """
+        Resolve requested times from user input rather than from lazy dataset coordinates.
+
+        This avoids forcing .values on a lazily selected DataArray just to enumerate times.
+        A slice step, when given, is interpreted as a pandas frequency (default daily).
+        """
+        if times is None:
+            raise ValueError(
+                "Explicit time selection is required for NClimGrid streaming access."
+            )
+
+        if isinstance(times, slice):
+            if times.start is None or times.stop is None:
+                raise ValueError(
+                    "Slice time selection must include both start and stop for streaming."
+                )
+            return list(pd.date_range(times.start, times.stop, freq=times.step or "D"))
+
+        # list-like path
+        resolved = list(pd.to_datetime(times))
+        # stable ordering, unique
+        return sorted(pd.unique(pd.Index(resolved)))
+
+    # -------------------------------------------------------
+    # Constructor
+    # -------------------------------------------------------
+
+    def __init__(
+        self,
+        bucket: str = "noaa-nclimgrid-daily-pds",
+        cache: bool = True,
+        verbose: bool = False,
+        cache_dir: str = "~/.earth2studio-cache/nclimgrid",
+    ) -> None:
+        self.bucket = bucket
+        self.verbose = verbose
+        self.cache = cache
+        self.cache_dir = Path(os.path.expanduser(cache_dir))
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.fs = s3fs.S3FileSystem(anon=True)
+        self._call_cache: dict[Any, pd.DataFrame] = {}
+
+    # -------------------------------------------------------
+    # Path utilities
+    # -------------------------------------------------------
+
+    def _monthly_nc_path(self, ts: Any) -> str:
+        ts = pd.Timestamp(ts)
+        return (
+            f"s3://{self.bucket}/access/grids/"
+            f"{ts.year}/ncdd-{ts.year}{ts.month:02d}-grd-scaled.nc"
+        )
+
+    # -------------------------------------------------------
+    # Internal cache utilities
+    # -------------------------------------------------------
+
+    @staticmethod
+    def _timestamp_cache_key(ts: Any) -> str:
+        """
+        Convert timestamp into stable YYYYMMDD cache key component.
+        """
+        return pd.Timestamp(ts).strftime("%Y%m%d")
+
+    def _cache_file_for_timestep(self, native: str, ts: Any) -> Path:
+        """
+        One cache file per native variable per timestep.
+        """
+        return self.cache_dir / f"{native}_{self._timestamp_cache_key(ts)}.parquet"
+
+    # -------------------------------------------------------
+    # Internal dataframe construction
+    # -------------------------------------------------------
+
+    def _dataarray_to_dataframe(
+        self,
+        da_t: xr.DataArray,
+        var: str,
+        modifier: Any,
+    ) -> pd.DataFrame:
+        """
+        Convert one timestep DataArray into standardized DataFrame.
+
+        Uses explicit per-timestep materialization to keep memory bounded and
+        avoid giant lazy graphs.
+        """
+        # Force only the current timestep into memory.
+        da_t = da_t.load()
+
+        values = modifier(da_t.values)
+        values = np.asarray(values, dtype="float32")
+
+        df = (
+            xr.DataArray(
+                values,
+                coords=da_t.coords,
+                dims=da_t.dims,
+                name="observation",
+            )
+            .to_dataframe()
+            .reset_index()
+        )
+
+        df["variable"] = var
+
+        # enforce schema column order
+        df = df[self.SCHEMA.names]
+
+        # enforce dtypes
+        df["lat"] = df["lat"].astype("float32")
+        df["lon"] = df["lon"].astype("float32")
+        df["observation"] = df["observation"].astype("float32")
+        df["variable"] = df["variable"].astype("string")
+
+        return df
+
+    def _fetch_variable_dataframe(
+        self,
+        var: str,
+        native: str,
+        modifier: Any,
+        times: Any,
+    ) -> pd.DataFrame:
+
+        selected_times = self._resolve_requested_times(times)
+        frames = []
+
+        iterator = selected_times
+        if getattr(self, "verbose", False) and len(selected_times) > 1:
+            iterator = tqdm(iterator, desc=f"NClimGrid {native}", leave=False)
+
+        current_month_ds = None
+        current_month = None
+
+        for ts in iterator:
+
+            ts = pd.Timestamp(ts)
+
+            # open new monthly file only when month changes
+            if current_month != (ts.year, ts.month):
+
+                # Release the previous month's file handle before opening a new one.
+                if current_month_ds is not None:
+                    current_month_ds.close()
+
+                path = self._monthly_nc_path(ts)
+
+                if self.verbose:
+                    logger.info(f"Opening {path}")
+
+                current_month_ds = xr.open_dataset(
+                    path,
+                    engine="h5netcdf",
+                    storage_options={"anon": True},
+                )
+
+                current_month = (ts.year, ts.month)
+
+            if current_month_ds is None:
+                continue
+
+            try:
+                da_t = current_month_ds[native].sel(time=ts)
+            except Exception as e:
+                logger.debug(f"Skipping timestep {ts}: {e}")
+                continue
+
+            cache_file = self._cache_file_for_timestep(native, ts)
+
+            if self.cache and cache_file.exists():
+                df_t = pd.read_parquet(cache_file)
+            else:
+                df_t = self._dataarray_to_dataframe(da_t, var, modifier)
+
+                if self.cache:
+                    df_t.to_parquet(cache_file)
+
+            frames.append(df_t)
+
+        # Done streaming: release the last open monthly file.
+        if current_month_ds is not None:
+            current_month_ds.close()
+
+        if not frames:
+            return pd.DataFrame(columns=self.SCHEMA.names)
+
+        return pd.concat(frames, ignore_index=True)
+
+    # -------------------------------------------------------
+    # Core fetch logic
+    # -------------------------------------------------------
+
+    def __call__(
+        self,
+        time: Any = None,
+        variable: Any = None,
+        fields: Any = None,
+    ) -> pd.DataFrame:
+        """
+        Fetch NClimGrid data as standardized DataFrame.
+
+        Parameters
+        ----------
+        time : None | datetime | slice | list-like of datetime
+            Requested timestep(s).
+        variable : str | list[str]
+            Canonical Earth2Studio variable name(s).
+        fields : None | str | list[str] | pa.Schema
+            Output field subset.
+
+        Returns
+        -------
+        pd.DataFrame
+        """
+        if variable is None:
+            raise ValueError("variable must be provided")
+        if time is None:
+            raise ValueError("time must be provided")
+
+        times = self._normalize_time(time)
+        variables = self._normalize_variable(variable)
+        schema = self.resolve_fields(fields)
+
+        times_key = tuple(self._resolve_requested_times(times))
+        vars_key = tuple(sorted(variables))
+        fields_key = tuple([f.name for f in schema])
+
+        cache_key = (times_key, vars_key, fields_key)
+
+        if self.cache and cache_key in self._call_cache:
+            return self._call_cache[cache_key].copy()
+
+        frames = []
+        for var in variables:
+            native, modifier = NClimGridLexicon.get_item(var)
+            df_var = self._fetch_variable_dataframe(var, native, modifier, times)
+            frames.append(df_var)
+
+        if frames:
+            out = pd.concat(frames, ignore_index=True)
+        else:
+            out = pd.DataFrame(columns=self.SCHEMA.names)
+
+        # Subset columns first: DataFrame.__getitem__ may drop .attrs, so tag after.
+        out = out[[f.name for f in schema]]
+        out.attrs["source"] = self.SOURCE_ID
+
+        if self.cache:
+            # Store a private copy so caller mutations cannot pollute the cache.
+            self._call_cache[cache_key] = out.copy()
+        return out
diff --git a/earth2studio/lexicon/nclimgrid.py b/earth2studio/lexicon/nclimgrid.py
new file mode 100644
index 000000000..cfa2bd0dd
--- /dev/null
+++ b/earth2studio/lexicon/nclimgrid.py
@@ -0,0 +1,111 @@
+from collections.abc import Callable
+
+import numpy as np
+
+from earth2studio.lexicon.base import LexiconType
+
+
+class NClimGridLexicon(metaclass=LexiconType):
+    """
+    NClimGrid gridded dataset lexicon.
+
+    Provides canonical → native variable mapping and unit normalization
+    for the CONUS NClimGrid-Daily dataset.
+
+    Dataset characteristics
+    -----------------------
+    • Spatial coverage: CONUS (~0.0417° grid)
+    • Temporal coverage: daily (1952–present)
+    • Variables are already physically meaningful fields (not station codes)
+    • Missing values may exist (NaN / masked)
+
+    Unit conventions
+    ----------------
+    • Temperature fields stored in °C → converted to Kelvin
+    • Precipitation stored in mm → converted to meters
+    • SPI is dimensionless → unchanged
+
+    Design goals
+    ------------
+    • Future extensibility (new variables / datasets)
+    • Robust unit normalization
+    • Strong validation behaviour
+    • Explicit metadata layer
+    """
+
+    # -------------------------------------------------------
+    # Rich variable metadata structure (future-proof)
+    # -------------------------------------------------------
+
+    META: dict[str, dict] = {
+        "t2m_max": {
+            "native": "tmax",
+            "units_native": "degC",
+            "units_e2s": "K",
+            "description": "daily maximum temperature at 2m",
+        },
+        "t2m_min": {
+            "native": "tmin",
+            "units_native": "degC",
+            "units_e2s": "K",
+            "description": "daily minimum temperature at 2m",
+        },
+        "tp": {
+            "native": "prcp",
+            "units_native": "mm",
+            "units_e2s": "m",
+            "description": "daily total precipitation",
+        },
+        "spi": {
+            "native": "spi",
+            "units_native": "dimensionless",
+            "units_e2s": "dimensionless",
+            "description": "standardized precipitation index",
+        },
+    }
+
+    # Canonical vocabulary (required by LexiconType)
+    VOCAB: dict[str, str] = {
+        k: f"{v['description']} ({v['units_e2s']})" for k, v in META.items()
+    }
+
+    # -------------------------------------------------------
+    # Strong validation + modifier factory
+    # -------------------------------------------------------
+
+    @classmethod
+    def get_item(cls, val: str) -> tuple[str, Callable]:
+        """
+        Resolve canonical variable.
+
+        Returns
+        -------
+        native_variable_name, modifier_function
+        """
+
+        if val not in cls.META:
+            raise KeyError(
+                f"NClimGridLexicon: unknown variable '{val}'. "
+                f"Valid variables: {list(cls.META.keys())}"
+            )
+
+        meta = cls.META[val]
+        native = meta["native"]
+
+        # ---------------------------------------------------
+        # Robust modifier (handles NaN + dtype normalization)
+        # ---------------------------------------------------
+
+        def modifier(x: np.ndarray) -> np.ndarray:
+            x = np.asarray(x, dtype="float32")
+
+            if native in ("tmax", "tmin"):
+                return x + 273.15
+
+            if native == "prcp":
+                return x / 1000.0
+
+            # SPI / future dimensionless variables
+            return x
+
+        return native, modifier
diff --git a/test/data/test_nclimgrid.py b/test/data/test_nclimgrid.py
new file mode 100644
index 000000000..b2f7d5141
--- /dev/null
+++ b/test/data/test_nclimgrid.py
@@ -0,0 +1,199 @@
+"""
+Unit tests for data/earth2studio/nclimgrid.py
+"""
+
+import time
+from datetime import datetime, timedelta
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+
+from earth2studio.data.nclimgrid import NClimGrid
+from earth2studio.lexicon.nclimgrid import NClimGridLexicon
+
+# ---------------------------------------------------------------------
+# GLOBAL DATASET FIXTURE (VERY IMPORTANT)
+# ---------------------------------------------------------------------
+
+
+@pytest.fixture(scope="session")
+def ds():
+    """
+    Open dataset ONCE for entire test session.
+
+    Prevents:
+    - repeated dataset open
+    - repeated S3 metadata scan
+    - test hangs
+    - extreme runtime
+    """
+    return NClimGrid(cache=True, verbose=False)
+
+
+# ---------------------------------------------------------------------
+# OFFLINE TESTS
+# ---------------------------------------------------------------------
+
+
+class TestNClimGridOffline:
+
+    def test_schema_fields(self):
+        assert NClimGrid.SCHEMA.names == [
+            "time",
+            "lat",
+            "lon",
+            "observation",
+            "variable",
+        ]
+
+    def test_schema_types(self):
+        assert NClimGrid.SCHEMA.field("time").type == pa.timestamp("ns")
+        assert NClimGrid.SCHEMA.field("lat").type == pa.float32()
+        assert NClimGrid.SCHEMA.field("lon").type == pa.float32()
+        assert NClimGrid.SCHEMA.field("observation").type == pa.float32()
+        assert NClimGrid.SCHEMA.field("variable").type == pa.string()
+
+    def test_resolve_fields_invalid(self):
+        with pytest.raises(KeyError):
+            NClimGrid.resolve_fields(["not_real"])
+
+    def test_source_id(self):
+        assert NClimGrid.SOURCE_ID == "earth2studio.data.nclimgrid"
+
+    def test_lexicon_variables(self):
+        for v in ["t2m_max", "t2m_min", "tp", "spi"]:
+            desc, mod = NClimGridLexicon[v]
+            assert isinstance(desc, str)
+            assert callable(mod)
+
+    def test_unit_conversion_kelvin(self):
+        _, mod = NClimGridLexicon["t2m_max"]
+        np.testing.assert_allclose(mod(np.array([25.0])), [298.15])
+
+    def test_unit_conversion_precip(self):
+        _, mod = NClimGridLexicon["tp"]
+        np.testing.assert_allclose(mod(np.array([100.0])), [0.1])
+
+    def test_spi_identity(self):
+        _, mod = NClimGridLexicon["spi"]
+        arr = np.array([1.2])
+        np.testing.assert_allclose(mod(arr), arr)
+
+
+# ---------------------------------------------------------------------
+# ONLINE TESTS
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.network
+class TestNClimGridOnline:
+
+    DATE = datetime(2010, 7, 1)
+
+    # ---------------- functional ----------------
+
+    def test_single_variable(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert len(df) > 1000
+        assert set(df["variable"]) == {"t2m_max"}
+
+    def test_multi_variable(self, ds):
+        df = ds(self.DATE, ["t2m_max", "tp"])
+        assert set(df["variable"]) == {"t2m_max", "tp"}
+
+    def test_multiple_dates(self, ds):
+        dates = [self.DATE, self.DATE + timedelta(days=1)]
+        df = ds(dates, "t2m_max")
+        assert df["time"].nunique() == 2
+
+    def test_slice_semantics(self, ds):
+        df = ds(slice(datetime(2010, 7, 1), datetime(2010, 7, 3)), "t2m_max")
+        assert df["time"].nunique() == 3
+
+    # ---------------- grid integrity ----------------
+
+    def test_lat_lon_bounds(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert df["lat"].between(20, 55).all()
+        assert df["lon"].between(-130, -60).all()
+
+    def test_unique_grid_density(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert df["lat"].nunique() > 100
+        assert df["lon"].nunique() > 100
+
+    def test_no_nan_coordinates(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert df["lat"].notna().all()
+        assert df["lon"].notna().all()
+
+    # ---------------- scientific sanity ----------------
+
+    def test_temperature_mean_range(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert 250 < df["observation"].mean() < 320
+
+    def test_temperature_extreme_range(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert df["observation"].min() > 200
+        assert df["observation"].max() < 350
+
+    def test_precip_nonnegative(self, ds):
+        df = ds(self.DATE, "tp")
+        valid = df["observation"].dropna()
+        assert len(valid) > 0
+        assert (valid >= 0).all()
+
+    # ---------------- scaling ----------------
+
+    def test_multi_variable_multi_time_scaling(self, ds):
+        dates = [self.DATE + timedelta(days=i) for i in range(5)]
+        df = ds(dates, ["t2m_max", "tp"])
+        assert df["time"].nunique() == 5
+        assert set(df["variable"]) == {"t2m_max", "tp"}
+
+    def test_large_time_window(self, ds):
+        df = ds(slice(datetime(2010, 7, 1), datetime(2010, 7, 10)), "t2m_max")
+        assert df["time"].nunique() == 10
+
+    def test_duplicate_time_input(self, ds):
+        df = ds([self.DATE, self.DATE], "t2m_max")
+        assert df["time"].nunique() == 1
+
+    def test_time_order_invariance(self, ds):
+        d1 = [datetime(2010, 7, 1), datetime(2010, 7, 2)]
+        d2 = list(reversed(d1))
+        assert set(ds(d1, "t2m_max")["time"]) == set(ds(d2, "t2m_max")["time"])
+
+    # ---------------- caching ----------------
+
+    def test_cache_speedup(self, ds):
+        t0 = time.time()
+        ds(self.DATE, "t2m_max")
+        first = time.time() - t0
+
+        t0 = time.time()
+        ds(self.DATE, "t2m_max")
+        second = time.time() - t0
+
+        assert second <= first
+
+    # ---------------- dataframe integrity ----------------
+
+    def test_output_types(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert pd.api.types.is_datetime64_any_dtype(df["time"])
+        assert pd.api.types.is_float_dtype(df["lat"])
+        assert pd.api.types.is_float_dtype(df["lon"])
+        assert pd.api.types.is_float_dtype(df["observation"])
+        assert pd.api.types.is_string_dtype(df["variable"])
+
+    def test_source_attr(self, ds):
+        df = ds(self.DATE, "t2m_max")
+        assert df.attrs["source"] == NClimGrid.SOURCE_ID
+
+    def test_invalid_variable(self, ds):
+        with pytest.raises(KeyError):
+            ds(self.DATE, "not_real_variable")