From de3db540b440ec57f99e34593aa847f74993b7d6 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 16 Jun 2025 12:49:02 -0400 Subject: [PATCH 01/12] Add pointer to grid in vector field This is needed for consistency of grid access independent of whether we are referencing a scalar or vector field. --- parcels/field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/parcels/field.py b/parcels/field.py index d1377db0d5..f7d07bacba 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -386,6 +386,7 @@ def __init__( self.U = U self.V = V self.W = W + self.grid = U.grid if W is None: assert_same_time_interval((U, V)) From e83c672620509d36be315404deb834c3548aeb22 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 16 Jun 2025 12:50:37 -0400 Subject: [PATCH 02/12] Add depth property for consistency with XGrid This property is needed for determining min/max depth bounds in particleset.__init__ . With consistency between XGrid and UXGrid, this getting depth ranges with either grid type can be accomplished with the same code in particleset.__init__ --- parcels/uxgrid.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index eccd51aa58..850b7b8a00 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -35,6 +35,20 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid: raise ValueError("z must be a 1D array of vertical coordinates") self.z = z + @property + def depth(self): + """ + Note + ---- + Included for compatibility with v3 codebase. May be removed in future. + TODO v4: Evaluate + """ + try: + _ = self.z.values + except KeyError: + return np.zeros(1) + return self.z.values + def search( self, z: float, y: float, x: float, ei: int | None = None, search2D: bool = False ) -> tuple[np.ndarray, int]: From bc4eae36372fb1c26452bae3c9ac6897cf822a84 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 16 Jun 2025 12:53:04 -0400 Subject: [PATCH 03/12] [#2034] Remove particle interaction setup in particleset.__init__ This commit also cleans up a few sections of init for working with the new field/fieldset API. --- parcels/particleset.py | 134 ++---------------------------- tests/v4/test_uxarray_fieldset.py | 2 +- 2 files changed, 10 insertions(+), 126 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 9fd5f65017..7a65254287 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,7 +1,6 @@ import sys import warnings from collections.abc import Iterable -from copy import copy from datetime import date, datetime, timedelta import cftime @@ -10,20 +9,13 @@ from scipy.spatial import KDTree from tqdm import tqdm -from parcels._compat import MPI from parcels._core.utils.time import TimeInterval from parcels._reprs import particleset_repr from parcels.application_kernels.advection import AdvectionRK4 from parcels.grid import GridType from parcels.interaction.interactionkernel import InteractionKernel -from parcels.interaction.neighborsearch import ( - BruteFlatNeighborSearch, - BruteSphericalNeighborSearch, - HashSphericalNeighborSearch, - KDTreeFlatNeighborSearch, -) from parcels.kernel import Kernel -from parcels.particle import Particle, Variable +from parcels.particle import Particle from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile from parcels.tools._helpers import timedelta_to_float @@ -109,37 +101,6 @@ def __init__( self.fieldset = fieldset self._pclass = pclass - # ==== first: create a new subclass of the pclass that includes the required variables ==== # - # ==== see dynamic-instantiation trick here: https://www.python-course.eu/python3_classes_and_type.php ==== # - class_name = pclass.__name__ - array_class = None - if class_name not in dir(): - - def ArrayClass_init(self, *args, **kwargs): - fieldset = kwargs.get("fieldset", None) - ngrids = kwargs.get("ngrids", None) - if type(self).ngrids.initial < 0: - numgrids = ngrids - if numgrids is None and fieldset is not None: - numgrids = fieldset.gridset_size - assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting." - type(self).ngrids.initial = numgrids - self.ngrids = type(self).ngrids.initial - if self.ngrids >= 0: - self.ei = np.zeros(self.ngrids, dtype=np.int32) - super(type(self), self).__init__(*args, **kwargs) - - array_class_vdict = { - "ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1), - "ei": Variable("ei", dtype=np.int32, to_write=False), - "__init__": ArrayClass_init, - } - array_class = type(class_name, (pclass,), array_class_vdict) - else: - array_class = locals()[class_name] - # ==== dynamic re-classing completed ==== # - _pclass = array_class - lon = np.empty(shape=0) if lon is None else convert_to_flat_array(lon) lat = np.empty(shape=0) if lat is None else convert_to_flat_array(lat) @@ -147,7 +108,10 @@ def ArrayClass_init(self, *args, **kwargs): pid_orig = np.arange(lon.size) if depth is None: - mindepth = self.fieldset.dimrange("depth")[0] + mindepth = 0 + for field in self.fieldset.fields.values(): + if field.grid.depth is not None: + mindepth = min(mindepth, field.grid.depth[0]) depth = np.ones(lon.size) * mindepth else: depth = convert_to_flat_array(depth) @@ -162,8 +126,8 @@ def ArrayClass_init(self, *args, **kwargs): raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double") time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time]) assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths." - if fieldset.time_interval: - _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval) + # if fieldset.time_interval: # TODO : Fixe time_interval for datasets with no time interval + # _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval) if lonlatdepth_dtype is None: lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) @@ -179,98 +143,18 @@ def ArrayClass_init(self, *args, **kwargs): lon.size == kwargs[kwvar].size ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths." - self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None - - if self.repeatdt: - if self.repeatdt <= 0: - raise ValueError("Repeatdt should be > 0") - if time[0] and not np.allclose(time, time[0]): - raise ValueError("All Particle.time should be the same when repeatdt is not None") - self._repeatpclass = pclass - self._repeatkwargs = kwargs - self._repeatkwargs.pop("partition_function", None) - - ngrids = fieldset.gridset_size - - # Variables used for interaction kernels. - inter_dist_horiz = None - inter_dist_vert = None - # The _dirty_neighbor attribute keeps track of whether - # the neighbor search structure needs to be rebuilt. - # If indices change (for example adding/deleting a particle) - # The NS structure needs to be rebuilt and _dirty_neighbor should be - # set to true. Since the NS structure isn't immediately initialized, - # it is set to True here. - self._dirty_neighbor = True - self.particledata = ParticleData( - _pclass, + self._pclass, lon=lon, lat=lat, depth=depth, time=time, lonlatdepth_dtype=lonlatdepth_dtype, pid_orig=pid_orig, - ngrid=ngrids, + ngrid=fieldset.gridset_size, **kwargs, ) - # Initialize neighbor search data structure (used for interaction). - if interaction_distance is not None: - meshes = [g.mesh for g in fieldset.gridset.grids] - # Assert all grids have the same mesh type - assert np.all(np.array(meshes) == meshes[0]) - mesh_type = meshes[0] - if mesh_type == "spherical": - if len(self) < 1000: - interaction_class = BruteSphericalNeighborSearch - else: - interaction_class = HashSphericalNeighborSearch - elif mesh_type == "flat": - if len(self) < 1000: - interaction_class = BruteFlatNeighborSearch - else: - interaction_class = KDTreeFlatNeighborSearch - else: - assert False, "Interaction is only possible on 'flat' and 'spherical' meshes" - try: - if len(interaction_distance) == 2: - inter_dist_vert, inter_dist_horiz = interaction_distance - else: - inter_dist_vert = interaction_distance[0] - inter_dist_horiz = interaction_distance[0] - except TypeError: - inter_dist_vert = interaction_distance - inter_dist_horiz = interaction_distance - self._neighbor_tree = interaction_class( - inter_dist_vert=inter_dist_vert, - inter_dist_horiz=inter_dist_horiz, - periodic_domain_zonal=periodic_domain_zonal, - ) - # End of neighbor search data structure initialization. - - if self.repeatdt: - if len(time) > 0 and time[0] is None: - self._repeat_starttime = time[0] - else: - if self.particledata.data["time"][0] and not np.allclose( - self.particledata.data["time"], self.particledata.data["time"][0] - ): - raise ValueError("All Particle.time should be the same when repeatdt is not None") - self._repeat_starttime = copy(self.particledata.data["time"][0]) - self._repeatlon = copy(self.particledata.data["lon"]) - self._repeatlat = copy(self.particledata.data["lat"]) - self._repeatdepth = copy(self.particledata.data["depth"]) - for kwvar in kwargs: - if kwvar not in ["partition_function"]: - self._repeatkwargs[kwvar] = copy(self.particledata.data[kwvar]) - - if self.repeatdt: - if MPI and self.particledata.pu_indicators is not None: - mpi_comm = MPI.COMM_WORLD - mpi_rank = mpi_comm.Get_rank() - self._repeatpid = pid_orig[self.particledata.pu_indicators == mpi_rank] - self._kernel = None def __del__(self): diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index fbb2cc3840..86dff027cc 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -83,9 +83,9 @@ def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): assert (fieldset.V == ds_fesom_channel.V).all() -@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring") def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) + # Check that the fieldset has the expected properties assert (fieldset.U == ds_fesom_channel.U).all() assert (fieldset.V == ds_fesom_channel.V).all() From 04f0393ac7482e6db118571aed99af5e4367c920 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 15:32:46 -0400 Subject: [PATCH 04/12] [#2034] Fix VectorField eval call stack --- parcels/field.py | 57 +++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index f7d07bacba..8b09239b99 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -431,36 +431,43 @@ def vector_interp_method(self, method: Callable): # and np.allclose(grid1.depth, grid2.depth) # and np.allclose(grid1.time, grid2.time) # ) - def _interpolate(self, time, z, y, x, ei): - bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) - - if self._vector_interp_method is None: - u = self.U.eval(time, z, y, x, _ei, applyConversion=False) - v = self.V.eval(time, z, y, x, _ei, applyConversion=False) - if "3D" in self.vector_type: - w = self.W.eval(time, z, y, x, _ei, applyConversion=False) - return (u, v, w) - else: - return (u, v, 0) - else: - (u, v, w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x) - return (u, v, w) + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): + """Interpolate field values in space and time. - def eval(self, time, z, y, x, ei=None, applyConversion=True): - if ei is None: - _ei = 0 + We interpolate linearly in time and apply implicit unit + conversion to the result. Note that we defer to + scipy.interpolate to perform spatial interpolation. + """ + if particle is None: + _ei = None else: - _ei = ei[self.igrid] + _ei = particle.ei[self.igrid] - (u, v, w) = self._interpolate(time, z, y, x, _ei) + try: + tau, ti = _search_time_index(self.U, time) + bcoords, _ei = self.grid.search(z, y, x, ei=_ei) + if self._vector_interp_method is None: + u = self.U._interp_method(self.U, ti, _ei, bcoords, tau, time, z, y, x) + v = self.V._interp_method(self.V, ti, _ei, bcoords, tau, time, z, y, x) + if "3D" in self.vector_type: + w = self.W._interp_method(self.W, ti, _ei, bcoords, tau, time, z, y, x) + else: + (u, v, w) = self._vector_interp_method(self, ti, _ei, bcoords, time, z, y, x) - if applyConversion: - u = self.U.units.to_target(u, z, y, x) - v = self.V.units.to_target(v, z, y, x) - if "3D" in self.vector_type: - w = self.W.units.to_target(w, z, y, x) + if applyConversion: + u = self.U.units.to_target(u, z, y, x) + v = self.V.units.to_target(v, z, y, x) + if "3D" in self.vector_type: + w = self.W.units.to_target(w, z, y, x) if self.W else 0.0 + else: + if "3D" in self.vector_type: + return (u, v, w) + else: + return (u, v, 0) - return (u, v, w) + except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: + e.add_note(f"Error interpolating field '{self.name}'.") + raise e def __getitem__(self, key): try: From 25bdfe9d7f02ce79d95db3dc6aed76c09db3aa35 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 15:34:56 -0400 Subject: [PATCH 05/12] [#2034] Simplify pset.execute for minimal forward step --- parcels/particleset.py | 244 ++++++++++------------------------------- 1 file changed, 57 insertions(+), 187 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 7a65254287..da04854b3e 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,9 +1,8 @@ import sys import warnings from collections.abc import Iterable -from datetime import date, datetime, timedelta +from datetime import date, datetime -import cftime import numpy as np import xarray as xr from scipy.spatial import KDTree @@ -18,7 +17,6 @@ from parcels.particle import Particle from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile -from parcels.tools._helpers import timedelta_to_float from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode @@ -81,11 +79,8 @@ def __init__( lat=None, depth=None, time=None, - repeatdt=None, lonlatdepth_dtype=None, pid_orig=None, - interaction_distance=None, - periodic_domain_zonal=None, **kwargs, ): self.particledata = None @@ -124,10 +119,11 @@ def __init__( time = np.array([np.datetime64(t) for t in time]) if time.size > 0 and isinstance(time[0], np.timedelta64) and not self.time_origin: raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double") + time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time]) assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths." - # if fieldset.time_interval: # TODO : Fixe time_interval for datasets with no time interval - # _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval) + if fieldset.time_interval: + _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval) if lonlatdepth_dtype is None: lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) @@ -746,15 +742,11 @@ def set_variable_write_status(self, var, write_status): def execute( self, + endtime: np.timedelta64 | np.datetime64, + dt: np.float64 | np.float32 | np.timedelta64, pyfunc=AdvectionRK4, - pyfunc_inter=None, - endtime=None, - runtime: float | timedelta | np.timedelta64 | None = None, - dt: float | timedelta | np.timedelta64 = 1.0, output_file=None, verbose_progress=True, - postIterationCallbacks=None, - callbackdt: float | timedelta | np.timedelta64 | None = None, ): """Execute a given kernel function over the particle set for multiple timesteps. @@ -767,13 +759,10 @@ def execute( Kernel function to execute. This can be the name of a defined Python function or a :class:`parcels.kernel.Kernel` object. Kernels can be concatenated using the + operator (Default value = AdvectionRK4) - endtime : - End time for the timestepping loop. - It is either a datetime object or a positive double. (Default value = None) - runtime : - Length of the timestepping loop. Use instead of endtime. - It is either a timedelta object or a positive double. (Default value = None) - dt : + endtime (datetime.datetime or np.timedelta64): : + End time for the timestepping loop. If a timedelta is provided, it is interpreted as the total simulation time. + If a datetime is provided, it is interpreted as the end time of the simulation. + dt (timedelta): Timestep interval (in seconds) to be passed to the kernel. It is either a timedelta object or a double. Use a negative value for a backward-in-time simulation. (Default value = 1 second) @@ -781,12 +770,6 @@ def execute( mod:`parcels.particlefile.ParticleFile` object for particle output (Default value = None) verbose_progress : bool Boolean for providing a progress bar for the kernel execution loop. (Default value = True) - postIterationCallbacks : - Optional, array of functions that are to be called after each iteration (post-process, non-Kernel) (Default value = None) - callbackdt : - Optional, in conjecture with 'postIterationCallbacks', timestep interval to (latest) interrupt the running kernel and invoke post-iteration callbacks from 'postIterationCallbacks' (Default value = None) - pyfunc_inter : - (Default value = None) Notes ----- @@ -803,197 +786,84 @@ def execute( self._kernel = pyfunc else: self._kernel = self.Kernel(pyfunc) + if output_file: output_file.metadata["parcels_kernels"] = self._kernel.name - # Set up the interaction kernel(s) if not set and given. - if self._interaction_kernel is None and pyfunc_inter is not None: - if isinstance(pyfunc_inter, InteractionKernel): - self._interaction_kernel = pyfunc_inter + # The fieldset time intervale defines the extent of time that is allowed to be + # simulated. If `fieldset.time_interval` is not None, it will be used to determine the endtime (the min of endtime or fieldset.time_interval[1]). + # If `fieldset.time_interval` is None, the endtime will be determined by the + # `endtime` parameter or the fieldset's time dimension. + # Time parameters for the main for loop are converted to floats, since the interpolation kernels expect float objects for time + # The initial time (in float point) representation is t0=0.0 and time is interpreted as relative to the start of the time interval + fieldset_timeinterval = self.fieldset.time_interval + + if fieldset_timeinterval is None: + if isinstance(endtime, np.datetime64): + raise NotImplementedError( + "If fieldset.time_interval is None, endtime must be a np.timedelta64 not a np.datetime64" + ) + duration = endtime / np.timedelta64(1, "s") # converts np.timedelta64 to seconds as float64 + + else: + # Get the particle time interval + if isinstance(endtime, np.datetime64): + simulation_endtime = np.min(fieldset_timeinterval[1], endtime) + if simulation_endtime < fieldset_timeinterval[1]: + print( + f"Simulation endtime is limited by fieldset.time_interval. End time adjusted to {simulation_endtime}" + ) + duration = (simulation_endtime - fieldset_timeinterval[0]) / np.timedelta64(1, "s") + else: - self._interaction_kernel = self.InteractionKernel(pyfunc_inter) - - # Convert all time variables to seconds - if isinstance(endtime, timedelta): - raise TypeError("endtime must be either a datetime or a double") - if isinstance(endtime, datetime): - endtime = np.datetime64(endtime) - elif isinstance(endtime, cftime.datetime): - endtime = self.time_origin.reltime(endtime) - if isinstance(endtime, np.datetime64): - if self.time_origin.calendar is None: - raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double") - endtime = self.time_origin.reltime(endtime) - - if runtime is not None: - runtime = timedelta_to_float(runtime) - - dt = timedelta_to_float(dt) - - if abs(dt) <= 1e-6: - raise ValueError("Time step dt is too small") - if (dt * 1e6) % 1 != 0: - raise ValueError("Output interval should not have finer precision than 1e-6 s") - outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf - - if callbackdt is not None: - callbackdt = timedelta_to_float(callbackdt) - - assert runtime is None or runtime >= 0, "runtime must be positive" - assert outputdt is None or outputdt >= 0, "outputdt must be positive" - - if runtime is not None and endtime is not None: - raise RuntimeError("Only one of (endtime, runtime) can be specified") - - mintime, maxtime = self.fieldset.dimrange("time") # TODO : change to fieldset.time_interval - - default_release_time = mintime if dt >= 0 else maxtime - if np.any(np.isnan(self.particledata.data["time"])): - self.particledata.data["time"][np.isnan(self.particledata.data["time"])] = default_release_time - self.particledata.data["time_nextloop"][np.isnan(self.particledata.data["time_nextloop"])] = ( - default_release_time - ) - min_rt = np.min(self.particledata.data["time_nextloop"]) - max_rt = np.max(self.particledata.data["time_nextloop"]) - - # Derive starttime and endtime from arguments or fieldset defaults - starttime = min_rt if dt >= 0 else max_rt - if self.repeatdt is not None and self._repeat_starttime is None: - self._repeat_starttime = starttime - if runtime is not None: - endtime = starttime + runtime * np.sign(dt) - elif endtime is None: - mintime, maxtime = self.fieldset.dimrange("time") - endtime = maxtime if dt >= 0 else mintime - - if (abs(endtime - starttime) < 1e-5 or runtime == 0) and dt == 0: - raise RuntimeError( - "dt and runtime are zero, or endtime is equal to Particle.time. " - "ParticleSet.execute() will not do anything." - ) + duration = endtime / np.timedelta64(1, "s") - if np.isfinite(outputdt): - _warn_outputdt_release_desync(outputdt, starttime, self.particledata.data["time_nextloop"]) + if isinstance(dt, np.datetime64): + dt = dt / np.timedelta64(1, "s") # convert to seconds as float64 - self.particledata._data["dt"][:] = dt + outputdt = output_file.outputdt if output_file else None - if callbackdt is None: - interupt_dts = [np.inf, outputdt] - if self.repeatdt is not None: - interupt_dts.append(self.repeatdt) - callbackdt = np.min(np.array(interupt_dts)) + self.particledata._data["dt"][:] = dt # Set up pbar if output_file: logger.info(f"Output files are stored in {output_file.fname}.") if verbose_progress: - pbar = tqdm(total=abs(endtime - starttime), file=sys.stdout) + pbar = tqdm(total=abs(duration), file=sys.stdout) - # Set up variables for first iteration - if self.repeatdt: - next_prelease = self._repeat_starttime + ( - abs(starttime - self._repeat_starttime) // self.repeatdt + 1 - ) * self.repeatdt * np.sign(dt) - else: - next_prelease = np.inf if dt > 0 else -np.inf if output_file: - next_output = starttime + dt + next_output = outputdt else: next_output = np.inf * np.sign(dt) - next_callback = starttime + callbackdt * np.sign(dt) tol = 1e-12 - time = starttime + time = 0.0 - while (time < endtime and dt > 0) or (time > endtime and dt < 0): + while time < duration and dt > 0: # Forward in time only for now # Check if we can fast-forward to the next time needed for the particles - if dt > 0: - skip_kernel = True if min(self.time) > (time + dt) else False - else: - skip_kernel = True if max(self.time) < (time + dt) else False - - time_at_startofloop = time - - next_input = self.fieldset.computeTimeChunk(time, dt) + # if dt > 0: + # skip_kernel = True if duration > (time + dt) else False + # else: + # skip_kernel = True if max(self.time) < (time + dt) else False + + t0 = time + next_time = t0 + dt + res = self._kernel.execute(self, endtime=next_time, dt=dt) + if res == StatusCode.StopAllExecution: + return StatusCode.StopAllExecution - # Define next_time (the timestamp when the execution needs to be handed back to python) - if dt > 0: - next_time = min(next_prelease, next_input, next_output, next_callback, endtime) - else: - next_time = max(next_prelease, next_input, next_output, next_callback, endtime) - - # If we don't perform interaction, only execute the normal kernel efficiently. - if self._interaction_kernel is None: - if not skip_kernel: - res = self._kernel.execute(self, endtime=next_time, dt=dt) - if res == StatusCode.StopAllExecution: - return StatusCode.StopAllExecution - # Interaction: interleave the interaction and non-interaction kernel for each time step. - # E.g. Normal -> Inter -> Normal -> Inter if endtime-time == 2*dt - else: - cur_time = time - while (cur_time < next_time and dt > 0) or (cur_time > next_time and dt < 0): - if dt > 0: - cur_end_time = min(cur_time + dt, next_time) - else: - cur_end_time = max(cur_time + dt, next_time) - self._kernel.execute(self, endtime=cur_end_time, dt=dt) - self._interaction_kernel.execute(self, endtime=cur_end_time, dt=dt) - cur_time += dt # End of interaction specific code time = next_time - # Check for empty ParticleSet - if np.isinf(next_prelease) and len(self) == 0: - return StatusCode.StopAllExecution - - if abs(time - next_output) < tol: - for fld in self.fieldset.get_fields(): - if hasattr(fld, "to_write") and fld.to_write: - if fld.grid.tdim > 1: - raise RuntimeError( - "Field writing during execution only works for Fields with one snapshot in time" - ) - fldfilename = str(output_file.fname).replace(".zarr", f"_{fld.to_write:04d}") - fld.write(fldfilename) - fld.to_write += 1 - if abs(time - next_output) < tol: if output_file: - output_file.write(self, time_at_startofloop) + output_file.write(self, t0) if np.isfinite(outputdt): next_output += outputdt * np.sign(dt) - # ==== insert post-process here to also allow for memory clean-up via external func ==== # - if abs(time - next_callback) < tol: - if postIterationCallbacks is not None: - for extFunc in postIterationCallbacks: - extFunc() - next_callback += callbackdt * np.sign(dt) - - if abs(time - next_prelease) < tol: - pset_new = self.__class__( - fieldset=self.fieldset, - time=time, - lon=self._repeatlon, - lat=self._repeatlat, - depth=self._repeatdepth, - pclass=self._repeatpclass, - lonlatdepth_dtype=self.particledata.lonlatdepth_dtype, - partition_function=False, - pid_orig=self._repeatpid, - **self._repeatkwargs, - ) - for p in pset_new: - p.dt = dt - self.add(pset_new) - next_prelease += self.repeatdt * np.sign(dt) - - if time != endtime: - next_input = self.fieldset.computeTimeChunk(time, dt) if verbose_progress: - pbar.update(abs(time - time_at_startofloop)) + pbar.update(abs(dt)) if verbose_progress: pbar.close() From 1ef59ac3ceaa7f40d92379581c5a35df2bf7ef20 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 21:47:25 -0400 Subject: [PATCH 06/12] Put stommel gyre u,v,p fields on faces; consistent with fesom --- parcels/_datasets/unstructured/generic.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/parcels/_datasets/unstructured/generic.py b/parcels/_datasets/unstructured/generic.py index ac9542638d..58481564f7 100644 --- a/parcels/_datasets/unstructured/generic.py +++ b/parcels/_datasets/unstructured/generic.py @@ -7,7 +7,7 @@ __all__ = ["Nx", "datasets"] T = 13 -Nx = 20 +Nx = 10 vmax = 1.0 delta = 0.1 TIME = xr.date_range("2000", "2001", T) @@ -44,12 +44,12 @@ def _stommel_gyre_delaunay(): uxgrid.attrs["Conventions"] = "UGRID-1.0" # Define arrays U (zonal), V (meridional) and P (sea surface height) - U = np.zeros((1, nz1, lat.size), dtype=np.float64) - V = np.zeros((1, nz1, lat.size), dtype=np.float64) + U = np.zeros((1, nz1, uxgrid.n_face), dtype=np.float64) + V = np.zeros((1, nz1, uxgrid.n_face), dtype=np.float64) W = np.zeros((1, nz, lat.size), dtype=np.float64) - P = np.zeros((1, nz1, lat.size), dtype=np.float64) + P = np.zeros((1, nz1, uxgrid.n_face), dtype=np.float64) - for i, (x, y) in enumerate(zip(lon_flat, lat_flat, strict=False)): + for i, (x, y) in enumerate(zip(uxgrid.face_lon, uxgrid.face_lat, strict=False)): xi = x / 60.0 yi = y / 60.0 @@ -61,10 +61,10 @@ def _stommel_gyre_delaunay(): data=U, name="U", uxgrid=uxgrid, - dims=["time", "nz1", "n_node"], + dims=["time", "nz1", "n_face"], coords=dict( time=(["time"], [TIME[0]]), - nz1=(["nz1"], [0]), + nz1=(["nz1"], zc), ), attrs=dict( description="zonal velocity", units="m/s", location="node", mesh="delaunay", Conventions="UGRID-1.0" @@ -74,7 +74,7 @@ def _stommel_gyre_delaunay(): data=V, name="V", uxgrid=uxgrid, - dims=["time", "nz1", "n_node"], + dims=["time", "nz1", "n_face"], coords=dict( time=(["time"], [TIME[0]]), nz1=(["nz1"], zc), @@ -100,7 +100,7 @@ def _stommel_gyre_delaunay(): data=P, name="p", uxgrid=uxgrid, - dims=["time", "nz1", "n_node"], + dims=["time", "nz1", "n_face"], coords=dict( time=(["time"], [TIME[0]]), nz1=(["nz1"], zc), From 4ab8494427c58d315b61870ceea4be8759d85dba Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 21:48:19 -0400 Subject: [PATCH 07/12] Change to timedelta/datetime for time management --- parcels/particleset.py | 61 +++++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index da04854b3e..3c9d3ecf49 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,7 +1,7 @@ import sys import warnings from collections.abc import Iterable -from datetime import date, datetime +from datetime import date, datetime, timedelta import numpy as np import xarray as xr @@ -14,7 +14,7 @@ from parcels.grid import GridType from parcels.interaction.interactionkernel import InteractionKernel from parcels.kernel import Kernel -from parcels.particle import Particle +from parcels.particle import Particle, Variable from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array @@ -96,6 +96,37 @@ def __init__( self.fieldset = fieldset self._pclass = pclass + # ==== first: create a new subclass of the pclass that includes the required variables ==== # + # ==== see dynamic-instantiation trick here: https://www.python-course.eu/python3_classes_and_type.php ==== # + class_name = pclass.__name__ + array_class = None + if class_name not in dir(): + + def ArrayClass_init(self, *args, **kwargs): + fieldset = kwargs.get("fieldset", None) + ngrids = kwargs.get("ngrids", None) + if type(self).ngrids.initial < 0: + numgrids = ngrids + if numgrids is None and fieldset is not None: + numgrids = fieldset.gridset_size + assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting." + type(self).ngrids.initial = numgrids + self.ngrids = type(self).ngrids.initial + if self.ngrids >= 0: + self.ei = np.zeros(self.ngrids, dtype=np.int32) + super(type(self), self).__init__(*args, **kwargs) + + array_class_vdict = { + "ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1), + "ei": Variable("ei", dtype=np.int32, to_write=False), + "__init__": ArrayClass_init, + } + array_class = type(class_name, (pclass,), array_class_vdict) + else: + array_class = locals()[class_name] + # ==== dynamic re-classing completed ==== # + _pclass = array_class + lon = np.empty(shape=0) if lon is None else convert_to_flat_array(lon) lat = np.empty(shape=0) if lat is None else convert_to_flat_array(lat) @@ -140,7 +171,7 @@ def __init__( ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths." self.particledata = ParticleData( - self._pclass, + _pclass, lon=lon, lat=lat, depth=depth, @@ -742,8 +773,8 @@ def set_variable_write_status(self, var, write_status): def execute( self, - endtime: np.timedelta64 | np.datetime64, - dt: np.float64 | np.float32 | np.timedelta64, + endtime: timedelta | datetime, + dt: np.float64 | np.float32 | timedelta, pyfunc=AdvectionRK4, output_file=None, verbose_progress=True, @@ -799,27 +830,27 @@ def execute( fieldset_timeinterval = self.fieldset.time_interval if fieldset_timeinterval is None: - if isinstance(endtime, np.datetime64): + if isinstance(endtime, datetime): raise NotImplementedError( - "If fieldset.time_interval is None, endtime must be a np.timedelta64 not a np.datetime64" + "If fieldset.time_interval is None, endtime must be a timedelta not a datetime" ) - duration = endtime / np.timedelta64(1, "s") # converts np.timedelta64 to seconds as float64 + duration = endtime.total_seconds() # converts timedelta to seconds as float64 else: # Get the particle time interval - if isinstance(endtime, np.datetime64): - simulation_endtime = np.min(fieldset_timeinterval[1], endtime) + if isinstance(endtime, datetime): + simulation_endtime = min(fieldset_timeinterval[1], endtime) if simulation_endtime < fieldset_timeinterval[1]: print( f"Simulation endtime is limited by fieldset.time_interval. End time adjusted to {simulation_endtime}" ) - duration = (simulation_endtime - fieldset_timeinterval[0]) / np.timedelta64(1, "s") + duration = (simulation_endtime - fieldset_timeinterval[0]).total_seconds() else: - duration = endtime / np.timedelta64(1, "s") + duration = endtime.total_seconds() - if isinstance(dt, np.datetime64): - dt = dt / np.timedelta64(1, "s") # convert to seconds as float64 + if isinstance(dt, timedelta): + dt = dt.total_seconds() # convert to seconds as float64 outputdt = output_file.outputdt if output_file else None @@ -839,14 +870,12 @@ def execute( tol = 1e-12 time = 0.0 - while time < duration and dt > 0: # Forward in time only for now # Check if we can fast-forward to the next time needed for the particles # if dt > 0: # skip_kernel = True if duration > (time + dt) else False # else: # skip_kernel = True if max(self.time) < (time + dt) else False - t0 = time next_time = t0 + dt res = self._kernel.execute(self, endtime=next_time, dt=dt) From 00ed176f0e066188609edaec5b7d4001a01c68ac Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 21:49:02 -0400 Subject: [PATCH 08/12] Disable reusing ei temporarily Currently, we don't have support for igrid selection --- parcels/field.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 8b09239b99..51ade1e263 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -309,10 +309,10 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - if particle is None: - _ei = None - else: - _ei = particle.ei[self.igrid] + # if particle is None: + _ei = None + # else: + # _ei = particle.ei[self.igrid] try: tau, ti = _search_time_index(self, time) @@ -438,10 +438,10 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - if particle is None: - _ei = None - else: - _ei = particle.ei[self.igrid] + # if particle is None: + _ei = None + # else: + # _ei = particle.ei[self.igrid] try: tau, ti = _search_time_index(self.U, time) @@ -454,16 +454,17 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): else: (u, v, w) = self._vector_interp_method(self, ti, _ei, bcoords, time, z, y, x) + # print(u,v) if applyConversion: u = self.U.units.to_target(u, z, y, x) v = self.V.units.to_target(v, z, y, x) if "3D" in self.vector_type: w = self.W.units.to_target(w, z, y, x) if self.W else 0.0 + + if "3D" in self.vector_type: + return (u, v, w) else: - if "3D" in self.vector_type: - return (u, v, w) - else: - return (u, v, 0) + return (u, v) except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: e.add_note(f"Error interpolating field '{self.name}'.") @@ -472,7 +473,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): def __getitem__(self, key): try: if _isParticle(key): - return self.eval(key.time, key.depth, key.lat, key.lon, key.ei) + return self.eval(key.time, key.depth, key.lat, key.lon, key) else: return self.eval(*key) except tuple(AllParcelsErrorCodes.keys()) as error: From 0c692ae3edeb6785ff87c51a7895ff6a306476bb Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 21:49:39 -0400 Subject: [PATCH 09/12] Return a tuple, even when no time interval is present This prevents errors in the field.eval calls for constant in time fields --- parcels/_index_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 08966d8fd2..805e5e1eb8 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -38,7 +38,7 @@ def _search_time_index(field: Field, time: datetime): if the sampled value is outside the time value range. """ if field.time_interval is None: - return 0 + return 0, 0 if time not in field.time_interval: _raise_time_extrapolation_error(time, field=None) From 72549e7ccc845a675616e41793e07bb3aaac07cc Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 19 Jun 2025 23:16:12 -0400 Subject: [PATCH 10/12] Update stommel uxarray notebook for demo --- docs/examples/tutorial_stommel_uxarray.ipynb | 439 ++++++++----------- 1 file changed, 173 insertions(+), 266 deletions(-) diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb index d5bc4317cf..527b66915c 100644 --- a/docs/examples/tutorial_stommel_uxarray.ipynb +++ b/docs/examples/tutorial_stommel_uxarray.ipynb @@ -5,7 +5,19 @@ "metadata": {}, "source": [ "# Stommel Gyre on Unstructured Grid\n", - "This tutorial walks through creating a UXArray dataset using the Stommel Gyre analytical solution for a closed rectangular domain on a beta-plane" + "This tutorial walks a simple example of using Parcels for particle advection on an unstructured grid. The purpose of this tutorial is to introduce you to the new way fields and fieldsets can be instantiated in Parcels using UXArray DataArrays and UXArray grids.\n", + "\n", + "We focus on a simple example, using constant-in-time velocity and pressure fields for the classic barotropic Stommel Gyre. This example dataset is included in Parcels' new `parcels._datasets` module. This module provides example XArray and UXArray datasets that are compatible with Parcels and mimic the way many general circulation model outputs are represented in (U)XArray. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the example dataset\n", + "Creating a particle simulation starts with defining a dataset that contains the fields that will be used to influence particle attributes, such as position, through kernels. In this example, we focus on advection. Because of this, the dataset we're using will provide velocity fields for our simulation.\n", + "\n", + "Parcels now includes pre-canned example datasets to demonstrate the schema of XArray and UXArray datasets that are compatible with Parcels. For unstructured grid datasets, you can use the `parcels._datasets.unstructured.generic.datasets` dictionary to see which datasets are available for unstructured grids." ] }, { @@ -14,122 +26,16 @@ "metadata": {}, "outputs": [], "source": [ - "def stommel_fieldset_uxarray(xdim=200, ydim=200):\n", - " \"\"\"Simulate a periodic current along a western boundary, with significantly\n", - " larger velocities along the western edge than the rest of the region\n", - "\n", - " The original test description can be found in: N. Fabbroni, 2009,\n", - " Numerical Simulation of Passive tracers dispersion in the sea,\n", - " Ph.D. dissertation, University of Bologna\n", - " http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n", - " \"\"\"\n", - " import math\n", - "\n", - " import numpy as np\n", - " import pandas as pd\n", - " import uxarray as ux\n", - "\n", - " a = b = 66666 * 1e3\n", - " scalefac = 0.00025 # to scale for physically meaningful velocities\n", - "\n", - " # Coordinates of the test fieldset\n", - " # Crowd points to the west edge of the domain\n", - " # using a polyonmial map on x-direction\n", - " x = np.linspace(0, 1, xdim, dtype=np.float32)\n", - " lon, lat = np.meshgrid(a * x, np.linspace(0, b, ydim, dtype=np.float32))\n", - " points = (lon.flatten() / 1111111.111111111, lat.flatten() / 1111111.111111111)\n", + "from parcels._datasets.unstructured.generic import datasets as datasets_unstructured\n", "\n", - " # Create the grid\n", - " uxgrid = ux.Grid.from_points(points, method=\"regional_delaunay\")\n", - " uxgrid.construct_face_centers()\n", - "\n", - " # Define arrays U (zonal), V (meridional) and P (sea surface height)\n", - " U = np.zeros((1, 1, lat.size), dtype=np.float32)\n", - " V = np.zeros((1, 1, lat.size), dtype=np.float32)\n", - " P = np.zeros((1, 1, lat.size), dtype=np.float32)\n", - "\n", - " beta = 2e-11\n", - " r = 1 / (11.6 * 86400)\n", - " es = r / (beta * a)\n", - "\n", - " i = 0\n", - " for x, y in zip(lon.flatten(), lat.flatten()):\n", - " xi = x / a\n", - " yi = y / b\n", - " P[0, 0, i] = (\n", - " (1 - math.exp(-xi / es) - xi) * math.pi * np.sin(math.pi * yi) * scalefac\n", - " )\n", - " U[0, 0, i] = (\n", - " -(1 - math.exp(-xi / es) - xi)\n", - " * math.pi**2\n", - " * np.cos(math.pi * yi)\n", - " * scalefac\n", - " )\n", - " V[0, 0, i] = (\n", - " (math.exp(-xi / es) / es - 1) * math.pi * np.sin(math.pi * yi) * scalefac\n", - " )\n", - " i += 1\n", - "\n", - " u = ux.UxDataArray(\n", - " data=U,\n", - " name=\"u\",\n", - " uxgrid=uxgrid,\n", - " dims=[\"time\", \"nz1\", \"n_node\"],\n", - " coords=dict(\n", - " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", - " nz1=([\"nz1\"], [0]),\n", - " ),\n", - " attrs=dict(\n", - " description=\"zonal velocity\",\n", - " units=\"m/s\",\n", - " location=\"node\",\n", - " mesh=\"delaunay\",\n", - " ),\n", - " )\n", - " v = ux.UxDataArray(\n", - " data=V,\n", - " name=\"v\",\n", - " uxgrid=uxgrid,\n", - " dims=[\"time\", \"nz1\", \"n_node\"],\n", - " coords=dict(\n", - " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", - " nz1=([\"nz1\"], [0]),\n", - " ),\n", - " attrs=dict(\n", - " description=\"meridional velocity\",\n", - " units=\"m/s\",\n", - " location=\"node\",\n", - " mesh=\"delaunay\",\n", - " ),\n", - " )\n", - " p = ux.UxDataArray(\n", - " data=P,\n", - " name=\"p\",\n", - " uxgrid=uxgrid,\n", - " dims=[\"time\", \"nz1\", \"n_node\"],\n", - " coords=dict(\n", - " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", - " nz1=([\"nz1\"], [0]),\n", - " ),\n", - " attrs=dict(\n", - " description=\"pressure\",\n", - " units=\"N/m^2\",\n", - " location=\"node\",\n", - " mesh=\"delaunay\",\n", - " ),\n", - " )\n", - "\n", - " return ux.UxDataset({\"u\": u, \"v\": v, \"p\": p}, uxgrid=uxgrid)\n", - "\n", - "\n", - "uxds = stommel_fieldset_uxarray(50, 50)\n", - "\n", - "uxds.uxgrid.plot(\n", - " line_width=0.5,\n", - " height=500,\n", - " width=1000,\n", - " title=\"Regional Delaunay Regions\",\n", - ")" + "datasets_unstructured.keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we'll be using the stommel_gyre_delaunay example dataset. This dataset is created by generating a delaunay triangulation of a uniform grid of points in a square domain $x \\in [0,60^\\circ] \\times [0,60^\\circ]$. There is a single vertical layer that is 1000m thick. This layer is defined by the layer surfaces $z_f = 0$ and $z_f = 1000$." ] }, { @@ -138,136 +44,58 @@ "metadata": {}, "outputs": [], "source": [ - "def stommel_fieldset_xarray(xdim=200, ydim=200, grid_type=\"A\"):\n", - " \"\"\"Simulate a periodic current along a western boundary, with significantly\n", - " larger velocities along the western edge than the rest of the region\n", - "\n", - " The original test description can be found in: N. Fabbroni, 2009,\n", - " Numerical Simulation of Passive tracers dispersion in the sea,\n", - " Ph.D. dissertation, University of Bologna\n", - " http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n", - " \"\"\"\n", - " import math\n", - "\n", - " import numpy as np\n", - " import pandas as pd\n", - " import xarray as xr\n", - "\n", - " a = b = 10000 * 1e3\n", - " scalefac = 0.05 # to scale for physically meaningful velocities\n", - " dx, dy = a / xdim, b / ydim\n", - "\n", - " # Coordinates of the test fieldset (on A-grid in deg)\n", - " lon = np.linspace(0, a, xdim, dtype=np.float32)\n", - " lat = np.linspace(0, b, ydim, dtype=np.float32)\n", - "\n", - " # Define arrays U (zonal), V (meridional) and P (sea surface height)\n", - " U = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", - " V = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", - " P = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", + "ds = datasets_unstructured[\"stommel_gyre_delaunay\"]\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the dataset, we have the following dimensions\n", "\n", - " beta = 2e-11\n", - " r = 1 / (11.6 * 86400)\n", - " es = r / (beta * a)\n", + "* `time: 1` - The number of time levels that the variables in this dataset are defined at. \n", + "* `nz1: 1` - The number of vertical layers. The `nz1` dimension is associated with the `nz1` coordinate that defines the vertical position of the center of each vertical layer. The `nz1` coordinate consists of non-negative values that are assumed to increase with `nz1` dimension index.\n", + "* `n_face: 721` - The number of 2-d unstructured grid faces in the `UXArray.grid`\n", + "* `nz: 2` - The number of vertical layer interfaces. The `nz` dimension is associated with the `nz` coordinate that defines the vertical positions of the interfaces of each vertical layer. The `nz` coordinate consists of non-negative values that are assumed to increase with `nz` dimension index. Note that the number of layer interfaces is always the number of layers plus one.\n", + "* `n_node: 400` - The number of corner node vertices in the grid.\n", "\n", - " for j in range(lat.size):\n", - " for i in range(lon.size):\n", - " xi = lon[i] / a\n", - " yi = lat[j] / b\n", - " P[..., j, i] = (\n", - " (1 - math.exp(-xi / es) - xi)\n", - " * math.pi\n", - " * np.sin(math.pi * yi)\n", - " * scalefac\n", - " )\n", - " if grid_type == \"A\":\n", - " U[..., j, i] = (\n", - " -(1 - math.exp(-xi / es) - xi)\n", - " * math.pi**2\n", - " * np.cos(math.pi * yi)\n", - " * scalefac\n", - " )\n", - " V[..., j, i] = (\n", - " (math.exp(-xi / es) / es - 1)\n", - " * math.pi\n", - " * np.sin(math.pi * yi)\n", - " * scalefac\n", - " )\n", + "Whenever you are building a UXArray dataset for use in Parcels, its important to keep in mind that these dimensions and coordinates are assumed to exist for your dataset. Further, it is highly recommended that you use UXArray when possible to load unstructured general circulation model data when possible. This ensures that other characteristics, such as the counterclockwise ordering of vertices for each element, are defined properly for use in Parcels." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining a Grid, Fields, and Vector Fields\n", "\n", - " time = pd.to_datetime([\"2000-01-01\"])\n", - " z = [0]\n", - " if grid_type == \"C\":\n", - " V[..., :, 1:] = (P[..., :, 1:] - P[..., :, 0:-1]) / dx * a\n", - " U[..., 1:, :] = -(P[..., 1:, :] - P[..., 0:-1, :]) / dy * b\n", - " u_dims = [\"time\", \"nz1\", \"face_lat\", \"node_lon\"]\n", - " u_lat = lat\n", - " u_lon = lon - dx * 0.5\n", - " u_location = \"x_edge\"\n", - " v_dims = [\"time\", \"nz1\", \"node_lat\", \"face_lon\"]\n", - " v_lat = lat - dy * 0.5\n", - " v_lon = lon\n", - " v_location = \"y_edge\"\n", - " p_dims = [\"time\", \"nz1\", \"face_lat\", \"face_lon\"]\n", - " p_lat = lat\n", - " p_lon = lon\n", - " p_location = \"face\"\n", + "A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n", "\n", - " else:\n", - " u_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", - " v_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", - " p_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", - " u_lat = lat\n", - " u_lon = lon\n", - " v_lat = lat\n", - " v_lon = lon\n", - " u_location = \"node\"\n", - " v_location = \"node\"\n", - " p_lat = lat\n", - " p_lon = lon\n", - " p_location = \"node\"\n", - "\n", - " u = xr.DataArray(\n", - " data=U,\n", - " name=\"u\",\n", - " dims=u_dims,\n", - " coords=[time, z, u_lat, u_lon],\n", - " attrs=dict(\n", - " description=\"zonal velocity\",\n", - " units=\"m/s\",\n", - " location=u_location,\n", - " mesh=f\"Arakawa-{grid_type}\",\n", - " ),\n", - " )\n", - " v = xr.DataArray(\n", - " data=V,\n", - " name=\"v\",\n", - " dims=v_dims,\n", - " coords=[time, z, v_lat, v_lon],\n", - " attrs=dict(\n", - " description=\"meridional velocity\",\n", - " units=\"m/s\",\n", - " location=v_location,\n", - " mesh=f\"Arakawa-{grid_type}\",\n", - " ),\n", - " )\n", - " p = xr.DataArray(\n", - " data=P,\n", - " name=\"p\",\n", - " dims=p_dims,\n", - " coords=[time, z, p_lat, p_lon],\n", - " attrs=dict(\n", - " description=\"pressure\",\n", - " units=\"N/m^2\",\n", - " location=p_location,\n", - " mesh=f\"Arakawa-{grid_type}\",\n", - " ),\n", - " )\n", + "The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from parcels.uxgrid import UxGrid\n", "\n", - " return xr.Dataset({\"u\": u, \"v\": v, \"p\": p})\n", + "grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n", + "# You can view the uxgrid object with the following command:\n", + "grid.uxgrid" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the `UxGrid` object defined, we can now define our `Field` objects, provided we can align a suitable interpolator what that `Field`. Aligning an interpolator requires you to be cognizant of the location that each `DataArray` is associated with. Since Parcels v4 provides flexibility to customize your interpolation scheme, care must be taken when pairing an interpolation scheme with a field. On unstructured grids, data is typically registered to \"nodes\", \"faces\", or \"edges\". For example, with FESOM2 data, `u` and `v` velocity components are face registered while the vertical velocity component `w` is node registered.\n", "\n", + "In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n", "\n", - "ds_arakawa_a = stommel_fieldset_xarray(50, 50, \"A\")\n", - "ds_arakawa_c = stommel_fieldset_xarray(50, 50, \"C\")" + "For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees." ] }, { @@ -276,7 +104,37 @@ "metadata": {}, "outputs": [], "source": [ - "ds_arakawa_a" + "from parcels.application_kernels.interpolation import UXPiecewiseConstantFace\n", + "from parcels.field import Field\n", + "\n", + "U = Field(\n", + " name=\"U\",\n", + " data=ds.U,\n", + " grid=grid,\n", + " mesh_type=\"spherical\",\n", + " interp_method=UXPiecewiseConstantFace,\n", + ")\n", + "V = Field(\n", + " name=\"V\",\n", + " data=ds.V,\n", + " grid=grid,\n", + " mesh_type=\"spherical\",\n", + " interp_method=UXPiecewiseConstantFace,\n", + ")\n", + "P = Field(\n", + " name=\"P\",\n", + " data=ds.p,\n", + " grid=grid,\n", + " mesh_type=\"spherical\",\n", + " interp_method=UXPiecewiseConstantFace,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we've defined the `U` and `V` fields, we can define a `VectorField`. The `VectorField` is created in a similar manner, except that it is initialized with `Field` objects. You can optionally define an `interp_method` on the `VectorField`. When this is done, the `VectorField.interp_method` is used for interpolation; otherwise, evaluation of the `VectorField` is done component-wise using the `interp_method` associated with each component separately." ] }, { @@ -285,7 +143,17 @@ "metadata": {}, "outputs": [], "source": [ - "ds_arakawa_a[\"u\"].attrs" + "from parcels.field import VectorField\n", + "\n", + "UV = VectorField(name=\"UV\", U=U, V=V)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining the FieldSet\n", + "With all of the fields defined, that we want for this simulation, we can now create the `FieldSet`. As the name suggests, the `FieldSet` is the set of all `Field`s that will be used for a particle simulation. A `FieldSet` is initialized with a list of `Field` objects" ] }, { @@ -294,7 +162,35 @@ "metadata": {}, "outputs": [], "source": [ - "ds_arakawa_c" + "from parcels.fieldset import FieldSet\n", + "\n", + "fieldset = FieldSet([UV, UV.U, UV.V, P])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting your own custom interpolator\n", + "You may be wondering how to set your own custom interpolator. In Parcels v4, this is as simple as defining a function that matches a specific API. The API you need to match is defined in the `field.py` module in the `Field._interp_template` and `VectorField._interp_template`. Specifically,\n", + "\n", + "```python\n", + "def _interp_template(\n", + " self, # Field or VectorField\n", + " ti: int, # Time index\n", + " ei: int, # Flat grid index\n", + " bcoords: np.ndarray, # Barycentric coordinates relative to the cell vertices\n", + " tau: np.float32 | np.float64, # Time interpolation weight\n", + " t: np.float32 | np.float64, # Current simulation time\n", + " z: np.float32 | np.float64, # Current particle depth\n", + " y: np.float32 | np.float64, # Current particle y-position\n", + " x: np.float32 | np.float64, # Current particle x-position\n", + " ) -> np.float32 | np.float64 # For `Field`, returns a float value.\n", + "```\n", + "\n", + "So long as your function matches this API, you can define such a function and set the `Field.interp_method` to that function.\n", + "\n", + "\n" ] }, { @@ -305,38 +201,49 @@ "source": [ "import numpy as np\n", "\n", - "min_length_scale = 1111111.111111111 * np.sqrt(np.min(uxds.uxgrid.face_areas))\n", - "print(min_length_scale)\n", "\n", - "max_v = np.sqrt(uxds[\"u\"] ** 2 + uxds[\"v\"] ** 2).max()\n", - "print(max_v)\n", + "def my_custom_interpolator(\n", + " self,\n", + " ti: int,\n", + " ei: int,\n", + " bcoords: np.ndarray,\n", + " tau: np.float32 | np.float64,\n", + " t: np.float32 | np.float64,\n", + " z: np.float32 | np.float64,\n", + " y: np.float32 | np.float64,\n", + " x: np.float32 | np.float64,\n", + ") -> np.float32 | np.float64:\n", + " \"\"\"Custom interpolation method for the P field.\n", + " This method interpolates the value at a face by averaging the values of its neighboring faces.\n", + " While this may be nonsense, it demonstrates how to create a custom interpolation method.\"\"\"\n", + "\n", + " zi, fi = self.grid.unravel_index(ei)\n", + " neighbors = self.grid.uxgrid.face_face_connectivity[fi]\n", + " f_at_neighbors = self.data.values[ti, zi, neighbors]\n", + " # Interpolate using the average of the neighboring face values\n", + " if len(f_at_neighbors) > 0:\n", + " return np.mean(f_at_neighbors)\n", + " # If no neighbors, return the value at the face itself\n", + " else:\n", + " return self.data.values[ti, zi, fi]\n", + "\n", "\n", - "cfl = 0.1\n", - "dt = cfl * min_length_scale / max_v\n", - "print(dt)" + "# Assign the custom interpolator to the P field\n", + "P.interp_method = my_custom_interpolator" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "from datetime import timedelta\n", - "\n", - "import numpy as np\n", - "import uxarray as ux\n", + "## Understanding the context inside an interpolator method\n", + "Providing the `Field` object as an input to an interpolator exposes you to a ton of useful information and methods for building complex interpolators. Particularly, the `Field.grid` attribute gives you access to connectivity tables and metric terms that you may find useful for constructing an interpolator. For context, the `Parcels.UXGrid` class is built on top of the `Parcels.BaseGrid` class (much likes it's structured grid `Parcels.XGrid` counterpart). The `Parcels.UXGrid` class combines a `UXArray.grid` object alongside the vertical layer interfaces, which provides sufficient information to define the API that the `BaseGrid` class demands. This includes\n", "\n", - "from parcels import Particle, ParticleSet, UxAdvectionEuler, UXFieldSet\n", + "* `search` - A method for returning a flat grid index `ei` for a position `(x,y,z)`\n", + "* `ravel_index` - A method for converting a face index `fi` and a vertical layer index `zi` into a single flat grid index `ei`\n", + "* `unravel_index` - A method for converted a single flat grid index `ei` into a face index `fi` and a vertical layer index `zi`\n", "\n", - "npart = 10\n", - "fieldset = UXFieldSet(uxds)\n", - "# pset = ParticleSet(\n", - "# fieldset,\n", - "# pclass=Particle,\n", - "# lon=np.linspace(1, 59, npart),\n", - "# lat=np.zeros(npart)+30)\n", - "# pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=dt))" + "The `ravel/unravel` methods are a necessity for most interpolators. For unstructured grids, the `Field.grid.uxgrid` attribute give you access to all of the attributes associated with a `UxArray.grid` object (See https://uxarray.readthedocs.io/en/latest/api.html#grid for more details.)" ] } ], @@ -356,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.12.10" } }, "nbformat": 4, From b6bd088f8e8592edbd68337d956a284cd330bba8 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Fri, 20 Jun 2025 08:38:38 -0400 Subject: [PATCH 11/12] Add cell for executing the particleset --- docs/examples/tutorial_stommel_uxarray.ipynb | 21 ++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb index 527b66915c..70a05a0db4 100644 --- a/docs/examples/tutorial_stommel_uxarray.ipynb +++ b/docs/examples/tutorial_stommel_uxarray.ipynb @@ -245,6 +245,27 @@ "\n", "The `ravel/unravel` methods are a necessity for most interpolators. For unstructured grids, the `Field.grid.uxgrid` attribute give you access to all of the attributes associated with a `UxArray.grid` object (See https://uxarray.readthedocs.io/en/latest/api.html#grid for more details.)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Running the forward integration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime, timedelta\n", + "\n", + "from parcels import Particle, ParticleSet\n", + "\n", + "pset = ParticleSet(fieldset, lon=30.0, lat=5.0, depth=50.0, pclass=Particle)\n", + "pset.execute(endtime=timedelta(seconds=30), dt=timedelta(seconds=1))" + ] } ], "metadata": { From 9677e08bc0e5a2cca5398ad45dfbc51591dc3b9e Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Thu, 26 Jun 2025 22:00:54 -0400 Subject: [PATCH 12/12] Change number of grid vertices back to 20x20 for test --- parcels/_datasets/unstructured/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/_datasets/unstructured/generic.py b/parcels/_datasets/unstructured/generic.py index 58481564f7..a18309c15b 100644 --- a/parcels/_datasets/unstructured/generic.py +++ b/parcels/_datasets/unstructured/generic.py @@ -7,7 +7,7 @@ __all__ = ["Nx", "datasets"] T = 13 -Nx = 10 +Nx = 20 vmax = 1.0 delta = 0.1 TIME = xr.date_range("2000", "2001", T)