diff --git a/src/aspire/samplers/smc/base.py b/src/aspire/samplers/smc/base.py index 338667e..7e713db 100644 --- a/src/aspire/samplers/smc/base.py +++ b/src/aspire/samplers/smc/base.py @@ -3,6 +3,7 @@ from typing import Any, Callable import array_api_compat.numpy as np +import array_api_extra as xpx from orng import ArrayRNG from ...flows.base import Flow @@ -14,7 +15,6 @@ effective_sample_size, to_numpy, track_calls, - update_at_indices, ) from ..mcmc import MCMCSampler @@ -483,9 +483,7 @@ def log_prob(self, z, beta=None): beta=beta ).flatten() + samples.array_to_namespace(log_abs_det_jacobian) - log_prob = update_at_indices( - log_prob, self.xp.isnan(log_prob), -self.xp.inf - ) + log_prob = xpx.at(log_prob, self.xp.isnan(log_prob)).set(-self.xp.inf) return log_prob def build_checkpoint_state( diff --git a/src/aspire/transforms.py b/src/aspire/transforms.py index 2d9bb65..406e9a1 100644 --- a/src/aspire/transforms.py +++ b/src/aspire/transforms.py @@ -3,6 +3,7 @@ import math from typing import Any, Callable +import array_api_extra as xpx import h5py from array_api_compat import device as get_device from array_api_compat import is_torch_namespace @@ -15,7 +16,6 @@ copy_array, logit, sigmoid, - update_at_indices, ) logger = logging.getLogger(__name__) @@ -253,19 +253,15 @@ def fit(self, x): logger.debug( f"Fitting periodic transform to parameters: {self.periodic_parameters}" ) - x = update_at_indices( - x, - (slice(None), self.periodic_mask), - self._periodic_transform.fit(x[:, self.periodic_mask]), + x = xpx.at(x, (slice(None), self.periodic_mask)).set( + self._periodic_transform.fit(x[:, self.periodic_mask]) ) if self.bounded_parameters: logger.debug( f"Fitting bounded transform to parameters: {self.bounded_parameters}" ) - x = update_at_indices( - x, - (slice(None), self.bounded_mask), - self._bounded_transform.fit(x[:, self.bounded_mask]), + x = xpx.at(x, (slice(None), self.bounded_mask)).set( + self._bounded_transform.fit(x[:, self.bounded_mask]) ) if self.affine_transform: logger.debug("Fitting affine transform to all parameters.") @@ -280,14 +276,14 @@ def forward(self, x): y, log_j_periodic = self._periodic_transform.forward( x[..., self.periodic_mask] ) - x = update_at_indices(x, (slice(None), self.periodic_mask), y) + x = xpx.at(x, (slice(None), self.periodic_mask)).set(y) log_abs_det_jacobian += log_j_periodic if self.bounded_parameters: y, log_j_bounded = self._bounded_transform.forward( x[..., self.bounded_mask] ) - x = update_at_indices(x, (slice(None), self.bounded_mask), y) + x = xpx.at(x, (slice(None), self.bounded_mask)).set(y) log_abs_det_jacobian += log_j_bounded if self.affine_transform: @@ -307,14 +303,14 @@ def inverse(self, x): y, log_j_bounded = self._bounded_transform.inverse( x[..., self.bounded_mask] ) - x = update_at_indices(x, (slice(None), self.bounded_mask), y) + x = xpx.at(x, (slice(None), self.bounded_mask)).set(y) log_abs_det_jacobian += log_j_bounded if self.periodic_parameters: y, log_j_periodic = self._periodic_transform.inverse( x[..., self.periodic_mask] ) - x = update_at_indices(x, (slice(None), self.periodic_mask), y) + x = xpx.at(x, (slice(None), self.periodic_mask)).set(y) log_abs_det_jacobian += log_j_periodic return x, log_abs_det_jacobian diff --git a/src/aspire/utils.py b/src/aspire/utils.py index cab6059..909d1e7 100644 --- a/src/aspire/utils.py +++ b/src/aspire/utils.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any import array_api_compat.numpy as np +import array_api_extra as xpx import h5py import wrapt from array_api_compat import ( @@ -955,11 +956,11 @@ def update_at_indices(x: Array, slc: Array, y: Array) -> Array: Array The updated array. """ - try: - x[slc] = y - return x - except TypeError: - return x.at[slc].set(y) + warnings.warn( + "update_at_indices is deprecated and will be removed in a future version. Please use array-api-extra.at instead", + UserWarning, + ) + return xpx.at(x, slc).set(y) @dataclass diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4f5fc75..bd5af39 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,9 +1,10 @@ import math +import array_api_extra as xpx import pytest from aspire import transforms -from aspire.utils import AspireFile, copy_array, update_at_indices +from aspire.utils import AspireFile, copy_array def _make_array(xp, data, dtype): @@ -270,14 +271,9 @@ def test_composite_transform_forward_inverse_roundtrip(xp, dtype): x_inv, inv_log_j = transform.inverse(y) x_exp = copy_array(x) - print(x_exp) - print((x_exp[:, 0] + 3) % 6 - 3) - x_exp = update_at_indices( - x_exp, - (slice(None), 0), + x_exp = xpx.at(x_exp, (slice(None), 0)).set( ((x_exp[:, 0] + 3) % 6) - 3, ) # Wrap x0 to [-3, 3] - print(x_exp) assert x.shape == y.shape assert xp.allclose(x_inv, x_exp, atol=1e-5)