From 16a6707a671c293c5382888d43d0c01c60e67a66 Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Fri, 6 Mar 2026 17:15:27 +1100 Subject: [PATCH 1/6] perf: batch dataset ingestion commits and fix spurious update logging Batch ingestion commits in groups of 50 to reduce SQLite fsync overhead from N commits to N/50. Memory is managed via expire_all() after each batch. Fix incorrect "updated" logging caused by numpy scalar types (e.g. numpy.int64) comparing unequal to Python native types from the database. Added _to_python_native() to normalize values before comparison in _values_differ(). --- .../climate-ref/src/climate_ref/database.py | 26 +++++++++- .../src/climate_ref/datasets/__init__.py | 51 ++++++++++++------- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/packages/climate-ref/src/climate_ref/database.py b/packages/climate-ref/src/climate_ref/database.py index 02edafaae..87e196074 100644 --- a/packages/climate-ref/src/climate_ref/database.py +++ b/packages/climate-ref/src/climate_ref/database.py @@ -139,6 +139,21 @@ def validate_database_url(database_url: str) -> str: return database_url +def _to_python_native(value: Any) -> Any: + """ + Convert numpy/pandas scalar types to Python natives for reliable comparison. + + Without this, comparisons like ``numpy.int64(5) != 5`` may produce + numpy booleans that behave unexpectedly, and type mismatches between + DB values (Python natives) and DataFrame values (numpy scalars) + can cause spurious "updated" detections. + """ + # numpy scalars expose .item() to convert to Python native + if hasattr(value, "item"): + return value.item() + return value + + def _values_differ(current: Any, new: Any) -> bool: """ Safely compare two values for inequality, handling ``pd.NA`` and ``np.nan``. @@ -146,7 +161,13 @@ def _values_differ(current: Any, new: Any) -> bool: Direct ``!=`` comparison with ``pd.NA`` raises ``TypeError`` because ``bool(pd.NA)`` is ambiguous. This helper avoids that by checking for NA on both sides first. + + Values are normalised to Python native types before comparison + to avoid spurious mismatches between numpy scalars and Python builtins. 
""" + current = _to_python_native(current) + new = _to_python_native(new) + try: current_is_na = pd.isna(current) new_is_na = pd.isna(new) @@ -341,8 +362,9 @@ def update_or_create( # Update existing instance with defaults if defaults: for key, value in defaults.items(): - if _values_differ(getattr(instance, key), value): - logger.debug(f"Updating {model.__name__} {key} to {value}") + current = getattr(instance, key) + if _values_differ(current, value): + logger.debug(f"Updating {model.__name__} {key}: {current!r} -> {value!r}") setattr(instance, key, value) state = ModelState.UPDATED return instance, state diff --git a/packages/climate-ref/src/climate_ref/datasets/__init__.py b/packages/climate-ref/src/climate_ref/datasets/__init__.py index 535ad1e3b..027b19a66 100644 --- a/packages/climate-ref/src/climate_ref/datasets/__init__.py +++ b/packages/climate-ref/src/climate_ref/datasets/__init__.py @@ -10,7 +10,7 @@ from loguru import logger from climate_ref.database import Database, ModelState -from climate_ref.datasets.base import DatasetAdapter +from climate_ref.datasets.base import DatasetAdapter, DatasetRegistrationResult from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter from climate_ref.datasets.cmip7 import CMIP7DatasetAdapter from climate_ref.datasets.obs4mips import Obs4MIPsDatasetAdapter @@ -71,6 +71,25 @@ def log_summary(self, prefix: str = "") -> None: ) +def _accumulate_stats(stats: IngestionStats, results: "DatasetRegistrationResult") -> None: + """Accumulate registration results into ingestion stats.""" + if results.dataset_state == ModelState.CREATED: + stats.datasets_created += 1 + elif results.dataset_state == ModelState.UPDATED: + stats.datasets_updated += 1 + else: + stats.datasets_unchanged += 1 + stats.files_added += len(results.files_added) + stats.files_updated += len(results.files_updated) + stats.files_removed += len(results.files_removed) + stats.files_unchanged += len(results.files_unchanged) + + +#: Number of datasets to commit in a single transaction. +#: Batching reduces SQLite fsync overhead from N commits to N/BATCH_SIZE commits. +INGEST_BATCH_SIZE = 50 + + def ingest_datasets( adapter: DatasetAdapter, directory: Path | None, @@ -85,6 +104,10 @@ def ingest_datasets( This is the common ingestion logic shared between the CLI ingest command and provider setup. + Datasets are committed in batches to reduce database overhead. + Each batch is a single transaction; if a dataset within a batch fails, + the entire batch is rolled back. + Parameters ---------- adapter @@ -132,23 +155,17 @@ def ingest_datasets( stats = IngestionStats() - for instance_id, data_catalog_dataset in data_catalog.groupby(adapter.slug_column): - logger.debug(f"Processing dataset {instance_id}") + groups = list(data_catalog.groupby(adapter.slug_column)) + + for batch_start in range(0, len(groups), INGEST_BATCH_SIZE): + batch = groups[batch_start : batch_start + INGEST_BATCH_SIZE] with db.session.begin(): - results = adapter.register_dataset(db, data_catalog_dataset) - - if results.dataset_state == ModelState.CREATED: - stats.datasets_created += 1 - elif results.dataset_state == ModelState.UPDATED: - stats.datasets_updated += 1 - else: - stats.datasets_unchanged += 1 - stats.files_added += len(results.files_added) - stats.files_updated += len(results.files_updated) - stats.files_removed += len(results.files_removed) - stats.files_unchanged += len(results.files_unchanged) - - # Release ORM objects from the session identity map after each commit. 
+ for instance_id, data_catalog_dataset in batch: + logger.debug(f"Processing dataset {instance_id}") + results = adapter.register_dataset(db, data_catalog_dataset) + _accumulate_stats(stats, results) + + # Release ORM objects from the session identity map after each batch commit. # Without this, all Dataset and DatasetFile objects accumulate in memory # across the entire ingestion loop. db.session.expire_all() From 69687557e387b6555b2be07bd52cd2c79b16fd54 Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Fri, 6 Mar 2026 21:32:12 +1100 Subject: [PATCH 2/6] docs: add changelog entry for PR #581 --- changelog/581.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/581.improvement.md diff --git a/changelog/581.improvement.md b/changelog/581.improvement.md new file mode 100644 index 000000000..464c389ce --- /dev/null +++ b/changelog/581.improvement.md @@ -0,0 +1 @@ +Batched dataset ingestion commits in groups of 50 to reduce SQLite overhead and fixed spurious "updated" logging caused by numpy scalar type mismatches. From 312105a7f25cdf039548c6583484008732a787be Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Sat, 7 Mar 2026 00:59:22 +0000 Subject: [PATCH 3/6] fix: filtering datasets --- .../src/climate_ref/cli/datasets.py | 22 +++---- .../src/climate_ref/datasets/__init__.py | 4 +- .../src/climate_ref/datasets/base.py | 49 +++++++++++---- .../climate-ref/tests/unit/test_database.py | 59 ++++++++++++++++++- 4 files changed, 109 insertions(+), 25 deletions(-) diff --git a/packages/climate-ref/src/climate_ref/cli/datasets.py b/packages/climate-ref/src/climate_ref/cli/datasets.py index 378b0c166..29a6d6442 100644 --- a/packages/climate-ref/src/climate_ref/cli/datasets.py +++ b/packages/climate-ref/src/climate_ref/cli/datasets.py @@ -16,7 +16,6 @@ from climate_ref.cli._utils import pretty_print_df from climate_ref.models import Dataset -from climate_ref.solver import apply_dataset_filters from climate_ref_core.dataset_registry import dataset_registry_manager, fetch_all_files from climate_ref_core.source_types import SourceDatasetType @@ -57,11 +56,9 @@ def list_( # noqa: PLR0913 database = ctx.obj.database - adapter = get_dataset_adapter(source_type.value) - data_catalog = adapter.load_catalog(database, include_files=include_files, limit=limit) - + parsed_filters: dict[str, list[str]] | None = None if dataset_filter: - parsed_filters: dict[str, list[str]] = {} + parsed_filters = {} for entry in dataset_filter: if "=" not in entry: raise typer.BadParameter( @@ -71,16 +68,21 @@ def list_( # noqa: PLR0913 key, value = entry.split("=", 1) parsed_filters.setdefault(key, []).append(value) + adapter = get_dataset_adapter(source_type.value) + + if parsed_filters: + valid_facets = set(adapter.dataset_specific_metadata) for facet in parsed_filters: - if facet not in data_catalog.columns: + if facet not in valid_facets: logger.error( - f"Filter facet '{facet}' not found in data catalog. " - f"Choose from: {', '.join(sorted(data_catalog.columns))}" + f"Filter facet '{facet}' not found in dataset metadata. 
" + f"Choose from: {', '.join(sorted(valid_facets))}" ) raise typer.Exit(code=1) - filtered = apply_dataset_filters({source_type: data_catalog}, parsed_filters) - data_catalog = filtered[source_type] # type: ignore[assignment] # input is DataFrame + data_catalog = adapter.load_catalog( + database, include_files=include_files, limit=limit, filters=parsed_filters + ) if column: missing = set(column) - set(data_catalog.columns) diff --git a/packages/climate-ref/src/climate_ref/datasets/__init__.py b/packages/climate-ref/src/climate_ref/datasets/__init__.py index 027b19a66..e54299f8e 100644 --- a/packages/climate-ref/src/climate_ref/datasets/__init__.py +++ b/packages/climate-ref/src/climate_ref/datasets/__init__.py @@ -85,8 +85,8 @@ def _accumulate_stats(stats: IngestionStats, results: "DatasetRegistrationResult stats.files_unchanged += len(results.files_unchanged) -#: Number of datasets to commit in a single transaction. -#: Batching reduces SQLite fsync overhead from N commits to N/BATCH_SIZE commits. +# Number of datasets to commit in a single transaction. +# Batching reduces SQLite fsync overhead from N commits to N/BATCH_SIZE commits. INGEST_BATCH_SIZE = 50 diff --git a/packages/climate-ref/src/climate_ref/datasets/base.py b/packages/climate-ref/src/climate_ref/datasets/base.py index de58554f5..3e619d010 100644 --- a/packages/climate-ref/src/climate_ref/datasets/base.py +++ b/packages/climate-ref/src/climate_ref/datasets/base.py @@ -409,10 +409,15 @@ def filter_latest_versions(self, catalog: pd.DataFrame) -> pd.DataFrame: return catalog[catalog[self.version_metadata] == max_version_per_group] - def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataFrame: + def _get_dataset_files( + self, + db: Database, + limit: int | None = None, + filters: dict[str, list[str]] | None = None, + ) -> pd.DataFrame: dataset_type = self.dataset_cls.__mapper_args__["polymorphic_identity"] - result = ( + query = ( db.session.query(DatasetFile) # The join is necessary to be able to order by the dataset columns .join(DatasetFile.dataset) @@ -420,11 +425,16 @@ def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataF # The joinedload is necessary to avoid N+1 queries (one for each dataset) # https://docs.sqlalchemy.org/en/14/orm/loading_relationships.html#the-zen-of-joined-eager-loading .options(joinedload(DatasetFile.dataset.of_type(self.dataset_cls))) - .order_by(Dataset.updated_at.desc()) - .limit(limit) - .all() ) + if filters: + for key, values in filters.items(): + column = getattr(self.dataset_cls, key, None) + if column is not None: + query = query.where(column.in_(values)) + + result = query.order_by(Dataset.updated_at.desc()).limit(limit).all() + return pd.DataFrame( [ { @@ -437,10 +447,21 @@ def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataF index=[file.dataset.id for file in result], ) - def _get_datasets(self, db: Database, limit: int | None = None) -> pd.DataFrame: - result_datasets = ( - db.session.query(self.dataset_cls).order_by(Dataset.updated_at.desc()).limit(limit).all() - ) + def _get_datasets( + self, + db: Database, + limit: int | None = None, + filters: dict[str, list[str]] | None = None, + ) -> pd.DataFrame: + query = db.session.query(self.dataset_cls) + + if filters: + for key, values in filters.items(): + column = getattr(self.dataset_cls, key, None) + if column is not None: + query = query.where(column.in_(values)) + + result_datasets = query.order_by(Dataset.updated_at.desc()).limit(limit).all() return 
pd.DataFrame( [{k: getattr(dataset, k) for k in self.dataset_specific_metadata} for dataset in result_datasets], @@ -448,7 +469,11 @@ def _get_datasets(self, db: Database, limit: int | None = None) -> pd.DataFrame: ) def load_catalog( - self, db: Database, include_files: bool = True, limit: int | None = None + self, + db: Database, + include_files: bool = True, + limit: int | None = None, + filters: dict[str, list[str]] | None = None, ) -> pd.DataFrame: """ Load the data catalog containing the currently tracked datasets/files from the database @@ -469,9 +494,9 @@ def load_catalog( with db.session.begin(): # TODO: Paginate this query to avoid loading all the data at once if include_files: - catalog = self._get_dataset_files(db, limit) + catalog = self._get_dataset_files(db, limit, filters) else: - catalog = self._get_datasets(db, limit) + catalog = self._get_datasets(db, limit, filters) # If there are no datasets, return an empty DataFrame if catalog.empty: diff --git a/packages/climate-ref/tests/unit/test_database.py b/packages/climate-ref/tests/unit/test_database.py index 6b114bb1c..bcf50331b 100644 --- a/packages/climate-ref/tests/unit/test_database.py +++ b/packages/climate-ref/tests/unit/test_database.py @@ -8,7 +8,13 @@ import sqlalchemy from sqlalchemy import inspect -from climate_ref.database import Database, _create_backup, _values_differ, validate_database_url +from climate_ref.database import ( + Database, + _create_backup, + _to_python_native, + _values_differ, + validate_database_url, +) from climate_ref.models import MetricValue from climate_ref.models.dataset import CMIP6Dataset, Dataset, Obs4MIPsDataset from climate_ref_core.datasets import SourceDatasetType @@ -360,3 +366,54 @@ def test_bool_vs_pd_na(self): assert _values_differ(True, pd.NA) assert _values_differ(False, pd.NA) assert _values_differ(pd.NA, True) + + def test_numpy_int64_equal_to_python_int(self): + """numpy.int64 from a DataFrame vs Python int from the DB must not flag as changed.""" + assert not _values_differ(np.int64(5), 5) + assert not _values_differ(5, np.int64(5)) + + def test_numpy_float64_equal_to_python_float(self): + assert not _values_differ(np.float64(1.5), 1.5) + assert not _values_differ(1.5, np.float64(1.5)) + + def test_numpy_int64_different_from_python_int(self): + assert _values_differ(np.int64(5), 6) + assert _values_differ(6, np.int64(5)) + + def test_numpy_bool_equal(self): + assert not _values_differ(np.bool_(True), True) + assert not _values_differ(np.bool_(False), False) + + +class TestToPythonNative: + """Tests for _to_python_native which normalises numpy scalars to Python builtins.""" + + def test_numpy_int64_converts_to_python_int(self): + result = _to_python_native(np.int64(42)) + assert result == 42 + assert type(result) is int + + def test_numpy_float64_converts_to_python_float(self): + result = _to_python_native(np.float64(3.14)) + assert abs(result - 3.14) < 1e-9 + assert type(result) is float + + def test_numpy_bool_converts_to_python_bool(self): + result = _to_python_native(np.bool_(True)) + assert result is True + assert type(result) is bool + + def test_python_str_passthrough(self): + assert _to_python_native("atmos") == "atmos" + + def test_python_int_passthrough(self): + value = 7 + assert _to_python_native(value) is value + + def test_none_passthrough(self): + assert _to_python_native(None) is None + + def test_pd_na_passthrough(self): + """pd.NA has no .item() so must be returned unchanged.""" + result = _to_python_native(pd.NA) + assert pd.isna(result) From 
7dfb4daab876e437a0b60fbbbe166fc8480a128a Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Sat, 7 Mar 2026 01:26:15 +0000 Subject: [PATCH 4/6] fix: resolve false positive file updates from cftime vs string comparison File metadata (start_time/end_time) is stored as strings in the DB but incoming values from parsers are cftime.datetime objects. The raw != comparison always returned True for these mismatched types, causing every file to be flagged as "updated" on every re-ingest. Add _file_meta_differs() that coerces non-string values to str before comparing, matching the DatasetFile._coerce_time_to_str storage behavior. --- .../src/climate_ref/datasets/base.py | 18 ++++++++-- .../tests/unit/datasets/test_datasets.py | 34 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/packages/climate-ref/src/climate_ref/datasets/base.py b/packages/climate-ref/src/climate_ref/datasets/base.py index 3e619d010..110295d18 100644 --- a/packages/climate-ref/src/climate_ref/datasets/base.py +++ b/packages/climate-ref/src/climate_ref/datasets/base.py @@ -6,12 +6,26 @@ from loguru import logger from sqlalchemy.orm import joinedload -from climate_ref.database import Database, ModelState +from climate_ref.database import Database, ModelState, _values_differ from climate_ref.datasets.utils import _is_na, parse_cftime_dates, validate_path from climate_ref.models.dataset import Dataset, DatasetFile from climate_ref_core.exceptions import RefException +def _file_meta_differs(db_value: Any, new_value: Any) -> bool: + """ + Compare a DB-stored file metadata value against an incoming value. + + DatasetFile stores start_time/end_time as strings (via _coerce_time_to_str), + but incoming values from the parser may be cftime.datetime objects. + Coerce non-string values to strings before comparison to avoid + type-mismatch false positives. 
+ """ + if not isinstance(new_value, str) and new_value is not None: + new_value = str(new_value) + return _values_differ(db_value, new_value) + + @define class DatasetRegistrationResult: """ @@ -322,7 +336,7 @@ def register_dataset( # noqa: PLR0912, PLR0915 changed = any( not _is_na(new_meta.get(c)) and hasattr(existing_file, c) - and getattr(existing_file, c) != new_meta[c] + and _file_meta_differs(getattr(existing_file, c), new_meta[c]) for c in file_meta_cols if c in new_meta ) diff --git a/packages/climate-ref/tests/unit/datasets/test_datasets.py b/packages/climate-ref/tests/unit/datasets/test_datasets.py index 86b713959..568bc0763 100644 --- a/packages/climate-ref/tests/unit/datasets/test_datasets.py +++ b/packages/climate-ref/tests/unit/datasets/test_datasets.py @@ -1,5 +1,6 @@ from pathlib import Path +import cftime import numpy as np import pandas as pd import pytest @@ -8,7 +9,7 @@ from climate_ref.database import Database, ModelState from climate_ref.datasets import IngestionStats, get_dataset_adapter, ingest_datasets from climate_ref.datasets import base as base_module -from climate_ref.datasets.base import DatasetAdapter, _is_na +from climate_ref.datasets.base import DatasetAdapter, _file_meta_differs, _is_na from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter from climate_ref.models.dataset import CMIP6Dataset, DatasetFile from climate_ref_core.datasets import SourceDatasetType @@ -980,3 +981,34 @@ def test_zero(self): def test_empty_string(self): assert _is_na("") is False + + +class TestFileMetaDiffers: + """Tests for _file_meta_differs, which compares DB-stored file metadata against incoming values.""" + + def test_cftime_vs_matching_string_is_not_different(self): + """cftime.datetime stringifies to the same value stored in the DB.""" + db_value = "1850-01-16 12:00:00" + incoming = cftime.datetime(1850, 1, 16, 12, 0, 0, 0, calendar="proleptic_gregorian") + assert _file_meta_differs(db_value, incoming) is False + + def test_cftime_vs_different_string_is_different(self): + """cftime.datetime with different date is detected as changed.""" + db_value = "1850-01-01 00:00:00" + incoming = cftime.datetime(1850, 1, 16, 12, 0, 0, 0, calendar="standard") + assert _file_meta_differs(db_value, incoming) is True + + def test_string_vs_string_equal(self): + assert _file_meta_differs("1850-01-01 00:00:00", "1850-01-01 00:00:00") is False + + def test_string_vs_string_different(self): + assert _file_meta_differs("1850-01-01 00:00:00", "1851-01-01 00:00:00") is True + + def test_none_vs_none(self): + assert _file_meta_differs(None, None) is False + + def test_none_vs_value(self): + assert _file_meta_differs(None, "1850-01-01") is True + + def test_value_vs_none(self): + assert _file_meta_differs("1850-01-01", None) is True From 02a906b1159566a686de1bb115b55017b4962a76 Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Sat, 7 Mar 2026 16:21:13 +1100 Subject: [PATCH 5/6] chore: Introduce a max members parameter --- scripts/fetch-esgf.py | 52 +++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/scripts/fetch-esgf.py b/scripts/fetch-esgf.py index 9538ca5ce..54e009b7e 100644 --- a/scripts/fetch-esgf.py +++ b/scripts/fetch-esgf.py @@ -2,8 +2,8 @@ CLI tool for fetching the required CMIP6 and Obs4MIPs datasets from ESGF. This script can either run all predefined requests or a specific request by ID. -By default, only one ensemble member per model is fetched to reduce the total data volume. 
-This can be changed with the --no-remove-ensembles flag.
+By default, only one ensemble member per source_id is fetched to reduce the total data volume.
+This can be changed with the --max-members option (0 = no limit).
 
 This fetches about 3TB of datasets into the default location for intake esgf.
 This can be adjusted via `~/.config/intake-esgf/config.yaml`.
@@ -24,15 +24,15 @@ class CMIP6Request:
     id: str
     facets: dict[str, str | tuple[str, ...] | list[str]]
 
-    def fetch(self, remove_ensembles: bool = True):
+    def fetch(self, max_members: int = 1):
         """
         Fetch CMIP6 data from the ESGF catalog and return it as a DataFrame.
 
         Parameters
         ----------
-        remove_ensembles : bool, default True
-            Whether to remove ensemble members, keeping only one per model.
-            If False, all ensemble members will be included.
+        max_members : int, default 1
+            Maximum number of ensemble members to fetch per source_id.
+            Set to 0 for no limit (fetch all ensemble members).
 
         Returns
         -------
@@ -48,8 +48,12 @@ def fetch(self, remove_ensembles: bool = True):
         logger.debug(f"Fetching CMIP6 data: {search_parameters}")
         try:
             cmip6_data = catalog.search(**search_parameters)
-            if remove_ensembles:
-                cmip6_data = cmip6_data.remove_ensembles()
+            if max_members > 0 and cmip6_data.df is not None:
+                df = cmip6_data.df
+                mask = df.groupby("source_id")["member_id"].transform(
+                    lambda s: s.isin(s.unique()[:max_members])
+                )
+                cmip6_data.df = df[mask]
             return cmip6_data.to_path_dict()
         except Exception:
             logger.info(f"Error fetching CMIP6 data: {search_parameters}")
@@ -65,13 +69,13 @@ class Obs4MIPsRequest:
     id: str
     facets: dict[str, str | tuple[str, ...] | list[str]]
 
-    def fetch(self, remove_ensembles: bool = True):
+    def fetch(self, max_members: int = 1):
         """
         Fetch Obs4MIPs data from the ESGF catalog and return it as a DataFrame.
 
         Parameters
         ----------
-        remove_ensembles : bool, default True
+        max_members : int, default 1
             Ignored as Obs4MIPs data does not have ensembles.
 
         Returns
@@ -332,12 +336,12 @@ def fetch(self, remove_ensembles: bool = True):
 ]
 
 
-def run_request(request: Request, remove_ensembles: bool = True):
+def run_request(request: Request, max_members: int = 1):
     """
     Fetch and log the results of a request
     """
     print(f"Processing request: {request.id}")
-    df = request.fetch(remove_ensembles=remove_ensembles)
+    df = request.fetch(max_members=max_members)
     print(f"{len(df)} datasets")
     print("\n")
 
@@ -346,11 +350,11 @@ def main(
     request_id: str = typer.Option(
         None, help="ID of a specific request to run. If not provided, all requests will be run."
     ),
-    remove_ensembles: bool = typer.Option(
-        True,
+    max_members: int = typer.Option(
+        1,
         help=(
-            "Remove ensemble members, keeping only one per model. "
-            "Use --no-remove-ensembles to fetch all ensembles."
+            "Maximum number of ensemble members to fetch per source_id. "
+            "Set to 0 to fetch all ensemble members."
         ),
     ),
 ):
@@ -358,8 +362,8 @@ def main(
     Fetch CMIP6 datasets from ESGF.
 
     This script can either run all predefined requests or a specific request by ID.
-    By default, only one ensemble member per model is fetched, but this can be changed
-    with the --no-remove-ensembles flag.
+    By default, only one ensemble member per source_id is fetched, but this can be
+    changed with --max-members (use 0 for no limit).
""" if request_id: # Find and run the specific request @@ -372,15 +376,15 @@ def main( raise typer.Exit(1) logger.info(f"Running single request: {request_id}") - if not remove_ensembles: - logger.info("Fetching all ensemble members") - run_request(matching_requests[0], remove_ensembles=remove_ensembles) + limit = max_members if max_members > 0 else "unlimited" + logger.info(f"Max ensemble members per source_id: {limit}") + run_request(matching_requests[0], max_members=max_members) else: logger.info("Running all requests...") - if not remove_ensembles: - logger.info("Fetching all ensemble members") + limit = max_members if max_members > 0 else "unlimited" + logger.info(f"Max ensemble members per source_id: {limit}") for request in requests: - run_request(request, remove_ensembles=remove_ensembles) + run_request(request, max_members=max_members) # joblib.Parallel(n_jobs=2)(joblib.delayed(run_request)(request) for request in requests) From f5b8cfe9c5409adcca4e23c339d96a1654bf0e96 Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Sat, 7 Mar 2026 16:24:52 +1100 Subject: [PATCH 6/6] chore: sort before selecting --- scripts/fetch-esgf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fetch-esgf.py b/scripts/fetch-esgf.py index 54e009b7e..5f3178dde 100644 --- a/scripts/fetch-esgf.py +++ b/scripts/fetch-esgf.py @@ -51,7 +51,7 @@ def fetch(self, max_members: int = 1): if max_members > 0 and cmip6_data.df is not None: df = cmip6_data.df mask = df.groupby("source_id")["member_id"].transform( - lambda s: s.isin(s.unique()[:max_members]) + lambda s: s.isin(sorted(s.unique())[:max_members]) ) cmip6_data.df = df[mask] return cmip6_data.to_path_dict()