Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions docs/reference/grids.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
Gridsets and grids
==================

parcels.gridset module
----------------------

.. automodule:: parcels.gridset
:members:
:show-inheritance:

parcels.grid module
-------------------

Expand Down
1 change: 0 additions & 1 deletion parcels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from parcels.field import *
from parcels.fieldset import *
from parcels.grid import *
from parcels.gridset import *
from parcels.interaction import *
from parcels.kernel import *
from parcels.particle import *
Expand Down
2 changes: 1 addition & 1 deletion parcels/_reprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def particleset_repr(pset: ParticleSet) -> str:

def fieldset_repr(fieldset: FieldSet) -> str: # TODO v4: Rework or remove entirely
"""Return a pretty repr for FieldSet"""
fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()])
fields_repr = "\n".join([repr(f) for f in fieldset.fields.values()])

out = f"""<{type(fieldset).__name__}>
fields:
Expand Down
2 changes: 1 addition & 1 deletion parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
tol = 1e-10
I_s = 10 # number of intermediate time steps
direction = 1.0 if particle.dt > 0 else -1.0
withW = True if "W" in [f.name for f in fieldset.get_fields()] else False
withW = True if "W" in [f.name for f in fieldset.fields.values()] else False
withTime = True if len(fieldset.U.grid.time) > 1 else False
tau, zeta, eta, xsi, ti, zi, yi, xi = fieldset.U._search_indices(
time, particle.depth, particle.lat, particle.lon, particle=particle
Expand Down
2 changes: 1 addition & 1 deletion parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
self.time_interval = get_time_interval(data)
except ValueError as e:
e.add_note(
f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects?"
f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects?"
)
raise e

Expand Down
24 changes: 9 additions & 15 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

if TYPE_CHECKING:
from parcels._typing import DatetimeLike
from parcels.basegrid import BaseGrid
__all__ = ["FieldSet"]


Expand Down Expand Up @@ -109,10 +110,6 @@ def dimrange(self, dim):

return maxleft, minright

@property
def gridset_size(self):
return len(self.fields)

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.

Expand Down Expand Up @@ -184,17 +181,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
)
)

def get_fields(self) -> list[Field | VectorField]:
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
objects associated with this FieldSet.
"""
fields = []
for v in self.__dict__.values():
if type(v) in [Field, VectorField]:
if v not in fields:
fields.append(v)
return fields

def add_constant(self, name, value):
"""Add a constant to the FieldSet. Note that all constants are
stored as 32-bit floats.
Expand All @@ -219,6 +205,14 @@ def add_constant(self, name, value):

self.constants[name] = np.float32(value)

@property
def gridset(self) -> list[BaseGrid]:
Comment thread
VeckoTheGecko marked this conversation as resolved.
grids = []
for field in self.fields.values():
if field.grid not in grids:
grids.append(field.grid)
return grids

# def computeTimeChunk(self, time=0.0, dt=1):
# """Load a chunk of three data time steps into the FieldSet.
# This is used when FieldSet uses data imported from netcdf,
Expand Down
61 changes: 0 additions & 61 deletions parcels/gridset.py

This file was deleted.

2 changes: 1 addition & 1 deletion parcels/interaction/interactionkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def execute_python(self, pset, endtime, dt):
InteractionKernel.
"""
if self.fieldset is not None:
for f in self.fieldset.get_fields():
for f in self.fieldset.fields.values():
if isinstance(f, VectorField):
continue
f.data = np.array(f.data)
Expand Down
2 changes: 1 addition & 1 deletion parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def execute(self, pset, endtime, dt):
)

if pset.fieldset is not None:
for f in self.fieldset.get_fields():
for f in self.fieldset.fields.values():
if isinstance(f, VectorField):
continue
f.data = np.array(f.data)
Expand Down
2 changes: 1 addition & 1 deletion parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, name, particleset, outputdt, chunks=None, create_new_zarrfile
self._particleset = particleset
self._parcels_mesh = "spherical"
if self.particleset.fieldset is not None:
self._parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
self._parcels_mesh = self.particleset.fieldset.gridset[0].mesh
self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype
self._maxids = 0
self._pids_written = {}
Expand Down
10 changes: 5 additions & 5 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def ArrayClass_init(self, *args, **kwargs):
if type(self).ngrids.initial < 0:
numgrids = ngrids
if numgrids is None and fieldset is not None:
numgrids = fieldset.gridset_size
numgrids = len(fieldset.gridset)
assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting."
type(self).ngrids.initial = numgrids
self.ngrids = type(self).ngrids.initial
Expand Down Expand Up @@ -190,7 +190,7 @@ def ArrayClass_init(self, *args, **kwargs):
self._repeatkwargs = kwargs
self._repeatkwargs.pop("partition_function", None)

ngrids = fieldset.gridset_size
ngrids = len(fieldset.gridset)

# Variables used for interaction kernels.
inter_dist_horiz = None
Expand All @@ -217,7 +217,7 @@ def ArrayClass_init(self, *args, **kwargs):

# Initialize neighbor search data structure (used for interaction).
if interaction_distance is not None:
meshes = [g.mesh for g in fieldset.gridset.grids]
meshes = [g.mesh for g in fieldset.gridset]
# Assert all grids have the same mesh type
assert np.all(np.array(meshes) == meshes[0])
mesh_type = meshes[0]
Expand Down Expand Up @@ -428,7 +428,7 @@ def populate_indices(self):
This is only intended for curvilinear grids, where the initial index search
may be quite expensive.
"""
for i, grid in enumerate(self.fieldset.gridset.grids):
for i, grid in enumerate(self.fieldset.gridset):
if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]:
continue

Expand Down Expand Up @@ -1065,7 +1065,7 @@ def execute(
return StatusCode.StopAllExecution

if abs(time - next_output) < tol:
for fld in self.fieldset.get_fields():
for fld in self.fieldset.fields.values():
if hasattr(fld, "to_write") and fld.to_write:
if fld.grid.tdim > 1:
raise RuntimeError(
Expand Down
2 changes: 1 addition & 1 deletion tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_field_init_fail_on_float_time_dim():
grid = XGrid(xgcm.Grid(ds))
with pytest.raises(
ValueError,
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects\?",
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?",
):
Field(
name="test_field",
Expand Down
14 changes: 11 additions & 3 deletions tests/v4/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,17 @@ def test_fieldset_add_field_already_exists(fieldset):
fieldset.add_field(field, "test_field")


@pytest.mark.xfail(reason="FieldSet doesn't yet correctly handle duplicate grids.")
def test_fieldset_gridset_size(fieldset):
assert fieldset.gridset_size == 1
def test_fieldset_gridset(fieldset):
assert fieldset.fields["U"].grid in fieldset.gridset
assert fieldset.fields["V"].grid in fieldset.gridset
assert fieldset.fields["UV"].grid in fieldset.gridset
assert len(fieldset.gridset) == 1

fieldset.add_constant_field("constant_field", 1.0)
assert len(fieldset.gridset) == 2


def test_fieldset_gridset_multiple_grids(): ...


def test_fieldset_time_interval():
Expand Down
Loading