diff --git a/src/condor/contrib.py b/src/condor/contrib.py index bfabc9e7..eca90e17 100644 --- a/src/condor/contrib.py +++ b/src/condor/contrib.py @@ -1,6 +1,7 @@ """Built-in model templates""" import logging +import warnings from dataclasses import dataclass, field import ndsplines @@ -651,9 +652,28 @@ def __getitem__(self, idx): return new_self def resample(self, dt, include_output=True, include_events=True, max_deg=3): - """Re-sample the trajectory, to a grid based on evenly-spaced points. With - include_events=True, two points will be inserted for each internal event to get - the state immediately before and after the event.""" + """Re-sample the trajectory on a grid of evenly-spaced points + + Parameters + ---------- + dt : float + Sample spacing in the independent variable (usually time). + include_output : bool, optional + Include :attr:`~ODESystem.dynamic_output` in the returned result. + include_events : bool, optional + Include events regardless of whether or not they fall on a multiple of `dt`. + Two points will be inserted for each internal event to get the state + immediately before and after the event. If disabled and a sample coincides + exactly with an event, the state *after* the update is returned. + max_deg : int, optional + Maximum degree of the interpolating spline. Actual degree used in any given + segment between events may be fewer if there are not sufficient samples. + + Returns + ------- + new_sim : TrajectoryAnalysis + A new trajectory instance with the requested sample spacing. + """ original_instance = getattr(self, "_original_instance", self) if original_instance is not self: @@ -664,89 +684,105 @@ def resample(self, dt, include_output=True, include_events=True, max_deg=3): max_deg=max_deg, ) + if getattr(self.Options, "separate_events", False): + msg = "Resampling a trajectory with separate_events not yet supported" + raise NotImplementedError(msg) + model = self.__class__ if dt <= 0.0: return self new_self = model.__new__(model) - new_self.implementation = self.implementation + + # TODO: add option to rebuild the implemention + if (impl := getattr(self, "implementation", None)) is not None: + new_self.implementation = impl + elif include_output: # override include_output if implementation is not found + include_output = False + warnings.warn( + "Trajectory instances without an implementation currently do not " + "support dynamic output sampling. Set include_output=False to " + "suppress.", + stacklevel=2, + ) + new_self._original_instance = original_instance new_self.bind_field(self.parameter) new_self.input_kwargs = self.input_kwargs - t_grid = np.arange(self._res.t[0], self._res.t[-1], dt) - t_size = t_grid.size - if include_events: - t_size += 2 * len(self._res.e) - 1 - es = [] - elif t_grid[-1] + dt == self._res.t[-1]: - t_size += 1 + t0, tf = self.t[[0, -1]] + + # grid with endpoint if it coincides with tf + length = np.ceil((tf - t0) / dt) + if length * dt == tf: + t_grid = np.arange(t0, tf + dt, dt) + else: + t_grid = np.arange(t0, tf, dt) + + interp = ResultInterpolant(self._res, max_deg=3) + + new_e = [] + new_y = [] if not include_events: - es = None - # self.t = np.empty((t_size,)) - new_self.t = np.ones((t_size,)) * -1 # all self.t should go to new_self.t - new_self.t[: t_grid.size] = t_grid - if t_grid[-1] + dt == self._res.t[-1]: - new_self.t[t_grid.size] = t_grid[-1] + dt - - state_interp = ResultInterpolant(self._res, max_deg=max_deg) - xs = np.empty((t_size, model.state._count)) - include_output = include_output and model.dynamic_output._count - if include_output: - dynamic_output = self.implementation.state_system.dynamic_output - p = self._res.p - ys = np.empty((t_size, model.dynamic_output._count)) + new_t = t_grid + new_x = np.empty((t_grid.size, self._res.x.shape[1]), float) + for seg in interp: + i_to_samp = np.nonzero((t_grid >= seg.t0) & (t_grid < seg.t1)) + x_seg = seg(t_grid[i_to_samp]) + if x_seg.ndim == 1: + x_seg = x_seg[:, None] + new_x[i_to_samp] = x_seg + new_x[-1] = self._res.x[-1] else: - ys = None + all_et = self._res.t[[e.index for e in self._res.e]] + samps_and_es = np.intersect1d(t_grid, all_et, assume_unique=True) + n_samps = t_grid.size + 2 * all_et.size - samps_and_es.size - idx0 = 0 + new_t = np.empty(n_samps, float) + new_x = np.empty((n_samps, self._res.x.shape[1]), float) - for event, x_interp_segment in zip(self._res.e, state_interp): - t_select = np.where( - (new_self.t >= x_interp_segment.t0) - & (new_self.t <= x_interp_segment.t1) - ) - idx0 = t_select[0][0] - idx1 = t_select[0][-1] + 1 - - if include_events: - new_self.t[idx0 + 1 :] = new_self.t[idx0:-1] - xs[idx0, :] = self._res.x[x_interp_segment.idx0] - if include_output: - ys[idx0, :] = self._res.y[x_interp_segment.idx0] - es.append(Root(idx0, event.rootsfound)) - idx0 += 1 - idx1 += 1 - # TODO figure out how to get root info - - ts_to_call = new_self.t[idx0:idx1] - xs[idx0:idx1] = x_interp_segment(ts_to_call) - if include_output: - for idx, t, x in zip(range(idx0, idx1), ts_to_call, xs[idx0:idx1]): - ys[idx, None] = dynamic_output(p, t, x).T - - if include_events: - new_self.t[idx1 + 1 :] = new_self.t[idx1:-1] - new_self.t[idx1 : idx1 + 2] = self._res.t[x_interp_segment.idx1] - xs[idx1, :] = self._res.x[x_interp_segment.idx1] - if include_output: - ys[idx1, :] = self._res.y[x_interp_segment.idx1] - - if include_events: - xs[idx1 + 1, :] = self._res.x[x_interp_segment.idx1] - if include_output: - ys[idx1 + 1, :] = self._res.y[x_interp_segment.idx1] - es.append(Root(idx1, self._res.e[-1].rootsfound)) + new_t[[0, -1]] = t0, tf + new_x[[0, -1]] = self._res.x[[0, -1]] + idx0 = 1 + for seg, ev in zip(interp, self._res.e, strict=False): + i_to_samp = np.nonzero((t_grid > seg.t0) & (t_grid < seg.t1))[0] + n_samp_seg = len(i_to_samp) - new_self.bind_field(model.state.wrap(xs.T)) + # insert event times + new_t[idx0] = seg.t0 + new_t[idx0 + 1 + n_samp_seg] = seg.t1 + # insert sample times + new_t[idx0 + 1 : idx0 + 1 + n_samp_seg] = t_grid[i_to_samp] + + # interpolate + x_seg = seg(new_t[idx0 : idx0 + n_samp_seg + 2]) + if x_seg.ndim == 1: + x_seg = x_seg[:, None] + new_x[idx0 : idx0 + n_samp_seg + 2] = x_seg + + new_e.append(Root(idx0, ev.rootsfound)) + + idx0 += n_samp_seg + 2 + new_e.append(Root(idx0, self._res.e[-1].rootsfound)) + + include_output = include_output and model.dynamic_output._count if include_output: - new_self.bind_field(model.dynamic_output.wrap(ys.T)) + dynamic_output = self.implementation.state_system.dynamic_output + p = self._res.p + new_y = np.empty((new_t.size, model.dynamic_output._count)) + for i, (t, x) in enumerate(zip(new_t, new_x, strict=True)): + new_y[i] = dynamic_output(p, t, x).T + + new_self.t = new_t + new_self.bind_field(model.state.wrap(new_x.T)) + if include_output: + new_self.bind_field(model.dynamic_output.wrap(new_y.T)) new_self._res = Result( - t=new_self.t, x=xs, y=ys, e=es, p=self._res.p, system=self._res.system + t=new_t, x=new_x, y=new_y, e=new_e, p=self._res.p, system=self._res.system ) return new_self diff --git a/src/condor/solvers/sweeping_gradient_method.py b/src/condor/solvers/sweeping_gradient_method.py index 177cdb07..0be089aa 100644 --- a/src/condor/solvers/sweeping_gradient_method.py +++ b/src/condor/solvers/sweeping_gradient_method.py @@ -693,13 +693,10 @@ def __getitem__(self, key): e=self.e[key], ) - def save( - self, - filename, - ): + def save(self, filename): e_idxs = [e.index for e in self.e] e_roots = [e.rootsfound for e in self.e] - np.savez( + np.savez_compressed( filename, e_idxs=e_idxs, e_roots=e_roots, diff --git a/tests/test_trajectory_analysis.py b/tests/test_trajectory_analysis.py index 762ff056..d1b652ed 100644 --- a/tests/test_trajectory_analysis.py +++ b/tests/test_trajectory_analysis.py @@ -332,7 +332,7 @@ class Options: @pytest.fixture -def odesys(): +def mass_spring_ode(): class MassSpring(co.ODESystem): x = state() v = state() @@ -342,41 +342,231 @@ class MassSpring(co.ODESystem): dot[v] = u - wn**2 * x initial[x] = 1 + dynamic_output.specific_energy = 0.5 * wn**2 * x**2 + 0.5 * v**2 + return MassSpring -def test_event_state_to_mode(odesys): +def test_event_state_to_mode(mass_spring_ode): # verify you can reference a state created in an event from a mode - class Event(odesys.Event): + class Event(mass_spring_ode.Event): function = v count = state(name="count_") update[count] = count + 1 - class Mode(odesys.Mode): + class Mode(mass_spring_ode.Mode): condition = Event.count > 0 action[u] = 1 - class Sim(odesys.TrajectoryAnalysis): + class Sim(mass_spring_ode.TrajectoryAnalysis): total_count = trajectory_output(Event.count) tf = 10 - print(Sim(wn=10).total_count) - -def test_mode_param_to_mode(odesys): +def test_mode_param_to_mode(mass_spring_ode): # verify you can reference a parameter created in a mode in another mode - class ModeA(odesys.Mode): + class ModeA(mass_spring_ode.Mode): condition = v > 0 u_hold = parameter() action[u] = u_hold - class ModeB(odesys.Mode): + class ModeB(mass_spring_ode.Mode): condition = 1 action[u] = ModeA.u_hold - class Sim(odesys.TrajectoryAnalysis): + class Sim(mass_spring_ode.TrajectoryAnalysis): tf = 10 Sim(wn=10, u_hold=0.8) + + +def test_file_io(mass_spring_ode, tmp_path): + class Ev(mass_spring_ode.Event): + function = x + + class Sim(mass_spring_ode.TrajectoryAnalysis): + tf = 10 + + sim = Sim(wn=1) + + fp1 = tmp_path / "sim.npz" + sim.to_file(fp1) + sim_from_file = Sim.from_file(fp1) + assert len(sim_from_file._res.e) > 1 + + fp2 = tmp_path / "sim_no_events.npz" + sim_resamp = sim.resample(0.1, include_events=False) + sim_resamp.to_file(fp2) + sim_resamp_from_file = Sim.from_file(fp2) + assert len(sim_resamp_from_file._res.e) == 0 + + +def make_resample_sim(tf_, e_times=None, add_output=False): + class MassSpring(co.ODESystem): + x = state() + v = state() + wn = parameter() + + dot[x] = v + dot[v] = wn**2 * x + + initial[x] = 1 + + if add_output: + dynamic_output.ke = 0.5 * v**2 + + if e_times: + for e in e_times: + + class Ev(MassSpring.Event): + at_time = e + + class Sim(MassSpring.TrajectoryAnalysis): + tf = tf_ + + sim = Sim(wn=8.0) + return sim + + +def test_resample_no_events(): + sim = make_resample_sim(1.0, add_output=False) + + simd = sim.resample(0.2, include_events=False) + np.testing.assert_allclose( + simd.t, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + rtol=1e-12, + atol=1e-12, + ) + + # include_events doubles t0 and tf + simd2 = sim.resample(0.2, include_events=True) + np.testing.assert_allclose( + simd2.t, + [0.0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0], + rtol=1e-12, + atol=1e-12, + ) + + assert simd._res.y == [] + + +@pytest.mark.parametrize("add_output", [False, True]) +def test_resample_with_events(add_output): + # cases: + # - no samples between two events + # - sample coincides exactly with event + + e_times = [0.05, 0.06, 0.3, 0.4] + sim = make_resample_sim(1.0, e_times=e_times, add_output=add_output) + + simd = sim.resample(0.2, include_events=False) + np.testing.assert_allclose( + simd.t, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + rtol=1e-12, + atol=1e-12, + ) + if add_output: + assert simd.ke.size == simd.t.size + else: + assert simd._res.y == [] + + simd = sim.resample(0.2, include_events=True, include_output=add_output) + np.testing.assert_allclose( + simd.t, + [0.0, 0.0, 0.05, 0.05, 0.06, 0.06, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.8, 1.0, 1.0], + rtol=1e-12, + atol=1e-12, + ) + if add_output: + assert simd.ke.size == simd.t.size + else: + assert simd._res.y == [] + + +def test_resample_nonsampled_tf(): + sim = make_resample_sim(1.1, add_output=False) + simd = sim.resample(0.2, include_events=False) + np.testing.assert_allclose( + simd.t, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + rtol=1e-12, + atol=1e-12, + ) + + tf = 1.0 + 1e-8 + sim = make_resample_sim(tf, e_times=[0.5]) + simd = sim.resample(0.2, include_events=True) + np.testing.assert_allclose( + simd.t, + [0.0, 0.0, 0.2, 0.4, 0.5, 0.5, 0.6, 0.8, 1.0, tf, tf], + rtol=1e-12, + atol=1e-12, + ) + + +def test_resample_single_state(): + class ODE(co.ODESystem): + a = parameter() + x = state() + dot[x] = -a * x + + class Sim(ODE.TrajectoryAnalysis): + tf = 10.0 + initial[x] = 1 + + sim = Sim(a=0.5) + + sim_resamp = sim.resample(1.0, include_events=False) + assert sim_resamp._res.x.shape == (11, 1) + assert sim_resamp._res.e == [] + + sim_resamp_events = sim.resample(1.0, include_events=True) + assert sim_resamp_events._res.x.shape == (13, 1) + assert len(sim_resamp_events._res.e) == 2 + + +def test_resample_separate_events(mass_spring_ode): + class Sim(mass_spring_ode.TrajectoryAnalysis): + tf = 1 + + class Options: + separate_events = True + + sim = Sim(wn=10) + + with pytest.raises(NotImplementedError): + sim.resample(0.1) + + +def test_resample_no_impl(mass_spring_ode): + # mock pickle dump/load (as in multiprocessing) by deleting implementation + class Sim(mass_spring_ode.TrajectoryAnalysis): + tf = 10 + + sim = Sim(wn=10) + del sim.implementation + + with pytest.warns(UserWarning, match="include_output"): + sim.resample(0.5, include_output=True) + + +def test_resample_check_tplus(mass_spring_ode): + # check that resample with include_events=False and a coincident event take from t+ + # strategy is to create an event exactly coincident with a sample time and update + # the state to switch signs, check that the sample at the event has the changed sign + + class Ev(mass_spring_ode.Event): + at_time = 0.5 + update[x] = -1 + + class Sim(mass_spring_ode.TrajectoryAnalysis): + tf = 1 + + sim = Sim(wn=1) + simd = sim.resample(0.1, include_events=False) + assert all(simd.x[simd.t < 0.45] > 0) + assert all(simd.x[simd.t > 0.45] < 0)