Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 98 additions & 52 deletions openmc/deplete/d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,33 +124,43 @@ def time_correction_factors(
def apply_time_correction(
tally: openmc.Tally,
time_correction_factors: dict[str, np.ndarray],
index: int = -1,
index: int | Sequence[int] = -1,
sum_nuclides: bool = True
) -> openmc.Tally:
) -> openmc.Tally | list[openmc.Tally]:
"""Apply time correction factors to a tally.

This function applies the time correction factors at the given index to a
tally that contains a :class:`~openmc.ParentNuclideFilter`. When
`sum_nuclides` is True, values over all parent nuclides will be summed,
leaving a single value for each filter combination.
This function applies the time correction factors at the given index (or
indices) to a tally that contains a :class:`~openmc.ParentNuclideFilter`.
When `sum_nuclides` is True, values over all parent nuclides will be
summed, leaving a single value for each filter combination.

Parameters
----------
tally : openmc.Tally
Tally to apply the time correction factors to
time_correction_factors : dict
Time correction factors as returned by :func:`time_correction_factors`
index : int, optional
Index of the time of interest. If N timesteps are provided in
:func:`time_correction_factors`, there are N + 1 times to select from.
The default is -1 which corresponds to the final time.
index : int or iterable of int, optional
Index (or indices) of the time(s) of interest. If N timesteps are
provided in :func:`time_correction_factors`, there are N + 1 times to
select from. The default is -1 which corresponds to the final time.
Passing an iterable returns a list of derived tallies, one per index.
The tally arrays are read and reshaped once and shared across all
indices, and (when ``sum_nuclides`` is True) the time-correction is
applied as a contraction over the parent-nuclide axis rather than a
broadcast-and-reduce per index, which is substantially faster for
large tallies (e.g. mesh tallies) evaluated at many times.
sum_nuclides : bool
Whether to sum over the parent nuclides

Returns
-------
openmc.Tally
Derived tally with time correction factors applied
openmc.Tally or list of openmc.Tally
Derived tally with time correction factors applied. A list is
returned when ``index`` is an iterable; otherwise a single tally is
returned. When ``sum_nuclides`` is True the result is a derived tally,
for which ``sum`` and ``sum_sq`` are None (as for any derived tally);
the meaningful results are ``mean`` and ``std_dev``.

"""
# Make sure the tally contains a ParentNuclideFilter
Expand All @@ -162,54 +172,90 @@ def apply_time_correction(

# Get list of radionuclides based on tally filter
radionuclides = [str(x) for x in tally.filters[i_filter].bins]
tcf = np.array([time_correction_factors[x][index] for x in radionuclides])

# Force tally results to be read and std_dev to be computed
tally.std_dev
# Normalize index to a list; remember whether the caller asked scalar
scalar_input = isinstance(index, (int, np.integer))
indices = [int(index)] if scalar_input else list(index)

# Create shallow copy of tally
new_tally = copy(tally)
new_tally._filters = copy(tally._filters)
# Force tally results to be read and std_dev to be computed (once)
tally.std_dev

# Determine number of bins in other filters
# Determine number of bins in other filters (computed once)
n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]])
n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]])

# Reshape sum and sum_sq, apply TCF, and sum along that axis
_, n_nuclides, n_scores = new_tally.shape
_, n_nuclides, n_scores = tally.shape
n_radionuclides = len(radionuclides)
shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
tally_sum = new_tally.sum.reshape(shape)
tally_sum_sq = new_tally.sum_sq.reshape(shape)
tally_mean = new_tally.mean.reshape(shape)
tally_std_dev = new_tally.std_dev.reshape(shape)

# Apply TCF, broadcasting to the correct dimensions
tcf.shape = (1, -1, 1, 1, 1)
new_tally._sum = tally_sum * tcf
new_tally._sum_sq = tally_sum_sq * (tcf*tcf)
new_tally._mean = tally_mean * tcf
new_tally._std_dev = tally_std_dev * tcf

shape = (-1, n_nuclides, n_scores)

shape5 = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
flat_shape = (-1, n_nuclides, n_scores)

# Reshape views shared across all indices
tally_mean_5d = tally.mean.reshape(shape5)
tally_std_dev_5d = tally.std_dev.reshape(shape5)

# Time correction factors for every requested index -> shape
# (n_indices, n_radionuclides). Indexing a row (``tcf[t]``) yields a
# contiguous per-index factor vector, so the einsum below rounds
# identically whether one index or many were requested.
tcf = np.array(
[[time_correction_factors[x][idx] for x in radionuclides]
for idx in indices]
)

results = []
if sum_nuclides:
# Sum over parent nuclides (note that when combining different bins for
# parent nuclide, we can't work directly on sum_sq)
new_tally._mean = new_tally.mean.sum(axis=1).reshape(shape)
new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(shape)
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
# The TCF-weighted sum over the parent-nuclide axis is a contraction of
# the 5-D arrays with the per-index factor vector, evaluated with a
# single (BLAS-backed) einsum instead of a Python-level
# multiply-and-reduce. Variances combine in quadrature, so std_dev
# contracts the squared values (cf. ``np.linalg.norm`` over the nuclide
# axis). ``sum``/``sum_sq`` are left unset: the public accessors return
# None for any derived tally (so this matches develop's observable
# behavior), and keeping the TCF-scaled per-nuclide arrays would only
# waste two full-array multiplies per index on data nothing reads -- it
# would also be shaped inconsistently with the popped filter, which
# breaks ``Tally.sparse``. Contracting one index at a time keeps the
# result bit-for-bit identical to a scalar call for that index.
# subscripts: i=bins_before, r=radionuclide, j=bins_after, k=nuclide,
# s=score
tally_var_5d = tally_std_dev_5d**2 # reused at every index
for t in range(len(indices)):
tcf_row = tcf[t]
mean = np.einsum('irjks,r->ijks', tally_mean_5d, tcf_row)
std_dev = np.sqrt(
np.einsum('irjks,r->ijks', tally_var_5d, tcf_row*tcf_row))

