From 3106df56c10bef3aca3830117f7faec6d3b158fc Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:52:51 +0200 Subject: [PATCH 1/3] Remove FieldSet.get_fields Not needed anymore now that in v4 fields are explicitly stored in FieldSet.fields --- parcels/_reprs.py | 2 +- parcels/application_kernels/advection.py | 2 +- parcels/fieldset.py | 11 ----------- parcels/interaction/interactionkernel.py | 2 +- parcels/kernel.py | 2 +- parcels/particleset.py | 2 +- 6 files changed, 5 insertions(+), 16 deletions(-) diff --git a/parcels/_reprs.py b/parcels/_reprs.py index 9ff61565aa..c888945caf 100644 --- a/parcels/_reprs.py +++ b/parcels/_reprs.py @@ -65,7 +65,7 @@ def particleset_repr(pset: ParticleSet) -> str: def fieldset_repr(fieldset: FieldSet) -> str: # TODO v4: Rework or remove entirely """Return a pretty repr for FieldSet""" - fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()]) + fields_repr = "\n".join([repr(f) for f in fieldset.fields.values()]) out = f"""<{type(fieldset).__name__}> fields: diff --git a/parcels/application_kernels/advection.py b/parcels/application_kernels/advection.py index 448c3dfdbd..8cd76708fa 100644 --- a/parcels/application_kernels/advection.py +++ b/parcels/application_kernels/advection.py @@ -175,7 +175,7 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover tol = 1e-10 I_s = 10 # number of intermediate time steps direction = 1.0 if particle.dt > 0 else -1.0 - withW = True if "W" in [f.name for f in fieldset.get_fields()] else False + withW = True if "W" in [f.name for f in fieldset.fields.values()] else False withTime = True if len(fieldset.U.grid.time) > 1 else False tau, zeta, eta, xsi, ti, zi, yi, xi = fieldset.U._search_indices( time, particle.depth, particle.lat, particle.lon, particle=particle diff --git a/parcels/fieldset.py b/parcels/fieldset.py index f29526fe3c..e9f4f52290 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -184,17 +184,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): ) ) - def get_fields(self) -> list[Field | VectorField]: - """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` - objects associated with this FieldSet. - """ - fields = [] - for v in self.__dict__.values(): - if type(v) in [Field, VectorField]: - if v not in fields: - fields.append(v) - return fields - def add_constant(self, name, value): """Add a constant to the FieldSet. Note that all constants are stored as 32-bit floats. diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index 0f51c821df..16d95fd8f8 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -120,7 +120,7 @@ def execute_python(self, pset, endtime, dt): InteractionKernel. """ if self.fieldset is not None: - for f in self.fieldset.get_fields(): + for f in self.fieldset.fields.values(): if isinstance(f, VectorField): continue f.data = np.array(f.data) diff --git a/parcels/kernel.py b/parcels/kernel.py index 716769e877..be2be29a02 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -314,7 +314,7 @@ def execute(self, pset, endtime, dt): ) if pset.fieldset is not None: - for f in self.fieldset.get_fields(): + for f in self.fieldset.fields.values(): if isinstance(f, VectorField): continue f.data = np.array(f.data) diff --git a/parcels/particleset.py b/parcels/particleset.py index 9fd5f65017..16b68ff971 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1065,7 +1065,7 @@ def execute( return StatusCode.StopAllExecution if abs(time - next_output) < tol: - for fld in self.fieldset.get_fields(): + for fld in self.fieldset.fields.values(): if hasattr(fld, "to_write") and fld.to_write: if fld.grid.tdim > 1: raise RuntimeError( From e28b486f294b8f6ac7f15deecc67ce196f7fe9c6 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:21:43 +0200 Subject: [PATCH 2/3] Update FieldSet.gridset to use a list Also remove `FieldSet.gridset_size` which was used as a patch --- docs/reference/grids.rst | 7 ----- parcels/__init__.py | 1 - parcels/fieldset.py | 13 ++++++--- parcels/gridset.py | 61 --------------------------------------- parcels/particlefile.py | 2 +- parcels/particleset.py | 8 ++--- tests/v4/test_fieldset.py | 14 +++++++-- 7 files changed, 25 insertions(+), 81 deletions(-) delete mode 100644 parcels/gridset.py diff --git a/docs/reference/grids.rst b/docs/reference/grids.rst index 9c9a54e07c..e8d47a1183 100644 --- a/docs/reference/grids.rst +++ b/docs/reference/grids.rst @@ -1,13 +1,6 @@ Gridsets and grids ================== -parcels.gridset module ----------------------- - -.. automodule:: parcels.gridset - :members: - :show-inheritance: - parcels.grid module ------------------- diff --git a/parcels/__init__.py b/parcels/__init__.py index 0eeb3d2807..fe8d4bfb47 100644 --- a/parcels/__init__.py +++ b/parcels/__init__.py @@ -6,7 +6,6 @@ from parcels.field import * from parcels.fieldset import * from parcels.grid import * -from parcels.gridset import * from parcels.interaction import * from parcels.kernel import * from parcels.particle import * diff --git a/parcels/fieldset.py b/parcels/fieldset.py index e9f4f52290..8de81421b9 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from parcels._typing import DatetimeLike + from parcels.basegrid import BaseGrid __all__ = ["FieldSet"] @@ -109,10 +110,6 @@ def dimrange(self, dim): return maxleft, minright - @property - def gridset_size(self): - return len(self.fields) - def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. @@ -208,6 +205,14 @@ def add_constant(self, name, value): self.constants[name] = np.float32(value) + @property + def gridset(self) -> list[BaseGrid]: + grids = [] + for field in self.fields.values(): + if field.grid not in grids: + grids.append(field.grid) + return grids + # def computeTimeChunk(self, time=0.0, dt=1): # """Load a chunk of three data time steps into the FieldSet. # This is used when FieldSet uses data imported from netcdf, diff --git a/parcels/gridset.py b/parcels/gridset.py deleted file mode 100644 index 3609ba795d..0000000000 --- a/parcels/gridset.py +++ /dev/null @@ -1,61 +0,0 @@ -import numpy as np - -__all__ = ["GridSet"] - - -class GridSet: - """GridSet class that holds the Grids on which the Fields are defined.""" - - def __init__(self): - self.grids = [] - - def add_grid(self, field): - grid = field.grid - existing_grid = False - for g in self.grids: - if g == grid: - existing_grid = True - break - sameGrid = True - if grid.time_origin != g.time_origin: - continue - for attr in ["lon", "lat", "depth", "time"]: - gattr = getattr(g, attr) - gridattr = getattr(grid, attr) - if gattr.shape != gridattr.shape or not np.allclose(gattr, gridattr): - sameGrid = False - break - - if sameGrid: - existing_grid = True - field._grid = g # TODO: Is this even necessary? - break - - if not existing_grid: - self.grids.append(grid) - field.igrid = self.grids.index(field.grid) - - def dimrange(self, dim): - """Returns maximum value of a dimension (lon, lat, depth or time) - on 'left' side and minimum value on 'right' side for all grids - in a gridset. Useful for finding e.g. longitude range that - overlaps on all grids in a gridset. - """ - maxleft, minright = (-np.inf, np.inf) - for g in self.grids: - if getattr(g, dim).size == 1: - continue # not including grids where only one entry - else: - if dim == "depth": - maxleft = max(maxleft, np.min(getattr(g, dim))) - minright = min(minright, np.max(getattr(g, dim))) - else: - maxleft = max(maxleft, getattr(g, dim)[0]) - minright = min(minright, getattr(g, dim)[-1]) - maxleft = 0 if maxleft == -np.inf else maxleft # if all len(dim) == 1 - minright = 0 if minright == np.inf else minright # if all len(dim) == 1 - return maxleft, minright - - @property - def size(self): - return len(self.grids) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 511fb6d395..f39dd7b3b8 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -54,7 +54,7 @@ def __init__(self, name, particleset, outputdt, chunks=None, create_new_zarrfile self._particleset = particleset self._parcels_mesh = "spherical" if self.particleset.fieldset is not None: - self._parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh + self._parcels_mesh = self.particleset.fieldset.gridset[0].mesh self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype self._maxids = 0 self._pids_written = {} diff --git a/parcels/particleset.py b/parcels/particleset.py index 16b68ff971..a037cc3868 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -121,7 +121,7 @@ def ArrayClass_init(self, *args, **kwargs): if type(self).ngrids.initial < 0: numgrids = ngrids if numgrids is None and fieldset is not None: - numgrids = fieldset.gridset_size + numgrids = len(fieldset.gridset) assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting." type(self).ngrids.initial = numgrids self.ngrids = type(self).ngrids.initial @@ -190,7 +190,7 @@ def ArrayClass_init(self, *args, **kwargs): self._repeatkwargs = kwargs self._repeatkwargs.pop("partition_function", None) - ngrids = fieldset.gridset_size + ngrids = len(fieldset.gridset) # Variables used for interaction kernels. inter_dist_horiz = None @@ -217,7 +217,7 @@ def ArrayClass_init(self, *args, **kwargs): # Initialize neighbor search data structure (used for interaction). if interaction_distance is not None: - meshes = [g.mesh for g in fieldset.gridset.grids] + meshes = [g.mesh for g in fieldset.gridset] # Assert all grids have the same mesh type assert np.all(np.array(meshes) == meshes[0]) mesh_type = meshes[0] @@ -428,7 +428,7 @@ def populate_indices(self): This is only intended for curvilinear grids, where the initial index search may be quite expensive. """ - for i, grid in enumerate(self.fieldset.gridset.grids): + for i, grid in enumerate(self.fieldset.gridset): if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]: continue diff --git a/tests/v4/test_fieldset.py b/tests/v4/test_fieldset.py index 63b260593b..be5fc04a41 100644 --- a/tests/v4/test_fieldset.py +++ b/tests/v4/test_fieldset.py @@ -72,9 +72,17 @@ def test_fieldset_add_field_already_exists(fieldset): fieldset.add_field(field, "test_field") -@pytest.mark.xfail(reason="FieldSet doesn't yet correctly handle duplicate grids.") -def test_fieldset_gridset_size(fieldset): - assert fieldset.gridset_size == 1 +def test_fieldset_gridset(fieldset): + assert fieldset.fields["U"].grid in fieldset.gridset + assert fieldset.fields["V"].grid in fieldset.gridset + assert fieldset.fields["UV"].grid in fieldset.gridset + assert len(fieldset.gridset) == 1 + + fieldset.add_constant_field("constant_field", 1.0) + assert len(fieldset.gridset) == 2 + + +def test_fieldset_gridset_multiple_grids(): ... def test_fieldset_time_interval(): From a342ca5b698874a4666291e9a02b59421634043a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:00:02 +0200 Subject: [PATCH 3/3] Fix error message --- parcels/field.py | 2 +- tests/v4/test_field.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 51ade1e263..9c7ae9f631 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -169,7 +169,7 @@ def __init__( self.time_interval = get_time_interval(data) except ValueError as e: e.add_note( - f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects?" + f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects?" ) raise e diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index 0e07d24b05..b336129dfe 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -84,7 +84,7 @@ def test_field_init_fail_on_float_time_dim(): grid = XGrid(xgcm.Grid(ds)) with pytest.raises( ValueError, - match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects\?", + match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?", ): Field( name="test_field",