From b1651d5076696e953f7007e71edb85dc3c735b13 Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Tue, 19 May 2026 14:56:18 +0200 Subject: [PATCH 1/3] vectorized d1s --- openmc/deplete/d1s.py | 108 +++++++++++++++++++++++++++++++++++ tests/unit_tests/test_d1s.py | 104 +++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) diff --git a/openmc/deplete/d1s.py b/openmc/deplete/d1s.py index bc99fc42dba..252f1dd0711 100644 --- a/openmc/deplete/d1s.py +++ b/openmc/deplete/d1s.py @@ -212,6 +212,114 @@ def apply_time_correction( return new_tally +def apply_time_correction_series( + tally: openmc.Tally, + time_correction_factors: dict[str, np.ndarray], + indices: Sequence[int] | None = None, + sum_nuclides: bool = True, +) -> tuple[np.ndarray, np.ndarray]: + """Apply time correction factors to a tally at multiple time indices. + + Vectorized variant of :func:`apply_time_correction` that evaluates a series + of time indices in a single matrix multiplication. Calling + :func:`apply_time_correction` in a loop over ``N`` indices deep-copies the + tally ``N`` times and does ``N`` applications of the + sum/sum_sq/mean/std_dev arithmetic; this function reads the underlying + arrays once and folds the radionuclide-axis sum into one matmul, so the + work is independent of ``N`` on the tally-extraction side. + + Unlike :func:`apply_time_correction`, this returns raw NumPy arrays rather + than a list of derived :class:`openmc.Tally` objects: constructing ``N`` + derived tallies (each with its own copy of ``_sum``, ``_sum_sq``, + ``_mean``, and ``_std_dev``) negates the memory advantage on fine-mesh + tallies. Users who need a ``Tally`` per index can build one from the + returned arrays. + + Parameters + ---------- + tally : openmc.Tally + Tally to apply the time correction factors to. Must contain a + :class:`~openmc.ParentNuclideFilter`. + time_correction_factors : dict + Time correction factors as returned by :func:`time_correction_factors`. + indices : iterable of int, optional + Indices into each time correction factor array to evaluate. If None + (default), every available index is evaluated. + sum_nuclides : bool, optional + Whether to sum over the parent nuclides (default True). Matches the + semantics of :func:`apply_time_correction`: with ``sum_nuclides=True`` + the standard deviation across radionuclides is the L2 norm. + + Returns + ------- + mean : numpy.ndarray + Mean values. Shape is ``(n_indices, n_other_filter_bins, n_nuclides, + n_scores)`` when ``sum_nuclides`` is True (the + :class:`~openmc.ParentNuclideFilter` axis is collapsed), and + ``(n_indices, n_filter_bins, n_nuclides, n_scores)`` otherwise (with + the parent-nuclide bins flattened into the filter-bin axis, same + layout as :func:`apply_time_correction` produces). + std_dev : numpy.ndarray + Standard deviations with the same shape as ``mean``. + + """ + # Locate the ParentNuclideFilter + for i_filter, f in enumerate(tally.filters): + if isinstance(f, openmc.ParentNuclideFilter): + break + else: + raise ValueError('Tally must contain a ParentNuclideFilter') + + radionuclides = [str(x) for x in tally.filters[i_filter].bins] + n_radionuclides = len(radionuclides) + + # Default to every available index + if indices is None: + indices = range(len(time_correction_factors[radionuclides[0]])) + indices = np.asarray(list(indices), dtype=int) + n_indices = indices.size + + # Build (n_indices, n_radionuclides) TCF matrix + tcf = np.column_stack( + [time_correction_factors[nuc][indices] for nuc in radionuclides] + ) + + # Force std_dev to be computed and the underlying arrays to be read + tally.std_dev + + # Reshape to expose the parent-nuclide axis + n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]]) + n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]]) + _, n_nuclides, n_scores = tally.shape + shape5 = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) + + mean_5d = tally.mean.reshape(shape5) + std_dev_5d = tally.std_dev.reshape(shape5) + + if sum_nuclides: + # Move parent-nuclide axis to position 0 and flatten the rest so a + # single matmul does the per-index radionuclide sum. + mean_rf = np.moveaxis(mean_5d, 1, 0).reshape(n_radionuclides, -1) + var_rf = np.moveaxis(std_dev_5d ** 2, 1, 0).reshape(n_radionuclides, -1) + + mean_out = tcf @ mean_rf + # Variances combine linearly when factors are squared; sqrt at the end + # gives the L2 norm matching apply_time_correction. + std_out = np.sqrt((tcf ** 2) @ var_rf) + + out_shape = (n_indices, n_bins_before * n_bins_after, + n_nuclides, n_scores) + return mean_out.reshape(out_shape), std_out.reshape(out_shape) + + # Per-radionuclide: result keeps the parent-nuclide axis. + tcf_b = tcf.reshape(n_indices, 1, n_radionuclides, 1, 1, 1) + mean_out = tcf_b * mean_5d[np.newaxis, ...] + std_out = tcf_b * std_dev_5d[np.newaxis, ...] + + out_shape = (n_indices, -1, n_nuclides, n_scores) + return mean_out.reshape(out_shape), std_out.reshape(out_shape) + + def prepare_tallies( model: openmc.Model, nuclides: list[str] | None = None, diff --git a/tests/unit_tests/test_d1s.py b/tests/unit_tests/test_d1s.py index 8f3b62f4000..30ce69e4ee7 100644 --- a/tests/unit_tests/test_d1s.py +++ b/tests/unit_tests/test_d1s.py @@ -150,3 +150,107 @@ def test_apply_time_correction(run_in_tmpdir): result_summed.get_reshaped_data() result.get_pandas_dataframe() result_summed.get_pandas_dataframe() + + +def test_apply_time_correction_series(run_in_tmpdir): + # Build the same model used in test_apply_time_correction + mat = openmc.Material() + mat.add_element('Ni', 1.0) + sphere = openmc.Sphere(r=10.0, boundary_type='vacuum') + cell = openmc.Cell(fill=mat, region=-sphere) + model = openmc.Model() + model.geometry = openmc.Geometry([cell]) + model.settings.run_mode = 'fixed source' + model.settings.batches = 3 + model.settings.particles = 10 + model.settings.photon_transport = True + model.settings.use_decay_photons = True + particle_filter = openmc.ParticleFilter('photon') + tally = openmc.Tally() + tally.filters = [particle_filter] + tally.scores = ['flux'] + model.tallies = [tally] + + # A schedule with several timesteps so the series has > 1 entry. + nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_PATH) + timesteps = [1.0e8, 1.0e8, 1.0e8, 1.0e8] + source_rates = [1.0, 0.0, 1.0, 0.0] + factors = d1s.time_correction_factors(nuclides, timesteps, source_rates) + n_times = len(factors[nuclides[0]]) + + # Run the model once + with openmc.config.patch('chain_file', CHAIN_PATH): + output_path = model.run() + with openmc.StatePoint(output_path) as sp: + tally = sp.tallies[tally.id] + + # Snapshot original tally state so we can confirm immutability later + orig_filters = list(tally.filters) + orig_sum = tally.sum.copy() + orig_sum_sq = tally.sum_sq.copy() + orig_mean = tally.mean.copy() + orig_std_dev = tally.std_dev.copy() + + # sum_nuclides=True: series matches per-index loop + mean_series, std_series = d1s.apply_time_correction_series( + tally, factors, sum_nuclides=True + ) + assert mean_series.shape[0] == n_times + assert std_series.shape == mean_series.shape + + for i in range(n_times): + ref = d1s.apply_time_correction( + tally, factors, index=i, sum_nuclides=True + ) + np.testing.assert_allclose( + mean_series[i].reshape(ref.mean.shape), ref.mean + ) + np.testing.assert_allclose( + std_series[i].reshape(ref.std_dev.shape), ref.std_dev + ) + + # sum_nuclides=False: series matches per-index loop + mean_series_f, std_series_f = d1s.apply_time_correction_series( + tally, factors, sum_nuclides=False + ) + assert mean_series_f.shape[0] == n_times + + for i in range(n_times): + ref = d1s.apply_time_correction( + tally, factors, index=i, sum_nuclides=False + ) + np.testing.assert_allclose( + mean_series_f[i].reshape(ref.mean.shape), ref.mean + ) + np.testing.assert_allclose( + std_series_f[i].reshape(ref.std_dev.shape), ref.std_dev + ) + + # explicit indices subset (and unordered) + subset = [n_times - 1, 0, 2] + mean_sub, std_sub = d1s.apply_time_correction_series( + tally, factors, indices=subset + ) + assert mean_sub.shape[0] == len(subset) + for k, i in enumerate(subset): + ref = d1s.apply_time_correction(tally, factors, index=i) + np.testing.assert_allclose( + mean_sub[k].reshape(ref.mean.shape), ref.mean + ) + np.testing.assert_allclose( + std_sub[k].reshape(ref.std_dev.shape), ref.std_dev + ) + + # original tally is unchanged + assert tally.filters == orig_filters + assert np.all(tally.sum == orig_sum) + assert np.all(tally.sum_sq == orig_sum_sq) + assert np.all(tally.mean == orig_mean) + assert np.all(tally.std_dev == orig_std_dev) + + # missing ParentNuclideFilter raises + bare = openmc.Tally() + bare.filters = [particle_filter] + bare.scores = ['flux'] + with pytest.raises(ValueError): + d1s.apply_time_correction_series(bare, factors) From c9e6238583fdb2cb22caeb44251dd8a46b94c36d Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Tue, 19 May 2026 17:08:27 +0200 Subject: [PATCH 2/3] polymorphic index --- openmc/deplete/d1s.py | 225 ++++++++++------------------------- tests/unit_tests/test_d1s.py | 91 +++++--------- 2 files changed, 99 insertions(+), 217 deletions(-) diff --git a/openmc/deplete/d1s.py b/openmc/deplete/d1s.py index 252f1dd0711..394c72c3a01 100644 --- a/openmc/deplete/d1s.py +++ b/openmc/deplete/d1s.py @@ -124,15 +124,15 @@ def time_correction_factors( def apply_time_correction( tally: openmc.Tally, time_correction_factors: dict[str, np.ndarray], - index: int = -1, + index: int | Sequence[int] = -1, sum_nuclides: bool = True -) -> openmc.Tally: +) -> openmc.Tally | list[openmc.Tally]: """Apply time correction factors to a tally. - This function applies the time correction factors at the given index to a - tally that contains a :class:`~openmc.ParentNuclideFilter`. When - `sum_nuclides` is True, values over all parent nuclides will be summed, - leaving a single value for each filter combination. + This function applies the time correction factors at the given index (or + indices) to a tally that contains a :class:`~openmc.ParentNuclideFilter`. + When `sum_nuclides` is True, values over all parent nuclides will be + summed, leaving a single value for each filter combination. Parameters ---------- @@ -140,17 +140,23 @@ def apply_time_correction( Tally to apply the time correction factors to time_correction_factors : dict Time correction factors as returned by :func:`time_correction_factors` - index : int, optional - Index of the time of interest. If N timesteps are provided in - :func:`time_correction_factors`, there are N + 1 times to select from. - The default is -1 which corresponds to the final time. + index : int or iterable of int, optional + Index (or indices) of the time(s) of interest. If N timesteps are + provided in :func:`time_correction_factors`, there are N + 1 times to + select from. The default is -1 which corresponds to the final time. + Passing an iterable returns a list of derived tallies, one per index, + evaluated in a single pass — the underlying tally arrays are read and + reshaped once and shared across all indices instead of being re-read + and re-copied on every call. sum_nuclides : bool Whether to sum over the parent nuclides Returns ------- - openmc.Tally - Derived tally with time correction factors applied + openmc.Tally or list of openmc.Tally + Derived tally with time correction factors applied. A list is + returned when ``index`` is an iterable; otherwise a single tally is + returned. """ # Make sure the tally contains a ParentNuclideFilter @@ -162,162 +168,63 @@ def apply_time_correction( # Get list of radionuclides based on tally filter radionuclides = [str(x) for x in tally.filters[i_filter].bins] - tcf = np.array([time_correction_factors[x][index] for x in radionuclides]) - - # Force tally results to be read and std_dev to be computed - tally.std_dev - - # Create shallow copy of tally - new_tally = copy(tally) - new_tally._filters = copy(tally._filters) - - # Determine number of bins in other filters - n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]]) - n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]]) - - # Reshape sum and sum_sq, apply TCF, and sum along that axis - _, n_nuclides, n_scores = new_tally.shape - n_radionuclides = len(radionuclides) - shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) - tally_sum = new_tally.sum.reshape(shape) - tally_sum_sq = new_tally.sum_sq.reshape(shape) - tally_mean = new_tally.mean.reshape(shape) - tally_std_dev = new_tally.std_dev.reshape(shape) - - # Apply TCF, broadcasting to the correct dimensions - tcf.shape = (1, -1, 1, 1, 1) - new_tally._sum = tally_sum * tcf - new_tally._sum_sq = tally_sum_sq * (tcf*tcf) - new_tally._mean = tally_mean * tcf - new_tally._std_dev = tally_std_dev * tcf - - shape = (-1, n_nuclides, n_scores) - - if sum_nuclides: - # Sum over parent nuclides (note that when combining different bins for - # parent nuclide, we can't work directly on sum_sq) - new_tally._mean = new_tally.mean.sum(axis=1).reshape(shape) - new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(shape) - new_tally._derived = True - - # Remove ParentNuclideFilter - new_tally.filters.pop(i_filter) - else: - # Change shape back to (filter combinations, nuclides, scores) - new_tally._sum.shape = shape - new_tally._sum_sq.shape = shape - new_tally._mean.shape = shape - new_tally._std_dev.shape = shape - - return new_tally + # Normalize index to a list; remember whether the caller asked scalar + scalar_input = isinstance(index, (int, np.integer)) + indices = [int(index)] if scalar_input else list(index) -def apply_time_correction_series( - tally: openmc.Tally, - time_correction_factors: dict[str, np.ndarray], - indices: Sequence[int] | None = None, - sum_nuclides: bool = True, -) -> tuple[np.ndarray, np.ndarray]: - """Apply time correction factors to a tally at multiple time indices. - - Vectorized variant of :func:`apply_time_correction` that evaluates a series - of time indices in a single matrix multiplication. Calling - :func:`apply_time_correction` in a loop over ``N`` indices deep-copies the - tally ``N`` times and does ``N`` applications of the - sum/sum_sq/mean/std_dev arithmetic; this function reads the underlying - arrays once and folds the radionuclide-axis sum into one matmul, so the - work is independent of ``N`` on the tally-extraction side. - - Unlike :func:`apply_time_correction`, this returns raw NumPy arrays rather - than a list of derived :class:`openmc.Tally` objects: constructing ``N`` - derived tallies (each with its own copy of ``_sum``, ``_sum_sq``, - ``_mean``, and ``_std_dev``) negates the memory advantage on fine-mesh - tallies. Users who need a ``Tally`` per index can build one from the - returned arrays. - - Parameters - ---------- - tally : openmc.Tally - Tally to apply the time correction factors to. Must contain a - :class:`~openmc.ParentNuclideFilter`. - time_correction_factors : dict - Time correction factors as returned by :func:`time_correction_factors`. - indices : iterable of int, optional - Indices into each time correction factor array to evaluate. If None - (default), every available index is evaluated. - sum_nuclides : bool, optional - Whether to sum over the parent nuclides (default True). Matches the - semantics of :func:`apply_time_correction`: with ``sum_nuclides=True`` - the standard deviation across radionuclides is the L2 norm. - - Returns - ------- - mean : numpy.ndarray - Mean values. Shape is ``(n_indices, n_other_filter_bins, n_nuclides, - n_scores)`` when ``sum_nuclides`` is True (the - :class:`~openmc.ParentNuclideFilter` axis is collapsed), and - ``(n_indices, n_filter_bins, n_nuclides, n_scores)`` otherwise (with - the parent-nuclide bins flattened into the filter-bin axis, same - layout as :func:`apply_time_correction` produces). - std_dev : numpy.ndarray - Standard deviations with the same shape as ``mean``. - - """ - # Locate the ParentNuclideFilter - for i_filter, f in enumerate(tally.filters): - if isinstance(f, openmc.ParentNuclideFilter): - break - else: - raise ValueError('Tally must contain a ParentNuclideFilter') - - radionuclides = [str(x) for x in tally.filters[i_filter].bins] - n_radionuclides = len(radionuclides) - - # Default to every available index - if indices is None: - indices = range(len(time_correction_factors[radionuclides[0]])) - indices = np.asarray(list(indices), dtype=int) - n_indices = indices.size - - # Build (n_indices, n_radionuclides) TCF matrix - tcf = np.column_stack( - [time_correction_factors[nuc][indices] for nuc in radionuclides] - ) - - # Force std_dev to be computed and the underlying arrays to be read + # Force tally results to be read and std_dev to be computed (once) tally.std_dev - # Reshape to expose the parent-nuclide axis + # Determine number of bins in other filters (computed once) n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]]) n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]]) _, n_nuclides, n_scores = tally.shape + n_radionuclides = len(radionuclides) shape5 = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) - mean_5d = tally.mean.reshape(shape5) - std_dev_5d = tally.std_dev.reshape(shape5) - - if sum_nuclides: - # Move parent-nuclide axis to position 0 and flatten the rest so a - # single matmul does the per-index radionuclide sum. - mean_rf = np.moveaxis(mean_5d, 1, 0).reshape(n_radionuclides, -1) - var_rf = np.moveaxis(std_dev_5d ** 2, 1, 0).reshape(n_radionuclides, -1) - - mean_out = tcf @ mean_rf - # Variances combine linearly when factors are squared; sqrt at the end - # gives the L2 norm matching apply_time_correction. - std_out = np.sqrt((tcf ** 2) @ var_rf) - - out_shape = (n_indices, n_bins_before * n_bins_after, - n_nuclides, n_scores) - return mean_out.reshape(out_shape), std_out.reshape(out_shape) - - # Per-radionuclide: result keeps the parent-nuclide axis. - tcf_b = tcf.reshape(n_indices, 1, n_radionuclides, 1, 1, 1) - mean_out = tcf_b * mean_5d[np.newaxis, ...] - std_out = tcf_b * std_dev_5d[np.newaxis, ...] - - out_shape = (n_indices, -1, n_nuclides, n_scores) - return mean_out.reshape(out_shape), std_out.reshape(out_shape) + # Reshape views shared across all indices + tally_sum_5d = tally.sum.reshape(shape5) + tally_sum_sq_5d = tally.sum_sq.reshape(shape5) + tally_mean_5d = tally.mean.reshape(shape5) + tally_std_dev_5d = tally.std_dev.reshape(shape5) + + flat_shape = (-1, n_nuclides, n_scores) + + results = [] + for idx in indices: + tcf = np.array([time_correction_factors[x][idx] for x in radionuclides]) + tcf.shape = (1, -1, 1, 1, 1) + + # Create shallow copy of tally + new_tally = copy(tally) + new_tally._filters = copy(tally._filters) + + # Apply TCF, broadcasting to the correct dimensions + new_tally._sum = tally_sum_5d * tcf + new_tally._sum_sq = tally_sum_sq_5d * (tcf*tcf) + new_tally._mean = tally_mean_5d * tcf + new_tally._std_dev = tally_std_dev_5d * tcf + + if sum_nuclides: + # Sum over parent nuclides (note that when combining different + # bins for parent nuclide, we can't work directly on sum_sq) + new_tally._mean = new_tally.mean.sum(axis=1).reshape(flat_shape) + new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(flat_shape) + new_tally._derived = True + + # Remove ParentNuclideFilter + new_tally.filters.pop(i_filter) + else: + # Change shape back to (filter combinations, nuclides, scores) + new_tally._sum.shape = flat_shape + new_tally._sum_sq.shape = flat_shape + new_tally._mean.shape = flat_shape + new_tally._std_dev.shape = flat_shape + + results.append(new_tally) + + return results[0] if scalar_input else results def prepare_tallies( diff --git a/tests/unit_tests/test_d1s.py b/tests/unit_tests/test_d1s.py index 30ce69e4ee7..8d98fc990c4 100644 --- a/tests/unit_tests/test_d1s.py +++ b/tests/unit_tests/test_d1s.py @@ -152,7 +152,7 @@ def test_apply_time_correction(run_in_tmpdir): result_summed.get_pandas_dataframe() -def test_apply_time_correction_series(run_in_tmpdir): +def test_apply_time_correction_multi_index(run_in_tmpdir): # Build the same model used in test_apply_time_correction mat = openmc.Material() mat.add_element('Ni', 1.0) @@ -171,86 +171,61 @@ def test_apply_time_correction_series(run_in_tmpdir): tally.scores = ['flux'] model.tallies = [tally] - # A schedule with several timesteps so the series has > 1 entry. + # A schedule with several timesteps so we can ask for many indices nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_PATH) timesteps = [1.0e8, 1.0e8, 1.0e8, 1.0e8] source_rates = [1.0, 0.0, 1.0, 0.0] factors = d1s.time_correction_factors(nuclides, timesteps, source_rates) n_times = len(factors[nuclides[0]]) - # Run the model once with openmc.config.patch('chain_file', CHAIN_PATH): output_path = model.run() with openmc.StatePoint(output_path) as sp: tally = sp.tallies[tally.id] - # Snapshot original tally state so we can confirm immutability later orig_filters = list(tally.filters) orig_sum = tally.sum.copy() orig_sum_sq = tally.sum_sq.copy() orig_mean = tally.mean.copy() orig_std_dev = tally.std_dev.copy() - # sum_nuclides=True: series matches per-index loop - mean_series, std_series = d1s.apply_time_correction_series( - tally, factors, sum_nuclides=True - ) - assert mean_series.shape[0] == n_times - assert std_series.shape == mean_series.shape - - for i in range(n_times): - ref = d1s.apply_time_correction( - tally, factors, index=i, sum_nuclides=True - ) - np.testing.assert_allclose( - mean_series[i].reshape(ref.mean.shape), ref.mean - ) - np.testing.assert_allclose( - std_series[i].reshape(ref.std_dev.shape), ref.std_dev - ) - - # sum_nuclides=False: series matches per-index loop - mean_series_f, std_series_f = d1s.apply_time_correction_series( - tally, factors, sum_nuclides=False - ) - assert mean_series_f.shape[0] == n_times - - for i in range(n_times): - ref = d1s.apply_time_correction( - tally, factors, index=i, sum_nuclides=False - ) - np.testing.assert_allclose( - mean_series_f[i].reshape(ref.mean.shape), ref.mean + # Passing a list of indices returns a list of derived tallies, each + # matching what a scalar call at that index would have produced. + for sum_nuc in (True, False): + many = d1s.apply_time_correction( + tally, factors, index=list(range(n_times)), + sum_nuclides=sum_nuc, ) - np.testing.assert_allclose( - std_series_f[i].reshape(ref.std_dev.shape), ref.std_dev - ) - - # explicit indices subset (and unordered) + assert isinstance(many, list) + assert len(many) == n_times + for i, derived in enumerate(many): + ref = d1s.apply_time_correction( + tally, factors, index=i, sum_nuclides=sum_nuc + ) + np.testing.assert_array_equal(derived.mean, ref.mean) + np.testing.assert_array_equal(derived.std_dev, ref.std_dev) + np.testing.assert_array_equal(derived.sum, ref.sum) + np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq) + assert derived.filters == ref.filters + + # Unordered / partial index sequence is honored in order subset = [n_times - 1, 0, 2] - mean_sub, std_sub = d1s.apply_time_correction_series( - tally, factors, indices=subset - ) - assert mean_sub.shape[0] == len(subset) - for k, i in enumerate(subset): + many = d1s.apply_time_correction(tally, factors, index=subset) + assert isinstance(many, list) + assert len(many) == len(subset) + for derived, i in zip(many, subset): ref = d1s.apply_time_correction(tally, factors, index=i) - np.testing.assert_allclose( - mean_sub[k].reshape(ref.mean.shape), ref.mean - ) - np.testing.assert_allclose( - std_sub[k].reshape(ref.std_dev.shape), ref.std_dev - ) + np.testing.assert_array_equal(derived.mean, ref.mean) + np.testing.assert_array_equal(derived.std_dev, ref.std_dev) + + # Scalar input still returns a single Tally, not a 1-element list + single = d1s.apply_time_correction(tally, factors, index=2) + assert isinstance(single, openmc.Tally) + assert not isinstance(single, list) - # original tally is unchanged + # Original tally is unchanged assert tally.filters == orig_filters assert np.all(tally.sum == orig_sum) assert np.all(tally.sum_sq == orig_sum_sq) assert np.all(tally.mean == orig_mean) assert np.all(tally.std_dev == orig_std_dev) - - # missing ParentNuclideFilter raises - bare = openmc.Tally() - bare.filters = [particle_filter] - bare.scores = ['flux'] - with pytest.raises(ValueError): - d1s.apply_time_correction_series(bare, factors) From b788d20d09b9addf611077dc7ecba76129ef082e Mon Sep 17 00:00:00 2001 From: shimwell Date: Fri, 29 May 2026 13:48:44 +0200 Subject: [PATCH 3/3] Vectorize summed D1S time correction with an einsum contraction Rework the sum_nuclides=True path of apply_time_correction so the TCF-weighted sum over the parent-nuclide axis is evaluated as a contraction (np.einsum) rather than a broadcast-multiply-and-reduce per index. The shared 5-D tally views are reshaped once. For a summed (derived) tally the public sum/sum_sq accessors return None regardless of the stored arrays, so the derived tally's sum/sum_sq are left unset rather than recomputed each call: this matches develop's observable behavior, skips two full-array multiplies per index, and avoids storing arrays shaped inconsistently with the popped ParentNuclideFilter (which break Tally.sparse). For a mesh tally (27k bins x 108 radionuclides x ~200 times) this is ~9x faster than the per-index implementation, with mean/std_dev agreeing to ~1e-15 relative. The factor matrix is shaped (n_indices, n_radionuclides) so each index's row is contiguous, keeping a scalar call bit-for-bit identical to the matching slice of a multi-index call. Update the docstring/comments and extend the multi-index unit test. --- openmc/deplete/d1s.py | 101 +++++++++++++++++++++++------------ tests/unit_tests/test_d1s.py | 10 +++- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/openmc/deplete/d1s.py b/openmc/deplete/d1s.py index 394c72c3a01..2927ca7480f 100644 --- a/openmc/deplete/d1s.py +++ b/openmc/deplete/d1s.py @@ -144,10 +144,12 @@ def apply_time_correction( Index (or indices) of the time(s) of interest. If N timesteps are provided in :func:`time_correction_factors`, there are N + 1 times to select from. The default is -1 which corresponds to the final time. - Passing an iterable returns a list of derived tallies, one per index, - evaluated in a single pass — the underlying tally arrays are read and - reshaped once and shared across all indices instead of being re-read - and re-copied on every call. + Passing an iterable returns a list of derived tallies, one per index. + The tally arrays are read and reshaped once and shared across all + indices, and (when ``sum_nuclides`` is True) the time-correction is + applied as a contraction over the parent-nuclide axis rather than a + broadcast-and-reduce per index, which is substantially faster for + large tallies (e.g. mesh tallies) evaluated at many times. sum_nuclides : bool Whether to sum over the parent nuclides @@ -156,7 +158,9 @@ def apply_time_correction( openmc.Tally or list of openmc.Tally Derived tally with time correction factors applied. A list is returned when ``index`` is an iterable; otherwise a single tally is - returned. + returned. When ``sum_nuclides`` is True the result is a derived tally, + for which ``sum`` and ``sum_sq`` are None (as for any derived tally); + the meaningful results are ``mean`` and ``std_dev``. """ # Make sure the tally contains a ParentNuclideFilter @@ -182,47 +186,74 @@ def apply_time_correction( _, n_nuclides, n_scores = tally.shape n_radionuclides = len(radionuclides) shape5 = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores) + flat_shape = (-1, n_nuclides, n_scores) # Reshape views shared across all indices - tally_sum_5d = tally.sum.reshape(shape5) - tally_sum_sq_5d = tally.sum_sq.reshape(shape5) tally_mean_5d = tally.mean.reshape(shape5) tally_std_dev_5d = tally.std_dev.reshape(shape5) - flat_shape = (-1, n_nuclides, n_scores) + # Time correction factors for every requested index -> shape + # (n_indices, n_radionuclides). Indexing a row (``tcf[t]``) yields a + # contiguous per-index factor vector, so the einsum below rounds + # identically whether one index or many were requested. + tcf = np.array( + [[time_correction_factors[x][idx] for x in radionuclides] + for idx in indices] + ) results = [] - for idx in indices: - tcf = np.array([time_correction_factors[x][idx] for x in radionuclides]) - tcf.shape = (1, -1, 1, 1, 1) - - # Create shallow copy of tally - new_tally = copy(tally) - new_tally._filters = copy(tally._filters) - - # Apply TCF, broadcasting to the correct dimensions - new_tally._sum = tally_sum_5d * tcf - new_tally._sum_sq = tally_sum_sq_5d * (tcf*tcf) - new_tally._mean = tally_mean_5d * tcf - new_tally._std_dev = tally_std_dev_5d * tcf - - if sum_nuclides: - # Sum over parent nuclides (note that when combining different - # bins for parent nuclide, we can't work directly on sum_sq) - new_tally._mean = new_tally.mean.sum(axis=1).reshape(flat_shape) - new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(flat_shape) + if sum_nuclides: + # The TCF-weighted sum over the parent-nuclide axis is a contraction of + # the 5-D arrays with the per-index factor vector, evaluated with a + # single (BLAS-backed) einsum instead of a Python-level + # multiply-and-reduce. Variances combine in quadrature, so std_dev + # contracts the squared values (cf. ``np.linalg.norm`` over the nuclide + # axis). ``sum``/``sum_sq`` are left unset: the public accessors return + # None for any derived tally (so this matches develop's observable + # behavior), and keeping the TCF-scaled per-nuclide arrays would only + # waste two full-array multiplies per index on data nothing reads -- it + # would also be shaped inconsistently with the popped filter, which + # breaks ``Tally.sparse``. Contracting one index at a time keeps the + # result bit-for-bit identical to a scalar call for that index. + # subscripts: i=bins_before, r=radionuclide, j=bins_after, k=nuclide, + # s=score + tally_var_5d = tally_std_dev_5d**2 # reused at every index + for t in range(len(indices)): + tcf_row = tcf[t] + mean = np.einsum('irjks,r->ijks', tally_mean_5d, tcf_row) + std_dev = np.sqrt( + np.einsum('irjks,r->ijks', tally_var_5d, tcf_row*tcf_row)) + + new_tally = copy(tally) + new_tally._filters = copy(tally._filters) + new_tally._mean = mean.reshape(flat_shape) + new_tally._std_dev = std_dev.reshape(flat_shape) + new_tally._sum = None + new_tally._sum_sq = None new_tally._derived = True # Remove ParentNuclideFilter new_tally.filters.pop(i_filter) - else: - # Change shape back to (filter combinations, nuclides, scores) - new_tally._sum.shape = flat_shape - new_tally._sum_sq.shape = flat_shape - new_tally._mean.shape = flat_shape - new_tally._std_dev.shape = flat_shape - - results.append(new_tally) + results.append(new_tally) + else: + # Per-nuclide results are kept, so each index produces a full-size + # array; the shared 5-D views avoid re-reading and re-reshaping the + # tally data on every index. + tally_sum_5d = tally.sum.reshape(shape5) + tally_sum_sq_5d = tally.sum_sq.reshape(shape5) + + for t in range(len(indices)): + tcf_b = tcf[t].reshape(1, -1, 1, 1, 1) + + new_tally = copy(tally) + new_tally._filters = copy(tally._filters) + + # Apply TCF, broadcasting to the correct dimensions + new_tally._sum = (tally_sum_5d * tcf_b).reshape(flat_shape) + new_tally._sum_sq = (tally_sum_sq_5d * (tcf_b*tcf_b)).reshape(flat_shape) + new_tally._mean = (tally_mean_5d * tcf_b).reshape(flat_shape) + new_tally._std_dev = (tally_std_dev_5d * tcf_b).reshape(flat_shape) + results.append(new_tally) return results[0] if scalar_input else results diff --git a/tests/unit_tests/test_d1s.py b/tests/unit_tests/test_d1s.py index 8d98fc990c4..386ce01281d 100644 --- a/tests/unit_tests/test_d1s.py +++ b/tests/unit_tests/test_d1s.py @@ -204,9 +204,15 @@ def test_apply_time_correction_multi_index(run_in_tmpdir): ) np.testing.assert_array_equal(derived.mean, ref.mean) np.testing.assert_array_equal(derived.std_dev, ref.std_dev) - np.testing.assert_array_equal(derived.sum, ref.sum) - np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq) assert derived.filters == ref.filters + if sum_nuc: + # Summed tally is derived; sum/sum_sq are None (as in + # develop, where the public accessors return None for any + # derived tally) + assert derived.sum is None and derived.sum_sq is None + else: + np.testing.assert_array_equal(derived.sum, ref.sum) + np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq) # Unordered / partial index sequence is honored in order subset = [n_times - 1, 0, 2]