diff --git a/src/pysatl_core/distributions/support.py b/src/pysatl_core/distributions/support.py index 40f49bf..9aaf2e7 100644 --- a/src/pysatl_core/distributions/support.py +++ b/src/pysatl_core/distributions/support.py @@ -14,13 +14,25 @@ __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" +from collections.abc import Callable from dataclasses import dataclass from math import floor -from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable +from typing import ( + TYPE_CHECKING, + Protocol, + cast, + runtime_checkable, +) import numpy as np -from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray +from pysatl_core.types import ( + BoolArray, + Interval1D, + IntervalND, + Number, + NumericArray, +) if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -34,10 +46,7 @@ class Support(Protocol): Support defines the set of values where a distribution is defined. """ - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... + def contains(self, x: NumericArray) -> bool | BoolArray: ... class ContinuousSupport(Interval1D, Support): @@ -49,6 +58,15 @@ class ContinuousSupport(Interval1D, Support): """ +class ContinuousNDSupport(IntervalND, Support): + """ + Support for continuous distributions represented as an array of intervals. + + This class inherits from IntervalND and implements the Support protocol + for continuous distributions defined on a list of intervals [left, right]. + """ + + @runtime_checkable class DiscreteSupport(Support, Protocol): """ @@ -128,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non self._points = arr[unique_mask] - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are in the support. @@ -162,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return bool(result) return cast(BoolArray, result) - def __contains__(self, x: object) -> bool: - """Check if a point is in the support.""" - return bool(self.contains(cast(Number, x))) - def iter_points(self) -> Iterator[Number]: """Iterate through all points in the support.""" return iter(self._points) @@ -252,12 +261,7 @@ def __post_init__(self) -> None: if self.modulus <= 0: raise ValueError("modulus must be a positive integer.") - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are in the integer lattice support. @@ -283,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return bool(result) return cast(BoolArray, result) - def __contains__(self, x: object) -> bool: - """Check if a point is in the integer lattice support.""" - return bool(self.contains(cast(Number, x))) - def iter_points(self) -> Iterator[int]: """ Iterate through all points in the integer lattice support. @@ -430,10 +430,20 @@ def is_right_bounded(self) -> bool: __iter__ = iter_points +@dataclass(frozen=True, slots=True) +class PredicateSupport(Support): + predicate: Callable[[NumericArray], bool | BoolArray] + + def contains(self, x: NumericArray) -> bool | BoolArray: + return self.predicate(x) + + __all__ = [ # Base support protocol "Support", "ContinuousSupport", + "ContinuousNDSupport", + "PredicateSupport", # Discrete support protocol and implementations "DiscreteSupport", "ExplicitTableDiscreteSupport", diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index ed30528..502dac9 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -14,6 +14,12 @@ from .builtins import __all__ as _builtins_all from .configuration import configure_families_register from .distribution import ParametricFamilyDistribution +from .exponential_family import ( + # CanonicalContinuousExponentialClassFamily, + ContinuousExponentialClassFamily, + ExponentialConjugateHyperparameters, + ExponentialFamilyParametrization, +) from .parametric_family import ParametricFamily from .parametrizations import ( Parametrization, @@ -34,6 +40,9 @@ "configure_families_register", # builtins *_builtins_all, + "ContinuousExponentialClassFamily", + "ExponentialFamilyParametrization", + "ExponentialConjugateHyperparameters", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py new file mode 100644 index 0000000..9f835bf --- /dev/null +++ b/src/pysatl_core/families/exponential_family.py @@ -0,0 +1,483 @@ +""" +Exponential family distributions in continuous spaces. + +This module implements the continuous exponential family of probability distributions, +their conjugate priors, posterior inference, and posterior predictive distributions. +""" + +from __future__ import annotations + +__author__ = "Vinogradov Ilya" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +from scipy.differentiate import jacobian +from scipy.integrate import nquad +from scipy.linalg import det + +from pysatl_core.distributions.support import ( + ContinuousSupport, + PredicateSupport, +) +from pysatl_core.families.parametric_family import ParametricFamily +from pysatl_core.families.parametrizations import Parametrization, constraint, parametrization +from pysatl_core.types import ( + CharacteristicName, + DistributionType, + Number, + NumericArray, + ParametrizationName, + UnivariateContinuous, +) + +if TYPE_CHECKING: + from pysatl_core.distributions.support import Support + from pysatl_core.families.parametric_family import ( + CharacteristicsMap, + ParametricFamilyCharacteristic, + ) + + type SupportArg = Callable[[Parametrization], Support | None] | None + + +@dataclass +class ExponentialFamilyParametrization(Parametrization): + """ + Standard parametrization of an exponential family distribution. + + This parametrization uses the natural (canonical) parameter vector `theta` + The density is expressed as: + f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ)) + + Attributes: + theta (NumericArray): Natural parameter vector (can be a scalar or array) + """ + + theta: NumericArray + + def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + """Return the base parametrization (identity transform for canonical form).""" + return self + + +@dataclass +class ExponentialConjugateHyperparameters(Parametrization): + """ + Hyperparameters for the conjugate prior of an exponential family + + For a prior of the form: + p(θ) ∝ exp(ν₀ᵀ T(θ) + n₀ A(θ)) + the hyperparameters are: + effective_suff_stat_value = ν₀ + effective_sample_size = n₀ + + Attributes: + effective_suff_stat_value (NumericArray): Pseudo‑sufficient statistic ν₀ + effective_sample_size (Number): Pseudo‑sample size n₀ (a non‑negative scalar) + """ + + effective_suff_stat_value: NumericArray + effective_sample_size: Number + + def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + """ + Convert hyperparameters to a canonical parametrization. + + The resulting parameter vector is [ν₀, n₀] concatenated. + """ + return ExponentialFamilyParametrization( + np.append(self.effective_suff_stat_value, self.effective_sample_size) + ) + + +class ContinuousExponentialClassFamily(ParametricFamily): + """ + Representation of a continuous exponential family distribution. + + The density is given by: + f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ)) + + where: + - θ is the natural parameter, + - T(x) is the sufficient statistic vector, + - h(x) is the base measure (the `normalization_constant`), + - A(θ) is the log‑partition function. + + This class supports: + - Canonical parametrization (θ) via `ExponentialFamilyParametrization`. + - Conjugate prior families. + - Posterior updates and posterior predictive distributions. + - Transformation of the random variable (change of variable with Jacobian). + + The user must supply functions for the log‑partition `log_partition`, + sufficient statistics `sufficient_statistics`, + base measure `normalization_constant`, as well as the support of the distribution, + the natural parameter space and the range of the sufficient statistic. + """ + + def __init__( + self, + *, + log_partition: Callable[[NumericArray], NumericArray], + sufficient_statistics: Callable[[NumericArray], NumericArray], + normalization_constant: Callable[[NumericArray], Number], + support: Support, + parameter_space: Support, + sufficient_statistics_values: Support, + name: str, + distr_type: DistributionType | Callable[[Parametrization], DistributionType], + distr_parametrizations: list[ParametrizationName], + distr_characteristics: CharacteristicsMap | None = None, + support_by_parametrization: SupportArg = None, + base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None, + ): + """ + Initialize a continuous exponential family distribution. + + Args: + log_partition: Function A(θ) – the log‑partition function. + sufficient_statistics: Function T(x) – the sufficient statistic vector. + normalization_constant: Function h(x) – the base measure. + support: Predicate defining the support of the distribution. + parameter_space: Predicate defining the natural parameter space. + sufficient_statistics_values: Predicate defining the range of T(x). + name: Name of the family. + distr_type: Type of distribution or a callable returning it. + distr_parametrizations: List of parametrization names this family supports. + distr_characteristics: Additional analytical characteristics to register. + support_by_parametrization: Callable that returns the support given a parametrization. + base_score: Optional base score function. + """ + self._sufficient = sufficient_statistics + self._log_partition = log_partition + self._normalization = normalization_constant + + self._support = support + self._parameter_space = parameter_space + self._sufficient_statistics_values = sufficient_statistics_values + + family_characteristics: CharacteristicsMap = { + CharacteristicName.PDF: self.density, + CharacteristicName.MEAN: self._mean, + CharacteristicName.VAR: self._var, + } + merged_characteristics = dict(distr_characteristics or {}) + merged_characteristics.update(family_characteristics) + + ParametricFamily.__init__( + self, + name=name, + distr_type=distr_type, + distr_parametrizations=distr_parametrizations, + distr_characteristics=merged_characteristics, + support_by_parametrization=support_by_parametrization, + base_score=base_score, + ) + + @parametrization(family=self, name="theta") + class ThetaParametrization(ExponentialFamilyParametrization): + @constraint(description="theta belongs to parameter_space") + def check_theta_in_parameter_space(self) -> bool: + theta = np.atleast_1d(np.asarray(self.theta, dtype=float)) + return bool(self.__family__._parameter_space.contains(theta)) # type: ignore[attr-defined] + + @property + def log_density(self) -> ParametricFamilyCharacteristic[NumericArray, Number]: + """ + Log‑density function for the exponential family. + + The function takes a parametrization (must be `ExponentialFamilyParametrization`) + and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support. + + Returns: + Callable[[Parametrization, NumericArray], Number] + """ + + def log_density_func(parametrization: Parametrization, x: NumericArray) -> Number: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + parametrization = parametrization.transform_to_base_parametrization() + if not self._support.contains(np.array([x])): + return -np.inf + + theta = parametrization.theta + sufficient = self._sufficient(x) + dot = np.dot(theta, sufficient) + if hasattr(dot, "__len__"): + dot = dot[0] + + result = np.log(self._normalization(x)) + dot + self._log_partition(theta) + return cast(np.floating, result.item()) + + return log_density_func + + @property + def density(self) -> ParametricFamilyCharacteristic[NumericArray, Number]: + """ + Density function (exponentiated log‑density). + + Returns: + Callable[[Parametrization, NumericArray], Number] + """ + log_density = cast(Callable[[Parametrization, NumericArray], Number], self.log_density) + + def density_func(parametrization: Parametrization, x: NumericArray) -> Number: + return cast(Number, np.exp(log_density(parametrization, x))) + + return density_func + + @property + def conjugate_prior_family(self) -> ContinuousExponentialClassFamily: + """ + Build the conjugate prior family for this exponential family. + + The conjugate prior is an exponential family in the natural parameter θ, + with sufficient statistic [θ, A(θ)] and base measure 1. The resulting + family has its own [log_partition, sufficient_statistics, ...] such that + the posterior updates are given by adding the observed sufficient statistics. + + Returns: + ContinuousExponentialClassFamily: The conjugate prior family. + """ + + def conjugate_sufficient( + theta: NumericArray, + ) -> NumericArray: + if not hasattr(theta, "__len__"): + theta = np.array([theta]) + + if not self._parameter_space.contains(theta): + return np.full(len(theta) + 1, float("-inf")) + return np.append(theta, self._log_partition(theta)) + + def conjugate_log_partition( + parametrization: NumericArray, + ) -> NumericArray: + def pdf(theta: NumericArray) -> Number: + if not hasattr(theta, "__len__"): + theta = np.array([theta]) + return cast( + np.floating, + np.exp( + np.dot( + conjugate_sufficient(theta), + parametrization, + ) + ).item(), + ) + + def integrand(x: float) -> float: + theta = np.asarray([x], dtype=float) + if not self._parameter_space.contains(theta): + return 0.0 + return float(pdf(theta)) + + all_value = nquad(integrand, [(float("-inf"), float("+inf"))])[0] + return np.array([cast(np.float64, -np.log(all_value))]) + + def conjugate_sufficient_accepts( + theta: NumericArray, + ) -> bool: + xi = theta[:-1] + nu = theta[-1] + + return bool(self._sufficient_statistics_values.contains(xi)) and bool( + ContinuousSupport(0, np.inf).contains(np.array([nu])) + ) + + return ContinuousExponentialClassFamily( + log_partition=conjugate_log_partition, + sufficient_statistics=conjugate_sufficient, + normalization_constant=lambda _: 1, + support=self._parameter_space, + sufficient_statistics_values=self._parameter_space, + parameter_space=PredicateSupport(predicate=conjugate_sufficient_accepts), + name=self.name, + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + support_by_parametrization=self.support_resolver, + ) + + def transform( + self, + transform_function: Callable[[NumericArray], NumericArray], + ) -> ContinuousExponentialClassFamily: + """ + Transform the random variable by a monotonic, differentiable function. + + The new density is obtained via the change‑of‑variable formula. + The sufficient statistic becomes T(transform⁻¹(x)) and the base measure + gains the Jacobian factor. + + Args: + transform_function: Invertible, differentiable function g(x) such that + y = g(x). Must be defined on the original support. + + Returns: + ContinuousExponentialClassFamily: A new family for the transformed variable. + """ + + def calculate_jacobian(x: NumericArray) -> NumericArray: + if not isinstance(x, Iterable): + x = np.array([x], dtype=float) + else: + x = np.atleast_1d(np.asarray(x, dtype=float)) + + return np.abs(det(jacobian(transform_function, x).df)) + + def new_support(x: NumericArray) -> bool: + return bool(self._support.contains(transform_function(x))) + + def new_sufficient(x: NumericArray) -> NumericArray: + return self._sufficient(transform_function(x)) + + def new_normalization(x: NumericArray) -> Number: + return cast(np.float64, self._normalization(x) * calculate_jacobian(x)) + + return ContinuousExponentialClassFamily( + log_partition=self._log_partition, + sufficient_statistics=new_sufficient, + normalization_constant=new_normalization, + support=PredicateSupport(predicate=new_support), + parameter_space=self._parameter_space, + sufficient_statistics_values=self._sufficient_statistics_values, + name=f"Transformed{self._name}", + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + support_by_parametrization=self.support_resolver, + ) + + @property + def _mean(self) -> ParametricFamilyCharacteristic[Any, Any]: + """Compute the mean E[X] by numerical integration over the density.""" + + def mean_func(parametrization: Parametrization) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + density = cast(Callable[[Parametrization, NumericArray], Number], self.density) + return nquad( + lambda x: ( + np.dot(x, density(parametrization, x)) + if self._support.contains(np.array([x])) + else 0 + ), + [(float("-inf"), float("inf"))], + )[0] + + return mean_func + + @property + def _second_moment(self) -> ParametricFamilyCharacteristic[Any, Any]: + """Compute the second moment E[X²] by numerical integration.""" + + def func(parametrization: Parametrization) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + density = cast(Callable[[Parametrization, NumericArray], Number], self.density) + return nquad( + lambda x: ( + x**2 * density(parametrization, x) + if self._support.contains(np.array([x])) + else 0 + ), + [(float("-inf"), float("inf"))], + )[0] + + return func + + @property + def _var(self) -> ParametricFamilyCharacteristic[Any, Any]: + """Compute the variance Var[X] = E[X²] - (E[X])².""" + + def func(parametrization: Parametrization) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + second_moment = cast(Callable[[Parametrization], Any], self._second_moment) + mean = cast(Callable[[Parametrization], Any], self._mean) + return second_moment(parametrization) - mean(parametrization) ** 2 + + return func + + def posterior_hyperparameters( + self, parametrizaiton: ExponentialConjugateHyperparameters, sample: list[Any] + ) -> ExponentialConjugateHyperparameters: + """ + Update the conjugate prior hyperparameters given observed data. + + For a conjugate prior with hyperparameters (ν₀, n₀), the posterior + hyperparameters become: + ν = ν₀ + Σ_{i} T(x_i) + n = n₀ + N + + Args: + parametrizaiton: Current conjugate hyperparameters. + sample: List of observations (each can be scalar or array). + + Returns: + ExponentialConjugateHyperparameters: + Updated hyperparameters after incorporating the sample. + """ + posterior_effective_suff_stat_value = np.array( + parametrizaiton.effective_suff_stat_value, + copy=True, + ) + posterior_effective_sample_size = parametrizaiton.effective_sample_size + if hasattr(sample, "__iter__") and not isinstance(sample, str): + posterior_effective_suff_stat_value += np.sum( + [self._sufficient(x) for x in sample], + axis=0, + ) + posterior_effective_sample_size += len(sample) + else: + posterior_effective_suff_stat_value += self._sufficient(sample) # type: ignore[arg-type] + posterior_effective_sample_size += 1 + + return ExponentialConjugateHyperparameters( + effective_suff_stat_value=posterior_effective_suff_stat_value, + effective_sample_size=posterior_effective_sample_size, + ) + + @property + def posterior_predictive(self) -> ParametricFamily: + """ + Construct the posterior predictive distribution. + + For a conjugate prior, the posterior predictive density of a new observation x + given hyperparameters (ν, n) is: + p(x | ν, n) = h(x) * exp( A(ν) - A(ν + T(x)) ) + where A(·) is the log‑partition function of the conjugate prior family. + + Returns: + ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters` + and a `pdf` method implementing the posterior predictive density. + """ + + def conjugate_log_partition( + parametrization: ExponentialConjugateHyperparameters, + ) -> NumericArray: + conjugate_value = self.conjugate_prior_family._log_partition( + parametrization.transform_to_base_parametrization().theta + ) + return np.exp(conjugate_value) + + def posterior_density(parametrization: Parametrization, x: NumericArray) -> Number: + parametrization = cast(ExponentialConjugateHyperparameters, parametrization) + return cast( + np.float32, + self._normalization(x) + * conjugate_log_partition(parametrization) + / conjugate_log_partition( + self.posterior_hyperparameters(parametrizaiton=parametrization, sample=[x]) + ), + ) + + family = ParametricFamily( + name=f"PosteriorPredictive{self.name}", + distr_type=UnivariateContinuous, + distr_characteristics={CharacteristicName.PDF: posterior_density}, + distr_parametrizations=["posterior"], + support_by_parametrization=lambda _: ContinuousSupport(), + ) + parametrization(family=family, name="posterior")(ExponentialConjugateHyperparameters) + return family diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index 25aa63e..33de0af 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import Enum, StrEnum, auto from math import inf -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pysatl_core.distributions.computations.computation import ( @@ -179,13 +179,7 @@ def __post_init__(self) -> None: if self.right == inf and self.right_closed: object.__setattr__(self, "right_closed", False) - @overload - def contains(self, x: Number) -> bool: ... - - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are contained in the interval. @@ -210,10 +204,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return result - def __contains__(self, x: object) -> bool: - """Check if a single point is in the interval.""" - return bool(self.contains(cast(Number, x))) - @property def is_empty(self) -> bool: """Check if the interval is empty.""" @@ -250,6 +240,25 @@ def shape(self) -> ContinuousSupportShape1D: type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out] """Type alias for a distribution computation method (analytical or fitted).""" + +@dataclass(frozen=True, slots=True) +class IntervalND: + intervals: list[Interval1D] + + def contains(self, x: NumericArray) -> bool | BoolArray: + def contains_for_point(point: NumericArray) -> bool: + assert len(point) == len(self.intervals) + return all( + bool(interval.contains(np.asarray(x_coordinate))) + for interval, x_coordinate in zip(self.intervals, point, strict=True) + ) + + if len(x.shape) == 1: + return contains_for_point(x) + + return np.array([contains_for_point(point) for point in x]) + + type GenericCharacteristicName = str """Type alias for characteristic names (e.g., 'pdf', 'cdf').""" @@ -465,6 +474,7 @@ class FamilyName(StrEnum): "TransformationMethodSpecsMap", "DistributionType", "Interval1D", + "IntervalND", "ContinuousSupportShape1D", "BoolArray", "NumPyNumber", diff --git a/tests/unit/distributions/test_support.py b/tests/unit/distributions/test_support.py index 11d70bf..a199ac3 100644 --- a/tests/unit/distributions/test_support.py +++ b/tests/unit/distributions/test_support.py @@ -40,7 +40,6 @@ class TestContinuousSupport: ], ) def test_continuous_support_contains_scalar(self, point, expected_result): - assert (point in self.support_example) is expected_result assert self.support_example.contains(point) is expected_result @pytest.mark.parametrize("infinity", [-inf, inf]) @@ -48,7 +47,6 @@ def test_continuous_support_doesnt_contain_inf(self, infinity): # inf isn't considered as a number # but as a limit so support doesn't contain it even if it's a real line support = ContinuousSupport() - assert infinity not in support assert support.contains(infinity) is False @pytest.mark.parametrize( @@ -59,7 +57,6 @@ def test_continuous_support_doesnt_contain_inf(self, infinity): ], ) def test_continuous_support_contains_array(self, points, expected_result): - # np.array doesn't have `in`(__contains__) syntax result = self.support_example.contains(points) assert isinstance(result, np.ndarray) assert result.tolist() == expected_result @@ -104,7 +101,6 @@ def test_table_is_sorted_and_deduplicated(self): ], ) def test_contains_scalar(self, point, expected_result): - assert (point in self.support_example) is expected_result assert self.support_example.contains(point) is expected_result @pytest.mark.parametrize( @@ -196,7 +192,6 @@ def test_invalid_modulus_raises(self): ) def test_contains_scalar(self, support_name, point, expected_result): support = self.support_examples[support_name] - assert (point in support) is expected_result assert support.contains(point) is expected_result @pytest.mark.parametrize( diff --git a/tests/unit/families/builtins/continuous/test_exponential.py b/tests/unit/families/builtins/continuous/test_exponential.py index 6363c47..e92ce65 100644 --- a/tests/unit/families/builtins/continuous/test_exponential.py +++ b/tests/unit/families/builtins/continuous/test_exponential.py @@ -206,10 +206,10 @@ def test_exponential_support(self): assert not dist.support.right_closed # Test containment - assert dist.support.contains(0.0) is True - assert dist.support.contains(1.0) is True - assert dist.support.contains(-0.1) is False - assert dist.support.contains(float("inf")) is False + assert dist.support.contains(np.asarray(0.0)) is True + assert dist.support.contains(np.asarray(1.0)) is True + assert dist.support.contains(np.asarray(-0.1)) is False + assert dist.support.contains(np.asarray(float("inf"))) is False # Test array test_points = np.array([-0.1, 0.0, 1.0, 10.0]) diff --git a/tests/unit/families/builtins/continuous/test_normal.py b/tests/unit/families/builtins/continuous/test_normal.py index ca59349..0f47a42 100644 --- a/tests/unit/families/builtins/continuous/test_normal.py +++ b/tests/unit/families/builtins/continuous/test_normal.py @@ -226,9 +226,9 @@ def test_normal_support(self): assert not dist.support.left_closed assert not dist.support.right_closed - assert dist.support.contains(0) is True - assert dist.support.contains(float("inf")) is False - assert dist.support.contains(float("-inf")) is False + assert dist.support.contains(np.asarray(0)) is True + assert dist.support.contains(np.asarray(float("inf"))) is False + assert dist.support.contains(np.asarray(float("-inf"))) is False test_points = np.array([-500, 0, 5]) results = dist.support.contains(test_points) diff --git a/tests/unit/families/builtins/continuous/test_uniform.py b/tests/unit/families/builtins/continuous/test_uniform.py index 8d63f7f..d0cffc3 100644 --- a/tests/unit/families/builtins/continuous/test_uniform.py +++ b/tests/unit/families/builtins/continuous/test_uniform.py @@ -230,11 +230,11 @@ def test_uniform_support(self): assert dist.support.right_closed # Test containment - assert dist.support.contains(2.0) is True - assert dist.support.contains(5.0) is True - assert dist.support.contains(3.5) is True - assert dist.support.contains(1.9) is False - assert dist.support.contains(5.1) is False + assert dist.support.contains(np.asarray(2.0)) is True + assert dist.support.contains(np.asarray(5.0)) is True + assert dist.support.contains(np.asarray(3.5)) is True + assert dist.support.contains(np.asarray(1.9)) is False + assert dist.support.contains(np.asarray(5.1)) is False # Test array test_points = np.array([1.9, 2.0, 3.5, 5.0, 5.1]) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py new file mode 100644 index 0000000..d03667c --- /dev/null +++ b/tests/unit/families/test_exponential_family.py @@ -0,0 +1,294 @@ +from collections.abc import Callable + +__author__ = "Vinogradov Ilya" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +import itertools +from typing import cast + +import numpy as np +import pytest +import scipy +from numpy.testing import assert_allclose + +from pysatl_core.distributions.support import ContinuousNDSupport, PredicateSupport +from pysatl_core.families import ( + ContinuousExponentialClassFamily, + ExponentialConjugateHyperparameters, + ExponentialFamilyParametrization, +) +from pysatl_core.families.registry import ParametricFamilyRegister +from pysatl_core.types import ( + CharacteristicName, + Interval1D, + Number, + NumericArray, + UnivariateContinuous, +) + + +def gamma_pdf(alpha: float, beta: float, x: float) -> float: + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() # type: ignore[attr-defined] + + +def lomax_pdf(shape: float, scale: float, x: float) -> float: + return scipy.stats.lomax(c=shape, scale=scale).pdf(x).item() # type: ignore[attr-defined] + + +def exponential_log_partition(parametrization): + return np.log(-parametrization) + + +def _make_exponential_family() -> ContinuousExponentialClassFamily: + support_neg = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport( + intervals=[Interval1D(-np.inf, 0, left_closed=False, right_closed=False)] + ).contains(np.array([x])) + ) + ) + support_pos = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport( + intervals=[Interval1D(0, np.inf, left_closed=False, right_closed=False)] + ).contains(np.array([x])) + ) + ) + return ContinuousExponentialClassFamily( + name="ExponentialFamily", + log_partition=exponential_log_partition, + sufficient_statistics=lambda x: x, + normalization_constant=lambda _: 1, + parameter_space=support_neg, + sufficient_statistics_values=support_pos, + support=support_pos, + distr_type=UnivariateContinuous, + distr_parametrizations=["theta"], + ) + + +@pytest.fixture(scope="function") +def exponential_family() -> ContinuousExponentialClassFamily: + return _make_exponential_family() + + +@pytest.fixture(scope="function") +def conjugate_for_exponential() -> ContinuousExponentialClassFamily: + def transform_function(x: NumericArray) -> NumericArray: + return -x + + fam = _make_exponential_family() + conjugate_fam = fam.conjugate_prior_family.transform(transform_function) + ParametricFamilyRegister().register(conjugate_fam) + return cast( + ContinuousExponentialClassFamily, + ParametricFamilyRegister().get("TransformedExponentialFamily"), + ) + + +def test_log_density_matches_exponential_form( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + params = ExponentialFamilyParametrization(theta=np.array([-2.0])) + log_density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.density, + ) + + log_density = log_density_func(params, np.asarray(0.5)) + density = density_func(params, np.asarray(0.5)) + + assert log_density == pytest.approx(np.log(2.0) - 1.0) + assert density == pytest.approx(2.0 * np.exp(-1.0)) + + +def test_constructor_merges_custom_characteristics() -> None: + support_neg = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]).contains(np.array([x])) + ) + ) + support_pos = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]).contains(np.array([x])) + ) + ) + family = ContinuousExponentialClassFamily( + name="ExponentialFamily", + log_partition=exponential_log_partition, + sufficient_statistics=lambda x: x, + normalization_constant=lambda _: 1, + parameter_space=support_neg, + sufficient_statistics_values=support_pos, + support=support_pos, + distr_type=UnivariateContinuous, + distr_parametrizations=["theta"], + distr_characteristics={CharacteristicName.CDF: lambda _params, x: x / (1 + x)}, + ) + + assert CharacteristicName.CDF in family.distr_characteristics + assert CharacteristicName.PDF in family.distr_characteristics + assert CharacteristicName.MEAN in family.distr_characteristics + assert CharacteristicName.VAR in family.distr_characteristics + + +def test_log_density_is_minus_infinity_outside_support( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + params = ExponentialFamilyParametrization(theta=np.array([-2.0])) + log_density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.density, + ) + + assert log_density_func(params, np.asarray(-0.1)) == -np.inf + assert density_func(params, np.asarray(-0.1)) == 0.0 + + +def test_distribution_rejects_theta_outside_parameter_space( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + with pytest.raises(ValueError, match="theta belongs to parameter_space"): + exponential_family(theta=np.array([0.0]), parametrization_name="theta") + + +def test_transform_with_negation_moves_support_and_preserves_density( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + transformed = exponential_family.transform(lambda x: -x) + params = ExponentialFamilyParametrization(theta=np.array([-1.5])) + transformed_log_density = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + transformed.log_density, + ) + log_density = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + + assert transformed.name == "TransformedExponentialFamily" + assert transformed_log_density(params, np.asarray(-2.0)) == pytest.approx( + log_density(params, np.asarray(2.0)) + ) + assert transformed_log_density(params, np.asarray(2.0)) == -np.inf + + +def test_posterior_hyperparameters_updates_sample_without_mutating_input( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + prior = ExponentialConjugateHyperparameters( + effective_suff_stat_value=np.array([3.0]), + effective_sample_size=2.0, + ) + + posterior = exponential_family.posterior_hyperparameters(prior, sample=[0.5, 1.5]) + + assert_allclose(posterior.effective_suff_stat_value, np.array([5.0])) + assert posterior.effective_sample_size == 4.0 + assert_allclose(prior.effective_suff_stat_value, np.array([3.0])) + assert prior.effective_sample_size == 2.0 + + +def test_posterior_hyperparameters_accepts_single_observation( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + prior = ExponentialConjugateHyperparameters( + effective_suff_stat_value=np.array([3.0]), + effective_sample_size=2.0, + ) + + posterior = exponential_family.posterior_hyperparameters(prior, sample=0.5) # type: ignore[arg-type] + + assert_allclose(posterior.effective_suff_stat_value, np.array([3.5])) + assert posterior.effective_sample_size == 3.0 + + +def test_posterior_predictive_builds_family_with_posterior_parametrization( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + predictive_family = exponential_family.posterior_predictive + + assert predictive_family.name == "PosteriorPredictiveExponentialFamily" + assert predictive_family.parametrization_names == ["posterior"] + assert predictive_family.get_parametrization("posterior") is ExponentialConjugateHyperparameters + assert CharacteristicName.PDF in predictive_family.distr_characteristics + + +@pytest.mark.parametrize( + ("xi", "nu"), + itertools.product((2.0, 3.0, 4.0), (2.0, 3.0, 4.0)), +) +def test_posterior_predictive_matches_lomax_density( + exponential_family: ContinuousExponentialClassFamily, + xi: float, + nu: float, +) -> None: + predictive = exponential_family.posterior_predictive.distribution( + parametrization_name="posterior", + effective_suff_stat_value=np.array([xi]), + effective_sample_size=nu, + ) + pdf = predictive.computation_strategy.query_method("pdf", distr=predictive) + x_values = np.array([0.0, 0.5, 1.5, 3.0, 6.0]) + + actual = np.asarray([pdf(x) for x in x_values], dtype=float).reshape(-1) + expected = np.asarray([lomax_pdf(shape=nu + 1, scale=xi, x=x) for x in x_values]) + + assert_allclose(actual, expected, rtol=1e-6) + + +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) +def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential + + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") + pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) + + x = [i / 10 for i in range(100)] + + assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6) + + +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) +def test_exponential_mean(theta1, theta2, conjugate_for_exponential): + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential + + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") + mean = exponential.computation_strategy.query_method(CharacteristicName.MEAN, distr=exponential) + assert np.isclose(mean(), alpha / beta, rtol=1e-6) + + +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) +def test_exponential_var(theta1, theta2, conjugate_for_exponential): + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential + + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") + var = exponential.computation_strategy.query_method(CharacteristicName.VAR, distr=exponential) + assert np.isclose(var(), alpha / beta**2, rtol=1e-6)