Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,6 @@ def execute(
dt: float | timedelta | np.timedelta64 = 1.0,
output_file=None,
verbose_progress=True,
postIterationCallbacks=None,
callbackdt: float | timedelta | np.timedelta64 | None = None,
):
"""Execute a given kernel function over the particle set for multiple timesteps.

Expand All @@ -903,16 +901,12 @@ def execute(
mod:`parcels.particlefile.ParticleFile` object for particle output (Default value = None)
verbose_progress : bool
Boolean for providing a progress bar for the kernel execution loop. (Default value = True)
postIterationCallbacks :
Optional, array of functions that are to be called after each iteration (post-process, non-Kernel) (Default value = None)
callbackdt :
Optional, in conjecture with 'postIterationCallbacks', timestep interval to (latest) interrupt the running kernel and invoke post-iteration callbacks from 'postIterationCallbacks' (Default value = None)
pyfunc_inter :
(Default value = None)

Notes
-----
``ParticleSet.execute()`` acts as the main entrypoint for simulations, and provides the simulation time-loop. This method encapsulates the logic controlling the switching between kernel execution, output file writing, reading in fields for new timesteps, adding new particles to the simulation domain, stopping the simulation, and executing custom functions (``postIterationCallbacks`` provided by the user).
``ParticleSet.execute()`` acts as the main entrypoint for simulations, and provides the simulation time-loop. This method encapsulates the logic controlling the switching between kernel execution, output file writing, reading in fields for new timesteps, adding new particles to the simulation domain, and stopping the simulation.
"""
# check if particleset is empty. If so, return immediately
if len(self) == 0:
Expand Down Expand Up @@ -958,9 +952,6 @@ def execute(
raise ValueError("Output interval should not have finer precision than 1e-6 s")
outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf

if callbackdt is not None:
callbackdt = timedelta_to_float(callbackdt)

assert runtime is None or runtime >= 0, "runtime must be positive"
assert outputdt is None or outputdt >= 0, "outputdt must be positive"

Expand Down Expand Up @@ -999,11 +990,11 @@ def execute(

self.particledata._data["dt"][:] = dt

if callbackdt is None:
interupt_dts = [np.inf, outputdt]
if self.repeatdt is not None:
interupt_dts.append(self.repeatdt)
callbackdt = np.min(np.array(interupt_dts))
# TODO: Nick remove?
interupt_dts = [np.inf, outputdt]
if self.repeatdt is not None:
interupt_dts.append(self.repeatdt)
callbackdt = np.min(np.array(interupt_dts))

# Set up pbar
if output_file:
Expand All @@ -1023,7 +1014,7 @@ def execute(
next_output = starttime + dt
else:
next_output = np.inf * np.sign(dt)
next_callback = starttime + callbackdt * np.sign(dt)
next_callback = starttime + callbackdt * np.sign(dt) # TODO: Nick remove?

tol = 1e-12
time = starttime
Expand All @@ -1041,9 +1032,9 @@ def execute(

# Define next_time (the timestamp when the execution needs to be handed back to python)
if dt > 0:
next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
next_time = min(next_prelease, next_input, next_output, next_callback, endtime) # TODO: Nick remove?
else:
next_time = max(next_prelease, next_input, next_output, next_callback, endtime)
next_time = max(next_prelease, next_input, next_output, next_callback, endtime) # TODO: Nick remove?

# If we don't perform interaction, only execute the normal kernel efficiently.
if self._interaction_kernel is None:
Expand Down Expand Up @@ -1092,9 +1083,6 @@ def execute(

# ==== insert post-process here to also allow for memory clean-up via external func ==== #
if abs(time - next_callback) < tol:
if postIterationCallbacks is not None:
for extFunc in postIterationCallbacks:
extFunc()
next_callback += callbackdt * np.sign(dt)

if abs(time - next_prelease) < tol:
Expand Down