diff --git a/CHANGES.rst b/CHANGES.rst index b63c28b8f..f3bd570a3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -20,6 +20,12 @@ Bug fixes invalid position or velocity units and silently pass these through to orbit integration later. +- Fixed a bug in the time interpolation of ``EXPPotential`` where evaluating a + time-evolving potential at (or very near) a stored snapshot time could silently return + a neighboring snapshot, or raise an out-of-bounds error at the first or last snapshot. + This was caused by floating-point differences between the requested time and the + limited-precision snapshot times stored in the coefficient file. + API changes ----------- diff --git a/src/gala/potential/potential/builtin/exp_fields.cc b/src/gala/potential/potential/builtin/exp_fields.cc index 45ea0080e..09be6be59 100644 --- a/src/gala/potential/potential/builtin/exp_fields.cc +++ b/src/gala/potential/potential/builtin/exp_fields.cc @@ -6,6 +6,7 @@ #include #include #include +#include namespace fs = std::filesystem; @@ -203,23 +204,38 @@ CoefClasses::CoefStrPtr interpolator(double t, CoefClasses::CoefsPtr coefs) // auto times = coefs->Times(); - if (ttimes.back()) { + // Allow for a small tolerance in comparing the requested time to the stored snapshot + // times. This is because floating point errors can come from: (1) the unit conversion + // applied to the requested time, and (2) the limited precision with which snapshot + // times are written to / read from the coefficient file. It seems that writing the times + // to HDF5 only stores the times with reduced precision? + const double scale = + std::max({std::abs(times.front()), std::abs(times.back()), 1.0}); + const double tol = 1.0e-8 * scale; // 1e-8 is somewhat arbitrary, but seems to work + + if (t < (times.front() - tol) or t > (times.back() + tol)) { std::ostringstream sout; sout << "FieldWrapper::interpolator: time t=" << t << " is out of bounds: [" << times.front() << ", " << times.back() << "] (raw EXP snapshot time units)"; throw std::runtime_error(sout.str()); } - auto it1 = std::lower_bound(times.begin(), times.end(), t); - auto it2 = it1 + 1; - - if (it2 == times.end()) { - it2--; - it1 = it2 - 1; - } - - // Handle degenerate case where it1 == it2 (single time entry) - if (it1 == it2 || *it1 == *it2) { + // Clamp into range so that endpoint roundoff interpolates cleanly to the + // boundary snapshot instead of extrapolating just outside it. + t = std::min(std::max(t, times.front()), times.back()); + + // Bracket the interval [it1, it2] with *it1 <= t <= *it2. Use upper_bound (first + // element strictly greater than t) so that a time sitting epsilon above a stored knot + // keeps that knot as the lower bracket instead of jumping a full interval (the + // previous lower_bound behavior picked the wrong interval in that case and silently + // returned the next snapshot's coefficients). + auto it2 = std::upper_bound(times.begin(), times.end(), t); + if (it2 == times.begin()) ++it2; // t at/below the first knot + if (it2 == times.end()) --it2; // t at/above the last knot + auto it1 = it2 - 1; + + // Handle degenerate case where the bracketing times coincide (duplicate knot) + if (*it1 == *it2) { return coefs->getCoefStruct(*it1); } diff --git a/tests/potential/potential/test_exp.py b/tests/potential/potential/test_exp.py index 48c605e2d..c665b1d14 100644 --- a/tests/potential/potential/test_exp.py +++ b/tests/potential/potential/test_exp.py @@ -271,6 +271,46 @@ def test_exp_unit_tests(): assert u.allclose(pot_multi.tmax_exp, 2.0 * u.Gyr) +def test_multi_eval_at_snapshot_times(): + """Regression test for the time-interpolation evaluation bug. + + Evaluating a time-evolving EXPPotential exactly AT a stored snapshot time + must equal the static snapshot at that index, and must not raise at the + first/last snapshot. Previously, floating-point differences between the + requested time and the (limited-precision) stored snapshot times pushed the + interval search onto the wrong bracket -- silently returning a neighbouring + snapshot's field -- or tripped the out-of-bounds check at the endpoints. See + ``gala_exp::interpolator`` in ``exp_fields.cc``. + """ + pot_multi = EXPPotential( + config_file=EXP_CONFIG_FILE, + coef_file=EXP_MULTI_COEF_FILE, + units=EXPTestBase.exp_units, + ) + + # snapshots were generated at 0, 500, 1000, 1500, 2000 Myr; see generate_exp.py + snapshot_times = [0, 500, 1000, 1500, 2000] * u.Myr + test_x = [8.0, 1.0, -2.0] * u.kpc + + for k, t_k in enumerate(snapshot_times): + frozen = EXPPotential( + config_file=EXP_CONFIG_FILE, + coef_file=EXP_MULTI_COEF_FILE, + snapshot_index=k, + units=EXPTestBase.exp_units, + ) + # Must not raise (including at the first and last snapshot) and must + # match the corresponding frozen snapshot to interpolation precision. + assert u.allclose( + pot_multi.energy(test_x, t=t_k), frozen.energy(test_x), rtol=1e-6 + ) + assert u.allclose( + pot_multi.acceleration(test_x, t=t_k), + frozen.acceleration(test_x), + rtol=1e-6, + ) + + @pytest.mark.skipif(not HAVE_PYEXP, reason="requires pyEXP") def test_pyexp_unit_tests(): """Test PyEXPPotential static/dynamic behavior"""