From f7d69af52c00b61107cef1f39668f2693b5b2d11 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Tue, 10 Mar 2026 22:22:26 -0700 Subject: [PATCH 01/20] First pass --- fme/downscaling/data/config.py | 88 ------------------ fme/downscaling/data/datasets.py | 84 +++--------------- fme/downscaling/data/test_config.py | 25 +----- fme/downscaling/data/test_patching.py | 43 +-------- fme/downscaling/evaluator.py | 17 +--- fme/downscaling/inference/inference.py | 45 +++++++--- fme/downscaling/inference/output.py | 27 ++---- fme/downscaling/inference/test_inference.py | 41 ++++++--- fme/downscaling/inference/test_output.py | 28 +----- fme/downscaling/inference/work_items.py | 7 +- fme/downscaling/models.py | 13 +-- fme/downscaling/predict.py | 17 +--- fme/downscaling/predictors/cascade.py | 85 ++---------------- fme/downscaling/predictors/composite.py | 28 ++---- fme/downscaling/predictors/test_cascade.py | 93 +------------------- fme/downscaling/predictors/test_composite.py | 32 ++----- fme/downscaling/test_models.py | 14 ++- fme/downscaling/train.py | 13 +-- 18 files changed, 134 insertions(+), 566 deletions(-) diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index aec2ff7c0..a8d9cb022 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,7 +23,6 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.static import StaticInputs from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range from fme.downscaling.requirements import DataRequirements @@ -132,18 +131,6 @@ def _full_configs( return all_configs -def _check_fine_res_static_input_compatibility( - static_input_shape: tuple[int, int], data_coords_shape: tuple[int, int] -) -> None: - for static, coord in zip(static_input_shape, data_coords_shape): - if static != coord: - raise ValueError( - f"Static input shape {static_input_shape} is not compatible with " - f"data coordinates shape {data_coords_shape}. Static input dimensions " - "must match fine resolution coordinate dimensions." - ) - - @dataclasses.dataclass class DataLoaderConfig: """ @@ -236,40 +223,6 @@ def get_xarray_dataset( strict_ensemble=self.strict_ensemble, ) - def build_static_inputs( - self, - coarse_coords: LatLonCoordinates, - requires_topography: bool, - static_inputs: StaticInputs | None = None, - ) -> StaticInputs | None: - if requires_topography is False: - return None - if static_inputs is not None: - # TODO: change to use full static inputs list - full_static_inputs = static_inputs - else: - raise ValueError( - "Static inputs required for this model, but no static inputs " - "datasets were specified in the trainer configuration or provided " - "in model checkpoint." - ) - - # Fine grid boundaries are adjusted to exactly match the coarse grid - fine_lat_interval = adjust_fine_coord_range( - self.lat_extent, - full_coarse_coord=coarse_coords.lat, - full_fine_coord=full_static_inputs.coords.lat, - ) - fine_lon_interval = adjust_fine_coord_range( - self.lon_extent, - full_coarse_coord=coarse_coords.lon, - full_fine_coord=full_static_inputs.coords.lon, - ) - subset_static_inputs = full_static_inputs.subset_latlon( - lat_interval=fine_lat_interval, lon_interval=fine_lon_interval - ) - return subset_static_inputs.to_device() - def build_batchitem_dataset( self, dataset: XarrayConcat, @@ -301,14 +254,7 @@ def build( self, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> GriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. - # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. xr_dataset, properties = self.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 ) @@ -316,7 +262,6 @@ def build( raise ValueError( "Downscaling data loader only supports datasets with latlon coords." ) - latlon_coords = properties.horizontal_coordinates dataset = self.build_batchitem_dataset( dataset=xr_dataset, properties=properties, @@ -343,14 +288,8 @@ def build( persistent_workers=True if self.num_data_workers > 0 else False, ) example = dataset[0] - subset_static_inputs = self.build_static_inputs( - coarse_coords=latlon_coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs, - ) return GriddedData( _loader=dataloader, - static_inputs=subset_static_inputs, shape=example.horizontal_shape, dims=example.latlon_coordinates.dims, variable_metadata=dataset.variable_metadata, @@ -468,14 +407,7 @@ def build( train: bool, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs: StaticInputs | None = None, ) -> PairedGriddedData: - # TODO: static_inputs_from_checkpoint is currently passed from the model - # to allow loading fine topography when no fine data is available. - # See PR https://github.com/ai2cm/ace/pull/728 - # In the future we could disentangle this dependency between the data loader - # and model by enabling the built GriddedData objects to take in full static - # input fields and subset them to the same coordinate range as data. if dist is None: dist = Distributed.get_instance() @@ -537,25 +469,6 @@ def build( full_fine_coord=properties_fine.horizontal_coordinates.lon, ) - if requirements.use_fine_topography: - if static_inputs is None: - raise ValueError( - "Model requires static inputs (use_fine_topography=True)," - " but no static inputs were provided to the data loader's" - " build method." - ) - - static_inputs = static_inputs.to_device() - _check_fine_res_static_input_compatibility( - static_inputs.shape, - properties_fine.horizontal_coordinates.shape, - ) - static_inputs = static_inputs.subset_latlon( - lat_interval=fine_lat_extent, lon_interval=fine_lon_extent - ) - else: - static_inputs = None - dataset_fine_subset = HorizontalSubsetDataset( dataset_fine, properties=properties_fine, @@ -611,7 +524,6 @@ def build( return PairedGriddedData( _loader=dataloader, - static_inputs=static_inputs, coarse_shape=example.coarse.horizontal_shape, downscale_factor=example.downscale_factor, dims=example.fine.latlon_coordinates.dims, diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 38e7f9191..3f073edfc 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -20,14 +20,12 @@ from fme.core.generics.data import SizedMap from fme.core.typing_ import TensorMapping from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.static import StaticInputs from fme.downscaling.data.utils import ( BatchedLatLonCoordinates, ClosedInterval, check_leading_dim, expand_and_fold_tensor, get_offset, - null_generator, paired_shuffle, scale_tuple, ) @@ -298,7 +296,6 @@ class GriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[BatchItem]: @@ -307,27 +304,9 @@ def on_device(batch: BatchItem) -> BatchItem: return SizedMap(on_device, self._loader) - @property - def topography_downscale_factor(self) -> int | None: - if self.static_inputs: - if ( - self.static_inputs.shape[0] % self.shape[0] != 0 - or self.static_inputs.shape[1] % self.shape[1] != 0 - ): - raise ValueError( - "Static inputs shape must be evenly divisible by data shape. " - f"Got static inputs with shape {self.static_inputs.shape} " - f"and data with shape {self.shape}" - ) - return self.static_inputs.shape[0] // self.shape[0] - else: - return None - - def get_generator( - self, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: + def get_generator(self) -> Iterator["BatchData"]: for batch in self.loader: - yield (batch, self.static_inputs) + yield batch def get_patched_generator( self, @@ -335,20 +314,18 @@ def get_patched_generator( overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, - ) -> Iterator[tuple["BatchData", StaticInputs | None]]: + ) -> Iterator["BatchData"]: patched_generator = patched_batch_gen_from_loader( loader=self.loader, - static_inputs=self.static_inputs, coarse_yx_extent=self.shape, coarse_yx_patch_extent=yx_patch_extent, - downscale_factor=self.topography_downscale_factor, coarse_overlap=overlap, drop_partial_patches=drop_partial_patches, random_offset=random_offset, ) return cast( - Iterator[tuple[BatchData, StaticInputs | None]], + Iterator[BatchData], patched_generator, ) @@ -361,7 +338,6 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[PairedBatchItem]: @@ -370,11 +346,9 @@ def on_device(batch: PairedBatchItem) -> PairedBatchItem: return SizedMap(on_device, self._loader) - def get_generator( - self, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: + def get_generator(self) -> Iterator["PairedBatchData"]: for batch in self.loader: - yield (batch, self.static_inputs) + yield batch def get_patched_generator( self, @@ -383,10 +357,9 @@ def get_patched_generator( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, - ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: + ) -> Iterator["PairedBatchData"]: patched_generator = patched_batch_gen_from_paired_loader( self.loader, - self.static_inputs, coarse_yx_extent=self.coarse_shape, coarse_yx_patch_extent=coarse_yx_patch_extent, downscale_factor=self.downscale_factor, @@ -396,7 +369,7 @@ def get_patched_generator( shuffle=shuffle, ) return cast( - Iterator[tuple[PairedBatchData, StaticInputs | None]], + Iterator[PairedBatchData], patched_generator, ) @@ -713,44 +686,28 @@ def _get_paired_patches( def patched_batch_gen_from_loader( loader: DataLoader[BatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], - downscale_factor: int | None, coarse_overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[BatchData, StaticInputs | None]]: +) -> Iterator[BatchData]: for batch in loader: - coarse_patches, fine_patches = _get_paired_patches( + coarse_patches, _ = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, coarse_yx_patch_extent=coarse_yx_patch_extent, coarse_overlap=coarse_overlap, - downscale_factor=downscale_factor, + downscale_factor=None, random_offset=random_offset, shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Topography provided but downscale_factor is None, cannot " - "generate fine patches." - ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches) def patched_batch_gen_from_paired_loader( loader: DataLoader[PairedBatchItem], - static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], downscale_factor: int, @@ -758,7 +715,7 @@ def patched_batch_gen_from_paired_loader( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[PairedBatchData, StaticInputs | None]]: +) -> Iterator[PairedBatchData]: for batch in loader: coarse_patches, fine_patches = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, @@ -769,17 +726,4 @@ def patched_batch_gen_from_paired_loader( shuffle=shuffle, drop_partial_patches=drop_partial_patches, ) - batch_data_patches = batch.generate_from_patches(coarse_patches, fine_patches) - - if static_inputs is not None: - if fine_patches is None: - raise ValueError( - "Static inputs provided but downscale_factor is None, cannot " - "generate fine patches." - ) - static_inputs_patches = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_patches = null_generator(len(coarse_patches)) - - # Combine outputs from both generators - yield from zip(batch_data_patches, static_inputs_patches) + yield from batch.generate_from_patches(coarse_patches, fine_patches) diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py index 1aaff8577..f916abadd 100644 --- a/fme/downscaling/data/test_config.py +++ b/fme/downscaling/data/test_config.py @@ -1,7 +1,6 @@ import dataclasses import pytest -import torch from fme.core.dataset.merged import MergeNoConcatDatasetConfig from fme.core.dataset.xarray import XarrayDataConfig @@ -10,25 +9,11 @@ PairedDataLoaderConfig, XarrayEnsembleDataConfig, ) -from fme.downscaling.data.static import StaticInput, StaticInputs -from fme.downscaling.data.utils import ClosedInterval, LatLonCoordinates +from fme.downscaling.data.utils import ClosedInterval from fme.downscaling.requirements import DataRequirements from fme.downscaling.test_utils import data_paths_helper -def get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] - ) - - @pytest.mark.parametrize( "fine_engine, coarse_engine, num_data_workers, expected", [ @@ -78,9 +63,7 @@ def test_DataLoaderConfig_build(tmp_path, very_fast_only: bool): lat_extent=ClosedInterval(1, 4), lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) batch = next(iter(data.loader)) # lat/lon midpoints are on (0.5, 1.5, ...) assert batch.data["var0"].shape == (2, 3, 3) @@ -152,9 +135,7 @@ def test_DataLoaderConfig_includes_merge(tmp_path, very_fast_only: bool): lon_extent=ClosedInterval(0, 3), ) - data = data_config.build( - requirements=requirements, static_inputs=get_static_inputs(shape=(8, 8)) - ) + data = data_config.build(requirements=requirements) # XarrayDataConfig + MergeNoConcatDatasetConfig each # contribute 4 timesteps = 8 total assert len(data.loader) == 4 # 8 samples / batch_size 2 diff --git a/fme/downscaling/data/test_patching.py b/fme/downscaling/data/test_patching.py index 7630722c4..fb6cbc7de 100644 --- a/fme/downscaling/data/test_patching.py +++ b/fme/downscaling/data/test_patching.py @@ -2,8 +2,7 @@ import pytest import torch -from fme.core.device import get_device -from fme.downscaling.data import PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import PairedBatchData from fme.downscaling.data.datasets import patched_batch_gen_from_paired_loader from fme.downscaling.data.patching import ( _divide_into_slices, @@ -115,19 +114,10 @@ def test_paired_patches_with_random_offset_consistent(overlap): full_coarse_coords = full_data.coarse.latlon_coordinates full_fine_coords = full_data.fine.latlon_coordinates - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - topography = StaticInputs( - fields=[StaticInput(data=topography_data, coords=full_fine_coords[0])] - ) y_offsets = [] x_offsets = [] batch_generator = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=topography, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(10, 10), downscale_factor=downscale_factor, @@ -136,7 +126,7 @@ def test_paired_patches_with_random_offset_consistent(overlap): random_offset=True, ) paired_batch: PairedBatchData - for paired_batch, _ in batch_generator: # type: ignore + for paired_batch in batch_generator: assert paired_batch.coarse.data["x"].shape == (batch_size, 10, 10) assert paired_batch.fine.data["x"].shape == (batch_size, 20, 20) @@ -177,19 +167,9 @@ def test_paired_patches_shuffle(shuffle): loader = _mock_data_loader( 10, *coarse_shape, downscale_factor=downscale_factor, batch_size=batch_size ) - topography_data = torch.randn( - coarse_shape[0] * downscale_factor, - coarse_shape[1] * downscale_factor, - device=get_device(), - ) - fine_coords = next(iter(loader)).fine.latlon_coordinates[0] - static_inputs = StaticInputs( - fields=[StaticInput(data=topography_data, coords=fine_coords)] - ) generator0 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -200,7 +180,6 @@ def test_paired_patches_shuffle(shuffle): ) generator1 = patched_batch_gen_from_paired_loader( loader=loader, - static_inputs=static_inputs, coarse_yx_extent=coarse_shape, coarse_yx_patch_extent=(2, 2), downscale_factor=downscale_factor, @@ -212,28 +191,14 @@ def test_paired_patches_shuffle(shuffle): patches0: list[PairedBatchData] = [] patches1: list[PairedBatchData] = [] - topography0: list[torch.Tensor] = [] - topography1: list[torch.Tensor] = [] for i in range(4): - p0, t0 = next(generator0) - patches0.append(p0) # type: ignore - topography0.append(t0) - p1, t1 = next(generator1) - patches1.append(p1) # type: ignore - topography1.append(t1) + patches0.append(next(generator0)) # type: ignore + patches1.append(next(generator1)) # type: ignore data0 = torch.concat([patch.coarse.data["x"] for patch in patches0], dim=0) data1 = torch.concat([patch.coarse.data["x"] for patch in patches1], dim=0) - topo_concat_0 = torch.concat( - [t0.fields[0].data for t0 in topography0 if t0 is not None], dim=0 - ) - topo_concat_1 = torch.concat( - [t1.fields[0].data for t1 in topography1 if t1 is not None], dim=0 - ) if shuffle: assert not torch.equal(data0, data1) - assert not torch.equal(topo_concat_0, topo_concat_1) else: assert torch.equal(data0, data1) - assert torch.equal(topo_concat_0, topo_concat_1) diff --git a/fme/downscaling/evaluator.py b/fme/downscaling/evaluator.py index af67f3026..c901b5e22 100644 --- a/fme/downscaling/evaluator.py +++ b/fme/downscaling/evaluator.py @@ -15,7 +15,6 @@ from fme.downscaling.data import ( PairedDataLoaderConfig, PairedGriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -60,12 +59,10 @@ def run(self): else: batch_generator = self.data.get_generator() - for i, (batch, static_inputs) in enumerate(batch_generator): + for i, batch in enumerate(batch_generator): with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=self.n_samples - ) + outputs = self.model.generate_on_batch(batch, n_samples=self.n_samples) logging.info("Recording diagnostics to aggregator") # Add sample dimension to coarse values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.coarse.data.items()} @@ -111,7 +108,7 @@ def __init__( def run(self): logging.info(f"Running {self.event_name} event evaluation") - batch, static_inputs = next(iter(self.data.get_generator())) + batch = next(iter(self.data.get_generator())) sample_agg = PairedSampleAggregator( target=batch[0].fine.data, coarse=batch[0].coarse.data, @@ -129,9 +126,7 @@ def run(self): f"Generating samples {start_idx} to {end_idx} " f"for event {self.event_name}" ) - outputs = self.model.generate_on_batch( - batch, static_inputs, n_samples=end_idx - start_idx - ) + outputs = self.model.generate_on_batch(batch, n_samples=end_idx - start_idx) sample_agg.record_batch(outputs.prediction) to_log = sample_agg.get_wandb() @@ -156,7 +151,6 @@ def get_paired_gridded_data( self, base_data_config: PairedDataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> PairedGriddedData: enforce_lat_bounds(self.lat_extent) time_slice = self._time_selection_slice @@ -177,7 +171,6 @@ def get_paired_gridded_data( return event_data_config.build( train=False, requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -204,7 +197,6 @@ def _build_default_evaluator(self) -> Evaluator: dataset = self.data.build( train=False, requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) evaluator_model: DiffusionModel | PatchPredictor if self.patch.divide_generation and self.patch.composite_prediction: @@ -241,7 +233,6 @@ def _build_event_evaluator( dataset = event_config.get_paired_gridded_data( base_data_config=self.data, requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) if (dataset.coarse_shape[0] > model.coarse_shape[0]) or ( diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index ed8539e56..daf9b5338 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -10,7 +10,7 @@ from fme.core.generics.trainer import count_parameters from fme.core.logging_utils import LoggingConfig -from ..data import DataLoaderConfig, StaticInputs +from ..data import DataLoaderConfig, adjust_fine_coord_range from ..models import CheckpointModelConfig, DiffusionModel from ..predictors import ( CascadePredictor, @@ -55,7 +55,7 @@ def run_all(self): def _get_generation_model( self, - static_inputs: StaticInputs, + fine_shape: tuple[int, int], output: DownscalingOutput, ) -> DiffusionModel | PatchPredictor | CascadePredictor: """ @@ -66,7 +66,7 @@ def _get_generation_model( generations. """ model_patch_shape = self.model.fine_shape - actual_shape = tuple(static_inputs.shape) + actual_shape = fine_shape if model_patch_shape == actual_shape: # short circuit, no patching necessary @@ -77,8 +77,8 @@ def _get_generation_model( ): # we don't support generating regions smaller than the model patch size raise ValueError( - f"Model coarse shape {model_patch_shape} is larger than " - f"actual topography shape {actual_shape} for output {output.name}." + f"Model fine shape {model_patch_shape} is larger than " + f"actual fine shape {actual_shape} for output {output.name}." ) elif output.patch.needs_patch_predictor: # Use a patch predictor @@ -97,8 +97,8 @@ def _get_generation_model( ) def _on_device_generator(self, loader): - for loaded_item, topography in loader: - yield loaded_item.to_device(), topography.to_device() + for loaded_item in loader: + yield loaded_item.to_device() def run_output_generation(self, output: DownscalingOutput): """Execute the generation loop for this output.""" @@ -107,22 +107,40 @@ def run_output_generation(self, output: DownscalingOutput): # initialize writer and model in loop for coord info model = None writer = None + fine_static_inputs = None total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem - static_inputs: StaticInputs - for i, (loaded_item, static_inputs) in enumerate(output.data.get_generator()): + for i, loaded_item in enumerate(output.data.get_generator()): if writer is None: + coarse_lat = loaded_item.batch.latlon_coordinates.lat[0] + coarse_lon = loaded_item.batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + loaded_item.batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.model.static_inputs.coords.lat, + downscale_factor=self.model.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + loaded_item.batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.model.static_inputs.coords.lon, + downscale_factor=self.model.downscale_factor, + ) + fine_static_inputs = self.model.static_inputs.subset_latlon( + fine_lat_interval, + fine_lon_interval, + ) writer = output.get_writer( - latlon_coords=static_inputs.coords, + latlon_coords=fine_static_inputs.coords, output_dir=self.output_dir, ) writer.initialize_store( - static_inputs.fields[0].data.cpu().numpy().dtype + fine_static_inputs.fields[0].data.cpu().numpy().dtype ) if model is None: model = self._get_generation_model( - static_inputs=static_inputs, output=output + fine_shape=fine_static_inputs.shape, output=output ) logging.info( @@ -132,7 +150,6 @@ def run_output_generation(self, output: DownscalingOutput): output_data = model.generate_on_batch_no_target( loaded_item.batch, - static_inputs=static_inputs, n_samples=loaded_item.n_ens, ) output_np = {key: value.cpu().numpy() for key, value in output_data.items()} @@ -243,7 +260,7 @@ def build(self) -> Downscaler: loader_config=self.data, requirements=self.model.data_requirements, patch=self.patch, - static_inputs_from_checkpoint=model.static_inputs, + fine_shape=model.fine_shape, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 012b555ad..c2a9f5f12 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -18,7 +18,6 @@ ClosedInterval, DataLoaderConfig, LatLonCoordinates, - StaticInputs, enforce_lat_bounds, ) from ..data.config import XarrayEnsembleDataConfig @@ -153,6 +152,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: """ Build an OutputTarget from this configuration. @@ -218,7 +218,7 @@ def _build_gridded_data( loader_config: DataLoaderConfig, requirements: DataRequirements, dist: Distributed | None = None, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -229,13 +229,6 @@ def _build_gridded_data( "Downscaling data loader only supports datasets with latlon coords." ) dataset = loader_config.build_batchitem_dataset(xr_dataset, properties) - topography = loader_config.build_static_inputs( - coords, - requires_topography=requirements.use_fine_topography, - static_inputs=static_inputs_from_checkpoint, - ) - if topography is None: - raise ValueError("Topography is required for downscaling generation.") work_items = get_work_items( n_times=len(dataset), @@ -243,11 +236,10 @@ def _build_gridded_data( max_samples_per_gpu=self.max_samples_per_gpu, ) - # defer topography device placement until after batch generation slice_dataset = SliceItemDataset( slice_items=work_items, dataset=dataset, - spatial_shape=topography.shape, + spatial_shape=fine_shape, ) # each SliceItemDataset work item loads its own full batch, so batch_size=1 @@ -274,7 +266,6 @@ def _build_gridded_data( all_times=xr_dataset.sample_start_times, dtype=slice_dataset.dtype, max_output_shape=slice_dataset.max_output_shape, - static_inputs=topography, ) def _build( @@ -286,7 +277,7 @@ def _build( requirements: DataRequirements, patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -299,7 +290,7 @@ def _build( gridded_data = self._build_gridded_data( updated_loader_config, requirements, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + fine_shape=fine_shape, ) if self.zarr_chunks is None: @@ -386,7 +377,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -409,7 +400,7 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + fine_shape=fine_shape, ) @@ -469,7 +460,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - static_inputs_from_checkpoint: StaticInputs | None = None, + fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -480,5 +471,5 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - static_inputs_from_checkpoint=static_inputs_from_checkpoint, + fine_shape=fine_shape, ) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index e5cf8c641..4d6bb0020 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -12,6 +12,7 @@ from fme.core.dataset.time import TimeSlice from fme.core.logging_utils import LoggingConfig from fme.downscaling.data import ( + ClosedInterval, LatLonCoordinates, StaticInput, StaticInputs, @@ -92,7 +93,6 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): Test _get_generation_model returns model unchanged when shapes match exactly. """ mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(16, 16)) downscaler = Downscaler( model=mock_model, @@ -100,7 +100,7 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): ) result = downscaler._get_generation_model( - static_inputs=static_inputs, + fine_shape=(16, 16), output=mock_output_target, ) @@ -116,7 +116,6 @@ def test_get_generation_model_raises_when_domain_too_small( smaller than model. """ mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=topo_shape) downscaler = Downscaler( model=mock_model, @@ -125,7 +124,7 @@ def test_get_generation_model_raises_when_domain_too_small( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + fine_shape=topo_shape, output=mock_output_target, ) @@ -138,7 +137,6 @@ def test_get_generation_model_creates_patch_predictor_when_needed( large domains with patching. """ mock_model.fine_shape = (16, 16) - static_inputs = get_static_inputs(shape=(32, 32)) # Larger than model patch_config = PatchPredictionConfig( divide_generation=True, @@ -152,7 +150,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( ) model = downscaler._get_generation_model( - static_inputs=static_inputs, + fine_shape=(32, 32), output=mock_output_target, ) @@ -168,7 +166,6 @@ def test_get_generation_model_raises_when_large_domain_without_patching( not configured. """ mock_model.fine_shape = (16, 16) - topo = get_static_inputs(shape=(32, 32)) # Larger than model mock_output_target.patch = PatchPredictionConfig(divide_generation=False) downscaler = Downscaler( @@ -178,7 +175,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( with pytest.raises(ValueError): downscaler._get_generation_model( - static_inputs=topo, + fine_shape=(32, 32), output=mock_output_target, ) @@ -192,15 +189,33 @@ def test_run_target_generation_skips_padding_items( mock_work_item = MagicMock() mock_work_item.is_padding = True mock_work_item.n_ens = 4 - mock_work_item.batch = MagicMock() - - static_inputs = get_static_inputs(shape=(16, 16)) + # to_device() should return self so batch coords are preserved + mock_work_item.to_device.return_value = mock_work_item + + # Set up proper batch coordinates for adjust_fine_coord_range to work. + # Coarse coords are interior so fine can have buffer on each side. + coarse_lat = torch.arange(1, 9).float() # 8 values: 1..8 + coarse_lon = torch.arange(1, 9).float() + mock_latlon = MagicMock() + mock_latlon.lat = coarse_lat.unsqueeze(0) + mock_latlon.lon = coarse_lon.unsqueeze(0) + mock_work_item.batch.latlon_coordinates = mock_latlon + mock_work_item.batch.lat_interval = ClosedInterval(1.0, 8.0) + mock_work_item.batch.lon_interval = ClosedInterval(1.0, 8.0) + + fine_static_inputs = get_static_inputs(shape=(16, 16)) + fine_lat = torch.arange(0, 18).float() # 18 values: 0..17 + fine_lon = torch.arange(0, 18).float() + mock_model.static_inputs.coords.lat = fine_lat + mock_model.static_inputs.coords.lon = fine_lon + mock_model.static_inputs.subset_latlon.return_value = fine_static_inputs + mock_model.downscale_factor = 2 + mock_model.fine_shape = (16, 16) mock_gridded_data = SliceWorkItemGriddedData( - [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16), static_inputs + [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16) ) mock_output_target.data = mock_gridded_data - mock_model.fine_shape = (16, 16) mock_output = { "var1": torch.randn(1, 4, 16, 16), diff --git a/fme/downscaling/inference/test_output.py b/fme/downscaling/inference/test_output.py index f9f041ae4..413fcfcb0 100644 --- a/fme/downscaling/inference/test_output.py +++ b/fme/downscaling/inference/test_output.py @@ -1,12 +1,10 @@ from unittest.mock import MagicMock import pytest -import torch from fme.core.dataset.time import TimeSlice from fme.core.dataset.xarray import XarrayDataConfig -from fme.downscaling.data import ClosedInterval, StaticInput, StaticInputs -from fme.downscaling.data.utils import LatLonCoordinates +from fme.downscaling.data import ClosedInterval from fme.downscaling.inference.output import ( DownscalingOutput, DownscalingOutputConfig, @@ -19,19 +17,6 @@ # Tests for OutputTargetConfig validation -def _get_static_inputs(shape=(8, 8)): - return StaticInputs( - fields=[ - StaticInput( - data=torch.ones(shape), - coords=LatLonCoordinates( - lat=torch.ones(shape[0]), lon=torch.ones(shape[1]) - ), - ) - ] - ) - - def test_single_xarray_config_accepts_single_config(): """Test that _single_xarray_config accepts a single XarrayDataConfig.""" xarray_config = XarrayDataConfig( @@ -116,10 +101,7 @@ def test_event_config_build_creates_output_target_with_single_time( lat_extent=ClosedInterval(0.0, 6.0), lon_extent=ClosedInterval(0.0, 6.0), ) - static_inputs = _get_static_inputs((8, 8)) - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) @@ -147,11 +129,7 @@ def test_region_config_build_creates_output_target_with_time_range( n_ens=4, save_vars=["var0", "var1"], ) - static_inputs = _get_static_inputs((8, 8)) - - output_target = config.build( - loader_config, requirements, patch_config, static_inputs - ) + output_target = config.build(loader_config, requirements, patch_config) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index 0b5ef27a5..ece36912d 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -10,7 +10,7 @@ from fme.core.distributed import Distributed from fme.core.generics.data import SizedMap -from ..data import BatchData, StaticInputs +from ..data import BatchData from ..data.config import BatchItemDatasetAdapter from .constants import ENSEMBLE_NAME, TIME_NAME @@ -297,7 +297,6 @@ class SliceWorkItemGriddedData: all_times: xr.CFTimeIndex dtype: torch.dtype max_output_shape: tuple[int, ...] - static_inputs: StaticInputs # TODO: currently no protocol or ABC for gridded data objects # if we want to unify, we will need one and just raise @@ -310,7 +309,7 @@ def on_device(work_item: LoadedSliceWorkItem) -> LoadedSliceWorkItem: return SizedMap(on_device, self._loader) - def get_generator(self) -> Iterator[tuple[LoadedSliceWorkItem, StaticInputs]]: + def get_generator(self) -> Iterator[LoadedSliceWorkItem]: work_item: LoadedSliceWorkItem for work_item in self.loader: - yield work_item, self.static_inputs + yield work_item diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 8ffa8c22b..ff7e44ee5 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -383,12 +383,9 @@ def _get_input_from_coarse( def train_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. _static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) @@ -452,8 +449,8 @@ def generate( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]: - # static_inputs receives an internally-subsetted value from the calling method; - # external callers should use generate_on_batch / generate_on_batch_no_target. + # Internal method; external callers should use generate_on_batch / + # generate_on_batch_no_target. inputs_ = self._get_input_from_coarse(coarse_data, static_inputs) # expand samples and fold to # [batch * n_samples, output_channels, height, width] @@ -503,11 +500,8 @@ def generate( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> TensorDict: - # Ignore the passed static_inputs; derive the fine lat/lon interval from coarse - # batch coordinates via adjust_fine_coord_range, then subset self.static_inputs. if self.config.use_fine_topography: if self.static_inputs is None: raise ValueError( @@ -540,11 +534,8 @@ def generate_on_batch_no_target( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, # TODO: remove in follow-on PR n_samples: int = 1, ) -> ModelOutputs: - # Ignore the passed static_inputs; subset self.static_inputs using fine batch - # coordinates. The caller-provided value is kept for signature compatibility. _static_inputs = self._subset_static_inputs( batch.fine.lat_interval, batch.fine.lon_interval ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index 0f0715d78..3f66937ac 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,7 +21,6 @@ ClosedInterval, DataLoaderConfig, GriddedData, - StaticInputs, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -93,7 +92,6 @@ def get_gridded_data( self, base_data_config: DataLoaderConfig, requirements: DataRequirements, - static_inputs_from_checkpoint: StaticInputs | None = None, ) -> GriddedData: enforce_lat_bounds(self.lat_extent) event_coarse = dataclasses.replace(base_data_config.full_config[0]) @@ -109,7 +107,6 @@ def get_gridded_data( ) return event_data_config.build( requirements=requirements, - static_inputs=static_inputs_from_checkpoint, ) @@ -151,7 +148,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") - batch, static_inputs = next(iter(self.data.get_generator())) + batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates fine_coords = LatLonCoordinates( lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), @@ -174,7 +171,7 @@ def run(self): f"for event {self.event_name}" ) outputs = self.model.generate_on_batch_no_target( - batch, static_inputs=static_inputs, n_samples=end_idx - start_idx + batch, n_samples=end_idx - start_idx ) sample_agg.record_batch(outputs) to_log = sample_agg.get_wandb() @@ -245,22 +242,18 @@ def save_netcdf_data(self, ds: xr.Dataset): @property def _fine_latlon_coordinates(self) -> LatLonCoordinates | None: - if self.data.static_inputs is not None: - return self.data.static_inputs.coords - else: - return None + return None def run(self): aggregator = NoTargetAggregator( downscale_factor=self.model.downscale_factor, latlon_coordinates=self._fine_latlon_coordinates, ) - for i, (batch, static_inputs) in enumerate(self.batch_generator): + for i, batch in enumerate(self.batch_generator): with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") prediction = self.generation_model.generate_on_batch_no_target( batch=batch, - static_inputs=static_inputs, n_samples=self.n_samples, ) logging.info("Recording diagnostics to aggregator") @@ -302,7 +295,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: model = self.model.build() dataset = self.data.build( requirements=self.model.data_requirements, - static_inputs=model.static_inputs, ) downscaler = Downscaler( data=dataset, @@ -316,7 +308,6 @@ def build(self) -> list[Downscaler | EventDownscaler]: event_dataset = event_config.get_gridded_data( base_data_config=self.data, requirements=self.model.data_requirements, - static_inputs_from_checkpoint=model.static_inputs, ) event_downscalers.append( EventDownscaler( diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py index 0f64c73b5..e113c0cef 100644 --- a/fme/downscaling/predictors/cascade.py +++ b/fme/downscaling/predictors/cascade.py @@ -3,28 +3,16 @@ import torch -from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.core.tensors import unfold_ensemble_dim from fme.core.typing_ import TensorDict, TensorMapping -from fme.downscaling.data import ( - BatchData, - ClosedInterval, - PairedBatchData, - StaticInputs, - adjust_fine_coord_range, - scale_tuple, -) +from fme.downscaling.data import BatchData, PairedBatchData, scale_tuple from fme.downscaling.metrics_and_maths import filter_tensor_mapping from fme.downscaling.models import CheckpointModelConfig, DiffusionModel, ModelOutputs from fme.downscaling.requirements import DataRequirements from fme.downscaling.typing_ import FineResCoarseResPair -def _closed_interval_from_coord(coord: torch.Tensor) -> ClosedInterval: - return ClosedInterval(start=coord.min().item(), stop=coord.max().item()) - - @dataclasses.dataclass class CascadePredictorConfig: """ @@ -46,9 +34,6 @@ def models(self): self._models = [cfg.build() for cfg in self.cascade_model_checkpoints] return self._models - def get_static_inputs(self) -> list[StaticInputs | None]: - return [model.static_inputs for model in self.models] - def build(self): for m in range(len(self.models) - 1): output_shape = scale_tuple( @@ -63,9 +48,7 @@ def build(self): f"input shape {input_shape_next_step} of model at step {m+1}. " ) - return CascadePredictor( - models=self.models, static_inputs=self.get_static_inputs() - ) + return CascadePredictor(models=self.models) @property def data_requirements(self) -> DataRequirements: @@ -87,11 +70,8 @@ def _restore_batch_and_sample_dims(data: TensorMapping, n_samples: int): class CascadePredictor: - def __init__( - self, models: list[DiffusionModel], static_inputs: list[StaticInputs | None] - ): + def __init__(self, models: list[DiffusionModel]): self.models = models - self._static_inputs = static_inputs self.out_packer = self.models[-1].out_packer self.normalizer = FineResCoarseResPair( coarse=self.models[0].normalizer.coarse, @@ -120,10 +100,9 @@ def generate( self, coarse: TensorMapping, n_samples: int, - static_inputs: list[StaticInputs | None], ): current_coarse = coarse - for i, (model, fine_topography) in enumerate(zip(self.models, static_inputs)): + for i, model in enumerate(self.models): sample_data = next(iter(current_coarse.values())) batch_size = sample_data.shape[0] # n_samples are generated for the first step, and subsequent models @@ -131,7 +110,7 @@ def generate( n_samples_cascade_step = n_samples if i == 0 else 1 generated, generated_norm, latent_steps = model.generate( - current_coarse, fine_topography, n_samples_cascade_step + current_coarse, n_samples=n_samples_cascade_step ) generated = { k: v.reshape(batch_size * n_samples_cascade_step, *v.shape[-2:]) @@ -145,28 +124,18 @@ def generate( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: - subset_static_inputs = self._get_subset_static_inputs( - coarse_coords=batch.latlon_coordinates[0] - ) - generated, _, _ = self.generate(batch.data, n_samples, subset_static_inputs) + generated, _, _ = self.generate(batch.data, n_samples) return generated @torch.no_grad() def generate_on_batch( self, batch: PairedBatchData, - static_inputs: list[StaticInputs | None], n_samples: int = 1, ) -> ModelOutputs: - static_inputs = self._get_subset_static_inputs( - coarse_coords=batch.coarse.latlon_coordinates[0] - ) - generated, _, latent_steps = self.generate( - batch.coarse.data, n_samples, static_inputs - ) + generated, _, latent_steps = self.generate(batch.coarse.data, n_samples) targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names)) targets = {k: v.unsqueeze(1) for k, v in targets.items()} @@ -176,43 +145,3 @@ def generate_on_batch( loss=torch.tensor(float("inf"), device=get_device()), latent_steps=latent_steps, ) - - def _get_subset_static_inputs( - self, - coarse_coords: LatLonCoordinates, - ) -> list[StaticInputs | None]: - # Intermediate topographies are loaded as full range and need to be subset - # to the matching lat/lon range for each batch. - # TODO: Will eventually move subsetting into checkpoint model. - subset_static_inputs: list[StaticInputs | None] = [] - _coarse_coords = coarse_coords - lat_range = _closed_interval_from_coord(_coarse_coords.lat) - lon_range = _closed_interval_from_coord(_coarse_coords.lon) - - for i, full_intermediate_static_inputs in enumerate(self._static_inputs): - if full_intermediate_static_inputs is not None: - _adjusted_lat_range = adjust_fine_coord_range( - lat_range, - _coarse_coords.lat, - full_intermediate_static_inputs.coords.lat, - downscale_factor=self.models[i].downscale_factor, - ) - _adjusted_lon_range = adjust_fine_coord_range( - lon_range, - _coarse_coords.lon, - full_intermediate_static_inputs.coords.lon, - downscale_factor=self.models[i].downscale_factor, - ) - subset_interm_static_inputs = ( - full_intermediate_static_inputs.subset_latlon( - lat_interval=_adjusted_lat_range, - lon_interval=_adjusted_lon_range, - ) - ) - _coarse_coords = subset_interm_static_inputs.coords - lat_range = _closed_interval_from_coord(_coarse_coords.lat) - lon_range = _closed_interval_from_coord(_coarse_coords.lon) - else: - subset_interm_static_inputs = None - subset_static_inputs.append(subset_interm_static_inputs) - return subset_static_inputs diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 8cbd08e83..5c375e417 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -3,9 +3,8 @@ import torch from fme.core.typing_ import TensorDict -from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs, scale_tuple +from fme.downscaling.data import BatchData, PairedBatchData, scale_tuple from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.utils import null_generator from fme.downscaling.models import DiffusionModel, ModelOutputs from fme.downscaling.predictors import CascadePredictor @@ -106,7 +105,6 @@ def _get_patches( def generate_on_batch( self, batch: PairedBatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> ModelOutputs: predictions = [] @@ -119,17 +117,9 @@ def generate_on_batch( batch_generator = batch.generate_from_patches( coarse_patches=coarse_patches, fine_patches=fine_patches ) - if static_inputs is not None: - static_inputs_generator = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): - model_output = self.model.generate_on_batch( - data_patch, static_inputs_patch, n_samples - ) + + for data_patch in batch_generator: + model_output = self.model.generate_on_batch(data_patch, n_samples) predictions.append(model_output.prediction) loss = loss + model_output.loss @@ -147,7 +137,6 @@ def generate_on_batch( def generate_on_batch_no_target( self, batch: BatchData, - static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: coarse_yx_extent = batch.horizontal_shape @@ -157,17 +146,10 @@ def generate_on_batch_no_target( ) predictions = [] batch_generator = batch.generate_from_patches(coarse_patches) - if static_inputs is not None: - static_inputs_generator = static_inputs.generate_from_patches(fine_patches) - else: - static_inputs_generator = null_generator(len(fine_patches)) - for data_patch, static_inputs_patch in zip( - batch_generator, static_inputs_generator - ): + for data_patch in batch_generator: predictions.append( self.model.generate_on_batch_no_target( batch=data_patch, - static_inputs=static_inputs_patch, n_samples=n_samples, ) ) diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py index 99511d576..50884b07e 100644 --- a/fme/downscaling/predictors/test_cascade.py +++ b/fme/downscaling/predictors/test_cascade.py @@ -1,28 +1,14 @@ import pytest import torch -from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig -from fme.downscaling.data import StaticInput, StaticInputs from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.predictors.cascade import CascadePredictor -def _get_static_inputs(shape, coords, n_fields=1): - return StaticInputs( - fields=[ - StaticInput( - data=torch.randn(shape, device=get_device()), - coords=coords, - ) - for _ in range(n_fields) - ] - ) - - def _latlon_coords_on_ngrid(n: int, edges=(0, 100)): start, end = edges dx = (end - start) / n @@ -30,7 +16,7 @@ def _latlon_coords_on_ngrid(n: int, edges=(0, 100)): return LatLonCoordinates(lat=midpoints, lon=midpoints) -def _get_diffusion_model(coarse_shape, downscale_factor, static_inputs=None): +def _get_diffusion_model(coarse_shape, downscale_factor): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), @@ -51,46 +37,28 @@ def _get_diffusion_model(coarse_shape, downscale_factor, static_inputs=None): churn=0.5, num_diffusion_generation_steps=3, predict_residual=True, - use_fine_topography=True, + use_fine_topography=False, ).build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, - static_inputs=static_inputs, ) @pytest.mark.parametrize("downscale_factors", [[2, 4], [2, 3, 4]]) def test_CascadePredictor_generate(downscale_factors): n_times, n_samples_generate, nside_coarse = 3, 2, 4 - grid_bounds = (0, 100) models = [] - static_inputs_list: list[StaticInputs | None] = [] input_n_cells = nside_coarse for downscale_factor in downscale_factors: - static_inputs_list.append( - _get_static_inputs( - shape=( - input_n_cells * downscale_factor, - input_n_cells * downscale_factor, - ), - coords=_latlon_coords_on_ngrid( - n=input_n_cells * downscale_factor, edges=grid_bounds - ), - n_fields=1, - ) - ) input_n_cells *= downscale_factor models.append( _get_diffusion_model( coarse_shape=(input_n_cells, input_n_cells), downscale_factor=downscale_factor, - static_inputs=static_inputs_list[-1], ) ) - cascade_predictor = CascadePredictor( - models=models, static_inputs=static_inputs_list - ) + cascade_predictor = CascadePredictor(models=models) coarse_input = { "x": torch.randn( (n_times, nside_coarse, nside_coarse), @@ -101,7 +69,6 @@ def test_CascadePredictor_generate(downscale_factors): generated, _, _ = cascade_predictor.generate( coarse=coarse_input, n_samples=n_samples_generate, - static_inputs=static_inputs_list, ) expected_nside = cascade_predictor.downscale_factor * nside_coarse assert generated["x"].shape == ( @@ -110,57 +77,3 @@ def test_CascadePredictor_generate(downscale_factors): expected_nside, expected_nside, ) - - -def test_CascadePredictor__subset_topographies(): - nside_coarse = 8 - downscale_factors = [2, 2] - grid_bounds = (0, 8) - models = [] - static_inputs_list: list[StaticInputs | None] = [] - input_n_cells = nside_coarse - - for downscale_factor in downscale_factors: - models.append( - _get_diffusion_model( - coarse_shape=(input_n_cells, input_n_cells), - downscale_factor=downscale_factor, - ) - ) - static_inputs_list.append( - _get_static_inputs( - shape=( - input_n_cells * downscale_factor, - input_n_cells * downscale_factor, - ), - coords=_latlon_coords_on_ngrid( - n=input_n_cells * downscale_factor, edges=grid_bounds - ), - n_fields=1, - ) - ) - input_n_cells *= downscale_factor - - cascade_predictor = CascadePredictor( - models=models, static_inputs=static_inputs_list - ) - # Coarse grid subset has 1.0 grid spacing and midpoints 1.5 ... 4.5 - coarse_coords = _latlon_coords_on_ngrid(n=4, edges=(1, 5)) - subset_intermediate_topographies = cascade_predictor._get_subset_static_inputs( - coarse_coords=coarse_coords - ) - - # First topography grid 0.5 grid spacing - assert isinstance(subset_intermediate_topographies[0], StaticInputs) - assert subset_intermediate_topographies[0].shape == (8, 8) - assert subset_intermediate_topographies[0].coords.lat[0] == 1.25 - assert subset_intermediate_topographies[0].coords.lat[-1] == 4.75 - assert subset_intermediate_topographies[0].coords.lon[0] == 1.25 - assert subset_intermediate_topographies[0].coords.lon[-1] == 4.75 - # Second topography grid has 0.25 grid spacing - assert isinstance(subset_intermediate_topographies[1], StaticInputs) - assert subset_intermediate_topographies[1].shape == (16, 16) - assert subset_intermediate_topographies[1].coords.lat[0] == 1.125 - assert subset_intermediate_topographies[1].coords.lat[-1] == 4.875 - assert subset_intermediate_topographies[1].coords.lon[0] == 1.125 - assert subset_intermediate_topographies[1].coords.lon[-1] == 4.875 diff --git a/fme/downscaling/predictors/test_composite.py b/fme/downscaling/predictors/test_composite.py index f6b6fa60d..9f31ff05d 100644 --- a/fme/downscaling/predictors/test_composite.py +++ b/fme/downscaling/predictors/test_composite.py @@ -6,7 +6,7 @@ from fme.core.device import get_device from fme.core.packer import Packer from fme.downscaling.aggregators.shape_helpers import upsample_tensor -from fme.downscaling.data import BatchData, PairedBatchData, StaticInput, StaticInputs +from fme.downscaling.data import BatchData, PairedBatchData from fme.downscaling.data.patching import get_patches from fme.downscaling.data.utils import BatchedLatLonCoordinates from fme.downscaling.models import ModelOutputs @@ -16,10 +16,6 @@ ) -def _get_static_inputs(shape, coords): - return StaticInputs(fields=[StaticInput(data=torch.randn(shape), coords=coords)]) - - def test_composite_predictions(): patch_yx_size = (2, 2) patches = get_patches((4, 4), patch_yx_size, overlap=0) @@ -54,9 +50,7 @@ def __init__(self, coarse_shape, downscale_factor): self.modules = [] self.out_packer = Packer(["x"]) - def generate_on_batch( - self, batch: PairedBatchData, static_inputs: StaticInputs | None, n_samples=1 - ): + def generate_on_batch(self, batch: PairedBatchData, n_samples=1): prediction_data = { k: v.unsqueeze(1).expand(-1, n_samples, -1, -1) for k, v in batch.fine.data.items() @@ -65,9 +59,7 @@ def generate_on_batch( prediction=prediction_data, target=prediction_data, loss=torch.tensor(1.0) ) - def generate_on_batch_no_target( - self, batch: BatchData, static_inputs: StaticInputs | None, n_samples=1 - ): + def generate_on_batch_no_target(self, batch: BatchData, n_samples=1): prediction_data = { k: upsample_tensor( v.unsqueeze(1).expand(-1, n_samples, -1, -1), @@ -137,13 +129,6 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=downscale_factor), # type: ignore @@ -152,7 +137,7 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse): ) n_samples_generate = 2 outputs = predictor.generate_on_batch( - paired_batch_data, static_inputs, n_samples=n_samples_generate + paired_batch_data, n_samples=n_samples_generate ) assert outputs.prediction["x"].shape == (batch_size, n_samples_generate, 8, 8) # dummy model predicts same value as fine data for all samples @@ -174,13 +159,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse paired_batch_data = get_paired_test_data( *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size ) - static_inputs = _get_static_inputs( - shape=( - coarse_extent[0] * downscale_factor, - coarse_extent[1] * downscale_factor, - ), - coords=paired_batch_data.fine.latlon_coordinates[0], - ) predictor = PatchPredictor( DummyModel(coarse_shape=patch_size_coarse, downscale_factor=2), # type: ignore coarse_extent, @@ -189,6 +167,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse n_samples_generate = 2 coarse_batch_data = paired_batch_data.coarse prediction = predictor.generate_on_batch_no_target( - coarse_batch_data, static_inputs, n_samples=n_samples_generate + coarse_batch_data, n_samples=n_samples_generate ) assert prediction["x"].shape == (batch_size, n_samples_generate, 8, 8) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index d6798e66f..e550fe4f1 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -206,7 +206,7 @@ def test_from_state_backward_compat_fine_topography(): # At runtime, omitting static inputs must raise a clear error batch = get_mock_paired_batch([2, *coarse_shape], [2, *fine_shape]) with pytest.raises(ValueError, match="Static inputs must be provided"): - model_from_old_state.generate_on_batch(batch, static_inputs=None) + model_from_old_state.generate_on_batch(batch) def _get_diffusion_model( @@ -265,13 +265,12 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph assert model._get_fine_shape(coarse_shape) == fine_shape optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) - train_outputs = model.train_on_batch(batch, static_inputs, optimization) + train_outputs = model.train_on_batch(batch, optimization) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) n_generated_samples = 2 generated_outputs = [ - model.generate_on_batch(batch, static_inputs) - for _ in range(n_generated_samples) + model.generate_on_batch(batch) for _ in range(n_generated_samples) ] for generated_output in generated_outputs: @@ -392,7 +391,7 @@ def test_model_error_cases(): # missing fine topography when model requires it batch.fine.topography = None with pytest.raises(ValueError): - model.generate_on_batch(batch, static_inputs=None) + model.generate_on_batch(batch) def test_DiffusionModel_generate_on_batch_no_target(): @@ -417,7 +416,6 @@ def test_DiffusionModel_generate_on_batch_no_target(): samples = model.generate_on_batch_no_target( coarse_batch, - static_inputs=static_inputs, n_samples=n_generated_samples, ) @@ -457,9 +455,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): coarse_batch = make_batch_data( (batch_size, *alternative_input_shape), coarse_lat, coarse_lon ) - samples = model.generate_on_batch_no_target( - coarse_batch, n_samples=n_ensemble, static_inputs=None - ) + samples = model.generate_on_batch_no_target(coarse_batch, n_samples=n_ensemble) assert samples["x"].shape == ( batch_size, diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 271d884be..a1dd08fde 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -174,11 +174,11 @@ def train_one_epoch(self) -> None: self.train_data, random_offset=True, shuffle=True ) outputs = None - for i, (batch, static_inputs) in enumerate(train_batch_generator): + for i, batch in enumerate(train_batch_generator): self.num_batches_seen += 1 if i % 10 == 0: logging.info(f"Training on batch {i + 1}") - outputs = self.model.train_on_batch(batch, static_inputs, self.optimization) + outputs = self.model.train_on_batch(batch, self.optimization) self.ema(self.model.modules) with torch.no_grad(): train_aggregator.record_batch( @@ -250,10 +250,8 @@ def valid_one_epoch(self) -> dict[str, float]: validation_batch_generator = self._get_batch_generator( self.validation_data, random_offset=False, shuffle=False ) - for batch, static_inputs in validation_batch_generator: - outputs = self.model.train_on_batch( - batch, static_inputs, self.null_optimization - ) + for batch in validation_batch_generator: + outputs = self.model.train_on_batch(batch, self.null_optimization) validation_aggregator.record_batch( outputs=outputs, coarse=batch.coarse.data, @@ -261,7 +259,6 @@ def valid_one_epoch(self) -> dict[str, float]: ) generated_outputs = self.model.generate_on_batch( batch, - static_inputs=static_inputs, n_samples=self.config.generate_n_samples, ) # Add sample dimension to coarse values for generation comparison @@ -429,12 +426,10 @@ def build(self) -> Trainer: train_data: PairedGriddedData = self.train_data.build( train=True, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) validation_data: PairedGriddedData = self.validation_data.build( train=False, requirements=self.model.data_requirements, - static_inputs=static_inputs, ) if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon: model_coarse_shape = ( From af4e22d582c175506a45cd1115bcb16abb7c23a2 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Wed, 11 Mar 2026 21:38:36 -0700 Subject: [PATCH 02/20] Test passing for inference --- fme/downscaling/data/datasets.py | 6 +- fme/downscaling/inference/inference.py | 53 ++++++++++----- fme/downscaling/inference/output.py | 4 +- fme/downscaling/inference/test_inference.py | 72 ++++++++------------- fme/downscaling/inference/work_items.py | 4 +- fme/downscaling/models.py | 6 +- fme/downscaling/predictors/cascade.py | 4 ++ fme/downscaling/predictors/composite.py | 4 ++ fme/downscaling/predictors/test_cascade.py | 7 -- 9 files changed, 81 insertions(+), 79 deletions(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 3f073edfc..4f6e05435 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -305,8 +305,7 @@ def on_device(batch: BatchItem) -> BatchItem: return SizedMap(on_device, self._loader) def get_generator(self) -> Iterator["BatchData"]: - for batch in self.loader: - yield batch + yield from self.loader def get_patched_generator( self, @@ -347,8 +346,7 @@ def on_device(batch: PairedBatchItem) -> PairedBatchItem: return SizedMap(on_device, self._loader) def get_generator(self) -> Iterator["PairedBatchData"]: - for batch in self.loader: - yield batch + yield from self.loader def get_patched_generator( self, diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index daf9b5338..9b875a688 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -55,7 +55,7 @@ def run_all(self): def _get_generation_model( self, - fine_shape: tuple[int, int], + input_shape: tuple[int, int], output: DownscalingOutput, ) -> DiffusionModel | PatchPredictor | CascadePredictor: """ @@ -65,20 +65,22 @@ def _get_generation_model( the user to use patching for larger domains because that provides better generations. """ - model_patch_shape = self.model.fine_shape - actual_shape = fine_shape + model_patch_shape = self.model.coarse_shape - if model_patch_shape == actual_shape: + if model_patch_shape == input_shape: # short circuit, no patching necessary return self.model elif any( expected > actual - for expected, actual in zip(model_patch_shape, actual_shape) + for expected, actual in zip(model_patch_shape, input_shape) ): # we don't support generating regions smaller than the model patch size raise ValueError( - f"Model fine shape {model_patch_shape} is larger than " - f"actual fine shape {actual_shape} for output {output.name}." + f"Model coarse shape {model_patch_shape} is larger than " + f"actual input shape {input_shape} for output {output.name}." + "We do not support generating outputs with a smaller spatial extent" + " than the model's trained patch size. Please adjust the spatial extent" + " to be at least as large as the model's input patch size." ) elif output.patch.needs_patch_predictor: # Use a patch predictor @@ -91,7 +93,7 @@ def _get_generation_model( # User should enable patching raise ValueError( f"Model coarse shape {model_patch_shape} does not match " - f"actual input shape {actual_shape} for output {output.name}, " + f"actual input shape {input_shape} for output {output.name}, " "and patch prediction is not configured. Generation for larger domains " "requires patch prediction." ) @@ -112,22 +114,43 @@ def run_output_generation(self, output: DownscalingOutput): loaded_item: LoadedSliceWorkItem for i, loaded_item in enumerate(output.data.get_generator()): + input_shape = loaded_item.batch.horizontal_shape + if model is None: + model = self._get_generation_model( + input_shape=input_shape, output=output + ) + if writer is None: coarse_lat = loaded_item.batch.latlon_coordinates.lat[0] coarse_lon = loaded_item.batch.latlon_coordinates.lon[0] + if model.static_inputs is None: + raise ValueError( + "Model is missing static inputs, which are required to " + "determine the coordinate information for the output " + "dataset. Please ensure the model is configured with " + "static inputs. This will be fixed in a future update." + ) + # TODO: this is a definciency of the implementation needing + # fine spatial information to determine the output region. + # Right now that requires that we use the model static + # inputs to get the fine coordinates, but we should be able + # to get the fine coordinate information from the output config + # instead, which would remove the need for the model to have + # static inputs at all. This works because the batch always + # contains the full coarse spatial extent fine_lat_interval = adjust_fine_coord_range( loaded_item.batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.model.static_inputs.coords.lat, - downscale_factor=self.model.downscale_factor, + full_fine_coord=model.static_inputs.coords.lat, + downscale_factor=model.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( loaded_item.batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.model.static_inputs.coords.lon, - downscale_factor=self.model.downscale_factor, + full_fine_coord=model.static_inputs.coords.lon, + downscale_factor=model.downscale_factor, ) - fine_static_inputs = self.model.static_inputs.subset_latlon( + fine_static_inputs = model.static_inputs.subset_latlon( fine_lat_interval, fine_lon_interval, ) @@ -138,10 +161,6 @@ def run_output_generation(self, output: DownscalingOutput): writer.initialize_store( fine_static_inputs.fields[0].data.cpu().numpy().dtype ) - if model is None: - model = self._get_generation_model( - fine_shape=fine_static_inputs.shape, output=output - ) logging.info( f"[{output.name}] Batch {i+1}/{total_batches}, " diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index c2a9f5f12..2ff091e7b 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -152,7 +152,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - fine_shape: tuple[int, int] | None = None, + output_fine_shape: tuple[int, int], ) -> DownscalingOutput: """ Build an OutputTarget from this configuration. @@ -161,6 +161,8 @@ def build( loader_config: Base data loader configuration to modify requirements: Model's data requirements (variable names, etc.) patch: Default patch prediction configuration + output_fine_shape: Fine shape of the output used as metadata + for the shape of the output to insert into the dataset """ pass diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index 4d6bb0020..ed97d2b60 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -25,7 +25,6 @@ EventConfig, TimeRangeConfig, ) -from fme.downscaling.inference.work_items import SliceWorkItemGriddedData from fme.downscaling.models import ( CheckpointModelConfig, DiffusionModelConfig, @@ -92,7 +91,7 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): """ Test _get_generation_model returns model unchanged when shapes match exactly. """ - mock_model.fine_shape = (16, 16) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -100,22 +99,22 @@ def test_get_generation_model_exact_match(mock_model, mock_output_target): ) result = downscaler._get_generation_model( - fine_shape=(16, 16), + input_shape=(16, 16), output=mock_output_target, ) assert result is mock_model -@pytest.mark.parametrize("topo_shape", [(8, 16), (16, 8), (8, 8)]) +@pytest.mark.parametrize("input_shape", [(8, 16), (16, 8), (8, 8)]) def test_get_generation_model_raises_when_domain_too_small( - mock_model, mock_output_target, topo_shape + mock_model, mock_output_target, input_shape ): """ Test _get_generation_model raises ValueError when domain is smaller than model. """ - mock_model.fine_shape = (16, 16) + mock_model.coarse_shape = (16, 16) downscaler = Downscaler( model=mock_model, @@ -124,7 +123,7 @@ def test_get_generation_model_raises_when_domain_too_small( with pytest.raises(ValueError): downscaler._get_generation_model( - fine_shape=topo_shape, + input_shape=input_shape, output=mock_output_target, ) @@ -136,7 +135,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( Test _get_generation_model creates PatchPredictor for large domains with patching. """ - mock_model.fine_shape = (16, 16) + mock_model.coarse_shape = (16, 16) patch_config = PatchPredictionConfig( divide_generation=True, @@ -150,7 +149,7 @@ def test_get_generation_model_creates_patch_predictor_when_needed( ) model = downscaler._get_generation_model( - fine_shape=(32, 32), + input_shape=(32, 32), output=mock_output_target, ) @@ -165,7 +164,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( Test _get_generation_model raises when domain is large but patching not configured. """ - mock_model.fine_shape = (16, 16) + mock_model.coarse_shape = (16, 16) mock_output_target.patch = PatchPredictionConfig(divide_generation=False) downscaler = Downscaler( @@ -175,7 +174,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( with pytest.raises(ValueError): downscaler._get_generation_model( - fine_shape=(32, 32), + input_shape=(32, 32), output=mock_output_target, ) @@ -184,44 +183,30 @@ def test_run_target_generation_skips_padding_items( mock_model, mock_output_target, ): - """Test run_target_generation skips writing output for padding items.""" - # Create padding work item + """ + Test run_output_generation calls the model but skips writing for padding items. + """ mock_work_item = MagicMock() mock_work_item.is_padding = True - mock_work_item.n_ens = 4 - # to_device() should return self so batch coords are preserved - mock_work_item.to_device.return_value = mock_work_item - - # Set up proper batch coordinates for adjust_fine_coord_range to work. + mock_work_item.batch.horizontal_shape = (16, 16) # Coarse coords are interior so fine can have buffer on each side. - coarse_lat = torch.arange(1, 9).float() # 8 values: 1..8 - coarse_lon = torch.arange(1, 9).float() - mock_latlon = MagicMock() - mock_latlon.lat = coarse_lat.unsqueeze(0) - mock_latlon.lon = coarse_lon.unsqueeze(0) - mock_work_item.batch.latlon_coordinates = mock_latlon + mock_work_item.batch.latlon_coordinates.lat = ( + torch.arange(1, 9).float().unsqueeze(0) + ) + mock_work_item.batch.latlon_coordinates.lon = ( + torch.arange(1, 9).float().unsqueeze(0) + ) mock_work_item.batch.lat_interval = ClosedInterval(1.0, 8.0) mock_work_item.batch.lon_interval = ClosedInterval(1.0, 8.0) + mock_output_target.data.get_generator.return_value = iter([mock_work_item]) - fine_static_inputs = get_static_inputs(shape=(16, 16)) - fine_lat = torch.arange(0, 18).float() # 18 values: 0..17 - fine_lon = torch.arange(0, 18).float() - mock_model.static_inputs.coords.lat = fine_lat - mock_model.static_inputs.coords.lon = fine_lon - mock_model.static_inputs.subset_latlon.return_value = fine_static_inputs mock_model.downscale_factor = 2 - mock_model.fine_shape = (16, 16) - - mock_gridded_data = SliceWorkItemGriddedData( - [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16) - ) - mock_output_target.data = mock_gridded_data - - mock_output = { - "var1": torch.randn(1, 4, 16, 16), - "var2": torch.randn(1, 4, 16, 16), + mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() + mock_model.static_inputs.coords.lon = torch.arange(0, 18).float() + mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1) + mock_model.generate_on_batch_no_target.return_value = { + "var1": torch.zeros(1, 4, 16, 16), } - mock_model.generate_on_batch_no_target.return_value = mock_output mock_writer = MagicMock() mock_output_target.get_writer.return_value = mock_writer @@ -233,11 +218,8 @@ def test_run_target_generation_skips_padding_items( downscaler.run_output_generation(output=mock_output_target) - # Verify model was still called mock_model.generate_on_batch_no_target.assert_called_once() - - # Verify the mock writer was not called - mock_writer.write_batch.assert_not_called() + mock_writer.record_batch.assert_not_called() # Tests for end-to-end generation process diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index ece36912d..25049cf5b 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -310,6 +310,4 @@ def on_device(work_item: LoadedSliceWorkItem) -> LoadedSliceWorkItem: return SizedMap(on_device, self._loader) def get_generator(self) -> Iterator[LoadedSliceWorkItem]: - work_item: LoadedSliceWorkItem - for work_item in self.loader: - yield work_item + yield from self.loader diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index ff7e44ee5..d567b7c7d 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -286,7 +286,8 @@ def __init__( normalizer: The normalizer object used for data normalization. loss: The loss function used for training the model. coarse_shape: The height (lat) and width (lon) of the - coarse-resolution input data. + coarse-resolution input data used to train the model + (same as patch extent, if training on patches). downscale_factor: The factor by which the data is downscaled from coarse to fine. sigma_data: The standard deviation of the data, used for diffusion @@ -338,7 +339,8 @@ def fine_shape(self) -> tuple[int, int]: def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]: """ - Calculate the fine shape based on the coarse shape and downscale factor. + Calculate the fine shape based on the coarse shape of data used to train + the model and the downscaling factor. """ return ( coarse_shape[0] * self.downscale_factor, diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py index e113c0cef..cd0a974b0 100644 --- a/fme/downscaling/predictors/cascade.py +++ b/fme/downscaling/predictors/cascade.py @@ -87,6 +87,10 @@ def coarse_shape(self): def fine_shape(self): return self.models[-1].fine_shape + @property + def static_inputs(self): + return self.models[-1].static_inputs + @property def downscale_factor(self): return math.prod([model.downscale_factor for model in self.models]) diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 5c375e417..c0986ac30 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -79,6 +79,10 @@ def __init__( def coarse_shape(self): return self.coarse_yx_patch_extent + @property + def static_inputs(self): + return self.model.static_inputs + def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py index 50884b07e..be7055e1c 100644 --- a/fme/downscaling/predictors/test_cascade.py +++ b/fme/downscaling/predictors/test_cascade.py @@ -9,13 +9,6 @@ from fme.downscaling.predictors.cascade import CascadePredictor -def _latlon_coords_on_ngrid(n: int, edges=(0, 100)): - start, end = edges - dx = (end - start) / n - midpoints = (start + (torch.arange(n) + 0.5) * dx).to(device=get_device()) - return LatLonCoordinates(lat=midpoints, lon=midpoints) - - def _get_diffusion_model(coarse_shape, downscale_factor): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), From 3fc1fe8988a6ae4d0fcd4e6de5d5ee17422d77e7 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 15:23:33 -0700 Subject: [PATCH 03/20] Remove unneded static code --- fme/downscaling/data/static.py | 23 ---------- fme/downscaling/data/test_static.py | 70 ----------------------------- 2 files changed, 93 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 2c7199cf9..4a860174c 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -1,12 +1,10 @@ import dataclasses -from collections.abc import Generator, Iterator import torch import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.patching import Patch from fme.downscaling.data.utils import ClosedInterval @@ -53,11 +51,6 @@ def to_device(self) -> "StaticInput": ), ) - def _apply_patch(self, patch: Patch): - return self._latlon_index_slice( - lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x - ) - def _latlon_index_slice( self, lat_slice: slice, @@ -73,13 +66,6 @@ def _latlon_index_slice( coords=sliced_latlon, ) - def generate_from_patches( - self, - patches: list[Patch], - ) -> Generator["StaticInput", None, None]: - for patch in patches: - yield self._apply_patch(patch) - def get_state(self) -> dict: return { "data": self.data.cpu(), @@ -162,15 +148,6 @@ def subset_latlon( def to_device(self) -> "StaticInputs": return StaticInputs(fields=[field.to_device() for field in self.fields]) - def generate_from_patches( - self, - patches: list[Patch], - ) -> Iterator["StaticInputs"]: - for patch in patches: - yield StaticInputs( - fields=[field._apply_patch(patch) for field in self.fields] - ) - def get_state(self) -> dict: return { "fields": [field.get_state() for field in self.fields], diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index 8aeace714..edfc0e080 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -2,7 +2,6 @@ import torch from fme.core.coordinates import LatLonCoordinates -from fme.downscaling.data.patching import Patch, _HorizontalSlice from .static import StaticInput, StaticInputs from .utils import ClosedInterval @@ -48,75 +47,6 @@ def test_subset_latlon(): assert torch.allclose(subset_topo.data, expected_data) -def test_Topography_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - topography = StaticInput( - torch.arange(16).reshape(4, 4), - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - topo_patch_generator = topography.generate_from_patches(patches) - generated_patches = [] - for topo_patch in topo_patch_generator: - generated_patches.append(topo_patch) - assert len(generated_patches) == 2 - assert torch.equal( - generated_patches[0].data, torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - ) - assert torch.equal(generated_patches[1].data, torch.tensor([[2], [6]])) - - -def test_StaticInputs_generate_from_patches(): - output_slice = _HorizontalSlice(y=slice(None), x=slice(None)) - patches = [ - Patch( - input_slice=_HorizontalSlice(y=slice(1, 3), x=slice(None, None)), - output_slice=output_slice, - ), - Patch( - input_slice=_HorizontalSlice(y=slice(0, 2), x=slice(2, 3)), - output_slice=output_slice, - ), - ] - data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - static_inputs = StaticInputs([topography, land_frac]) - static_inputs_patch_generator = static_inputs.generate_from_patches(patches) - generated_patches = [] - for static_inputs_patch in static_inputs_patch_generator: - generated_patches.append(static_inputs_patch) - - assert len(generated_patches) == 2 - - expected_topography_patch_0 = torch.tensor([[4, 5, 6, 7], [8, 9, 10, 11]]) - expected_topography_patch_1 = torch.tensor([[2], [6]]) - - # first index is the patch, second is the static input field within - # the StaticInputs container - assert torch.equal(generated_patches[0][0].data, expected_topography_patch_0) - assert torch.equal(generated_patches[1][0].data, expected_topography_patch_1) - - # land_frac field values are -1 * topography - assert torch.equal(generated_patches[0][1].data, expected_topography_patch_0 * -1.0) - assert torch.equal(generated_patches[1][1].data, expected_topography_patch_1 * -1.0) - - def test_StaticInputs_serialize(): data = torch.arange(16).reshape(4, 4) topography = StaticInput( From 3b9c19335fc3bb4f6f141b7a5b693bcd7906f03a Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 16:18:35 -0700 Subject: [PATCH 04/20] updates based on feedback --- fme/downscaling/data/datasets.py | 3 +++ fme/downscaling/inference/output.py | 4 ++-- fme/downscaling/predict.py | 12 +++++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 4f6e05435..bce3d59c4 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -640,6 +640,9 @@ def __iter__(self): return iter(indices[start:end]) +# downscale_factor=None means fine patches not needed here, but reusing +# _get_paired_patches in both paired and no-target cases to share the +# coincident offset logic. def _get_paired_patches( coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 2ff091e7b..2af28e27e 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -152,7 +152,7 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - output_fine_shape: tuple[int, int], + fine_shape: tuple[int, int], ) -> DownscalingOutput: """ Build an OutputTarget from this configuration. @@ -161,7 +161,7 @@ def build( loader_config: Base data loader configuration to modify requirements: Model's data requirements (variable names, etc.) patch: Default patch prediction configuration - output_fine_shape: Fine shape of the output used as metadata + fine_shape: Fine shape of the output used as metadata for the shape of the output to insert into the dataset """ pass diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index a3397a40f..33b1f11c3 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -235,14 +235,16 @@ def save_netcdf_data(self, ds: xr.Dataset): f"{self.experiment_dir}/generated_maps_and_metrics.nc", mode="w" ) - @property - def _fine_latlon_coordinates(self) -> LatLonCoordinates | None: - return None - def run(self): + # TODO: remove when coordinates are stored with the model + if self.model.static_inputs is None: + raise ValueError( + "Model must have static inputs with coordinates for downscaling " + "generation." + ) aggregator = NoTargetAggregator( downscale_factor=self.model.downscale_factor, - latlon_coordinates=self._fine_latlon_coordinates, + latlon_coordinates=self.model.static_inputs.fields[0].coords, ) for i, batch in enumerate(self.batch_generator): with torch.no_grad(): From 066b28b59f96629e166a237492e99de087bb1dfb Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 16:41:45 -0700 Subject: [PATCH 05/20] simplify the inference code --- fme/downscaling/inference/inference.py | 79 +++++++++++++------------- 1 file changed, 41 insertions(+), 38 deletions(-) diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index 40e897d8d..584e93b43 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -3,10 +3,12 @@ from dataclasses import dataclass, field import dacite +import numpy as np import torch import yaml from fme.core.cli import prepare_directory +from fme.core.coordinates import LatLonCoordinates from fme.core.generics.trainer import count_parameters from fme.core.logging_utils import LoggingConfig @@ -17,6 +19,42 @@ from .work_items import LoadedSliceWorkItem +def _get_fine_coords( + model: DiffusionModel | PatchPredictor, work_item: LoadedSliceWorkItem +) -> LatLonCoordinates: + # TODO: simplify to just use model coordinates when those are saved instead + # of them being appended to static inputs + if model.static_inputs is None: + raise ValueError( + "Model is missing static inputs, which are required to " + "determine the coordinate information for the output " + "dataset. Please ensure the model is configured with " + "static inputs. This will be fixed in a future update." + ) + coarse_lat = work_item.batch.latlon_coordinates.lat[0] + coarse_lon = work_item.batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + work_item.batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=model.static_inputs.coords.lat, + downscale_factor=model.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + work_item.batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=model.static_inputs.coords.lon, + downscale_factor=model.downscale_factor, + ) + return LatLonCoordinates( + lat=model.static_inputs.coords.lat[ + fine_lat_interval.slice_of(model.static_inputs.coords.lat) + ], + lon=model.static_inputs.coords.lon[ + fine_lon_interval.slice_of(model.static_inputs.coords.lon) + ], + ) + + class Downscaler: """ Orchestrates downscaling generation across multiple outputs. @@ -104,7 +142,6 @@ def run_output_generation(self, output: DownscalingOutput): # initialize writer and model in loop for coord info model = None writer = None - fine_static_inputs = None total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem @@ -116,46 +153,12 @@ def run_output_generation(self, output: DownscalingOutput): ) if writer is None: - coarse_lat = loaded_item.batch.latlon_coordinates.lat[0] - coarse_lon = loaded_item.batch.latlon_coordinates.lon[0] - if model.static_inputs is None: - raise ValueError( - "Model is missing static inputs, which are required to " - "determine the coordinate information for the output " - "dataset. Please ensure the model is configured with " - "static inputs. This will be fixed in a future update." - ) - # TODO: this is a definciency of the implementation needing - # fine spatial information to determine the output region. - # Right now that requires that we use the model static - # inputs to get the fine coordinates, but we should be able - # to get the fine coordinate information from the output config - # instead, which would remove the need for the model to have - # static inputs at all. This works because the batch always - # contains the full coarse spatial extent - fine_lat_interval = adjust_fine_coord_range( - loaded_item.batch.lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=model.static_inputs.coords.lat, - downscale_factor=model.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - loaded_item.batch.lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=model.static_inputs.coords.lon, - downscale_factor=model.downscale_factor, - ) - fine_static_inputs = model.static_inputs.subset_latlon( - fine_lat_interval, - fine_lon_interval, - ) + fine_latlon_coords = _get_fine_coords(model, loaded_item) writer = output.get_writer( - latlon_coords=fine_static_inputs.coords, + latlon_coords=fine_latlon_coords, output_dir=self.output_dir, ) - writer.initialize_store( - fine_static_inputs.fields[0].data.cpu().numpy().dtype - ) + writer.initialize_store(np.float32) logging.info( f"[{output.name}] Batch {i+1}/{total_batches}, " From 67b4dccac3e6d8fa02ad2a938727cb19fe95b5b5 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 14:03:41 -0700 Subject: [PATCH 06/20] Fix renaming by adding static_inputs to model in test_predict --- fme/downscaling/test_predict.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 6626d5a01..ce253ab14 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -147,7 +147,7 @@ def test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, _ = create_predictor_config( + predictor_config_path, fine_data_path = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -158,7 +158,11 @@ def test_predictor_renaming( model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) - model = model_config.build(coarse_shape=coarse_shape, downscale_factor=2) + model = model_config.build( + coarse_shape=coarse_shape, + downscale_factor=2, + static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) os.makedirs( From fb980f0c95a05cdb2669e55c820fd1fca30f862b Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 14:11:46 -0700 Subject: [PATCH 07/20] Simplify latlon coordinate retrieval from batch and model info --- fme/downscaling/inference/inference.py | 41 ++----------------------- fme/downscaling/models.py | 31 +++++++++++++++++++ fme/downscaling/predict.py | 20 ++++++------ fme/downscaling/predictors/composite.py | 4 +++ 4 files changed, 47 insertions(+), 49 deletions(-) diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index 584e93b43..5ba046140 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -8,53 +8,16 @@ import yaml from fme.core.cli import prepare_directory -from fme.core.coordinates import LatLonCoordinates from fme.core.generics.trainer import count_parameters from fme.core.logging_utils import LoggingConfig -from ..data import DataLoaderConfig, adjust_fine_coord_range +from ..data import DataLoaderConfig from ..models import CheckpointModelConfig, DiffusionModel from ..predictors import PatchPredictionConfig, PatchPredictor from .output import DownscalingOutput, EventConfig, TimeRangeConfig from .work_items import LoadedSliceWorkItem -def _get_fine_coords( - model: DiffusionModel | PatchPredictor, work_item: LoadedSliceWorkItem -) -> LatLonCoordinates: - # TODO: simplify to just use model coordinates when those are saved instead - # of them being appended to static inputs - if model.static_inputs is None: - raise ValueError( - "Model is missing static inputs, which are required to " - "determine the coordinate information for the output " - "dataset. Please ensure the model is configured with " - "static inputs. This will be fixed in a future update." - ) - coarse_lat = work_item.batch.latlon_coordinates.lat[0] - coarse_lon = work_item.batch.latlon_coordinates.lon[0] - fine_lat_interval = adjust_fine_coord_range( - work_item.batch.lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=model.static_inputs.coords.lat, - downscale_factor=model.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - work_item.batch.lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=model.static_inputs.coords.lon, - downscale_factor=model.downscale_factor, - ) - return LatLonCoordinates( - lat=model.static_inputs.coords.lat[ - fine_lat_interval.slice_of(model.static_inputs.coords.lat) - ], - lon=model.static_inputs.coords.lon[ - fine_lon_interval.slice_of(model.static_inputs.coords.lon) - ], - ) - - class Downscaler: """ Orchestrates downscaling generation across multiple outputs. @@ -153,7 +116,7 @@ def run_output_generation(self, output: DownscalingOutput): ) if writer is None: - fine_latlon_coords = _get_fine_coords(model, loaded_item) + fine_latlon_coords = model.get_fine_coords_for_batch(loaded_item.batch) writer = output.get_writer( latlon_coords=fine_latlon_coords, output_dir=self.output_dir, diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index d567b7c7d..2d0d1057e 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -6,6 +6,7 @@ import dacite import torch +from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device from fme.core.distributed import Distributed from fme.core.loss import LossConfig @@ -333,6 +334,36 @@ def _subset_static_inputs( ) return self.static_inputs.subset_latlon(lat_interval, lon_interval) + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + """Return fine-resolution coordinates matching the spatial extent of batch.""" + if self.static_inputs is None: + raise ValueError( + "Model is missing static inputs, which are required to determine " + "the coordinate information for the output dataset." + ) + coarse_lat = batch.latlon_coordinates.lat[0] + coarse_lon = batch.latlon_coordinates.lon[0] + fine_lat_interval = adjust_fine_coord_range( + batch.lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.static_inputs.coords.lat, + downscale_factor=self.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + batch.lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.static_inputs.coords.lon, + downscale_factor=self.downscale_factor, + ) + return LatLonCoordinates( + lat=self.static_inputs.coords.lat[ + fine_lat_interval.slice_of(self.static_inputs.coords.lat) + ], + lon=self.static_inputs.coords.lon[ + fine_lon_interval.slice_of(self.static_inputs.coords.lon) + ], + ) + @property def fine_shape(self) -> tuple[int, int]: return self._get_fine_shape(self.coarse_shape) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index 33b1f11c3..b4da331aa 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -236,17 +236,14 @@ def save_netcdf_data(self, ds: xr.Dataset): ) def run(self): - # TODO: remove when coordinates are stored with the model - if self.model.static_inputs is None: - raise ValueError( - "Model must have static inputs with coordinates for downscaling " - "generation." - ) - aggregator = NoTargetAggregator( - downscale_factor=self.model.downscale_factor, - latlon_coordinates=self.model.static_inputs.fields[0].coords, - ) + aggregator: NoTargetAggregator | None = None for i, batch in enumerate(self.batch_generator): + if aggregator is None: + fine_coords = self.model.get_fine_coords_for_batch(batch) + aggregator = NoTargetAggregator( + downscale_factor=self.model.downscale_factor, + latlon_coordinates=fine_coords, + ) with torch.no_grad(): logging.info(f"Generating predictions on batch {i + 1}") prediction = self.generation_model.generate_on_batch_no_target( @@ -257,6 +254,9 @@ def run(self): # Add sample dimension to coarse values for generation comparison coarse = {k: v.unsqueeze(1) for k, v in batch.data.items()} aggregator.record_batch(prediction, coarse, batch.time) + + # dataset build ensures non-empty batch_generator + assert aggregator is not None logs = aggregator.get_wandb() wandb = WandB.get_instance() wandb.log(logs, step=0) diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index bfd6162b2..05add74a0 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -2,6 +2,7 @@ import torch +from fme.core.coordinates import LatLonCoordinates from fme.core.typing_ import TensorDict from fme.downscaling.data import BatchData, PairedBatchData, scale_tuple from fme.downscaling.data.patching import Patch, get_patches @@ -82,6 +83,9 @@ def coarse_shape(self): def static_inputs(self): return self.model.static_inputs + def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: + return self.model.get_fine_coords_for_batch(batch) + def _get_patches( self, coarse_yx_extent, fine_yx_extent ) -> tuple[list[Patch], list[Patch]]: From e2241ea2131135ed0362214a08d54242faa160bf Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 14:52:08 -0700 Subject: [PATCH 08/20] Add test --- fme/downscaling/test_models.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index e550fe4f1..864f511f3 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -532,6 +532,52 @@ def test_noise_config_error(): ) +def test_get_fine_coords_for_batch(): + # Model trained on full coarse (8x16) / fine (16x32) grid + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + use_fine_topography=True, + static_inputs=static_inputs, + ) + + # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. + full_coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) + full_coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + patch_coarse_lat = full_coarse_lat[2:6].tolist() # [5, 7, 9, 11] + patch_coarse_lon = full_coarse_lon[4:12].tolist() # [9, 11, ..., 23] + batch = make_batch_data((2, 4, 8), patch_coarse_lat, patch_coarse_lon) + + result = model.get_fine_coords_for_batch(batch) + + expected_lat = model.static_inputs.coords.lat[4:12] + expected_lon = model.static_inputs.coords.lon[8:24] + # model.static_inputs has been moved to device; index into it directly + # to match devices + assert torch.allclose(result.lat, expected_lat) + assert torch.allclose(result.lon, expected_lon) + + +def test_get_fine_coords_for_batch_raises_without_static_inputs(): + model = _get_diffusion_model( + coarse_shape=(16, 16), + downscale_factor=2, + use_fine_topography=False, + static_inputs=None, + ) + batch = make_batch_data( + (1, 16, 16), + _get_monotonic_coordinate(16, stop=16).tolist(), + _get_monotonic_coordinate(16, stop=16).tolist(), + ) + with pytest.raises(ValueError, match="missing static inputs"): + model.get_fine_coords_for_batch(batch) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From 73f36515952d366f838381a13b7053d81e5cab66 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 14:58:07 -0700 Subject: [PATCH 09/20] Update config references to static inputs --- fme/downscaling/data/config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index a8d9cb022..eda213fbc 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -151,8 +151,9 @@ class DataLoaderConfig: (For multi-GPU runtime, it's the number of workers per GPU.) strict_ensemble: Whether to enforce that the datasets to be concatened have the same dimensions and coordinates. - topography: Deprecated field for specifying the topography dataset. Now - provided via build method's `static_inputs` argument. + topography: Deprecated field for specifying the topography dataset. + StaticInput data are expected to be stored and serialized within a + model through the Trainer build process. lat_extent: The latitude extent to use for the dataset specified in degrees, limited to (-88.0, 88.0). The extent is inclusive, so the start and stop values are included in the extent. Defaults to [-66, 70] which @@ -189,8 +190,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on DataLoaderConfig is deprecated and will be " - "removed in a future release. Pass static_inputs via build's " - "`static_inputs` argument instead." + "removed in a future release. `StaticInputs` are now stored within " + " the model when it is first built and trained." ) @property @@ -334,7 +335,6 @@ class PairedDataLoaderConfig: time dimension. Useful to include longer sequences of small data for testing. topography: Deprecated field for specifying the topography dataset. - Now provided via build method's `static_inputs` argument. sample_with_replacement: If provided, the dataset will be sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not). @@ -366,8 +366,8 @@ def __post_init__(self): if self.topography is not None: raise ValueError( "The `topography` field on PairedDataLoaderConfig is deprecated and " - "will be removed in a future release. Pass static_inputs via the " - "build method's `static_inputs` argument instead." + "will be removed in a future release. `StaticInputs` are now stored " + "within the model when it is first built and trained." ) def _first_data_config( From 08030d1f7368765360f7f2312e298879870a8131 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Thu, 12 Mar 2026 16:44:16 -0700 Subject: [PATCH 10/20] Initial shot --- fme/downscaling/data/config.py | 7 +- fme/downscaling/data/datasets.py | 1 + fme/downscaling/data/static.py | 80 +++----------- fme/downscaling/data/test_static.py | 64 ++++++----- fme/downscaling/inference/test_inference.py | 23 ++-- fme/downscaling/models.py | 111 ++++++++++++++++++-- fme/downscaling/predict.py | 53 +++++----- fme/downscaling/predictors/composite.py | 4 + fme/downscaling/test_models.py | 78 ++++++++++++-- fme/downscaling/test_predict.py | 28 +++-- fme/downscaling/train.py | 1 + 11 files changed, 296 insertions(+), 154 deletions(-) diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index eda213fbc..14396a8e7 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -23,7 +23,11 @@ PairedBatchData, PairedGriddedData, ) -from fme.downscaling.data.utils import ClosedInterval, adjust_fine_coord_range +from fme.downscaling.data.utils import ( + ClosedInterval, + adjust_fine_coord_range, + get_latlon_coords_from_properties, +) from fme.downscaling.requirements import DataRequirements @@ -529,6 +533,7 @@ def build( dims=example.fine.latlon_coordinates.dims, variable_metadata=variable_metadata, all_times=all_times, + fine_coords=get_latlon_coords_from_properties(properties_fine), ) def _get_sampler( diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index bce3d59c4..2b6798f04 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -337,6 +337,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex + fine_coords: LatLonCoordinates | None = None @property def loader(self) -> DataLoader[PairedBatchItem]: diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index 4a860174c..ae4721f8d 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -3,26 +3,17 @@ import torch import xarray as xr -from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.utils import ClosedInterval +from fme.downscaling.data.patching import Patch @dataclasses.dataclass class StaticInput: data: torch.Tensor - coords: LatLonCoordinates def __post_init__(self): if len(self.data.shape) != 2: raise ValueError(f"Topography data must be 2D. Got shape {self.data.shape}") - if self.data.shape[0] != len(self.coords.lat) or self.data.shape[1] != len( - self.coords.lon - ): - raise ValueError( - f"Static inputs data shape {self.data.shape} does not match lat/lon " - f"coordinates shape {(len(self.coords.lat), len(self.coords.lon))}" - ) @property def dim(self) -> int: @@ -32,44 +23,23 @@ def dim(self) -> int: def shape(self) -> tuple[int, int]: return self.data.shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInput": - lat_slice = lat_interval.slice_of(self.coords.lat) - lon_slice = lon_interval.slice_of(self.coords.lon) - return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice) + return StaticInput(data=self.data[lat_slice, lon_slice]) def to_device(self) -> "StaticInput": device = get_device() - return StaticInput( - data=self.data.to(device), - coords=LatLonCoordinates( - lat=self.coords.lat.to(device), - lon=self.coords.lon.to(device), - ), - ) + return StaticInput(data=self.data.to(device)) - def _latlon_index_slice( - self, - lat_slice: slice, - lon_slice: slice, - ) -> "StaticInput": - sliced_data = self.data[lat_slice, lon_slice] - sliced_latlon = LatLonCoordinates( - lat=self.coords.lat[lat_slice], - lon=self.coords.lon[lon_slice], - ) - return StaticInput( - data=sliced_data, - coords=sliced_latlon, - ) + def _apply_patch(self, patch: Patch): + return self.subset(lat_slice=patch.input_slice.y, lon_slice=patch.input_slice.x) def get_state(self) -> dict: return { "data": self.data.cpu(), - "coords": self.coords.get_state(), } @@ -93,17 +63,11 @@ def _get_normalized_static_input(path: str, field_name: str): f"unexpected shape {static_input.shape} for static input." "Currently, only lat/lon static input is supported." ) - lat_name, lon_name = static_input.dims[-2:] - coords = LatLonCoordinates( - lon=torch.tensor(static_input[lon_name].values), - lat=torch.tensor(static_input[lat_name].values), - ) static_input_normalized = (static_input - static_input.mean()) / static_input.std() return StaticInput( data=torch.tensor(static_input_normalized.values, dtype=torch.float32), - coords=coords, ) @@ -113,36 +77,28 @@ class StaticInputs: def __post_init__(self): for i, field in enumerate(self.fields[1:]): - if field.coords != self.fields[0].coords: + if field.shape != self.fields[0].shape: raise ValueError( - f"All StaticInput fields must have the same coordinates. " - f"Fields {i} and 0 do not match coordinates." + f"All StaticInput fields must have the same shape. " + f"Fields {i + 1} and 0 do not match shapes." ) def __getitem__(self, index: int): return self.fields[index] - @property - def coords(self) -> LatLonCoordinates: - if len(self.fields) == 0: - raise ValueError("No fields in StaticInputs to get coordinates from.") - return self.fields[0].coords - @property def shape(self) -> tuple[int, int]: if len(self.fields) == 0: raise ValueError("No fields in StaticInputs to get shape from.") return self.fields[0].shape - def subset_latlon( + def subset( self, - lat_interval: ClosedInterval, - lon_interval: ClosedInterval, + lat_slice: slice, + lon_slice: slice, ) -> "StaticInputs": return StaticInputs( - fields=[ - field.subset_latlon(lat_interval, lon_interval) for field in self.fields - ] + fields=[field.subset(lat_slice, lon_slice) for field in self.fields] ) def to_device(self) -> "StaticInputs": @@ -159,10 +115,6 @@ def from_state(cls, state: dict) -> "StaticInputs": fields=[ StaticInput( data=field_state["data"], - coords=LatLonCoordinates( - lat=field_state["coords"]["lat"], - lon=field_state["coords"]["lon"], - ), ) for field_state in state["fields"] ] @@ -171,7 +123,7 @@ def from_state(cls, state: dict) -> "StaticInputs": def load_static_inputs( static_inputs_config: dict[str, str] | None, -) -> StaticInputs | None: +) -> "StaticInputs | None": """ Load normalized static inputs from a mapping of field names to file paths. Returns None if the input config is empty. diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index edfc0e080..104ffdc9f 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -1,26 +1,16 @@ import pytest import torch -from fme.core.coordinates import LatLonCoordinates - from .static import StaticInput, StaticInputs -from .utils import ClosedInterval @pytest.mark.parametrize( "init_args", [ pytest.param( - [ - torch.randn((1, 2, 2)), - LatLonCoordinates(torch.arange(2), torch.arange(2)), - ], + [torch.randn((1, 2, 2))], id="3d_data", ), - pytest.param( - [torch.randn((2, 2)), LatLonCoordinates(torch.arange(2), torch.arange(5))], - id="dim_size_mismatch", - ), ], ) def test_Topography_error_cases(init_args): @@ -28,37 +18,43 @@ def test_Topography_error_cases(init_args): StaticInput(*init_args) -def test_subset_latlon(): +def test_subset(): full_data_shape = (10, 10) - expected_slices = [slice(2, 6), slice(3, 8)] data = torch.randn(*full_data_shape) - coords = LatLonCoordinates( - lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10) - ) - topo = StaticInput(data=data, coords=coords) - lat_interval = ClosedInterval(2, 5) - lon_interval = ClosedInterval(3, 7) - subset_topo = topo.subset_latlon(lat_interval, lon_interval) - expected_lats = torch.tensor([2, 3, 4, 5], dtype=coords.lat.dtype) - expected_lons = torch.tensor([3, 4, 5, 6, 7], dtype=coords.lon.dtype) - expected_data = data[*expected_slices] - assert torch.equal(subset_topo.coords.lat, expected_lats) - assert torch.equal(subset_topo.coords.lon, expected_lons) - assert torch.allclose(subset_topo.data, expected_data) + topo = StaticInput(data=data) + lat_slice = slice(2, 6) + lon_slice = slice(3, 8) + subset_topo = topo.subset(lat_slice, lon_slice) + assert torch.allclose(subset_topo.data, data[lat_slice, lon_slice]) def test_StaticInputs_serialize(): data = torch.arange(16).reshape(4, 4) - topography = StaticInput( - data, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) - land_frac = StaticInput( - data * -1.0, - LatLonCoordinates(torch.arange(4), torch.arange(4)), - ) + topography = StaticInput(data) + land_frac = StaticInput(data * -1.0) static_inputs = StaticInputs([topography, land_frac]) state = static_inputs.get_state() + # Verify coords are NOT stored in state + assert "coords" not in state["fields"][0] static_inputs_reconstructed = StaticInputs.from_state(state) assert static_inputs_reconstructed[0].data.equal(static_inputs[0].data) assert static_inputs_reconstructed[1].data.equal(static_inputs[1].data) + + +def test_StaticInputs_serialize_backward_compat_with_coords(): + """from_state should silently ignore 'coords' key for old state dicts.""" + data = torch.arange(16, dtype=torch.float32).reshape(4, 4) + # Simulate old state dict format that included coords + old_state = { + "fields": [ + { + "data": data, + "coords": { + "lat": torch.arange(4, dtype=torch.float32), + "lon": torch.arange(4, dtype=torch.float32), + }, + } + ] + } + static_inputs = StaticInputs.from_state(old_state) + assert torch.equal(static_inputs[0].data, data) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index ed97d2b60..c06b7b0cc 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -64,8 +64,7 @@ def mock_output_target(): def get_static_inputs(shape=(16, 16)): data = torch.randn(shape) - coords = LatLonCoordinates(lat=torch.arange(shape[0]), lon=torch.arange(shape[1])) - return StaticInputs([StaticInput(data=data, coords=coords)]) + return StaticInputs([StaticInput(data=data)]) # Tests for Downscaler initialization @@ -201,9 +200,11 @@ def test_run_target_generation_skips_padding_items( mock_output_target.data.get_generator.return_value = iter([mock_work_item]) mock_model.downscale_factor = 2 - mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() - mock_model.static_inputs.coords.lon = torch.arange(0, 18).float() - mock_model.static_inputs.subset_latlon.return_value.fields[0].data = torch.zeros(1) + mock_model.fine_coords = LatLonCoordinates( + lat=torch.arange(0, 18).float(), + lon=torch.arange(0, 18).float(), + ) + mock_model.static_inputs = None mock_model.generate_on_batch_no_target.return_value = { "var1": torch.zeros(1, 4, 16, 16), } @@ -273,8 +274,16 @@ def checkpointed_model_config( # loader_config is passed in to add static inputs into model # that correspond to the dataset coordinates - static_inputs = load_static_inputs({"HGTsfc": f"{data_paths.fine}/data.nc"}) - model = model_config.build(coarse_shape, 2, static_inputs=static_inputs) + fine_data_path = f"{data_paths.fine}/data.nc" + static_inputs = load_static_inputs({"HGTsfc": fine_data_path}) + ds = xr.open_dataset(fine_data_path) + fine_coords = LatLonCoordinates( + lat=torch.tensor(ds["lat"].values, dtype=torch.float32), + lon=torch.tensor(ds["lon"].values, dtype=torch.float32), + ) + model = model_config.build( + coarse_shape, 2, static_inputs=static_inputs, fine_coords=fine_coords + ) checkpoint_path = tmp_path / "model_checkpoint.pth" model.get_state() diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 2d0d1057e..1b5108b74 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -5,6 +5,7 @@ import dacite import torch +import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device @@ -183,6 +184,7 @@ def build( downscale_factor: int, rename: dict[str, str] | None = None, static_inputs: StaticInputs | None = None, + fine_coords: LatLonCoordinates | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -221,6 +223,7 @@ def build( downscale_factor=downscale_factor, sigma_data=sigma_data, static_inputs=static_inputs, + fine_coords=fine_coords, ) def get_state(self) -> Mapping[str, Any]: @@ -277,6 +280,7 @@ def __init__( downscale_factor: int, sigma_data: float, static_inputs: StaticInputs | None = None, + fine_coords: LatLonCoordinates | None = None, ) -> None: """ Args: @@ -295,6 +299,8 @@ def __init__( model preconditioning. static_inputs: Static inputs to the model, loaded from the trainer config or checkpoint. Must be set when use_fine_topography is True. + fine_coords: Full-domain fine-resolution coordinates. Used as the + single coordinate authority for output spatial metadata. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -310,6 +316,14 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) + self.fine_coords = fine_coords + if fine_coords is not None and static_inputs is not None: + expected = (len(fine_coords.lat), len(fine_coords.lon)) + if static_inputs.shape != expected: + raise ValueError( + f"static_inputs shape {static_inputs.shape} does not match " + f"fine_coords grid {expected}" + ) @property def modules(self) -> torch.nn.ModuleList: @@ -332,7 +346,13 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) - return self.static_inputs.subset_latlon(lat_interval, lon_interval) + if self.fine_coords is None: + raise ValueError( + "fine_coords must be set on the model to subset static inputs." + ) + lat_slice = lat_interval.slice_of(self.fine_coords.lat) + lon_slice = lon_interval.slice_of(self.fine_coords.lon) + return self.static_inputs.subset(lat_slice, lon_slice) def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: """Return fine-resolution coordinates matching the spatial extent of batch.""" @@ -541,23 +561,28 @@ def generate_on_batch_no_target( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) + if self.fine_coords is None: + raise ValueError( + "fine_coords must be set on the model when use_fine_topography " + "is enabled." + ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) - _static_inputs = self.static_inputs.subset_latlon( - fine_lat_interval, fine_lon_interval - ) + lat_slice = fine_lat_interval.slice_of(self.fine_coords.lat) + lon_slice = fine_lon_interval.slice_of(self.fine_coords.lon) + _static_inputs = self.static_inputs.subset(lat_slice, lon_slice) else: _static_inputs = None generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) @@ -602,6 +627,9 @@ def get_state(self) -> Mapping[str, Any]: "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, "static_inputs": static_inputs_state, + "fine_coords": ( + self.fine_coords.get_state() if self.fine_coords is not None else None + ), } @classmethod @@ -615,10 +643,33 @@ def from_state( static_inputs = StaticInputs.from_state(state["static_inputs"]).to_device() else: static_inputs = None + + # Load fine_coords: new checkpoints store it directly; old checkpoints + # that had static_inputs with coords can auto-migrate from raw state. + if state.get("fine_coords") is not None: + fine_coords = LatLonCoordinates( + lat=state["fine_coords"]["lat"], + lon=state["fine_coords"]["lon"], + ) + elif ( + state.get("static_inputs") is not None + and len(state["static_inputs"].get("fields", [])) > 0 + and "coords" in state["static_inputs"]["fields"][0] + ): + # Backward compat: old checkpoints stored coords inside static_inputs fields + coords_state = state["static_inputs"]["fields"][0]["coords"] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + else: + fine_coords = None + model = config.build( state["coarse_shape"], state["downscale_factor"], static_inputs=static_inputs, + fine_coords=fine_coords, ) model.module.load_state_dict(state["module"], strict=True) return model @@ -648,6 +699,9 @@ class CheckpointModelConfig: but the model requires static input data. Raises an error if the checkpoint already has static inputs from training. fine_topography_path: Deprecated. Use static_inputs instead. + fine_coordinates_path: Optional path to a netCDF/zarr file containing lat/lon + coordinates for the full fine domain. Used for old checkpoints that have + no static_inputs and no stored fine_coords. model_updates: Optional mapping of {key: new_value} model config updates to apply when loading the model. This is useful for running evaluation with updated parameters than at training time. Use with caution; not all @@ -658,6 +712,7 @@ class CheckpointModelConfig: rename: dict[str, str] | None = None static_inputs: dict[str, str] | None = None fine_topography_path: str | None = None + fine_coordinates_path: str | None = None model_updates: dict[str, Any] | None = None def __post_init__(self) -> None: @@ -688,6 +743,8 @@ def _checkpoint(self) -> Mapping[str, Any]: ] # backwards compatibility for models before static inputs serialization checkpoint_data["model"].setdefault("static_inputs", None) + # backwards compatibility for models before fine_coords serialization + checkpoint_data["model"].setdefault("fine_coords", None) self._checkpoint_data = checkpoint_data self._checkpoint_is_loaded = True @@ -696,6 +753,23 @@ def _checkpoint(self) -> Mapping[str, Any]: checkpoint_data["model"]["config"][k] = v return self._checkpoint_data + def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude"] if n in ds.coords), None) + lon_name = next((n for n in ["lon", "longitude"] if n in ds.coords), None) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude' and 'lon'/'longitude'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + def build( self, ) -> DiffusionModel: @@ -723,6 +797,31 @@ def build( static_inputs=static_inputs, ) model.module.load_state_dict(self._checkpoint["model"]["module"]) + + # Restore fine_coords: new checkpoints have it stored directly; old + # checkpoints may have coords embedded in static_inputs fields. + model_state = self._checkpoint["model"] + if model_state.get("fine_coords") is not None: + fine_coords_state = model_state["fine_coords"] + model.fine_coords = LatLonCoordinates( + lat=fine_coords_state["lat"], + lon=fine_coords_state["lon"], + ) + elif ( + model_state.get("static_inputs") is not None + and len(model_state["static_inputs"].get("fields", [])) > 0 + and "coords" in model_state["static_inputs"]["fields"][0] + ): + coords_state = model_state["static_inputs"]["fields"][0]["coords"] + model.fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + elif self.fine_coordinates_path is not None: + model.fine_coords = self._load_fine_coords_from_path( + self.fine_coordinates_path + ) + return model @property diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index b4da331aa..a93361aa0 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,6 +21,7 @@ ClosedInterval, DataLoaderConfig, GriddedData, + adjust_fine_coord_range, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -29,31 +30,6 @@ from fme.downscaling.typing_ import FineResCoarseResPair -def _downscale_coord(coord: torch.tensor, downscale_factor: int): - """ - This is a bandaid fix for the issue where BatchData does not - contain coords for the topography, which is fine-res in the no-target - generation case. The SampleAggregator requires the fine-res coords - for the predictions. - - TODO: remove after topography refactors to have its own data container. - """ - if len(coord.shape) != 1: - raise ValueError("coord tensor to downscale must be 1d") - spacing = coord[1] - coord[0] - # Compute edges from midpoints - first_edge = coord[0] - spacing / 2 - last_edge = coord[-1] + spacing / 2 - - # Subdivide edges - step = spacing / downscale_factor - new_edges = torch.arange(first_edge, last_edge + step / 2, step) - - # Compute new midpoints - coord_new = (new_edges[:-1] + new_edges[1:]) / 2 - return coord_new.to(device=coord.device, dtype=coord.dtype) - - @dataclasses.dataclass class EventConfig: name: str @@ -145,9 +121,32 @@ def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates + if self.model.fine_coords is None: + raise ValueError( + "Model fine_coords must be set for event downscaling output " + "coordinates." + ) + coarse_lat = coarse_coords.lat + coarse_lon = coarse_coords.lon + lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) + lon_interval = ClosedInterval(coarse_lon.min().item(), coarse_lon.max().item()) + fine_lat_interval = adjust_fine_coord_range( + lat_interval, + full_coarse_coord=coarse_lat, + full_fine_coord=self.model.fine_coords.lat, + downscale_factor=self.model.downscale_factor, + ) + fine_lon_interval = adjust_fine_coord_range( + lon_interval, + full_coarse_coord=coarse_lon, + full_fine_coord=self.model.fine_coords.lon, + downscale_factor=self.model.downscale_factor, + ) + lat_slice = fine_lat_interval.slice_of(self.model.fine_coords.lat) + lon_slice = fine_lon_interval.slice_of(self.model.fine_coords.lon) fine_coords = LatLonCoordinates( - lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor), - lon=_downscale_coord(coarse_coords.lon, self.model.downscale_factor), + lat=self.model.fine_coords.lat[lat_slice], + lon=self.model.fine_coords.lon[lon_slice], ) sample_agg = SampleAggregator( coarse=batch[0].data, diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 05add74a0..405ac1eed 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -85,6 +85,10 @@ def static_inputs(self): def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) + + @property + def fine_coords(self): + return self.model.fine_coords def _get_patches( self, coarse_yx_extent, fine_yx_extent diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 864f511f3..dc916014c 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -124,30 +124,37 @@ def make_paired_batch_data( def make_static_inputs(fine_shape: tuple[int, int]) -> StaticInputs: - """Create StaticInputs with proper monotonic coordinates for given shape.""" - lat_size, lon_size = fine_shape + """Create StaticInputs for given shape.""" return StaticInputs( fields=[ StaticInput( torch.ones(*fine_shape, device=get_device()), - LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), - ), ) ] ) +def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: + """Create LatLonCoordinates for given fine shape.""" + lat_size, lon_size = fine_shape + return LatLonCoordinates( + lat=_get_monotonic_coordinate(lat_size, stop=lat_size), + lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + ) + + def test_module_serialization(tmp_path): coarse_shape = (8, 16) - static_inputs = make_static_inputs((16, 32)) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=False, static_inputs=static_inputs, + fine_coords=fine_coords, ) model_from_state = DiffusionModel.from_state( model.get_state(), @@ -158,6 +165,9 @@ def test_module_serialization(tmp_path): model.module.parameters(), model_from_state.module.parameters() ) ) + assert model_from_state.fine_coords is not None + assert torch.equal(model_from_state.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_state.fine_coords.lon, fine_coords.lon) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -174,6 +184,9 @@ def test_module_serialization(tmp_path): assert torch.equal( loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) + assert model_from_disk.fine_coords is not None + assert torch.equal(model_from_disk.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_disk.fine_coords.lon, fine_coords.lon) def test_from_state_backward_compat_fine_topography(): @@ -181,21 +194,25 @@ def test_from_state_backward_compat_fine_topography(): fine_shape = (16, 32) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) # Simulate old checkpoint format: static_inputs not serialized state = model.get_state() state["static_inputs"] = None + state["fine_coords"] = None # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None + assert model_from_old_state.fine_coords is None assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -209,12 +226,40 @@ def test_from_state_backward_compat_fine_topography(): model_from_old_state.generate_on_batch(batch) +def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs(): + """Old checkpoints that stored coords in static_inputs fields should have + fine_coords auto-migrated on from_state.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + downscale_factor = 2 + static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=downscale_factor, + predict_residual=True, + use_fine_topography=True, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) + state = model.get_state() + # Simulate old format: fine_coords absent, but static_inputs fields have coords + del state["fine_coords"] + state["static_inputs"]["fields"][0]["coords"] = fine_coords.get_state() + + model_from_old_state = DiffusionModel.from_state(state) + assert model_from_old_state.fine_coords is not None + assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) + assert torch.equal(model_from_old_state.fine_coords.lon, fine_coords.lon) + + def _get_diffusion_model( coarse_shape, downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=None, + fine_coords=None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), @@ -237,7 +282,12 @@ def _get_diffusion_model( num_diffusion_generation_steps=3, predict_residual=predict_residual, use_fine_topography=use_fine_topography, - ).build(coarse_shape, downscale_factor, static_inputs=static_inputs) + ).build( + coarse_shape, + downscale_factor, + static_inputs=static_inputs, + fine_coords=fine_coords, + ) @pytest.mark.parametrize("predict_residual", [True, False]) @@ -248,9 +298,11 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph batch_size = 2 if use_fine_topography: static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: static_inputs = None + fine_coords = None batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] ) @@ -260,6 +312,7 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph predict_residual=predict_residual, use_fine_topography=use_fine_topography, static_inputs=static_inputs, + fine_coords=fine_coords, ) assert model._get_fine_shape(coarse_shape) == fine_shape @@ -399,12 +452,14 @@ def test_DiffusionModel_generate_on_batch_no_target(): coarse_shape = (16, 16) downscale_factor = 2 static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=downscale_factor, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) batch_size = 2 @@ -435,7 +490,9 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): # Full fine domain: 64x64 covers inputs for both (8,8) and (32,32) coarse inputs # with a downscaling factor of 2 full_fine_size = 64 - static_inputs = make_static_inputs((full_fine_size, full_fine_size)) + full_fine_shape = (full_fine_size, full_fine_size) + static_inputs = make_static_inputs(full_fine_shape) + fine_coords = make_fine_coords(full_fine_shape) # need to build with static inputs to get the correct n_in_channels model = _get_diffusion_model( coarse_shape=coarse_shape, @@ -443,6 +500,7 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) n_ensemble = 2 batch_size = 2 @@ -590,12 +648,14 @@ def test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs(tmp_pat coarse_shape = (8, 16) fine_shape = (16, 32) static_inputs = make_static_inputs(fine_shape) + fine_coords = make_fine_coords(fine_shape) model = _get_diffusion_model( coarse_shape=coarse_shape, downscale_factor=2, predict_residual=True, use_fine_topography=True, static_inputs=static_inputs, + fine_coords=fine_coords, ) checkpoint_path = tmp_path / "test.ckpt" torch.save({"model": model.get_state()}, checkpoint_path) diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index ce253ab14..5054fb538 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -2,8 +2,10 @@ import pytest import torch +import xarray as xr import yaml +from fme.core.coordinates import LatLonCoordinates from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb @@ -64,6 +66,18 @@ def get_model_config( ) +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + """Load lat/lon coordinates from a netCDF or zarr data file.""" + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + return LatLonCoordinates( + lat=torch.tensor(ds["lat"].values, dtype=torch.float32), + lon=torch.tensor(ds["lon"].values, dtype=torch.float32), + ) + + def create_predictor_config( tmp_path, n_samples: int, @@ -97,7 +111,7 @@ def create_predictor_config( out_path = tmp_path / "predictor-config.yaml" with open(out_path, "w") as file: yaml.dump(config, file) - return out_path, f"{paths.fine}/data.nc" + return out_path, paths def test_predictor_runs(tmp_path, very_fast_only: bool): @@ -106,15 +120,18 @@ def test_predictor_runs(tmp_path, very_fast_only: bool): n_samples = 2 coarse_shape = (4, 4) downscale_factor = 2 - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, ) + fine_data_path = f"{paths.fine}/data.nc" + fine_coords = load_fine_coords_from_path(fine_data_path) model_config = get_model_config(coarse_shape, downscale_factor=downscale_factor) model = model_config.build( coarse_shape=coarse_shape, downscale_factor=downscale_factor, static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + fine_coords=fine_coords, ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) @@ -147,7 +164,7 @@ def test_predictor_renaming( coarse_shape = (4, 4) downscale_factor = 2 renaming = {"var0": "var0_renamed", "var1": "var1_renamed"} - predictor_config_path, fine_data_path = create_predictor_config( + predictor_config_path, paths = create_predictor_config( tmp_path, n_samples, model_renaming=renaming, @@ -155,13 +172,12 @@ def test_predictor_renaming( "rename": {"var0": "var0_renamed", "var1": "var1_renamed"} }, ) + fine_coords = load_fine_coords_from_path(f"{paths.fine}/data.nc") model_config = get_model_config( coarse_shape, downscale_factor, use_fine_topography=False ) model = model_config.build( - coarse_shape=coarse_shape, - downscale_factor=2, - static_inputs=load_static_inputs({"HGTsfc": fine_data_path}), + coarse_shape=coarse_shape, downscale_factor=2, fine_coords=fine_coords ) with open(predictor_config_path) as f: predictor_config = yaml.safe_load(f) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index a1dd08fde..5a0c25666 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -443,6 +443,7 @@ def build(self) -> Trainer: model_coarse_shape, train_data.downscale_factor, static_inputs=static_inputs, + fine_coords=train_data.fine_coords, ) optimization = self.optimization.build( From 4468da38ff7aa180466b37f3e96f3c9ea74b6bdf Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 14:58:38 -0700 Subject: [PATCH 11/20] Make fine coords required --- fme/downscaling/models.py | 101 ++++++++++++++++++++++++--------- fme/downscaling/predict.py | 5 -- fme/downscaling/test_models.py | 19 +++++-- 3 files changed, 87 insertions(+), 38 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 1b5108b74..94f0c5a81 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -182,9 +182,9 @@ def build( self, coarse_shape: tuple[int, int], downscale_factor: int, + fine_coords: LatLonCoordinates, rename: dict[str, str] | None = None, static_inputs: StaticInputs | None = None, - fine_coords: LatLonCoordinates | None = None, ) -> "DiffusionModel": invert_rename = {v: k for k, v in (rename or {}).items()} orig_in_names = [invert_rename.get(name, name) for name in self.in_names] @@ -279,8 +279,8 @@ def __init__( coarse_shape: tuple[int, int], downscale_factor: int, sigma_data: float, + fine_coords: LatLonCoordinates, static_inputs: StaticInputs | None = None, - fine_coords: LatLonCoordinates | None = None, ) -> None: """ Args: @@ -297,10 +297,12 @@ def __init__( coarse to fine. sigma_data: The standard deviation of the data, used for diffusion model preconditioning. + fine_coords: the full-domain fine-resolution coordinates to use + for spatial metadata in the model output. static_inputs: Static inputs to the model, loaded from the trainer config or checkpoint. Must be set when use_fine_topography is True. - fine_coords: Full-domain fine-resolution coordinates. Used as the - single coordinate authority for output spatial metadata. + Expected to be on the same coordinate grid as fine_coords for now, + but this may be relaxed in the future. """ self.coarse_shape = coarse_shape self.downscale_factor = downscale_factor @@ -317,12 +319,13 @@ def __init__( static_inputs.to_device() if static_inputs is not None else None ) self.fine_coords = fine_coords - if fine_coords is not None and static_inputs is not None: + if static_inputs is not None: expected = (len(fine_coords.lat), len(fine_coords.lon)) if static_inputs.shape != expected: raise ValueError( - f"static_inputs shape {static_inputs.shape} does not match " - f"fine_coords grid {expected}" + f"static_inputs are expected to be on the same coordinate grid as " + f"fine_coords. StaticInputs shape {static_inputs.shape} does not " + f"match fine_coords grid {expected}." ) @property @@ -346,10 +349,6 @@ def _subset_static_inputs( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) - if self.fine_coords is None: - raise ValueError( - "fine_coords must be set on the model to subset static inputs." - ) lat_slice = lat_interval.slice_of(self.fine_coords.lat) lon_slice = lon_interval.slice_of(self.fine_coords.lon) return self.static_inputs.subset(lat_slice, lon_slice) @@ -561,11 +560,6 @@ def generate_on_batch_no_target( "Static inputs must be provided for each batch when use of fine " "static inputs is enabled." ) - if self.fine_coords is None: - raise ValueError( - "fine_coords must be set on the model when use_fine_topography " - "is enabled." - ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( @@ -627,9 +621,7 @@ def get_state(self) -> Mapping[str, Any]: "coarse_shape": self.coarse_shape, "downscale_factor": self.downscale_factor, "static_inputs": static_inputs_state, - "fine_coords": ( - self.fine_coords.get_state() if self.fine_coords is not None else None - ), + "fine_coords": self.fine_coords.get_state(), } @classmethod @@ -646,24 +638,31 @@ def from_state( # Load fine_coords: new checkpoints store it directly; old checkpoints # that had static_inputs with coords can auto-migrate from raw state. - if state.get("fine_coords") is not None: + fine_coords = state.get("fine_coords") + if fine_coords is not None: + # TODO: Why doesn't LatLonCoordinates have a from_state method? fine_coords = LatLonCoordinates( lat=state["fine_coords"]["lat"], lon=state["fine_coords"]["lon"], ) elif ( - state.get("static_inputs") is not None - and len(state["static_inputs"].get("fields", [])) > 0 + static_inputs is not None + and static_inputs.fields and "coords" in state["static_inputs"]["fields"][0] ): - # Backward compat: old checkpoints stored coords inside static_inputs fields + # Backward compat: old checkpoints with static inputs stored coords inside + # static_inputs fields[0]["coords"] coords_state = state["static_inputs"]["fields"][0]["coords"] fine_coords = LatLonCoordinates( lat=coords_state["lat"], lon=coords_state["lon"], ) else: - fine_coords = None + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coords must be serialized with the" + " checkpoint to resume from training." + ) model = config.build( state["coarse_shape"], @@ -758,12 +757,16 @@ def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: ds = xr.open_zarr(path) else: ds = xr.open_dataset(path) - lat_name = next((n for n in ["lat", "latitude"] if n in ds.coords), None) - lon_name = next((n for n in ["lon", "longitude"] if n in ds.coords), None) + lat_name = next( + (n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None + ) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) if lat_name is None or lon_name is None: raise ValueError( f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude' and 'lon'/'longitude'." + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." ) return LatLonCoordinates( lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), @@ -788,11 +791,55 @@ def build( static_inputs = load_static_inputs(self.static_inputs) else: static_inputs = None + + fine_coords: LatLonCoordinates + has_fine_coords = self._checkpoint["model"]["fine_coords"] is not None + has_static_input_coords = ( + self._checkpoint["model"]["static_inputs"] is not None + and self._checkpoint["model"]["static_inputs"]["fields"][0].get("coords") + is not None + ) + # TODO: simplify with static input refactor that deisables empty StaticInputs + if ( + has_fine_coords + or has_static_input_coords + and self.fine_coordinates_path is not None + ): + raise ValueError( + "The model checkpoint already has fine coordinates from training. " + "fine_coordinates_path should not be provided in checkpoint model " + "config." + ) + elif has_fine_coords: + fine_coords_state = self._checkpoint["model"]["fine_coords"] + fine_coords = LatLonCoordinates( + lat=fine_coords_state["lat"], + lon=fine_coords_state["lon"], + ) + elif has_static_input_coords: + coords_state = self._checkpoint["model"]["static_inputs"]["fields"][0][ + "coords" + ] + fine_coords = LatLonCoordinates( + lat=coords_state["lat"], + lon=coords_state["lon"], + ) + elif self.fine_coordinates_path is not None: + fine_coords = self._load_fine_coords_from_path(self.fine_coordinates_path) + else: + raise ValueError( + "No fine coordinates found in checkpoint state and no static inputs " + " were available to infer them. fine_coordinates_path must be provided " + "in the checkpoint model configuration to load fine coordinates from " + "the provided path." + ) + model = _CheckpointModelConfigSelector.from_state( self._checkpoint["model"]["config"] ).build( coarse_shape=self._checkpoint["model"]["coarse_shape"], downscale_factor=self._checkpoint["model"]["downscale_factor"], + fine_coords=fine_coords, rename=self._rename, static_inputs=static_inputs, ) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index a93361aa0..449e84368 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -121,11 +121,6 @@ def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) coarse_coords = batch[0].latlon_coordinates - if self.model.fine_coords is None: - raise ValueError( - "Model fine_coords must be set for event downscaling output " - "coordinates." - ) coarse_lat = coarse_coords.lat coarse_lon = coarse_coords.lon lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index dc916014c..d801eec25 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -204,15 +204,15 @@ def test_from_state_backward_compat_fine_topography(): fine_coords=fine_coords, ) - # Simulate old checkpoint format: static_inputs not serialized + # Simulate old checkpoint format: static_inputs not serialized, fine_coords still + # present state = model.get_state() state["static_inputs"] = None - state["fine_coords"] = None # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None - assert model_from_old_state.fine_coords is None + assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -259,12 +259,18 @@ def _get_diffusion_model( predict_residual=True, use_fine_topography=True, static_inputs=None, - fine_coords=None, + fine_coords: LatLonCoordinates | None = None, ): normalizer = PairedNormalizationConfig( NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), ) + if fine_coords is None: + fine_shape = ( + coarse_shape[0] * downscale_factor, + coarse_shape[1] * downscale_factor, + ) + fine_coords = make_fine_coords(fine_shape) return DiffusionModelConfig( module=DiffusionModuleRegistrySelector( @@ -296,13 +302,12 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph coarse_shape = (8, 16) fine_shape = (16, 32) batch_size = 2 + fine_coords = make_fine_coords(fine_shape) if use_fine_topography: static_inputs = make_static_inputs(fine_shape) - fine_coords = make_fine_coords(fine_shape) batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) else: static_inputs = None - fine_coords = None batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] ) @@ -436,6 +441,7 @@ def test_model_error_cases(): ).build( coarse_shape, upscaling_factor, + fine_coords=make_fine_coords(fine_shape), ) batch = get_mock_paired_batch( [batch_size, *coarse_shape], [batch_size, *fine_shape] @@ -552,6 +558,7 @@ def test_lognorm_noise_backwards_compatibility(): model = model_config.build( (32, 32), 2, + fine_coords=make_fine_coords((64, 64)), ) state = model.get_state() From 229e85823073abca2d30c164b578c13d9d05b6bd Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 13 Mar 2026 15:12:27 -0700 Subject: [PATCH 12/20] Fine coords required for paired data --- fme/downscaling/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index 2b6798f04..b11d0e49d 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -337,7 +337,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - fine_coords: LatLonCoordinates | None = None + fine_coords: LatLonCoordinates @property def loader(self) -> DataLoader[PairedBatchItem]: From 22426dced1f642c817665e0fbf2bafad78c2f7cb Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 20:39:08 -0700 Subject: [PATCH 13/20] Mesh with previous updates in refactor pr --- fme/downscaling/data/static.py | 2 +- fme/downscaling/models.py | 60 ++++++++-------------------------- fme/downscaling/test_models.py | 42 +++++++++--------------- 3 files changed, 30 insertions(+), 74 deletions(-) diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index ae4721f8d..48ab54ed7 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -123,7 +123,7 @@ def from_state(cls, state: dict) -> "StaticInputs": def load_static_inputs( static_inputs_config: dict[str, str] | None, -) -> "StaticInputs | None": +) -> StaticInputs | None: """ Load normalized static inputs from a mapping of field names to file paths. Returns None if the input config is empty. diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 94f0c5a81..11ec1b1db 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -318,7 +318,11 @@ def __init__( self.static_inputs = ( static_inputs.to_device() if static_inputs is not None else None ) - self.fine_coords = fine_coords + device = get_device() + self.fine_coords = LatLonCoordinates( + lat=fine_coords.lat.to(device), + lon=fine_coords.lon.to(device), + ) if static_inputs is not None: expected = (len(fine_coords.lat), len(fine_coords.lon)) if static_inputs.shape != expected: @@ -355,32 +359,23 @@ def _subset_static_inputs( def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: """Return fine-resolution coordinates matching the spatial extent of batch.""" - if self.static_inputs is None: - raise ValueError( - "Model is missing static inputs, which are required to determine " - "the coordinate information for the output dataset." - ) coarse_lat = batch.latlon_coordinates.lat[0] coarse_lon = batch.latlon_coordinates.lon[0] fine_lat_interval = adjust_fine_coord_range( batch.lat_interval, full_coarse_coord=coarse_lat, - full_fine_coord=self.static_inputs.coords.lat, + full_fine_coord=self.fine_coords.lat, downscale_factor=self.downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( batch.lon_interval, full_coarse_coord=coarse_lon, - full_fine_coord=self.static_inputs.coords.lon, + full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) return LatLonCoordinates( - lat=self.static_inputs.coords.lat[ - fine_lat_interval.slice_of(self.static_inputs.coords.lat) - ], - lon=self.static_inputs.coords.lon[ - fine_lon_interval.slice_of(self.static_inputs.coords.lon) - ], + lat=self.fine_coords.lat[fine_lat_interval.slice_of(self.fine_coords.lat)], + lon=self.fine_coords.lon[fine_lon_interval.slice_of(self.fine_coords.lon)], ) @property @@ -574,9 +569,9 @@ def generate_on_batch_no_target( full_fine_coord=self.fine_coords.lon, downscale_factor=self.downscale_factor, ) - lat_slice = fine_lat_interval.slice_of(self.fine_coords.lat) - lon_slice = fine_lon_interval.slice_of(self.fine_coords.lon) - _static_inputs = self.static_inputs.subset(lat_slice, lon_slice) + _static_inputs = self._subset_static_inputs( + fine_lat_interval, fine_lon_interval + ) else: _static_inputs = None generated, _, _ = self.generate(batch.data, _static_inputs, n_samples) @@ -801,10 +796,8 @@ def build( ) # TODO: simplify with static input refactor that deisables empty StaticInputs if ( - has_fine_coords - or has_static_input_coords - and self.fine_coordinates_path is not None - ): + has_fine_coords or has_static_input_coords + ) and self.fine_coordinates_path is not None: raise ValueError( "The model checkpoint already has fine coordinates from training. " "fine_coordinates_path should not be provided in checkpoint model " @@ -844,31 +837,6 @@ def build( static_inputs=static_inputs, ) model.module.load_state_dict(self._checkpoint["model"]["module"]) - - # Restore fine_coords: new checkpoints have it stored directly; old - # checkpoints may have coords embedded in static_inputs fields. - model_state = self._checkpoint["model"] - if model_state.get("fine_coords") is not None: - fine_coords_state = model_state["fine_coords"] - model.fine_coords = LatLonCoordinates( - lat=fine_coords_state["lat"], - lon=fine_coords_state["lon"], - ) - elif ( - model_state.get("static_inputs") is not None - and len(model_state["static_inputs"].get("fields", [])) > 0 - and "coords" in model_state["static_inputs"]["fields"][0] - ): - coords_state = model_state["static_inputs"]["fields"][0]["coords"] - model.fine_coords = LatLonCoordinates( - lat=coords_state["lat"], - lon=coords_state["lon"], - ) - elif self.fine_coordinates_path is not None: - model.fine_coords = self._load_fine_coords_from_path( - self.fine_coordinates_path - ) - return model @property diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index d801eec25..803baefda 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -166,8 +166,8 @@ def test_module_serialization(tmp_path): ) ) assert model_from_state.fine_coords is not None - assert torch.equal(model_from_state.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_state.fine_coords.lon, fine_coords.lon) + assert torch.equal(model_from_state.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_state.fine_coords.lon.cpu(), fine_coords.lon.cpu()) torch.save(model.get_state(), tmp_path / "test.ckpt") model_from_disk = DiffusionModel.from_state( @@ -185,8 +185,8 @@ def test_module_serialization(tmp_path): loaded_static_inputs.fields[0].data, static_inputs.fields[0].data ) assert model_from_disk.fine_coords is not None - assert torch.equal(model_from_disk.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_disk.fine_coords.lon, fine_coords.lon) + assert torch.equal(model_from_disk.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(model_from_disk.fine_coords.lon.cpu(), fine_coords.lon.cpu()) def test_from_state_backward_compat_fine_topography(): @@ -212,7 +212,9 @@ def test_from_state_backward_compat_fine_topography(): # Should load correctly via the elif use_fine_topography branch (+1 channel) model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.static_inputs is None - assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) assert all( torch.equal(p1, p2) for p1, p2 in zip( @@ -249,8 +251,12 @@ def test_from_state_backward_compat_migrates_fine_coords_from_old_static_inputs( model_from_old_state = DiffusionModel.from_state(state) assert model_from_old_state.fine_coords is not None - assert torch.equal(model_from_old_state.fine_coords.lat, fine_coords.lat) - assert torch.equal(model_from_old_state.fine_coords.lon, fine_coords.lon) + assert torch.equal( + model_from_old_state.fine_coords.lat.cpu(), fine_coords.lat.cpu() + ) + assert torch.equal( + model_from_old_state.fine_coords.lon.cpu(), fine_coords.lon.cpu() + ) def _get_diffusion_model( @@ -619,30 +625,12 @@ def test_get_fine_coords_for_batch(): result = model.get_fine_coords_for_batch(batch) - expected_lat = model.static_inputs.coords.lat[4:12] - expected_lon = model.static_inputs.coords.lon[8:24] - # model.static_inputs has been moved to device; index into it directly - # to match devices + expected_lat = model.fine_coords.lat[4:12] + expected_lon = model.fine_coords.lon[8:24] assert torch.allclose(result.lat, expected_lat) assert torch.allclose(result.lon, expected_lon) -def test_get_fine_coords_for_batch_raises_without_static_inputs(): - model = _get_diffusion_model( - coarse_shape=(16, 16), - downscale_factor=2, - use_fine_topography=False, - static_inputs=None, - ) - batch = make_batch_data( - (1, 16, 16), - _get_monotonic_coordinate(16, stop=16).tolist(), - _get_monotonic_coordinate(16, stop=16).tolist(), - ) - with pytest.raises(ValueError, match="missing static inputs"): - model.get_fine_coords_for_batch(batch) - - def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From a6895736dfc5871c037d7c3263f2755b676e664c Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:19:02 -0700 Subject: [PATCH 14/20] Simplify event downscaler coordinate in run() --- fme/downscaling/predict.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index 449e84368..d56a16e56 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -21,7 +21,6 @@ ClosedInterval, DataLoaderConfig, GriddedData, - adjust_fine_coord_range, enforce_lat_bounds, ) from fme.downscaling.models import CheckpointModelConfig, DiffusionModel @@ -120,29 +119,11 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates - coarse_lat = coarse_coords.lat - coarse_lon = coarse_coords.lon - lat_interval = ClosedInterval(coarse_lat.min().item(), coarse_lat.max().item()) - lon_interval = ClosedInterval(coarse_lon.min().item(), coarse_lon.max().item()) - fine_lat_interval = adjust_fine_coord_range( - lat_interval, - full_coarse_coord=coarse_lat, - full_fine_coord=self.model.fine_coords.lat, - downscale_factor=self.model.downscale_factor, - ) - fine_lon_interval = adjust_fine_coord_range( - lon_interval, - full_coarse_coord=coarse_lon, - full_fine_coord=self.model.fine_coords.lon, - downscale_factor=self.model.downscale_factor, - ) - lat_slice = fine_lat_interval.slice_of(self.model.fine_coords.lat) - lon_slice = fine_lon_interval.slice_of(self.model.fine_coords.lon) - fine_coords = LatLonCoordinates( - lat=self.model.fine_coords.lat[lat_slice], - lon=self.model.fine_coords.lon[lon_slice], + coarse_coords = LatLonCoordinates( + lat=batch[0].latlon_coordinates.lat, + lon=batch[0].latlon_coordinates.lon, ) + fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, latlon_coordinates=FineResCoarseResPair( From e28502399c87a2db35b1e589488a562918bd2d74 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:20:00 -0700 Subject: [PATCH 15/20] use batch latlon coardinates for coarse --- fme/downscaling/predict.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index d56a16e56..d35d0bc1a 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -9,7 +9,6 @@ import yaml from fme.core.cli import prepare_directory -from fme.core.coordinates import LatLonCoordinates from fme.core.dataset.time import TimeSlice from fme.core.dicts import to_flat_dict from fme.core.distributed import Distributed @@ -119,10 +118,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = LatLonCoordinates( - lat=batch[0].latlon_coordinates.lat, - lon=batch[0].latlon_coordinates.lon, - ) + coarse_coords = batch[0].latlon_coordinates fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, From db2732cb3bc3c3550b9494a6924bc0350b6e7cc9 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:26:29 -0700 Subject: [PATCH 16/20] Make fine coord loader public --- fme/downscaling/models.py | 43 ++++++++++++++++----------------- fme/downscaling/test_predict.py | 20 ++++----------- fme/downscaling/test_utils.py | 1 + 3 files changed, 27 insertions(+), 37 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 11ec1b1db..f50071cad 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -680,6 +680,26 @@ def from_state(cls, state: Mapping[str, Any]) -> DiffusionModelConfig: ).wrapper +def load_fine_coords_from_path(path: str) -> LatLonCoordinates: + if path.endswith(".zarr"): + ds = xr.open_zarr(path) + else: + ds = xr.open_dataset(path) + lat_name = next((n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None) + lon_name = next( + (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None + ) + if lat_name is None or lon_name is None: + raise ValueError( + f"Could not find lat/lon coordinates in {path}. " + "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." + ) + return LatLonCoordinates( + lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), + lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), + ) + + @dataclasses.dataclass class CheckpointModelConfig: """ @@ -747,27 +767,6 @@ def _checkpoint(self) -> Mapping[str, Any]: checkpoint_data["model"]["config"][k] = v return self._checkpoint_data - def _load_fine_coords_from_path(self, path: str) -> LatLonCoordinates: - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - lat_name = next( - (n for n in ["lat", "latitude", "grid_yt"] if n in ds.coords), None - ) - lon_name = next( - (n for n in ["lon", "longitude", "grid_xt"] if n in ds.coords), None - ) - if lat_name is None or lon_name is None: - raise ValueError( - f"Could not find lat/lon coordinates in {path}. " - "Expected 'lat'/'latitude'/'grid_yt' and 'lon'/'longitude'/'grid_xt'." - ) - return LatLonCoordinates( - lat=torch.tensor(ds[lat_name].values, dtype=torch.float32), - lon=torch.tensor(ds[lon_name].values, dtype=torch.float32), - ) - def build( self, ) -> DiffusionModel: @@ -818,7 +817,7 @@ def build( lon=coords_state["lon"], ) elif self.fine_coordinates_path is not None: - fine_coords = self._load_fine_coords_from_path(self.fine_coordinates_path) + fine_coords = load_fine_coords_from_path(self.fine_coordinates_path) else: raise ValueError( "No fine coordinates found in checkpoint state and no static inputs " diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py index 5054fb538..8101b293e 100644 --- a/fme/downscaling/test_predict.py +++ b/fme/downscaling/test_predict.py @@ -2,16 +2,18 @@ import pytest import torch -import xarray as xr import yaml -from fme.core.coordinates import LatLonCoordinates from fme.core.loss import LossConfig from fme.core.normalizer import NormalizationConfig from fme.core.testing.wandb import mock_wandb from fme.downscaling import predict from fme.downscaling.data import load_static_inputs -from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig +from fme.downscaling.models import ( + DiffusionModelConfig, + PairedNormalizationConfig, + load_fine_coords_from_path, +) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.test_models import LinearDownscaling from fme.downscaling.test_utils import data_paths_helper @@ -66,18 +68,6 @@ def get_model_config( ) -def load_fine_coords_from_path(path: str) -> LatLonCoordinates: - """Load lat/lon coordinates from a netCDF or zarr data file.""" - if path.endswith(".zarr"): - ds = xr.open_zarr(path) - else: - ds = xr.open_dataset(path) - return LatLonCoordinates( - lat=torch.tensor(ds["lat"].values, dtype=torch.float32), - lon=torch.tensor(ds["lon"].values, dtype=torch.float32), - ) - - def create_predictor_config( tmp_path, n_samples: int, diff --git a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py index 4f34a4451..b1546e044 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ -78,4 +78,5 @@ def data_paths_helper( create_test_data_on_disk( coarse_path / "data.nc", dim_sizes.coarse, variable_names, coords ) + # TODO: should this return the full filename instead of just the directory? return FineResCoarseResPair[str](fine=fine_path, coarse=coarse_path) From 5b13f6749e5193b2d12e95d8c5e3b850a858741d Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:40:17 -0700 Subject: [PATCH 17/20] BatchLatLon coord access consistency --- fme/downscaling/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py index d35d0bc1a..c6b0a3677 100644 --- a/fme/downscaling/predict.py +++ b/fme/downscaling/predict.py @@ -118,7 +118,7 @@ def generation_model(self): def run(self): logging.info(f"Running {self.event_name} event downscaling...") batch = next(iter(self.data.get_generator())) - coarse_coords = batch[0].latlon_coordinates + coarse_coords = batch.latlon_coordinates[0] fine_coords = self.model.get_fine_coords_for_batch(batch) sample_agg = SampleAggregator( coarse=batch[0].data, From eb57411d06fe96917b16f928fe90ffcef537a7de Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:41:35 -0700 Subject: [PATCH 18/20] linting --- fme/downscaling/predictors/composite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py index 405ac1eed..5d0a4be8f 100644 --- a/fme/downscaling/predictors/composite.py +++ b/fme/downscaling/predictors/composite.py @@ -85,7 +85,7 @@ def static_inputs(self): def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self.model.get_fine_coords_for_batch(batch) - + @property def fine_coords(self): return self.model.fine_coords From ae4b09ecb3c57f6ab5e5a4594b5960f4ccd68fa1 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 21:45:45 -0700 Subject: [PATCH 19/20] Add no coords checkpoint with path test --- fme/downscaling/test_models.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 803baefda..98bb76b33 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -631,6 +631,43 @@ def test_get_fine_coords_for_batch(): assert torch.allclose(result.lon, expected_lon) +def test_checkpoint_model_build_with_fine_coordinates_path(tmp_path): + """Old-format checkpoint (no fine_coords key, no coords in static_inputs) + should load correctly when fine_coordinates_path is provided.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + fine_coords = make_fine_coords(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + use_fine_topography=False, + fine_coords=fine_coords, + ) + # Simulate old checkpoint: no fine_coords stored + state = model.get_state() + del state["fine_coords"] + checkpoint_path = tmp_path / "test.ckpt" + torch.save({"model": state}, checkpoint_path) + + # Write fine coords to a netCDF file for the loader to read + coords_path = tmp_path / "fine_coords.nc" + ds = xr.Dataset( + coords={ + "lat": fine_coords.lat.cpu().numpy(), + "lon": fine_coords.lon.cpu().numpy(), + } + ) + ds.to_netcdf(coords_path) + + loaded_model = CheckpointModelConfig( + checkpoint_path=str(checkpoint_path), + fine_coordinates_path=str(coords_path), + ).build() + + assert torch.equal(loaded_model.fine_coords.lat.cpu(), fine_coords.lat.cpu()) + assert torch.equal(loaded_model.fine_coords.lon.cpu(), fine_coords.lon.cpu()) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From 227fd19628a698f3c10254ed2820ffd3e045c5c8 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 16 Mar 2026 22:09:56 -0700 Subject: [PATCH 20/20] Remove fine_shape call signature chain --- fme/downscaling/inference/inference.py | 1 - fme/downscaling/inference/output.py | 95 +++++++++----------- fme/downscaling/inference/test_inference.py | 1 - fme/downscaling/inference/test_output.py | 60 +++++++++++-- fme/downscaling/inference/test_work_items.py | 4 +- fme/downscaling/inference/work_items.py | 13 +-- 6 files changed, 98 insertions(+), 76 deletions(-) diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index 5ba046140..c78885e90 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -240,7 +240,6 @@ def build(self) -> Downscaler: loader_config=self.data, requirements=self.model.data_requirements, patch=self.patch, - fine_shape=model.fine_shape, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 2af28e27e..e29000679 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -28,6 +28,13 @@ from .zarr_utils import determine_zarr_chunks +@dataclass +class WriterParams: + chunks: dict[str, int] + shards: dict[str, int] + coords: dict[str, np.ndarray] + + def _identity_collate(batch): """ Collate function that returns the single batch item. @@ -66,8 +73,8 @@ def __init__( max_samples_per_gpu: int, data: SliceWorkItemGriddedData, patch: PatchPredictionConfig, - chunks: dict[str, int], - shards: dict[str, int], + zarr_chunks_override: dict[str, int] | None, + zarr_shards_override: dict[str, int] | None, dims: tuple[str, ...] = DIMS, ) -> None: self.name = name @@ -76,22 +83,20 @@ def __init__( self.max_samples_per_gpu = max_samples_per_gpu self.data = data self.patch = patch - self.chunks = chunks - self.shards = shards + self.zarr_chunks_override = zarr_chunks_override + self.zarr_shards_override = zarr_shards_override self.dims = dims - def get_writer( - self, - latlon_coords: LatLonCoordinates, - output_dir: str, - ) -> ZarrWriter: - """ - Create a ZarrWriter for this target. - - Args: - latlon_coords: High-resolution spatial coordinates for outputs - output_dir: Directory to store output zarr file - """ + def _build_writer_params(self, latlon_coords: LatLonCoordinates) -> WriterParams: + lat_size = len(latlon_coords.lat) + lon_size = len(latlon_coords.lon) + n_times, n_ens = self.data.max_output_shape + full_shape = (n_times, n_ens, lat_size, lon_size) + element_size = torch.tensor([], dtype=self.data.dtype).element_size() + chunks = self.zarr_chunks_override or determine_zarr_chunks( + DIMS, full_shape, element_size + ) + shards = self.zarr_shards_override or dict(zip(DIMS, full_shape)) ensemble = list(range(self.n_ens)) coords = dict( zip( @@ -104,15 +109,29 @@ def get_writer( ], ) ) - dims = tuple(coords.keys()) + return WriterParams(chunks=chunks, shards=shards, coords=coords) + + def get_writer( + self, + latlon_coords: LatLonCoordinates, + output_dir: str, + ) -> ZarrWriter: + """ + Create a ZarrWriter for this target. + Args: + latlon_coords: High-resolution spatial coordinates for outputs + output_dir: Directory to store output zarr file + """ + params = self._build_writer_params(latlon_coords) + dims = tuple(params.coords.keys()) return ZarrWriter( path=f"{output_dir}/{self.name}.zarr", dims=dims, - coords=coords, + coords=params.coords, data_vars=self.save_vars, - chunks=self.chunks, - shards=self.shards, + chunks=params.chunks, + shards=params.shards, ) @@ -152,7 +171,6 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - fine_shape: tuple[int, int], ) -> DownscalingOutput: """ Build an OutputTarget from this configuration. @@ -161,8 +179,6 @@ def build( loader_config: Base data loader configuration to modify requirements: Model's data requirements (variable names, etc.) patch: Default patch prediction configuration - fine_shape: Fine shape of the output used as metadata - for the shape of the output to insert into the dataset """ pass @@ -220,7 +236,6 @@ def _build_gridded_data( loader_config: DataLoaderConfig, requirements: DataRequirements, dist: Distributed | None = None, - fine_shape: tuple[int, int] | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -241,7 +256,6 @@ def _build_gridded_data( slice_dataset = SliceItemDataset( slice_items=work_items, dataset=dataset, - spatial_shape=fine_shape, ) # each SliceItemDataset work item loads its own full batch, so batch_size=1 @@ -279,7 +293,6 @@ def _build( requirements: DataRequirements, patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], - fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -289,27 +302,7 @@ def _build( loader_config, ) - gridded_data = self._build_gridded_data( - updated_loader_config, - requirements, - fine_shape=fine_shape, - ) - - if self.zarr_chunks is None: - # Get element size from dtype by creating a dummy tensor - element_size = torch.tensor([], dtype=gridded_data.dtype).element_size() - chunks = determine_zarr_chunks( - dims=DIMS, - data_shape=gridded_data.max_output_shape, - bytes_per_element=element_size, - ) - else: - chunks = self.zarr_chunks - - if self.zarr_shards is None: - shards = dict(zip(DIMS, gridded_data.max_output_shape)) - else: - shards = self.zarr_shards + gridded_data = self._build_gridded_data(updated_loader_config, requirements) return DownscalingOutput( name=self.name, @@ -318,8 +311,8 @@ def _build( max_samples_per_gpu=self.max_samples_per_gpu, data=gridded_data, patch=patch, - chunks=chunks, - shards=shards, + zarr_chunks_override=self.zarr_chunks, + zarr_shards_override=self.zarr_shards, dims=DIMS, ) @@ -379,7 +372,6 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -402,7 +394,6 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - fine_shape=fine_shape, ) @@ -462,7 +453,6 @@ def build( loader_config: DataLoaderConfig, requirements: DataRequirements, patch: PatchPredictionConfig, - fine_shape: tuple[int, int] | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -473,5 +463,4 @@ def build( requirements=requirements, patch=patch, coarse=coarse, - fine_shape=fine_shape, ) diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index c06b7b0cc..91c10daf8 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -44,7 +44,6 @@ def mock_model(): """Create a mock model with coarse_shape attribute.""" model = MagicMock() model.coarse_shape = (16, 16) - model.fine_shape = (32, 32) return model diff --git a/fme/downscaling/inference/test_output.py b/fme/downscaling/inference/test_output.py index 413fcfcb0..23bee3d44 100644 --- a/fme/downscaling/inference/test_output.py +++ b/fme/downscaling/inference/test_output.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock +import numpy as np import pytest +import torch from fme.core.dataset.time import TimeSlice from fme.core.dataset.xarray import XarrayDataConfig @@ -10,10 +12,59 @@ DownscalingOutputConfig, EventConfig, TimeRangeConfig, + WriterParams, ) from fme.downscaling.predictors import PatchPredictionConfig from fme.downscaling.requirements import DataRequirements + +def _make_downscaling_output(zarr_chunks_override=None, zarr_shards_override=None): + mock_data = MagicMock() + mock_data.max_output_shape = (2, 4) + mock_data.dtype = torch.float32 + mock_data.all_times.to_numpy.return_value = np.zeros(2) + return DownscalingOutput( + name="test", + save_vars=None, + n_ens=4, + max_samples_per_gpu=4, + data=mock_data, + patch=MagicMock(), + zarr_chunks_override=zarr_chunks_override, + zarr_shards_override=zarr_shards_override, + ) + + +def _make_latlon(lat_size=10, lon_size=20): + latlon = MagicMock() + latlon.lat = torch.zeros(lat_size) + latlon.lon = torch.zeros(lon_size) + return latlon + + +def test_build_writer_params_default_chunks_and_shards(): + output = _make_downscaling_output() + latlon = _make_latlon(lat_size=10, lon_size=20) + params = output._build_writer_params(latlon) + assert isinstance(params, WriterParams) + assert params.shards == {"time": 2, "ensemble": 4, "latitude": 10, "longitude": 20} + assert params.chunks["time"] == 1 + assert params.chunks["ensemble"] == 1 + + +def test_build_writer_params_override_chunks_and_shards(): + zarr_chunks_override = {"time": 5, "ensemble": 5, "latitude": 5, "longitude": 5} + zarr_shards_override = {"time": 10, "ensemble": 10, "latitude": 10, "longitude": 10} + output = _make_downscaling_output( + zarr_chunks_override=zarr_chunks_override, + zarr_shards_override=zarr_shards_override, + ) + latlon = _make_latlon(lat_size=10, lon_size=20) + params = output._build_writer_params(latlon) + assert params.chunks == zarr_chunks_override + assert params.shards == zarr_shards_override + + # Tests for OutputTargetConfig validation @@ -112,10 +163,6 @@ def test_event_config_build_creates_output_target_with_single_time( # Verify time dimension - should have exactly 1 timestep assert len(output_target.data.all_times) == 1 assert output_target.data is not None - assert output_target.chunks is not None - assert tuple(output_target.chunks.values())[:2] == (1, 1) - assert output_target.shards is not None - assert tuple(output_target.shards.values()) == output_target.data.max_output_shape @pytest.mark.parametrize("loader_config", [True], indirect=True) @@ -137,12 +184,7 @@ def test_region_config_build_creates_output_target_with_time_range( assert output_target.n_ens == 4 assert len(output_target.data.all_times) == 2 - # Verify chunks dict structure assert output_target.data is not None - assert output_target.chunks is not None - assert tuple(output_target.chunks.values())[:2] == (1, 1) - assert output_target.shards is not None - assert tuple(output_target.shards.values()) == output_target.data.max_output_shape def test_time_range_config_raise_error_invalid_lat_extent(): diff --git a/fme/downscaling/inference/test_work_items.py b/fme/downscaling/inference/test_work_items.py index d6bf9c37a..3bdd24cc8 100644 --- a/fme/downscaling/inference/test_work_items.py +++ b/fme/downscaling/inference/test_work_items.py @@ -409,8 +409,8 @@ def test_slice_item_dataset_max_output_shape( shape = dataset.max_output_shape # First item: time_slice=slice(0,2), ens_slice=slice(0,4) - # n_times = 2, n_ens = 4, spatial = (64, 64) - assert shape == (2, 4, 64, 64) + # n_times = 2, n_ens = 4 + assert shape == (2, 4) def test_slice_item_dataset_dtype_property( diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index 25049cf5b..643fe184f 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -110,18 +110,11 @@ def __init__( self, slice_items: list[SliceWorkItem], dataset: BatchItemDatasetAdapter, - spatial_shape: tuple[int, int] | None = None, ) -> None: self.slice_items = slice_items self.dataset = dataset self._dtype = None - if spatial_shape is None: - sample_batch_item = self.dataset[0] - self.spatial_shape = sample_batch_item.horizontal_shape - else: - self.spatial_shape = spatial_shape - def __len__(self) -> int: return len(self.slice_items) @@ -133,11 +126,11 @@ def __getitem__(self, idx: int) -> LoadedSliceWorkItem: return loaded_item @property - def max_output_shape(self): + def max_output_shape(self) -> tuple[int, int]: first_item = self.slice_items[0] n_times = first_item.time_slice.stop - first_item.time_slice.start n_ensembles = first_item.ens_slice.stop - first_item.ens_slice.start - return (n_times, n_ensembles, *self.spatial_shape) + return (n_times, n_ensembles) @property def dtype(self) -> torch.dtype: @@ -296,7 +289,7 @@ class SliceWorkItemGriddedData: variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex dtype: torch.dtype - max_output_shape: tuple[int, ...] + max_output_shape: tuple[int, int] # TODO: currently no protocol or ABC for gridded data objects # if we want to unify, we will need one and just raise