new_tally = copy(tally)
new_tally._filters = copy(tally._filters)
new_tally._mean = mean.reshape(flat_shape)
new_tally._std_dev = std_dev.reshape(flat_shape)
new_tally._sum = None
new_tally._sum_sq = None
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
results.append(new_tally)
else:
# Change shape back to (filter combinations, nuclides, scores)
new_tally._sum.shape = shape
new_tally._sum_sq.shape = shape
new_tally._mean.shape = shape
new_tally._std_dev.shape = shape

return new_tally
# Per-nuclide results are kept, so each index produces a full-size
# array; the shared 5-D views avoid re-reading and re-reshaping the
# tally data on every index.
tally_sum_5d = tally.sum.reshape(shape5)
tally_sum_sq_5d = tally.sum_sq.reshape(shape5)

for t in range(len(indices)):
tcf_b = tcf[t].reshape(1, -1, 1, 1, 1)

new_tally = copy(tally)
new_tally._filters = copy(tally._filters)

# Apply TCF, broadcasting to the correct dimensions
new_tally._sum = (tally_sum_5d * tcf_b).reshape(flat_shape)
new_tally._sum_sq = (tally_sum_sq_5d * (tcf_b*tcf_b)).reshape(flat_shape)
new_tally._mean = (tally_mean_5d * tcf_b).reshape(flat_shape)
new_tally._std_dev = (tally_std_dev_5d * tcf_b).reshape(flat_shape)
results.append(new_tally)

return results[0] if scalar_input else results


def prepare_tallies(
Expand Down
85 changes: 85 additions & 0 deletions tests/unit_tests/test_d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,88 @@ def test_apply_time_correction(run_in_tmpdir):
result_summed.get_reshaped_data()
result.get_pandas_dataframe()
result_summed.get_pandas_dataframe()


def test_apply_time_correction_multi_index(run_in_tmpdir):
# Build the same model used in test_apply_time_correction
mat = openmc.Material()
mat.add_element('Ni', 1.0)
sphere = openmc.Sphere(r=10.0, boundary_type='vacuum')
cell = openmc.Cell(fill=mat, region=-sphere)
model = openmc.Model()
model.geometry = openmc.Geometry([cell])
model.settings.run_mode = 'fixed source'
model.settings.batches = 3
model.settings.particles = 10
model.settings.photon_transport = True
model.settings.use_decay_photons = True
particle_filter = openmc.ParticleFilter('photon')
tally = openmc.Tally()
tally.filters = [particle_filter]
tally.scores = ['flux']
model.tallies = [tally]

# A schedule with several timesteps so we can ask for many indices
nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_PATH)
timesteps = [1.0e8, 1.0e8, 1.0e8, 1.0e8]
source_rates = [1.0, 0.0, 1.0, 0.0]
factors = d1s.time_correction_factors(nuclides, timesteps, source_rates)
n_times = len(factors[nuclides[0]])

with openmc.config.patch('chain_file', CHAIN_PATH):
output_path = model.run()
with openmc.StatePoint(output_path) as sp:
tally = sp.tallies[tally.id]

orig_filters = list(tally.filters)
orig_sum = tally.sum.copy()
orig_sum_sq = tally.sum_sq.copy()
orig_mean = tally.mean.copy()
orig_std_dev = tally.std_dev.copy()

# Passing a list of indices returns a list of derived tallies, each
# matching what a scalar call at that index would have produced.
for sum_nuc in (True, False):
many = d1s.apply_time_correction(
tally, factors, index=list(range(n_times)),
sum_nuclides=sum_nuc,
)
assert isinstance(many, list)
assert len(many) == n_times
for i, derived in enumerate(many):
ref = d1s.apply_time_correction(
tally, factors, index=i, sum_nuclides=sum_nuc
)
np.testing.assert_array_equal(derived.mean, ref.mean)
np.testing.assert_array_equal(derived.std_dev, ref.std_dev)
assert derived.filters == ref.filters
if sum_nuc:
# Summed tally is derived; sum/sum_sq are None (as in
# develop, where the public accessors return None for any
# derived tally)
assert derived.sum is None and derived.sum_sq is None
else:
np.testing.assert_array_equal(derived.sum, ref.sum)
np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq)

# Unordered / partial index sequence is honored in order
subset = [n_times - 1, 0, 2]
many = d1s.apply_time_correction(tally, factors, index=subset)
assert isinstance(many, list)
assert len(many) == len(subset)
for derived, i in zip(many, subset):
ref = d1s.apply_time_correction(tally, factors, index=i)
np.testing.assert_array_equal(derived.mean, ref.mean)
np.testing.assert_array_equal(derived.std_dev, ref.std_dev)

# Scalar input still returns a single Tally, not a 1-element list
single = d1s.apply_time_correction(tally, factors, index=2)
assert isinstance(single, openmc.Tally)
assert not isinstance(single, list)

# Original tally is unchanged
assert tally.filters == orig_filters
assert np.all(tally.sum == orig_sum)
assert np.all(tally.sum_sq == orig_sum_sq)
assert np.all(tally.mean == orig_mean)
assert np.all(tally.std_dev == orig_std_dev)