Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nessai/proposal/augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def set_rescaling(self):
def update_flow_config(self):
"""Update the flow configuration dictionary"""
super().update_flow_config()
m = np.ones(self.rescaled_dims)
m = np.ones(self.prime_dims)
m[-self.augment_dims :] = -1
self.flow_config["mask"] = m

Expand Down
13 changes: 12 additions & 1 deletion src/nessai/proposal/flowproposal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import abstractmethod
from inspect import signature
from typing import Optional
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -212,6 +213,16 @@ def dims(self):
@property
def rescaled_dims(self):
"""Return the number of rescaled dimensions"""
warn(
"rescaled_dims is deprecated and will be removed in a future "
"release, use prime_dims instead",
DeprecationWarning,
)
return len(self.prime_parameters)

@property
def prime_dims(self):
"""Return the number of prime dimensions"""
return len(self.prime_parameters)

@property
Expand Down Expand Up @@ -319,7 +330,7 @@ def configure_plotting(self, plot):

def update_flow_config(self):
"""Update the flow configuration dictionary."""
self.flow_config["n_inputs"] = self.rescaled_dims
self.flow_config["n_inputs"] = self.prime_dims

def initialise(self, resumed: bool = False) -> None:
"""
Expand Down
12 changes: 5 additions & 7 deletions src/nessai/proposal/flowproposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def configure_constant_volume(self):
f"{self.latent_prior}"
)
self.fixed_radius = compute_radius(
self.rescaled_dims, self.volume_fraction
self.prime_dims, self.volume_fraction
)
self.fuzz = 1.0
if self.max_radius < self.fixed_radius:
Expand Down Expand Up @@ -267,9 +267,7 @@ def set_rescaling(self):
super().set_rescaling()
if self.expansion_fraction and self.expansion_fraction is not None:
logger.info("Overwriting fuzz factor with expansion fraction")
self.fuzz = (1 + self.expansion_fraction) ** (
1 / self.rescaled_dims
)
self.fuzz = (1 + self.expansion_fraction) ** (1 / self.prime_dims)
logger.info(f"New fuzz factor: {self.fuzz}")
self.configure_constant_volume()

Expand Down Expand Up @@ -302,7 +300,7 @@ def prep_latent_prior(self):
"""Prepare the latent prior."""
if self.latent_prior == "truncated_gaussian":
self._populate_dist = NDimensionalTruncatedGaussian(
self.dims,
self.prime_dims,
self.r,
fuzz=self.fuzz,
rng=self.rng,
Expand All @@ -312,7 +310,7 @@ def prep_latent_prior(self):
self._draw_func = lambda N: self.flow.sample_latent_distribution(N)
else:
draw_kwargs = dict(
dims=self.dims,
dims=self.prime_dims,
r=self.r,
fuzz=self.fuzz,
rng=self.rng,
Expand Down Expand Up @@ -601,7 +599,7 @@ def get_alt_distribution(self):
"""
if self.latent_prior in ["uniform_nsphere", "uniform_nball"]:
return get_uniform_distribution(
self.dims, self.r * self.fuzz, device=self.flow.device
self.prime_dims, self.r * self.fuzz, device=self.flow.device
)

def reset(self):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_deprecation_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ def test_compute_evidence_ratio_deprecation():
state = create_autospec(_INSIntegralState)
with pytest.deprecated_call():
_INSIntegralState.compute_evidence_ratio(state)


def test_rescaled_dims_deprecation():
"""Assert a warning is raised when rescaled_dims is accessed"""
from nessai.proposal.flowproposal import FlowProposal

proposal = create_autospec(FlowProposal, prime_parameters=["x", "y"])
with pytest.deprecated_call():
assert FlowProposal.rescaled_dims.__get__(proposal) == 2
2 changes: 1 addition & 1 deletion tests/test_proposal/test_augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_init(model):

def test_update_flow_config(proposal):
"""Test update flow config"""
proposal.rescaled_dims = 4
proposal.prime_dims = 4
proposal.augment_dims = 2
proposal.flow_config = dict()
with patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_configure_plotting(proposal, plot, plot_pool, plot_train):
def test_update_flow_proposal(proposal):
"""Assert the number of inputs is updated"""
proposal.flow_config = {}
proposal.rescaled_dims = 4
proposal.prime_dims = 4
BaseFlowProposal.update_flow_config(proposal)
assert proposal.flow_config["n_inputs"] == 4

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def test_dims(proposal):
assert BaseFlowProposal.dims.__get__(proposal) == 2


def test_rescaled_dims(proposal):
"""Test rescaled_dims property"""
def test_prime_dims(proposal):
"""Test prime_dims property"""
proposal.prime_parameters = ["x", "y"]
assert BaseFlowProposal.rescaled_dims.__get__(proposal) == 2
assert BaseFlowProposal.prime_dims.__get__(proposal) == 2


def test_dtype(proposal):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_configure_constant_volume(proposal, latent_prior):
"""Test configuration for constant volume mode."""
proposal.constant_volume_mode = True
proposal.volume_fraction = 0.95
proposal.rescaled_dims = 5
proposal.prime_dims = 5
proposal.latent_prior = latent_prior
proposal.max_radius = 3.0
proposal.min_radius = 5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_get_alt_distribution_uniform(proposal, prior):
the n-ball.
"""
proposal.latent_prior = prior
proposal.dims = 2
proposal.prime_dims = 2
proposal.r = 2.0
proposal.fuzz = 1.2
proposal.flow = Mock()
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_prep_latent_prior_truncated(proposal):
"""Assert prep latent prior calls the correct values"""

proposal.latent_prior = "truncated_gaussian"
proposal.dims = 2
proposal.prime_dims = 2
proposal.r = 3.0
proposal.fuzz = 1.2
dist = MagicMock()
Expand All @@ -108,7 +108,7 @@ def test_prep_latent_prior_other(proposal):
"""Assert partial acts as expected"""
proposal.latent_prior = "gaussian"
proposal.latent_temperature = 0.9
proposal.dims = 2
proposal.prime_dims = 2
proposal.r = 3.0
proposal.fuzz = 1.2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"expansion_fraction, fuzz", [(None, 2.0), (0.5, 1.5**0.5)]
)
def test_set_rescaling(proposal, expansion_fraction, fuzz):
proposal.rescaled_dims = 2
proposal.prime_dims = 2
proposal.expansion_fraction = expansion_fraction
proposal.fuzz = 2.0
with patch(
Expand Down
Loading