From a419b3e28ea8bc3194dee5cccc0eba1623f167bc Mon Sep 17 00:00:00 2001 From: shimwell Date: Fri, 29 May 2026 14:25:04 +0200 Subject: [PATCH] Add multi-index apply_time_correction with einsum contraction (variant) Variant of the multi-index apply_time_correction that evaluates the summed TCF-weighted sum over the parent-nuclide axis with an einsum contraction instead of a broadcast-multiply-and-reduce. This is substantially faster for large mesh tallies but shifts mean/std_dev by ~1e-15 relative (floating-point summation order), so results match develop to machine precision rather than bit-for-bit -- well below Monte Carlo statistical noise and regression-test tolerances. Same multi-index API and sum/sum_sq handling as the bitwise variant. --- openmc/deplete/d1s.py | 151 +++++++++++++++++++++++------------ tests/unit_tests/test_d1s.py | 85 ++++++++++++++++++++ 2 files changed, 183 insertions(+), 53 deletions(-) diff --git a/openmc/deplete/d1s.py b/openmc/deplete/d1s.py index d85d2e8a79f..2927ca7480f 100644 --- a/openmc/deplete/d1s.py +++ b/openmc/deplete/d1s.py @@ -124,15 +124,15 @@ def time_correction_factors( def apply_time_correction( tally: openmc.Tally, time_correction_factors: dict[str, np.ndarray], - index: int = -1, + index: int | Sequence[int] = -1, sum_nuclides: bool = True -) -> openmc.Tally: +) -> openmc.Tally | list[openmc.Tally]: """Apply time correction factors to a tally. - This function applies the time correction factors at the given index to a - tally that contains a :class:`~openmc.ParentNuclideFilter`. When - `sum_nuclides` is True, values over all parent nuclides will be summed, - leaving a single value for each filter combination. + This function applies the time correction factors at the given index (or + indices) to a tally that contains a :class:`~openmc.ParentNuclideFilter`. + When `sum_nuclides` is True, values over all parent nuclides will be + summed, leaving a single value for each filter combination. Parameters ---------- @@ -140,17 +140,27 @@ def apply_time_correction( Tally to apply the time correction factors to time_correction_factors : dict Time correction factors as returned by :func:`time_correction_factors` - index : int, optional - Index of the time of interest. If N timesteps are provided in - :func:`time_correction_factors`, there are N + 1 times to select from. - The default is -1 which corresponds to the final time. + index : int or iterable of int, optional + Index (or indices) of the time(s) of interest. If N timesteps are + provided in :func:`time_correction_factors`, there are N + 1 times to + select from. The default is -1 which corresponds to the final time. + Passing an iterable returns a list of derived tallies, one per index. + The tally arrays are read and reshaped once and shared across all + indices, and (when ``sum_nuclides`` is True) the time-correction is + applied as a contraction over the parent-nuclide axis rather than a + broadcast-and-reduce per index, which is substantially faster for + large tallies (e.g. mesh tallies) evaluated at many times. sum_nuclides : bool Whether to sum over the parent nuclides Returns ------- - openmc.Tally - Derived tally with time correction factors applied + openmc.Tally or list of openmc.Tally + Derived tally with time correction factors applied. A list is + returned when ``index`` is an iterable; otherwise a single tally is + returned. When ``sum_nuclides`` is True the result is a derived tally, + for which ``sum`` and ``sum_sq`` are None (as for any derived tally); + the meaningful results are ``mean`` and ``std_dev``. """ # Make sure the tally contains a ParentNuclideFilter @@ -162,55 +172,90 @@ def apply_time_correction( # Get list of radionuclides based on tally filter radionuclides = [str(x) for x in tally.filters[i_filter].bins] - tcf = np.array([time_correction_factors[x][index] for x in radionuclides]) - # Force tally results to be read and std_dev to be computed - tally.std_dev + # Normalize index to a list; remember whether the caller asked scalar + scalar_input = isinstance(index, (int, np.integer)) + indices = [int(index)] if scalar_input else list(index) - # Create shallow copy of tally - new_tally = copy(tally) - new_tally._filters = copy(tally._filters) + # Force tally results to be read and std_dev to be computed (once) + tally.std_dev - # Determine number of bins in other filters + # Determine number of bins in other filters (computed once) n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]]) n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]]) - - # Reshape sum and sum_sq, apply TCF, and sum along that axis - _, n_nuclides, n_scores = new_tally.shape + _, n_nuclides, n_scores = tally.shape n_radionuclides = len(radionuclides) - shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) - tally_sum = new_tally.sum.reshape(shape) - tally_sum_sq = new_tally.sum_sq.reshape(shape) - tally_mean = new_tally.mean.reshape(shape) - tally_std_dev = new_tally.std_dev.reshape(shape) - - # Apply TCF, broadcasting to the correct dimensions - tcf.shape = (1, -1, 1, 1, 1) - new_tally._mean = tally_mean * tcf - new_tally._std_dev = tally_std_dev * tcf - - shape = (-1, n_nuclides, n_scores) - + shape5 = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) + flat_shape = (-1, n_nuclides, n_scores) + + # Reshape views shared across all indices + tally_mean_5d = tally.mean.reshape(shape5) + tally_std_dev_5d = tally.std_dev.reshape(shape5) + + # Time correction factors for every requested index -> shape + # (n_indices, n_radionuclides). Indexing a row (``tcf[t]``) yields a + # contiguous per-index factor vector, so the einsum below rounds + # identically whether one index or many were requested. + tcf = np.array( + [[time_correction_factors[x][idx] for x in radionuclides] + for idx in indices] + ) + + results = [] if sum_nuclides: - # Sum over parent nuclides (note that when combining different bins for - # parent nuclide, we can't work directly on sum_sq) - new_tally._sum = None - new_tally._sum_sq = None - new_tally._mean = new_tally.mean.sum(axis=1).reshape(shape) - new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(shape) - new_tally._derived = True - - # Remove ParentNuclideFilter - new_tally.filters.pop(i_filter) + # The TCF-weighted sum over the parent-nuclide axis is a contraction of + # the 5-D arrays with the per-index factor vector, evaluated with a + # single (BLAS-backed) einsum instead of a Python-level + # multiply-and-reduce. Variances combine in quadrature, so std_dev + # contracts the squared values (cf. ``np.linalg.norm`` over the nuclide + # axis). ``sum``/``sum_sq`` are left unset: the public accessors return + # None for any derived tally (so this matches develop's observable + # behavior), and keeping the TCF-scaled per-nuclide arrays would only + # waste two full-array multiplies per index on data nothing reads -- it + # would also be shaped inconsistently with the popped filter, which + # breaks ``Tally.sparse``. Contracting one index at a time keeps the + # result bit-for-bit identical to a scalar call for that index. + # subscripts: i=bins_before, r=radionuclide, j=bins_after, k=nuclide, + # s=score + tally_var_5d = tally_std_dev_5d**2 # reused at every index + for t in range(len(indices)): + tcf_row = tcf[t] + mean = np.einsum('irjks,r->ijks', tally_mean_5d, tcf_row) + std_dev = np.sqrt( + np.einsum('irjks,r->ijks', tally_var_5d, tcf_row*tcf_row)) + + new_tally = copy(tally) + new_tally._filters = copy(tally._filters) + new_tally._mean = mean.reshape(flat_shape) + new_tally._std_dev = std_dev.reshape(flat_shape) + new_tally._sum = None + new_tally._sum_sq = None + new_tally._derived = True + + # Remove ParentNuclideFilter + new_tally.filters.pop(i_filter) + results.append(new_tally) else: - # Apply TCF and change shape back to (filter combinations, nuclides, - # scores) - new_tally._sum = (tally_sum * tcf).reshape(shape) - new_tally._sum_sq = (tally_sum_sq * (tcf*tcf)).reshape(shape) - new_tally._mean.shape = shape - new_tally._std_dev.shape = shape - - return new_tally + # Per-nuclide results are kept, so each index produces a full-size + # array; the shared 5-D views avoid re-reading and re-reshaping the + # tally data on every index. + tally_sum_5d = tally.sum.reshape(shape5) + tally_sum_sq_5d = tally.sum_sq.reshape(shape5) + + for t in range(len(indices)): + tcf_b = tcf[t].reshape(1, -1, 1, 1, 1) + + new_tally = copy(tally) + new_tally._filters = copy(tally._filters) + + # Apply TCF, broadcasting to the correct dimensions + new_tally._sum = (tally_sum_5d * tcf_b).reshape(flat_shape) + new_tally._sum_sq = (tally_sum_sq_5d * (tcf_b*tcf_b)).reshape(flat_shape) + new_tally._mean = (tally_mean_5d * tcf_b).reshape(flat_shape) + new_tally._std_dev = (tally_std_dev_5d * tcf_b).reshape(flat_shape) + results.append(new_tally) + + return results[0] if scalar_input else results def prepare_tallies( diff --git a/tests/unit_tests/test_d1s.py b/tests/unit_tests/test_d1s.py index 49c1b30499c..a95d5d31f98 100644 --- a/tests/unit_tests/test_d1s.py +++ b/tests/unit_tests/test_d1s.py @@ -154,3 +154,88 @@ def test_apply_time_correction(run_in_tmpdir): # The summed tally is derived, so sum/sum_sq are None assert result_summed.sum is None assert result_summed.sum_sq is None + + +def test_apply_time_correction_multi_index(run_in_tmpdir): + # Build the same model used in test_apply_time_correction + mat = openmc.Material() + mat.add_element('Ni', 1.0) + sphere = openmc.Sphere(r=10.0, boundary_type='vacuum') + cell = openmc.Cell(fill=mat, region=-sphere) + model = openmc.Model() + model.geometry = openmc.Geometry([cell]) + model.settings.run_mode = 'fixed source' + model.settings.batches = 3 + model.settings.particles = 10 + model.settings.photon_transport = True + model.settings.use_decay_photons = True + particle_filter = openmc.ParticleFilter('photon') + tally = openmc.Tally() + tally.filters = [particle_filter] + tally.scores = ['flux'] + model.tallies = [tally] + + # A schedule with several timesteps so we can ask for many indices + nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_PATH) + timesteps = [1.0e8, 1.0e8, 1.0e8, 1.0e8] + source_rates = [1.0, 0.0, 1.0, 0.0] + factors = d1s.time_correction_factors(nuclides, timesteps, source_rates) + n_times = len(factors[nuclides[0]]) + + with openmc.config.patch('chain_file', CHAIN_PATH): + output_path = model.run() + with openmc.StatePoint(output_path) as sp: + tally = sp.tallies[tally.id] + + orig_filters = list(tally.filters) + orig_sum = tally.sum.copy() + orig_sum_sq = tally.sum_sq.copy() + orig_mean = tally.mean.copy() + orig_std_dev = tally.std_dev.copy() + + # Passing a list of indices returns a list of derived tallies, each + # matching what a scalar call at that index would have produced. + for sum_nuc in (True, False): + many = d1s.apply_time_correction( + tally, factors, index=list(range(n_times)), + sum_nuclides=sum_nuc, + ) + assert isinstance(many, list) + assert len(many) == n_times + for i, derived in enumerate(many): + ref = d1s.apply_time_correction( + tally, factors, index=i, sum_nuclides=sum_nuc + ) + np.testing.assert_array_equal(derived.mean, ref.mean) + np.testing.assert_array_equal(derived.std_dev, ref.std_dev) + assert derived.filters == ref.filters + if sum_nuc: + # Summed tally is derived; sum/sum_sq are None (as in + # develop, where the public accessors return None for any + # derived tally) + assert derived.sum is None and derived.sum_sq is None + else: + np.testing.assert_array_equal(derived.sum, ref.sum) + np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq) + + # Unordered / partial index sequence is honored in order + subset = [n_times - 1, 0, 2] + many = d1s.apply_time_correction(tally, factors, index=subset) + assert isinstance(many, list) + assert len(many) == len(subset) + for derived, i in zip(many, subset): + ref = d1s.apply_time_correction(tally, factors, index=i) + np.testing.assert_array_equal(derived.mean, ref.mean) + np.testing.assert_array_equal(derived.std_dev, ref.std_dev) + + # Scalar input still returns a single Tally, not a 1-element list + single = d1s.apply_time_correction(tally, factors, index=2) + assert isinstance(single, openmc.Tally) + assert not isinstance(single, list) + + # Original tally is unchanged + assert tally.filters == orig_filters + assert np.all(tally.sum == orig_sum) + assert np.all(tally.sum_sq == orig_sum_sq) + assert np.all(tally.mean == orig_mean) + assert np.all(tally.std_dev == orig_std_dev)