Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/aspire/samplers/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable

import array_api_compat.numpy as np
import array_api_extra as xpx
from orng import ArrayRNG

from ...flows.base import Flow
Expand All @@ -14,7 +15,6 @@
effective_sample_size,
to_numpy,
track_calls,
update_at_indices,
)
from ..mcmc import MCMCSampler

Expand Down Expand Up @@ -483,9 +483,7 @@ def log_prob(self, z, beta=None):
beta=beta
).flatten() + samples.array_to_namespace(log_abs_det_jacobian)

log_prob = update_at_indices(
log_prob, self.xp.isnan(log_prob), -self.xp.inf
)
log_prob = xpx.at(log_prob, self.xp.isnan(log_prob)).set(-self.xp.inf)
return log_prob

def build_checkpoint_state(
Expand Down
22 changes: 9 additions & 13 deletions src/aspire/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
from typing import Any, Callable

import array_api_extra as xpx
import h5py
from array_api_compat import device as get_device
from array_api_compat import is_torch_namespace
Expand All @@ -15,7 +16,6 @@
copy_array,
logit,
sigmoid,
update_at_indices,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -253,19 +253,15 @@ def fit(self, x):
logger.debug(
f"Fitting periodic transform to parameters: {self.periodic_parameters}"
)
x = update_at_indices(
x,
(slice(None), self.periodic_mask),
self._periodic_transform.fit(x[:, self.periodic_mask]),
x = xpx.at(x, (slice(None), self.periodic_mask)).set(
self._periodic_transform.fit(x[:, self.periodic_mask])
)
if self.bounded_parameters:
logger.debug(
f"Fitting bounded transform to parameters: {self.bounded_parameters}"
)
x = update_at_indices(
x,
(slice(None), self.bounded_mask),
self._bounded_transform.fit(x[:, self.bounded_mask]),
x = xpx.at(x, (slice(None), self.bounded_mask)).set(
self._bounded_transform.fit(x[:, self.bounded_mask])
)
if self.affine_transform:
logger.debug("Fitting affine transform to all parameters.")
Expand All @@ -280,14 +276,14 @@ def forward(self, x):
y, log_j_periodic = self._periodic_transform.forward(
x[..., self.periodic_mask]
)
x = update_at_indices(x, (slice(None), self.periodic_mask), y)
x = xpx.at(x, (slice(None), self.periodic_mask)).set(y)
log_abs_det_jacobian += log_j_periodic

if self.bounded_parameters:
y, log_j_bounded = self._bounded_transform.forward(
x[..., self.bounded_mask]
)
x = update_at_indices(x, (slice(None), self.bounded_mask), y)
x = xpx.at(x, (slice(None), self.bounded_mask)).set(y)
log_abs_det_jacobian += log_j_bounded

if self.affine_transform:
Expand All @@ -307,14 +303,14 @@ def inverse(self, x):
y, log_j_bounded = self._bounded_transform.inverse(
x[..., self.bounded_mask]
)
x = update_at_indices(x, (slice(None), self.bounded_mask), y)
x = xpx.at(x, (slice(None), self.bounded_mask)).set(y)
log_abs_det_jacobian += log_j_bounded

if self.periodic_parameters:
y, log_j_periodic = self._periodic_transform.inverse(
x[..., self.periodic_mask]
)
x = update_at_indices(x, (slice(None), self.periodic_mask), y)
x = xpx.at(x, (slice(None), self.periodic_mask)).set(y)
log_abs_det_jacobian += log_j_periodic

return x, log_abs_det_jacobian
Expand Down
11 changes: 6 additions & 5 deletions src/aspire/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any

import array_api_compat.numpy as np
import array_api_extra as xpx
import h5py
import wrapt
from array_api_compat import (
Expand Down Expand Up @@ -955,11 +956,11 @@ def update_at_indices(x: Array, slc: Array, y: Array) -> Array:
Array
The updated array.
"""
try:
x[slc] = y
return x
except TypeError:
return x.at[slc].set(y)
warnings.warn(
"update_at_indices is deprecated and will be removed in a future version. Please use array-api-extra.at instead",
UserWarning,
)
return xpx.at(x, slc).set(y)


@dataclass
Expand Down
10 changes: 3 additions & 7 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import math

import array_api_extra as xpx
import pytest

from aspire import transforms
from aspire.utils import AspireFile, copy_array, update_at_indices
from aspire.utils import AspireFile, copy_array


def _make_array(xp, data, dtype):
Expand Down Expand Up @@ -270,14 +271,9 @@ def test_composite_transform_forward_inverse_roundtrip(xp, dtype):
x_inv, inv_log_j = transform.inverse(y)

x_exp = copy_array(x)
print(x_exp)
print((x_exp[:, 0] + 3) % 6 - 3)
x_exp = update_at_indices(
x_exp,
(slice(None), 0),
x_exp = xpx.at(x_exp, (slice(None), 0)).set(
((x_exp[:, 0] + 3) % 6) - 3,
) # Wrap x0 to [-3, 3]
print(x_exp)

assert x.shape == y.shape
assert xp.allclose(x_inv, x_exp, atol=1e-5)
Expand Down
Loading