1 change: 1 addition & 0 deletions recipes/eval/.gitignore
@@ -0,0 +1 @@
outputs/*
177 changes: 140 additions & 37 deletions recipes/eval/README.md
@@ -11,6 +11,7 @@ forecast outputs to zarr.
Key features:

- Works with **any** Earth2Studio prognostic or diagnostic model
- **Extensible pipeline interface** — subclass `Pipeline` to add custom inference loops
- Multi-GPU distributed inference via `torchrun` / SLURM / MPI
- Clean work-distribution with automatic load balancing across ranks
- Parallel, non-blocking zarr I/O via thread pool
@@ -59,13 +60,12 @@ the model's `input_coords()` / `output_coords()`.
# Single process (login node or interactive session)
python predownload.py

# With a campaign config
python predownload.py +campaign=fcn3_2024_full

# Distributed — parallelise across CPU workers
torchrun --nproc_per_node=8 --standalone predownload.py

# Also pre-fetch ERA5 verification data for the full forecast window
# (variables taken from output.variables — only what will be scored)
python predownload.py predownload.verification.enabled=true
@@ -98,34 +98,92 @@ Work items (one per initial-time / ensemble-member pair) are partitioned
automatically and evenly across ranks. Remainder items are absorbed by the
first rank rather than requiring exact divisibility.
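
The even split with a first-rank remainder can be sketched as follows (a hypothetical `distribute_work`; the recipe's real logic lives in `src/work.py`):

```python
# Sketch of even work partitioning with the remainder going to rank 0.
# Hypothetical helper -- the recipe's actual implementation is in src/work.py.
def distribute_work(items, rank, world_size):
    base, rem = divmod(len(items), world_size)
    # Rank 0 absorbs the remainder; every other rank gets `base` items.
    if rank == 0:
        return items[: base + rem]
    start = base + rem + (rank - 1) * base
    return items[start : start + base]

items = list(range(10))
parts = [distribute_work(items, r, 4) for r in range(4)]
# parts -> [[0, 1, 2, 3], [4, 5], [6, 7], [8, 9]]
```

Every item lands on exactly one rank, and no rank count needs to divide the item count evenly.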

## Resuming and Multi-Job Runs

Set `resume=true` to skip already-completed work items and append to the
existing zarr store. This is useful in two scenarios:

### Resuming after a failure

If a job is killed or times out partway through, re-submit with the same
config plus `resume=true`. Completed work items are detected via marker
files in `<output.path>/.progress/` and automatically skipped:

```bash
torchrun --nproc_per_node=$NGPU --standalone main.py resume=true
```
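
A minimal sketch of the marker-file bookkeeping, with illustrative names (`is_done`, `mark_done`) rather than the recipe's actual API:

```python
# Sketch of marker-file resume logic -- names and the item-id format are
# illustrative, not the recipe's real interface.
import tempfile
from pathlib import Path

def is_done(progress_dir: Path, item_id: str) -> bool:
    # A work item is complete iff its marker file exists.
    return (progress_dir / f"{item_id}.done").exists()

def mark_done(progress_dir: Path, item_id: str) -> None:
    # Written only after the item's zarr chunks are fully flushed.
    progress_dir.mkdir(parents=True, exist_ok=True)
    (progress_dir / f"{item_id}.done").touch()

progress = Path(tempfile.mkdtemp()) / ".progress"
mark_done(progress, "2024-01-01T00_member00")
assert is_done(progress, "2024-01-01T00_member00")
assert not is_done(progress, "2024-01-01T00_member01")
```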

### Splitting work across multiple SLURM jobs

Submit N identical jobs with `resume=true`. The first job to start creates
the zarr store; subsequent jobs validate the schema and append. Each job
skips items that have already been completed by earlier jobs:

```bash
# Submit the same command multiple times (or as a SLURM array)
torchrun --nproc_per_node=$NGPU --standalone main.py resume=true
```

Because zarr chunks are non-overlapping per `(time, lead_time)` slice,
concurrent writes from different jobs to different ICs are safe.
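
The safety argument is index arithmetic: with unit chunks along `time` and `lead_time`, the chunk a write touches is fully determined by its indices, so work items with different ICs touch disjoint chunk sets. An illustrative sketch:

```python
# With chunks of size 1 along (time, lead_time), the chunk index of a write
# equals its (time, lead_time) index, so writes for different ICs are
# disjoint by construction. Illustrative arithmetic only.
def chunk_key(time_idx, lead_idx, chunks=(1, 1)):
    return (time_idx // chunks[0], lead_idx // chunks[1])

job_a = {chunk_key(0, l) for l in range(4)}  # IC 0, lead steps 0..3
job_b = {chunk_key(1, l) for l in range(4)}  # IC 1, lead steps 0..3
assert job_a.isdisjoint(job_b)
```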

When `resume=true`, the `output.overwrite` setting is ignored — existing
data is never deleted. When all items are complete, subsequent runs exit
immediately with a success message.

## Configuration

All configuration lives under `cfg/` and uses [Hydra](https://hydra.cc/docs/intro/).
The config is organized into three layers:

### Project and initial conditions
| Layer | Location | Purpose |
|---|---|---|
| Base | `cfg/default.yaml` | Shared defaults (pipeline, data source, output, predownload) |
| Model | `cfg/model/*.yaml` | Model architecture and checkpoint |
| Campaign | `cfg/campaign/*.yaml` | ICs, ensemble, variables, forecast length |

### Campaign configs

Campaign configs are the primary way to set up evaluation runs. They
override only what differs from the base config — model, ICs, ensemble
size, and output variables. Apply with `+campaign=`:

```bash
# DLWP monthly deterministic
python main.py +campaign=dlwp_2024_monthly

# FCN3 full 56-member ensemble
python main.py +campaign=fcn3_2024_full
```

Both `main.py` and `predownload.py` accept the same `+campaign=` flag,
so the two scripts stay in sync automatically.

To add a new model benchmark, create one file in `cfg/campaign/`:

```yaml
# cfg/campaign/my_model_2024.yaml
# @package _global_
defaults:
  - override /model: my_model

run_id: my_model_2024

# Explicit list of ICs
start_times:
  - "2024-01-01 00:00:00"
  - "2024-01-02 00:00:00"

# Or a range (remove start_times first). ic_block_end is inclusive on the step grid.
# ic_block_start: "2024-01-01 00:00:00"
# ic_block_end: "2024-03-31 00:00:00"
# ic_block_step: 24  # hours

nsteps: 40

output:
  variables: [t2m, z500]
```

### Model selection

Models are selected via Hydra defaults. Each model config lives in
`cfg/model/` and specifies the architecture class. Campaign configs
override the model via `defaults: [override /model: ...]`, or you
can switch on the command line:

```bash
python main.py model=dlwp
python main.py model=fcn3
```

### Ensemble runs
@@ -139,18 +197,8 @@ perturbation:
noise_amplitude: 0.05
```

For stochastic models (e.g. FCN3), the pipeline also calls
`model.set_rng(seed=...)` per ensemble member when available.
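
Sketched with a dummy model below (the `hasattr` guard is the point; the base-seed offset scheme is an assumption, not the recipe's actual seeding):

```python
# Illustrative per-member seeding -- the real pipeline derives seeds from
# its config; the base_seed + member offset here is an assumption.
class StochasticModel:
    def set_rng(self, seed):
        self.seed = seed

def seed_member(model, base_seed, member):
    # Only stochastic models expose set_rng; deterministic ones are skipped.
    if hasattr(model, "set_rng"):
        model.set_rng(seed=base_seed + member)

m = StochasticModel()
seed_member(m, 1234, 7)
# m.seed -> 1241
seed_member(object(), 1234, 7)  # no set_rng: silently skipped
```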

### Scoring (planned)

@@ -165,28 +213,83 @@

```
recipes/eval/
├── predownload.py # Hydra entry point — data pre-fetch
├── cfg/
│ ├── default.yaml # Base config (shared defaults + predownload)
│ ├── predownload.yaml # Thin overlay (hydra.run.dir only)
│ ├── model/
│ │ ├── dlwp.yaml
│ │ └── fcn3.yaml
│ └── campaign/ # One file per evaluation campaign
│ ├── dlwp_2024_monthly.yaml
│ └── fcn3_2024_full.yaml
├── src/
│ ├── pipeline.py # Pipeline ABC + built-in pipelines
│ ├── work.py # WorkItem, distribution, resume markers
│ ├── distributed.py # Rank-ordered execution, logging setup
│ ├── models.py # Model loading (prognostic + diagnostic)
│ └── output.py # OutputManager (zarr lifecycle)
└── pyproject.toml
```

Each source module has a specific, scoped responsibility:

| Module | Responsibility |
|---|---|
| `pipeline.py` | `Pipeline` ABC and built-in implementations (Forecast, Diagnostic) |
| `work.py` | Define work units; parse ICs from config; distribute across ranks |
| `distributed.py` | Rank-ordered execution primitive; logging setup |
| `models.py` | Load prognostic/diagnostic models from config |
| `output.py` | Zarr store creation, validation, threaded writes, consolidation |
| `inference.py` | Fetch ICs, perturb, run model iterator, apply diagnostics, write |

### Pipeline interface

All inference logic is driven by a **Pipeline** — an abstract base class
(`src/pipeline.py`) that separates per-work-item inference from the shared
scaffolding (work iteration, output filtering, ensemble injection, zarr
writes). Subclasses implement three methods:

| Method | Purpose |
|---|---|
| `setup(cfg, device)` | Load models, move to device, cache coordinate metadata |
| `build_total_coords(times, ensemble_size)` | Define the full zarr output coordinate system |
| `run_item(item, data_source, device)` | Yield `(tensor, coords)` pairs for one work item |

The base class `Pipeline.run()` handles everything else: iterating work
items, building the output variable filter, injecting the ensemble dimension,
and writing to the `OutputManager`.
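
In outline, the division of labour looks like this (a stripped-down sketch with illustrative names, not the recipe's actual class):

```python
# Stripped-down sketch of the Pipeline split: run() owns the shared loop,
# subclasses own per-item inference. Names are illustrative.
from abc import ABC, abstractmethod

class Pipeline(ABC):
    @abstractmethod
    def run_item(self, item):
        """Yield (tensor, coords) pairs for one work item."""

    def run(self, items, writer):
        # Shared scaffolding: iterate work items and forward every
        # yielded output to the writer.
        for item in items:
            for tensor, coords in self.run_item(item):
                writer.append((item, tensor, coords))

class Doubler(Pipeline):
    def run_item(self, item):
        yield item * 2, {"lead_time": 0}

out = []
Doubler().run([1, 2, 3], out)
# out -> [(1, 2, {'lead_time': 0}), (2, 4, {'lead_time': 0}), (3, 6, {'lead_time': 0})]
```

A real subclass would yield model outputs instead of doubled integers, and the writer would be the `OutputManager` rather than a list.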

Two built-in pipelines are provided:

- **`ForecastPipeline`** (`pipeline=forecast`) — prognostic rollout with
optional diagnostic models. Yields one output per lead-time step.
- **`DiagnosticPipeline`** (`pipeline=diagnostic`) — diagnostic-only (no
prognostic model). Yields a single output per work item.

### Custom pipelines

To add a custom inference loop, subclass `Pipeline` and set `pipeline` in
your Hydra config to the fully-qualified class name:

```python
# my_pipeline.py
from src.pipeline import Pipeline

class MyPipeline(Pipeline):
    def setup(self, cfg, device):
        ...

    def build_total_coords(self, times, ensemble_size):
        ...

    def run_item(self, item, data_source, device):
        ...
        yield x, coords
```

```yaml
# In your Hydra config override:
pipeline: my_pipeline.MyPipeline
```

Custom pipelines inherit the full shared machinery — distributed output
management, ensemble dimension handling, threaded zarr writes — for free.

## Testing

32 changes: 32 additions & 0 deletions recipes/eval/cfg/campaign/dlwp_2024_monthly.yaml
@@ -0,0 +1,32 @@
# @package _global_
# DLWP monthly 2024 — deterministic 10-day forecasts, all 7 output variables.
defaults:
- override /model: dlwp

run_id: dlwp_2024_monthly

start_times:
- "2024-01-01 00:00:00"
- "2024-02-01 00:00:00"
- "2024-03-01 00:00:00"
- "2024-04-01 00:00:00"
- "2024-05-01 00:00:00"
- "2024-06-01 00:00:00"
- "2024-07-01 00:00:00"
- "2024-08-01 00:00:00"
- "2024-09-01 00:00:00"
- "2024-10-01 00:00:00"
- "2024-11-01 00:00:00"
- "2024-12-01 00:00:00"

nsteps: 40

output:
  variables:
    - t2m
    - t850
    - z1000
    - z700
    - z500
    - z300
    - tcwv
104 changes: 104 additions & 0 deletions recipes/eval/cfg/campaign/fcn3_2024_full.yaml
@@ -0,0 +1,104 @@
# @package _global_
# FCN3 full 2024 — 56-member ensemble, 00Z/12Z ICs, 10-day forecasts.
# 732 ICs x 56 members = 40,992 work items; use resume=true and
# submit multiple SLURM jobs to distribute work.
defaults:
- override /model: fcn3

run_id: fcn3_2024_full

# Every 00Z and 12Z in 2024 (366 days x 2 = 732 ICs)
start_times: null
ic_block_start: "2024-01-01 00:00:00"
ic_block_end: "2024-12-31 12:00:00"
ic_block_step: 12

nsteps: 40
ensemble_size: 56
resume: true

perturbation:
  _target_: earth2studio.perturbation.CorrelatedSphericalGaussian
  noise_amplitude: 0.05

output:
  overwrite: false
  variables:
    # Surface (7)
    - u10m
    - v10m
    - u100m
    - v100m
    - t2m
    - msl
    - tcwv
    # U-wind at pressure levels (13)
    - u50
    - u100
    - u150
    - u200
    - u250
    - u300
    - u400
    - u500
    - u600
    - u700
    - u850
    - u925
    - u1000
    # V-wind at pressure levels (13)
    - v50
    - v100
    - v150
    - v200
    - v250
    - v300
    - v400
    - v500
    - v600
    - v700
    - v850
    - v925
    - v1000
    # Geopotential at pressure levels (13)
    - z50
    - z100
    - z150
    - z200
    - z250
    - z300
    - z400
    - z500
    - z600
    - z700
    - z850
    - z925
    - z1000
    # Temperature at pressure levels (13)
    - t50
    - t100
    - t150
    - t200
    - t250
    - t300
    - t400
    - t500
    - t600
    - t700
    - t850
    - t925
    - t1000
    # Specific humidity at pressure levels (13)
    - q50
    - q100
    - q150
    - q200
    - q250
    - q300
    - q400
    - q500
    - q600
    - q700
    - q850
    - q925
    - q1000