Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ Bug fixes
invalid position or velocity units and silently pass these through to orbit
integration later.

- Fixed a bug in the time interpolation of ``EXPPotential`` where evaluating a
time-evolving potential at (or very near) a stored snapshot time could silently return
a neighboring snapshot, or raise an out-of-bounds error at the first or last snapshot.
This was caused by floating-point differences between the requested time and the
limited-precision snapshot times stored in the coefficient file.

API changes
-----------

Expand Down
38 changes: 27 additions & 11 deletions src/gala/potential/potential/builtin/exp_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <cmath>
#include <filesystem>
#include <algorithm>

namespace fs = std::filesystem;

Expand Down Expand Up @@ -203,23 +204,38 @@ CoefClasses::CoefStrPtr interpolator(double t, CoefClasses::CoefsPtr coefs)
//
auto times = coefs->Times();

if (t<times.front() or t>times.back()) {
// Allow for a small tolerance in comparing the requested time to the stored snapshot
// times. This is because floating point errors can come from: (1) the unit conversion
// applied to the requested time, and (2) the limited precision with which snapshot
// times are written to / read from the coefficient file. It seems that writing the times
// to HDF5 only stores the times with reduced precision?
const double scale =
std::max({std::abs(times.front()), std::abs(times.back()), 1.0});
const double tol = 1.0e-8 * scale; // 1e-8 is somewhat arbitrary, but seems to work

if (t < (times.front() - tol) or t > (times.back() + tol)) {
std::ostringstream sout;
sout << "FieldWrapper::interpolator: time t=" << t << " is out of bounds: ["
<< times.front() << ", " << times.back() << "] (raw EXP snapshot time units)";
throw std::runtime_error(sout.str());
}

auto it1 = std::lower_bound(times.begin(), times.end(), t);
auto it2 = it1 + 1;

if (it2 == times.end()) {
it2--;
it1 = it2 - 1;
}

// Handle degenerate case where it1 == it2 (single time entry)
if (it1 == it2 || *it1 == *it2) {
// Clamp into range so that endpoint roundoff interpolates cleanly to the
// boundary snapshot instead of extrapolating just outside it.
t = std::min(std::max(t, times.front()), times.back());

// Bracket the interval [it1, it2] with *it1 <= t <= *it2. Use upper_bound (first
// element strictly greater than t) so that a time sitting epsilon above a stored knot
// keeps that knot as the lower bracket instead of jumping a full interval (the
// previous lower_bound behavior picked the wrong interval in that case and silently
// returned the next snapshot's coefficients).
auto it2 = std::upper_bound(times.begin(), times.end(), t);
if (it2 == times.begin()) ++it2; // t at/below the first knot
if (it2 == times.end()) --it2; // t at/above the last knot
auto it1 = it2 - 1;

// Handle degenerate case where the bracketing times coincide (duplicate knot)
if (*it1 == *it2) {
return coefs->getCoefStruct(*it1);
}

Expand Down
40 changes: 40 additions & 0 deletions tests/potential/potential/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,46 @@ def test_exp_unit_tests():
assert u.allclose(pot_multi.tmax_exp, 2.0 * u.Gyr)


def test_multi_eval_at_snapshot_times():
"""Regression test for the time-interpolation evaluation bug.

Evaluating a time-evolving EXPPotential exactly AT a stored snapshot time
must equal the static snapshot at that index, and must not raise at the
first/last snapshot. Previously, floating-point differences between the
requested time and the (limited-precision) stored snapshot times pushed the
interval search onto the wrong bracket -- silently returning a neighbouring
snapshot's field -- or tripped the out-of-bounds check at the endpoints. See
``gala_exp::interpolator`` in ``exp_fields.cc``.
"""
pot_multi = EXPPotential(
config_file=EXP_CONFIG_FILE,
coef_file=EXP_MULTI_COEF_FILE,
units=EXPTestBase.exp_units,
)

# snapshots were generated at 0, 500, 1000, 1500, 2000 Myr; see generate_exp.py
snapshot_times = [0, 500, 1000, 1500, 2000] * u.Myr
test_x = [8.0, 1.0, -2.0] * u.kpc

for k, t_k in enumerate(snapshot_times):
frozen = EXPPotential(
config_file=EXP_CONFIG_FILE,
coef_file=EXP_MULTI_COEF_FILE,
snapshot_index=k,
units=EXPTestBase.exp_units,
)
# Must not raise (including at the first and last snapshot) and must
# match the corresponding frozen snapshot to interpolation precision.
assert u.allclose(
pot_multi.energy(test_x, t=t_k), frozen.energy(test_x), rtol=1e-6
)
assert u.allclose(
pot_multi.acceleration(test_x, t=t_k),
frozen.acceleration(test_x),
rtol=1e-6,
)


@pytest.mark.skipif(not HAVE_PYEXP, reason="requires pyEXP")
def test_pyexp_unit_tests():
"""Test PyEXPPotential static/dynamic behavior"""
Expand Down
Loading