diff --git a/parcels/_core/utils/time.py b/parcels/_core/utils/time.py index caef07b122..76d35cf838 100644 --- a/parcels/_core/utils/time.py +++ b/parcels/_core/utils/time.py @@ -13,13 +13,13 @@ class TimeInterval: - """A class representing a time interval between two datetime objects. + """A class representing a time interval between two datetime or np.timedelta64 objects. Parameters ---------- - left : datetime or cftime.datetime + left : np.datetime64 or cftime.datetime or np.timedelta64 The left endpoint of the interval. - right : datetime or cftime.datetime + right : np.datetime64 or cftime.datetime or np.timedelta64 The right endpoint of the interval. Notes @@ -28,12 +28,17 @@ class TimeInterval: """ def __init__(self, left: T, right: T) -> None: - if not isinstance(left, (datetime, cftime.datetime, np.datetime64)): - raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(left)}.") - if not isinstance(right, (datetime, cftime.datetime, np.datetime64)): - raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(right)}.") + if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)): + raise ValueError( + f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}." + ) + if not isinstance(right, (np.timedelta64, datetime, cftime.datetime, np.datetime64)): + raise ValueError( + f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(right)}." + ) if left >= right: raise ValueError(f"Expected left to be strictly less than right, got left={left} and right={right}.") + if not is_compatible(left, right): raise ValueError(f"Expected left and right to be compatible, got left={left} and right={right}.") @@ -58,6 +63,8 @@ def intersection(self, other: TimeInterval) -> TimeInterval | None: """Return the intersection of two time intervals. Returns None if there is no overlap.""" if not is_compatible(self.left, other.left): raise ValueError("TimeIntervals are not compatible.") + if not is_compatible(self.right, other.right): + raise ValueError("TimeIntervals are not compatible.") start = max(self.left, other.left) end = min(self.right, other.right) @@ -65,8 +72,17 @@ def intersection(self, other: TimeInterval) -> TimeInterval | None: return TimeInterval(start, end) if start <= end else None -def is_compatible(t1: datetime | cftime.datetime, t2: datetime | cftime.datetime) -> bool: - """Checks whether two (cftime.)datetime objects are compatible.""" +def is_compatible( + t1: datetime | cftime.datetime | np.timedelta64, t2: datetime | cftime.datetime | np.timedelta64 +) -> bool: + """ + Defines whether two datetime or np.timedelta64 objects are compatible in the context + of being left and right sides of an interval. + """ + # Ensure if either is a timedelta64, both must be + if isinstance(t1, np.timedelta64) ^ isinstance(t2, np.timedelta64): + return False + try: t1 - t2 except Exception: diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index 6dbe77f85b..1bbfb6c11e 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -72,11 +72,13 @@ def test_field_init_structured_grid(data, grid): assert field.grid == grid -@pytest.mark.parametrize("numpy_dtype", ["timedelta64[s]", "float64"]) -def test_field_init_fail_on_bad_time_type(numpy_dtype): - """Tests that field initialisation fails when the time isn't given as datetime object (i.e., is float or timedelta).""" +def test_field_init_fail_on_float_time_dim(): + """Test field initialisation fails when given float array as time dimension. + + (users are expected to use timedelta64 or datetime). + """ ds = datasets_structured["ds_2d_left"].copy() - ds["time"] = np.arange(0, T_structured, dtype=numpy_dtype) + ds["time"] = np.arange(0, T_structured, dtype="float64") data = ds["data_g"] grid = XGrid(xgcm.Grid(ds)) diff --git a/tests/v4/test_fieldset.py b/tests/v4/test_fieldset.py index 99cc54e432..63b260593b 100644 --- a/tests/v4/test_fieldset.py +++ b/tests/v4/test_fieldset.py @@ -128,6 +128,14 @@ def test_fieldset_add_field_incompatible_calendars(fieldset): with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): fieldset.add_field(field, "test_field") + ds_test = ds.copy() + ds_test["time"] = np.linspace(0, 100, T_structured, dtype="timedelta64[s]") + grid = XGrid(xgcm.Grid(ds_test)) + field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat") + + with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): + fieldset.add_field(field, "test_field") + @pytest.mark.parametrize( "input_, expected", diff --git a/tests/v4/utils/test_time.py b/tests/v4/utils/test_time.py index c736cdfd66..b5105ac3f9 100644 --- a/tests/v4/utils/test_time.py +++ b/tests/v4/utils/test_time.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime import numpy as np import pytest @@ -11,17 +11,36 @@ from parcels._core.utils.time import TimeInterval calendar_strategy = st.sampled_from( - ["gregorian", "proleptic_gregorian", "365_day", "360_day", "julian", "366_day", np.datetime64, datetime] + [ + "gregorian", + "proleptic_gregorian", + "365_day", + "360_day", + "julian", + "366_day", + np.datetime64, + datetime, + np.timedelta64, + ] ) +@st.composite +def np_timedelta64_strategy(draw): + """Strategy for generating np.timedelta64 objects.""" + return np.timedelta64(draw(st.integers(1, 60 * 60 * 24 * 100 * 365)), "s") + + @st.composite def datetime_strategy(draw, calendar=None): + if calendar is None: + calendar = draw(calendar_strategy) + if calendar is np.timedelta64: + return draw(np_timedelta64_strategy()) + year = draw(st.integers(1900, 2100)) month = draw(st.integers(1, 12)) day = draw(st.integers(1, 28)) - if calendar is None: - calendar = draw(calendar_strategy) if calendar is datetime: return datetime(year, month, day) if calendar is np.datetime64: @@ -34,12 +53,8 @@ def datetime_strategy(draw, calendar=None): def time_interval_strategy(draw, left=None, calendar=None): if left is None: left = draw(datetime_strategy(calendar=calendar)) - right = left + draw( - st.timedeltas( - min_value=timedelta(seconds=1), - max_value=timedelta(days=100 * 365), - ) - ) + right = left + draw(np_timedelta64_strategy()) + return TimeInterval(left, right)