diff --git a/scripts/ckpt_extract_stats/extract_stats.py b/scripts/ckpt_extract_stats/extract_stats.py
new file mode 100644
index 000000000..62a759eec
--- /dev/null
+++ b/scripts/ckpt_extract_stats/extract_stats.py
@@ -0,0 +1,173 @@
+"""Extract normalization statistics from an ACE Trainer checkpoint as netCDF files.
+
+Handles both modern and legacy checkpoint formats via StepperConfig.from_stepper_state.
+Produces files named after the normalization config keys:
+ - network-means.nc, network-stds.nc (always present)
+ - residual-means.nc, residual-stds.nc (if residual normalization is present)
+ - loss-means.nc, loss-stds.nc (if loss normalization is present)
+"""
+
+import logging
+import os
+import pathlib
+from typing import Any
+
+import click
+import xarray as xr
+
+logger = logging.getLogger(__name__)
+
+# Normalization sub-configs that may carry means/stds dicts.
+NORMALIZATION_KEYS = ("network", "residual", "loss")
+
+
+def _find_normalization(config: dict[str, Any]) -> dict[str, Any]:
+    """Recursively search a stepper config dict for the normalization config.
+
+    The normalization config is identified as a dict containing a "network" key
+    whose value is a dict with "means" and "stds" keys.
+
+    Raises:
+        ValueError: if no such dict exists anywhere in the nested config.
+    """
+    if (
+        "normalization" in config
+        and isinstance(config["normalization"], dict)
+        and "network" in config["normalization"]
+    ):
+        network = config["normalization"]["network"]
+        if isinstance(network, dict) and "means" in network and "stds" in network:
+            return config["normalization"]
+
+    # Depth-first search of nested dicts; a ValueError from a child branch
+    # just means "not found there", so keep looking in siblings.
+    for value in config.values():
+        if isinstance(value, dict):
+            try:
+                return _find_normalization(value)
+            except ValueError:
+                continue
+
+    raise ValueError(
+        "Could not find normalization config with network means/stds "
+        f"in config keys: {list(config.keys())}"
+    )
+
+
+def _dict_to_dataset(data: dict[str, float]) -> xr.Dataset:
+    """Convert a {variable_name: scalar_value} dict to an xarray Dataset."""
+    return xr.Dataset({name: xr.DataArray(value) for name, value in data.items()})
+
+
+def extract_stats(checkpoint_path: str | pathlib.Path) -> dict[str, xr.Dataset]:
+    """Extract normalization stats from a checkpoint as xarray Datasets.
+
+    Uses StepperConfig.from_stepper_state to parse both legacy and modern
+    checkpoint formats without building the full stepper (which would require
+    distributed context, GPU, etc.).
+
+    Returns:
+        Dict mapping filename to Dataset, e.g.
+        {"network-means.nc": Dataset, "network-stds.nc": Dataset, ...}
+    """
+    # Imported lazily so that merely importing this module does not require
+    # torch/fme to be installed.
+    import dataclasses
+
+    import torch
+
+    from fme.ace.stepper.single_module import StepperConfig
+
+    # weights_only=False: checkpoints contain pickled config objects, not just
+    # tensors. Only load checkpoints from trusted sources.
+    checkpoint = torch.load(
+        str(checkpoint_path), map_location=torch.device("cpu"), weights_only=False
+    )
+    stepper_state = checkpoint["stepper"]
+    config = StepperConfig.from_stepper_state(stepper_state)
+    config_dict = dataclasses.asdict(config)
+    normalization = _find_normalization(config_dict)
+
+    result: dict[str, xr.Dataset] = {}
+    for key in NORMALIZATION_KEYS:
+        norm = normalization.get(key)
+        if norm is None:
+            continue
+        means = norm.get("means", {})
+        stds = norm.get("stds", {})
+        if means:
+            filename = f"{key}-means.nc"
+            result[filename] = _dict_to_dataset(means)
+            logger.info(f"Extracted {filename} with {len(means)} variables")
+        if stds:
+            filename = f"{key}-stds.nc"
+            result[filename] = _dict_to_dataset(stds)
+            logger.info(f"Extracted {filename} with {len(stds)} variables")
+
+    return result
+
+
+def write_stats(stats: dict[str, xr.Dataset], output_dir: str | pathlib.Path) -> None:
+    """Write extracted stats datasets to netCDF files in output_dir."""
+    output_dir = pathlib.Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for filename, ds in stats.items():
+        path = output_dir / filename
+        ds.to_netcdf(path)
+        logger.info(f"Wrote {path}")
+
+
+def upload_to_beaker(
+    output_dir: str | pathlib.Path,
+    dataset_name: str,
+    description: str = "",
+) -> None:
+    """Upload the stats directory to a Beaker dataset."""
+    import beaker as beaker_module
+    from beaker import Beaker
+
+    beaker_client = Beaker.from_env()
+    # Creating an existing dataset would fail; treat "already exists" as a
+    # no-op so the script is idempotent.
+    try:
+        beaker_client.dataset.get(dataset_name)
+        logger.info(f"Beaker dataset '{dataset_name}' already exists. Skipping upload.")
+        return
+    except beaker_module.exceptions.BeakerDatasetNotFound:
+        pass
+
+    beaker_client.dataset.create(
+        dataset_name,
+        str(output_dir),
+        workspace="ai2/ace",
+        description=description,
+    )
+    logger.info(f"Uploaded stats to Beaker dataset '{dataset_name}'")
+
+
+@click.command()
+@click.argument("checkpoint_path", type=click.Path(exists=True))
+@click.option(
+    "--output-dir",
+    required=True,
+    type=click.Path(),
+    help="Directory to write the extracted netCDF stats files.",
+)
+@click.option(
+    "--beaker-dataset",
+    default=None,
+    type=str,
+    help="If provided, upload extracted stats to this Beaker dataset name.",
+)
+def main(checkpoint_path: str, output_dir: str, beaker_dataset: str | None):
+    """Extract normalization statistics from a Trainer checkpoint.
+
+    CHECKPOINT_PATH is the path to a .tar checkpoint file.
+    """
+    logging.basicConfig(level=logging.INFO)
+
+    stats = extract_stats(checkpoint_path)
+    write_stats(stats, output_dir)
+
+    filenames = ", ".join(stats.keys())
+    logger.info(f"Extracted stats files: {filenames}")
+
+    if beaker_dataset is not None:
+        description = (
+            f"Normalization stats extracted from checkpoint "
+            f"{os.path.basename(checkpoint_path)}. "
+            f"Files: {filenames}."
+        )
+        upload_to_beaker(output_dir, beaker_dataset, description=description)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/ckpt_extract_stats/test_extract_stats.py b/scripts/ckpt_extract_stats/test_extract_stats.py
new file mode 100644
index 000000000..a70bc0240
--- /dev/null
+++ b/scripts/ckpt_extract_stats/test_extract_stats.py
@@ -0,0 +1,251 @@
+import dataclasses
+import datetime
+import pathlib
+import tempfile
+
+import pytest
+import torch
+import xarray as xr
+from extract_stats import _find_normalization, extract_stats, write_stats
+
+from fme.ace.registry import ModuleSelector
+from fme.ace.stepper.single_module import StepperConfig
+from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.dataset_info import DatasetInfo
+from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig
+from fme.core.step.multi_call import MultiCallStepConfig
+from fme.core.step.single_module import SingleModuleStepConfig
+from fme.core.step.step import StepSelector
+
+TIMESTEP = datetime.timedelta(hours=6)
+IN_NAMES = ["a", "b"]
+OUT_NAMES = ["a", "b"]
+
+
+class PlusOne(torch.nn.Module):
+    def forward(self, x):
+        return x + 1
+
+
+def _build_stepper(
+    network_means: dict[str, float],
+    network_stds: dict[str, float],
+    residual_stds: dict[str, float] | None = None,
+):
+    """Build a minimal stepper and return it."""
+    residual = None
+    if residual_stds is not None:
+        residual = NormalizationConfig(
+            means={k: 0.0 for k in residual_stds},
+            stds=residual_stds,
+        )
+    config = StepperConfig(
+        step=StepSelector(
+            type="multi_call",
+            config=dataclasses.asdict(
+                MultiCallStepConfig(
+                    wrapped_step=StepSelector(
+                        type="single_module",
+                        config=dataclasses.asdict(
+                            SingleModuleStepConfig(
+                                builder=ModuleSelector(
+                                    type="prebuilt", config={"module": PlusOne()}
+                                ),
+                                in_names=IN_NAMES,
+                                out_names=OUT_NAMES,
+                                normalization=NetworkAndLossNormalizationConfig(
+                                    network=NormalizationConfig(
+                                        means=network_means,
+                                        stds=network_stds,
+                                    ),
+                                    residual=residual,
+                                ),
+                            ),
+                        ),
+                    ),
+                    include_multi_call_in_loss=False,
+                ),
+            ),
+        ),
+    )
+    dataset_info = DatasetInfo(
+        horizontal_coordinates=LatLonCoordinates(
+            lat=torch.zeros(4), lon=torch.zeros(8)
+        ),
+        vertical_coordinate=HybridSigmaPressureCoordinate(
+            ak=torch.arange(7, dtype=torch.float32),
+            bk=torch.arange(7, dtype=torch.float32),
+        ),
+        timestep=TIMESTEP,
+        variable_metadata={
+            OUT_NAMES[0]: VariableMetadata(units="K", long_name="temperature"),
+        },
+    )
+    return config.get_stepper(dataset_info=dataset_info)
+
+
+def _save_checkpoint(path: pathlib.Path, stepper):
+    torch.save({"stepper": stepper.get_state()}, path)
+
+
+class TestFindNormalization:
+    def test_finds_normalization_in_nested_config(self):
+        config = {
+            "step": {
+                "type": "multi_call",
+                "config": {
+                    "wrapped_step": {
+                        "type": "single_module",
+                        "config": {
+                            "normalization": {
+                                "network": {
+                                    "means": {"x": 1.0},
+                                    "stds": {"x": 2.0},
+                                },
+                                "residual": None,
+                                "loss": None,
+                            }
+                        },
+                    }
+                },
+            }
+        }
+        norm = _find_normalization(config)
+        assert norm["network"]["means"] == {"x": 1.0}
+        assert norm["network"]["stds"] == {"x": 2.0}
+
+    def test_raises_when_no_normalization(self):
+        with pytest.raises(ValueError, match="Could not find normalization"):
+            _find_normalization({"step": {"config": {}}})
+
+
+class TestExtractStatsModern:
+    """Test extraction from modern-format checkpoints."""
+
+    def test_network_only(self):
+        means = {"a": 1.5, "b": -0.3}
+        stds = {"a": 0.5, "b": 1.2}
+        stepper = _build_stepper(network_means=means, network_stds=stds)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ckpt_path = pathlib.Path(tmpdir) / "ckpt.tar"
+            _save_checkpoint(ckpt_path, stepper)
+            stats = extract_stats(ckpt_path)
+
+        assert "network-means.nc" in stats
+        assert "network-stds.nc" in stats
+        assert "residual-means.nc" not in stats
+        assert "residual-stds.nc" not in stats
+        assert "loss-means.nc" not in stats
+        assert "loss-stds.nc" not in stats
+
+        for name, expected in means.items():
+            assert float(stats["network-means.nc"][name].values) == pytest.approx(
+                expected
+            )
+        for name, expected in stds.items():
+            assert float(stats["network-stds.nc"][name].values) == pytest.approx(
+                expected
+            )
+
+    def test_network_with_residual(self):
+        means = {"a": 1.5, "b": -0.3}
+        stds = {"a": 0.5, "b": 1.2}
+        residual_stds = {"a": 0.1, "b": 0.05}
+        stepper = _build_stepper(
+            network_means=means,
+            network_stds=stds,
+            residual_stds=residual_stds,
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ckpt_path = pathlib.Path(tmpdir) / "ckpt.tar"
+            _save_checkpoint(ckpt_path, stepper)
+            stats = extract_stats(ckpt_path)
+
+        assert "network-means.nc" in stats
+        assert "network-stds.nc" in stats
+        assert "residual-means.nc" in stats
+        assert "residual-stds.nc" in stats
+
+        for name, expected in residual_stds.items():
+            assert float(stats["residual-stds.nc"][name].values) == pytest.approx(
+                expected
+            )
+
+
+class TestExtractStatsLegacy:
+    """Test extraction from legacy-format checkpoints."""
+
+    def test_legacy_checkpoint(self):
+        means = {"a": 2.0, "b": -1.0}
+        stds = {"a": 0.8, "b": 1.5}
+        loss_stds = {"a": 0.2, "b": 0.4}
+
+        stepper = _build_stepper(network_means=means, network_stds=stds)
+        modern_state = stepper.get_state()
+
+        module_weights = modern_state["step"]["wrapped_step"]["module"]
+        dataset_info = modern_state["dataset_info"]
+        legacy_config = {
+            "builder": {"type": "prebuilt", "config": {"module": PlusOne()}},
+            "in_names": IN_NAMES,
+            "out_names": OUT_NAMES,
+            "normalization": {
+                "global_means_path": None,
+                "global_stds_path": None,
+                "means": means,
+                "stds": stds,
+            },
+        }
+        legacy_stepper_state = {
+            "config": legacy_config,
+            "normalizer": {"means": means, "stds": stds},
+            "loss_normalizer": {"means": means, "stds": loss_stds},
+            "module": module_weights,
+            "vertical_coordinate": dataset_info["vertical_coordinate"],
+            "gridded_operations": {
+                "type": "LatLonOperations",
+                "state": {"area_weights": torch.ones(4, 8)},
+            },
+            "img_shape": [4, 8],
+        }
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ckpt_path = pathlib.Path(tmpdir) / "legacy_ckpt.tar"
+            torch.save({"stepper": legacy_stepper_state}, ckpt_path)
+            stats = extract_stats(ckpt_path)
+
+        assert "network-means.nc" in stats
+        assert "network-stds.nc" in stats
+
+        for name, expected in means.items():
+            assert float(stats["network-means.nc"][name].values) == pytest.approx(
+                expected
+            )
+        for name, expected in stds.items():
+            assert float(stats["network-stds.nc"][name].values) == pytest.approx(
+                expected
+            )
+
+        # Legacy loss_normalizer is converted to the "loss" normalization key
+        assert "loss-means.nc" in stats
+        assert "loss-stds.nc" in stats
+        for name, expected in loss_stds.items():
+            assert float(stats["loss-stds.nc"][name].values) == pytest.approx(expected)
+
+
+class TestWriteStats:
+    def test_write_creates_files(self):
+        stats = {
+            "network-means.nc": xr.Dataset({"x": xr.DataArray(1.0)}),
+            "network-stds.nc": xr.Dataset({"x": xr.DataArray(2.0)}),
+        }
+        with tempfile.TemporaryDirectory() as tmpdir:
+            write_stats(stats, tmpdir)
+            for filename in stats:
+                path = pathlib.Path(tmpdir) / filename
+                assert path.exists()
+                ds = xr.open_dataset(path)
+                assert "x" in ds