Skip to content
Open
1 change: 1 addition & 0 deletions changelog/581.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Batched dataset ingestion commits in groups of 50 to reduce SQLite overhead and fixed spurious "updated" logging caused by numpy scalar type mismatches.
22 changes: 12 additions & 10 deletions packages/climate-ref/src/climate_ref/cli/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from climate_ref.cli._utils import pretty_print_df
from climate_ref.models import Dataset
from climate_ref.solver import apply_dataset_filters
from climate_ref_core.dataset_registry import dataset_registry_manager, fetch_all_files
from climate_ref_core.source_types import SourceDatasetType

Expand Down Expand Up @@ -57,11 +56,9 @@ def list_( # noqa: PLR0913

database = ctx.obj.database

adapter = get_dataset_adapter(source_type.value)
data_catalog = adapter.load_catalog(database, include_files=include_files, limit=limit)

parsed_filters: dict[str, list[str]] | None = None
if dataset_filter:
parsed_filters: dict[str, list[str]] = {}
parsed_filters = {}
for entry in dataset_filter:
if "=" not in entry:
raise typer.BadParameter(
Expand All @@ -71,16 +68,21 @@ def list_( # noqa: PLR0913
key, value = entry.split("=", 1)
parsed_filters.setdefault(key, []).append(value)

adapter = get_dataset_adapter(source_type.value)

if parsed_filters:
valid_facets = set(adapter.dataset_specific_metadata)
for facet in parsed_filters:
if facet not in data_catalog.columns:
if facet not in valid_facets:
logger.error(
f"Filter facet '{facet}' not found in data catalog. "
f"Choose from: {', '.join(sorted(data_catalog.columns))}"
f"Filter facet '{facet}' not found in dataset metadata. "
f"Choose from: {', '.join(sorted(valid_facets))}"
)
raise typer.Exit(code=1)

filtered = apply_dataset_filters({source_type: data_catalog}, parsed_filters)
data_catalog = filtered[source_type] # type: ignore[assignment] # input is DataFrame
data_catalog = adapter.load_catalog(
database, include_files=include_files, limit=limit, filters=parsed_filters
)

if column:
missing = set(column) - set(data_catalog.columns)
Expand Down
26 changes: 24 additions & 2 deletions packages/climate-ref/src/climate_ref/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,35 @@ def validate_database_url(database_url: str) -> str:
return database_url


def _to_python_native(value: Any) -> Any:
"""
Convert numpy/pandas scalar types to Python natives for reliable comparison.

Without this, comparisons like ``numpy.int64(5) != 5`` may produce
numpy booleans that behave unexpectedly, and type mismatches between
DB values (Python natives) and DataFrame values (numpy scalars)
can cause spurious "updated" detections.
"""
# numpy scalars expose .item() to convert to Python native
if hasattr(value, "item"):
return value.item()
return value


def _values_differ(current: Any, new: Any) -> bool:
"""
Safely compare two values for inequality, handling ``pd.NA`` and ``np.nan``.

Direct ``!=`` comparison with ``pd.NA`` raises ``TypeError`` because
``bool(pd.NA)`` is ambiguous. This helper avoids that by checking
for NA on both sides first.

Values are normalised to Python native types before comparison
to avoid spurious mismatches between numpy scalars and Python builtins.
"""
current = _to_python_native(current)
new = _to_python_native(new)

try:
current_is_na = pd.isna(current)
new_is_na = pd.isna(new)
Expand Down Expand Up @@ -341,8 +362,9 @@ def update_or_create(
# Update existing instance with defaults
if defaults:
for key, value in defaults.items():
if _values_differ(getattr(instance, key), value):
logger.debug(f"Updating {model.__name__} {key} to {value}")
current = getattr(instance, key)
if _values_differ(current, value):
logger.debug(f"Updating {model.__name__} {key}: {current!r} -> {value!r}")
setattr(instance, key, value)
state = ModelState.UPDATED
return instance, state
Expand Down
51 changes: 34 additions & 17 deletions packages/climate-ref/src/climate_ref/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loguru import logger

from climate_ref.database import Database, ModelState
from climate_ref.datasets.base import DatasetAdapter
from climate_ref.datasets.base import DatasetAdapter, DatasetRegistrationResult
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter
from climate_ref.datasets.cmip7 import CMIP7DatasetAdapter
from climate_ref.datasets.obs4mips import Obs4MIPsDatasetAdapter
Expand Down Expand Up @@ -71,6 +71,25 @@ def log_summary(self, prefix: str = "") -> None:
)


def _accumulate_stats(stats: IngestionStats, results: DatasetRegistrationResult) -> None:
    """
    Accumulate one dataset's registration results into running ingestion stats.

    Parameters
    ----------
    stats
        Running totals for the whole ingestion run; mutated in place.
    results
        Outcome of registering a single dataset and its files.
    """
    # Exactly one of the three dataset counters is bumped per dataset:
    # created, updated, or (the fallback) unchanged.
    if results.dataset_state == ModelState.CREATED:
        stats.datasets_created += 1
    elif results.dataset_state == ModelState.UPDATED:
        stats.datasets_updated += 1
    else:
        stats.datasets_unchanged += 1
    stats.files_added += len(results.files_added)
    stats.files_updated += len(results.files_updated)
    stats.files_removed += len(results.files_removed)
    stats.files_unchanged += len(results.files_unchanged)


# Number of datasets to commit in a single transaction.
# Batching reduces SQLite fsync overhead from N commits to N/BATCH_SIZE commits.
# NOTE: a whole batch is rolled back if any dataset within it fails, so a
# larger value trades fewer fsyncs against more re-done work on failure.
INGEST_BATCH_SIZE = 50


def ingest_datasets(
adapter: DatasetAdapter,
directory: Path | None,
Expand All @@ -85,6 +104,10 @@ def ingest_datasets(
This is the common ingestion logic shared between the CLI ingest command
and provider setup.

Datasets are committed in batches to reduce database overhead.
Each batch is a single transaction; if a dataset within a batch fails,
the entire batch is rolled back.

Parameters
----------
adapter
Expand Down Expand Up @@ -132,23 +155,17 @@ def ingest_datasets(

stats = IngestionStats()

for instance_id, data_catalog_dataset in data_catalog.groupby(adapter.slug_column):
logger.debug(f"Processing dataset {instance_id}")
groups = list(data_catalog.groupby(adapter.slug_column))

for batch_start in range(0, len(groups), INGEST_BATCH_SIZE):
batch = groups[batch_start : batch_start + INGEST_BATCH_SIZE]
with db.session.begin():
results = adapter.register_dataset(db, data_catalog_dataset)

if results.dataset_state == ModelState.CREATED:
stats.datasets_created += 1
elif results.dataset_state == ModelState.UPDATED:
stats.datasets_updated += 1
else:
stats.datasets_unchanged += 1
stats.files_added += len(results.files_added)
stats.files_updated += len(results.files_updated)
stats.files_removed += len(results.files_removed)
stats.files_unchanged += len(results.files_unchanged)

# Release ORM objects from the session identity map after each commit.
for instance_id, data_catalog_dataset in batch:
logger.debug(f"Processing dataset {instance_id}")
results = adapter.register_dataset(db, data_catalog_dataset)
_accumulate_stats(stats, results)

# Release ORM objects from the session identity map after each batch commit.
# Without this, all Dataset and DatasetFile objects accumulate in memory
# across the entire ingestion loop.
db.session.expire_all()
Expand Down
67 changes: 53 additions & 14 deletions packages/climate-ref/src/climate_ref/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,26 @@
from loguru import logger
from sqlalchemy.orm import joinedload

from climate_ref.database import Database, ModelState
from climate_ref.database import Database, ModelState, _values_differ
from climate_ref.datasets.utils import _is_na, parse_cftime_dates, validate_path
from climate_ref.models.dataset import Dataset, DatasetFile
from climate_ref_core.exceptions import RefException


def _file_meta_differs(db_value: Any, new_value: Any) -> bool:
    """
    Compare a DB-stored file metadata value against an incoming value.

    DatasetFile stores start_time/end_time as strings (via _coerce_time_to_str),
    but incoming values from the parser may be cftime.datetime objects.
    Non-string, non-None incoming values are coerced to strings before the
    comparison to avoid type-mismatch false positives.
    """
    # None passes through untouched; everything else is normalised to str
    # so it matches the string representation stored in the database.
    if new_value is None or isinstance(new_value, str):
        coerced = new_value
    else:
        coerced = str(new_value)
    return _values_differ(db_value, coerced)


@define
class DatasetRegistrationResult:
"""
Expand Down Expand Up @@ -322,7 +336,7 @@ def register_dataset( # noqa: PLR0912, PLR0915
changed = any(
not _is_na(new_meta.get(c))
and hasattr(existing_file, c)
and getattr(existing_file, c) != new_meta[c]
and _file_meta_differs(getattr(existing_file, c), new_meta[c])
for c in file_meta_cols
if c in new_meta
)
Expand Down Expand Up @@ -409,22 +423,32 @@ def filter_latest_versions(self, catalog: pd.DataFrame) -> pd.DataFrame:

return catalog[catalog[self.version_metadata] == max_version_per_group]

def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataFrame:
def _get_dataset_files(
self,
db: Database,
limit: int | None = None,
filters: dict[str, list[str]] | None = None,
) -> pd.DataFrame:
dataset_type = self.dataset_cls.__mapper_args__["polymorphic_identity"]

result = (
query = (
db.session.query(DatasetFile)
# The join is necessary to be able to order by the dataset columns
.join(DatasetFile.dataset)
.where(Dataset.dataset_type == dataset_type)
# The joinedload is necessary to avoid N+1 queries (one for each dataset)
# https://docs.sqlalchemy.org/en/14/orm/loading_relationships.html#the-zen-of-joined-eager-loading
.options(joinedload(DatasetFile.dataset.of_type(self.dataset_cls)))
.order_by(Dataset.updated_at.desc())
.limit(limit)
.all()
)

if filters:
for key, values in filters.items():
column = getattr(self.dataset_cls, key, None)
if column is not None:
query = query.where(column.in_(values))

result = query.order_by(Dataset.updated_at.desc()).limit(limit).all()

return pd.DataFrame(
[
{
Expand All @@ -437,18 +461,33 @@ def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataF
index=[file.dataset.id for file in result],
)

def _get_datasets(
    self,
    db: Database,
    limit: int | None = None,
    filters: dict[str, list[str]] | None = None,
) -> pd.DataFrame:
    """
    Fetch datasets of this adapter's type from the database as a DataFrame.

    Parameters
    ----------
    db
        Database to query.
    limit
        Maximum number of datasets to return (most recently updated first).
    filters
        Mapping of facet name to accepted values. Facets that are not
        columns on the dataset model are silently ignored.

    Returns
    -------
    :
        One row per dataset containing this adapter's dataset-specific
        metadata, indexed by dataset id.
    """
    query = db.session.query(self.dataset_cls)

    if filters:
        for key, values in filters.items():
            # Only filter on facets that are real columns of the model;
            # unknown keys are ignored rather than raising.
            column = getattr(self.dataset_cls, key, None)
            if column is not None:
                query = query.where(column.in_(values))

    result_datasets = query.order_by(Dataset.updated_at.desc()).limit(limit).all()

    # result_datasets holds Dataset ORM objects, so name the loop variable
    # accordingly (previously misleadingly named `file`).
    return pd.DataFrame(
        [{k: getattr(dataset, k) for k in self.dataset_specific_metadata} for dataset in result_datasets],
        index=[dataset.id for dataset in result_datasets],
    )

def load_catalog(
self, db: Database, include_files: bool = True, limit: int | None = None
self,
db: Database,
include_files: bool = True,
limit: int | None = None,
filters: dict[str, list[str]] | None = None,
) -> pd.DataFrame:
"""
Load the data catalog containing the currently tracked datasets/files from the database
Expand All @@ -469,9 +508,9 @@ def load_catalog(
with db.session.begin():
# TODO: Paginate this query to avoid loading all the data at once
if include_files:
catalog = self._get_dataset_files(db, limit)
catalog = self._get_dataset_files(db, limit, filters)
else:
catalog = self._get_datasets(db, limit)
catalog = self._get_datasets(db, limit, filters)

# If there are no datasets, return an empty DataFrame
if catalog.empty:
Expand Down
34 changes: 33 additions & 1 deletion packages/climate-ref/tests/unit/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import cftime
import numpy as np
import pandas as pd
import pytest
Expand All @@ -8,7 +9,7 @@
from climate_ref.database import Database, ModelState
from climate_ref.datasets import IngestionStats, get_dataset_adapter, ingest_datasets
from climate_ref.datasets import base as base_module
from climate_ref.datasets.base import DatasetAdapter, _is_na
from climate_ref.datasets.base import DatasetAdapter, _file_meta_differs, _is_na
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter
from climate_ref.models.dataset import CMIP6Dataset, DatasetFile
from climate_ref_core.datasets import SourceDatasetType
Expand Down Expand Up @@ -980,3 +981,34 @@ def test_zero(self):

def test_empty_string(self):
assert _is_na("") is False


class TestFileMetaDiffers:
    """Tests for _file_meta_differs, which compares DB-stored file metadata against incoming values."""

    def test_cftime_vs_matching_string_is_not_different(self):
        """cftime.datetime stringifies to the same value stored in the DB."""
        db_value = "1850-01-16 12:00:00"
        incoming = cftime.datetime(1850, 1, 16, 12, 0, 0, 0, calendar="proleptic_gregorian")
        assert _file_meta_differs(db_value, incoming) is False

    def test_cftime_vs_different_string_is_different(self):
        """cftime.datetime with different date is detected as changed."""
        db_value = "1850-01-01 00:00:00"
        incoming = cftime.datetime(1850, 1, 16, 12, 0, 0, 0, calendar="standard")
        assert _file_meta_differs(db_value, incoming) is True

    def test_string_vs_string_equal(self):
        """Identical stored and incoming strings are not flagged as changed."""
        assert _file_meta_differs("1850-01-01 00:00:00", "1850-01-01 00:00:00") is False

    def test_string_vs_string_different(self):
        """Differing stored and incoming strings are flagged as changed."""
        assert _file_meta_differs("1850-01-01 00:00:00", "1851-01-01 00:00:00") is True

    def test_none_vs_none(self):
        """None on both sides compares as unchanged."""
        assert _file_meta_differs(None, None) is False

    def test_none_vs_value(self):
        """A value appearing where the DB stored None counts as a change."""
        assert _file_meta_differs(None, "1850-01-01") is True

    def test_value_vs_none(self):
        """None replacing a stored value counts as a change."""
        assert _file_meta_differs("1850-01-01", None) is True
Loading
Loading