From 23c5607f4d70eaf49dca744eec5d41d6d2989cca Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:34:07 +0300 Subject: [PATCH 1/9] feat(exponential): added base version of executable ExponentialFamily --- src/pysatl_core/families/__init__.py | 8 + .../families/exponential_family.py | 228 ++++++++++++++++++ .../unit/families/test_exponential_family.py | 53 ++++ 3 files changed, 289 insertions(+) create mode 100644 src/pysatl_core/families/exponential_family.py create mode 100644 tests/unit/families/test_exponential_family.py diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index ed305289..8ed0cb96 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -21,6 +21,11 @@ constraint, parametrization, ) +from .exponential_family import ( + ExponentialFamily, + ExponentialClassParametrization, + ExponentialConjugateHyperparameters, +) from .registry import ParametricFamilyRegister __all__ = [ @@ -34,6 +39,9 @@ "configure_families_register", # builtins *_builtins_all, + "ExponentialFamily", + "ExponentialClassParametrization", + "ExponentialConjugateHyperparameters", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py new file mode 100644 index 00000000..d87f93e7 --- /dev/null +++ b/src/pysatl_core/families/exponential_family.py @@ -0,0 +1,228 @@ +from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass +import math +from typing import Any, cast +from scipy.integrate import nquad, quad +import numpy as np + +from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf +from pysatl_core.families.parametric_family import ( + ParametricFamily, +) +from pysatl_core.families.parametrizations import Parametrization, parametrization +from pysatl_core.types import ( + DistributionType, + ParametrizationName, +) +from pysatl_core.distributions import ( + SamplingStrategy, +) + +PDF = "pdf" +CDF = "cdf" +PPF = "ppf" +CF = "char_func" +MEAN = "mean" +VAR = "var" +SKEW = "skewness" +KURT = "kurtosis" + + +class ExponentialClassParametrization(Parametrization): + """ + Standard parametrization of Exponential Family. + """ + + theta: list[Callable[[float], float]] # TODO: mb more clever + + +class ExponentialConjugateHyperparameters: + def __init__(self, alpha: Any, beta: int): + self.alpha = alpha + self.beta = beta + + def __str__(self): + return f"alpha={self.alpha}, beta={self.beta}" + + +def accepts(x, support): + if not hasattr(x, "__len__"): + x = [x] + + def accept_1D(x, borders): + left, right = borders + return left <= x <= right + + return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) + + +class ExponentialFamily(ParametricFamily): + def __init__( + self, + *, + A: Callable[[ExponentialClassParametrization], float], + T: Callable[[Any], Any], + h: Callable[[Any], float], + eta: Callable[[Any], Any], + support: list[tuple[float, float]], + param_space: list[tuple[float, float]], + natural_param_space: list[tuple[float, float]], + name: str = "ExponentialFamily", + theta_from_eta: Callable[[Any], Any] = None, + distr_type: DistributionType | Callable[[Parametrization], DistributionType], + distr_parametrizations: list[ParametrizationName], + sampling_strategy: SamplingStrategy, + support_by_parametrization: SupportArg = None, + ): + + self._A = A + self._T = T + self._h = h + + self._eta = eta if eta is not None else (lambda th: th) + self._theta_from_eta = theta_from_eta + self._natural_param_space = natural_param_space + self._param_space = param_space + self._support = support + + distr_characteristics = { + PDF: self.density, + MEAN: self._mean, + VAR: self._var, + } + + ParametricFamily.__init__( + self, + name=name, + distr_type=distr_type, + distr_parametrizations=distr_parametrizations, + distr_characteristics=distr_characteristics, + sampling_strategy=sampling_strategy, + support_by_parametrization=support_by_parametrization, + ) + parametrization(family=self, name="theta")((ExponentialClassParametrization)) + + @property + def log_density(self) -> ParametrizedFunction: + def log_density_func( + parametrization: ExponentialClassParametrization, x: Any + ) -> Any: + if not accepts(x, self._support): + return float("-inf") + + params = cast(ExponentialClassParametrization, parametrization) + theta = params.parameters.get("theta") + eta = self._eta(theta) + sufficient = self._T(x) + dot = np.dot(eta, sufficient) + + result = float(np.log(self._h(x)) + dot + self._A(parametrization)) + return result + + return log_density_func + + @property + def density(self) -> ParametrizedFunction: + return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) + + @property + def conjugate_prior_family(self): + def conjugate_sufficient(eta: Any): + theta = [self._theta_from_eta(eta)] + if not accepts(theta, self._param_space): + return [float("-inf"), float("-inf")] + + return [eta, self._A(ExponentialClassParametrization(theta=theta))] + + def conjugate_log_partition(parametrization: ExponentialClassParametrization): + alpha = parametrization.theta[0] + beta = parametrization.theta[1] + + def pdf(eta: Any): + theta = self._theta_from_eta(eta) + if not hasattr(theta, "__len__"): + theta = [theta] + parametrization = ExponentialClassParametrization( + theta=theta, + ) + return np.exp(np.dot(eta, alpha) + beta * self._A(parametrization)) + + all_value = nquad(pdf, self._natural_param_space)[0] + return -np.log(all_value) + + if self._theta_from_eta is None: + raise RuntimeError("Theta from eta wasn't specified") + + return ExponentialFamily( + A=conjugate_log_partition, + T=conjugate_sufficient, + h=lambda _: 1, + eta=lambda x: x, + theta_from_eta=lambda eta: eta, + support=self._natural_param_space, + natural_param_space=[(float("-inf"), float("inf"))] * 2, + param_space=[(float("-inf"), float("inf"))] * 2, + sampling_strategy=self.sampling_strategy, + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + support_by_parametrization=self.support_resolver, + ) + + @property + def _mean(self) -> ParametrizedFunction: + def mean_func(parametrization: Parametrization, x: Any) -> Any: + if hasattr(x, "__len__"): + dimension_size = len(x) + else: + dimension_size = 1 + print(dimension_size) + return nquad( + lambda x: np.dot(x, self.density(parametrization, x)), + [(float("-inf"), float("inf"))] * dimension_size, + )[0] + + return mean_func + + @property + def _second_moment(self) -> ParametrizedFunction: + def func(parametrization: Parametrization, x: Any) -> Any: + if hasattr(x, "__len__"): + dimension_size = len(x) + else: + dimension_size = 1 + return nquad( + lambda x: x**2 * self.density(parametrization, x), + [(float("-inf"), float("inf"))] * dimension_size, + )[0] + + return func + + @property + def _var(self): + def func(parametrization, x: Any): + return ( + self._second_moment(parametrization, x) + - self._mean(parametrization, x) ** 2 + ) + + return func + + def posterior_hyperparameters( + self, prior_hyper: ExponentialConjugateHyperparameters, sample + ): + alpha = prior_hyper.alpha + beta = prior_hyper.beta + + alpha_post = None + beta_post = None + if hasattr(sample, "__iter__") and not isinstance(sample, str): + alpha_post = np.sum([self._T(x) for x in sample], axis=0) + beta_post = len(sample) + else: + alpha_post = self.T(sample) + beta_post = 1 + + return ExponentialConjugateHyperparameters( + alpha=alpha + alpha_post, beta=beta + beta_post + ) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py new file mode 100644 index 00000000..620ec325 --- /dev/null +++ b/tests/unit/families/test_exponential_family.py @@ -0,0 +1,53 @@ +from typing import cast +import pytest +import numpy as np + +# from pysatl_core.distributions.computation import PDF +from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy +from pysatl_core.distributions.support import ContinuousSupport +from pysatl_core.families import ( + ExponentialFamily, + ExponentialConjugateHyperparameters, + ExponentialClassParametrization, +) +from pysatl_core.families.registry import ParametricFamilyRegister +from pysatl_core.types import UnivariateContinuous +import math + + +# TODO: WRITE TEEEEEEESTS +def test_exponential(): + pass + # fam = ExponentialFamily( + # A=lambda parametrization: np.log(parametrization.theta[0]), + # T=lambda x: x, + # h=lambda _: 1, + # eta=lambda theta: -1 * theta, + # theta_from_eta=lambda eta: -1 * eta, + # param_space=[(0, float("+inf"))], + # support=[(0, float("+inf"))], + # natural_param_space=[(float("-inf"), 0)], + # distr_type=UnivariateContinuous, + # distr_parametrizations=["theta"], + # sampling_strategy=DefaultSamplingUnivariateStrategy(), + # ) + + # conjugate_fam = fam + # conjugate_fam = fam.conjugate_prior_family + # params = ExponentialClassParametrization(theta=np.array([2, -1])) + # print(conjugate_fam._A(params)) + # ParametricFamilyRegister().register(conjugate_fam) + # # print( + # # fam.posterior_hyperparameters( + # # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] + # # ) + # # ) + # gamma_family: ExponentialFamily = cast( + # ExponentialFamily, ParametricFamilyRegister().get("ExponentialFamily") + # ) + # print(type(gamma_family)) + # # conjugate = gamma_family.conjugate_prior_family + # # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") + # exponential = gamma_family(theta=np.array([0, -1]), parametrization_name="theta") + # pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) + # print(pdf(-1)) From 82e959a7269d4030b86d16d28fa2470a8e862139 Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:34:23 +0300 Subject: [PATCH 2/9] feat(exponential): added new class structure and manual testing for conjugate prior --- src/pysatl_core/families/__init__.py | 12 +- .../families/exponential_family.py | 221 ++++++++++++------ .../unit/families/test_exponential_family.py | 118 +++++++--- 3 files changed, 247 insertions(+), 104 deletions(-) diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 8ed0cb96..3fd94c7d 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -22,9 +22,12 @@ parametrization, ) from .exponential_family import ( - ExponentialFamily, - ExponentialClassParametrization, ExponentialConjugateHyperparameters, + ExponentialFamily, + ExponentialFamilyParametrization, + NaturalExponentialFamily, + SpacePredicate, + SpacePredicateArray, ) from .registry import ParametricFamilyRegister @@ -40,8 +43,11 @@ # builtins *_builtins_all, "ExponentialFamily", - "ExponentialClassParametrization", + "ExponentialFamilyParametrization", "ExponentialConjugateHyperparameters", + "SpacePredicate", + "SpacePredicateArray", + "NaturalExponentialFamily", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index d87f93e7..dfa14f11 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,15 +1,11 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import dataclass -import math -from typing import Any, cast +from typing import Any, cast, TYPE_CHECKING from scipy.integrate import nquad, quad import numpy as np from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf -from pysatl_core.families.parametric_family import ( - ParametricFamily, -) +from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( DistributionType, @@ -19,6 +15,13 @@ SamplingStrategy, ) +if TYPE_CHECKING: + from pysatl_core.distributions.support import Support + + type ParametrizedFunction = Callable[[Parametrization, Any], Any] + type SupportArg = Callable[[Parametrization], Support | None] | None + + PDF = "pdf" CDF = "cdf" PPF = "ppf" @@ -29,7 +32,7 @@ KURT = "kurtosis" -class ExponentialClassParametrization(Parametrization): +class ExponentialFamilyParametrization(Parametrization): """ Standard parametrization of Exponential Family. """ @@ -46,45 +49,56 @@ def __str__(self): return f"alpha={self.alpha}, beta={self.beta}" -def accepts(x, support): +def doesAccept(x, support): if not hasattr(x, "__len__"): x = [x] def accept_1D(x, borders): left, right = borders + if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0): + return False return left <= x <= right return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) -class ExponentialFamily(ParametricFamily): +class SpacePredicate: + def __init__(self, predicate: Callable[[Any], bool]): + self._predicate = predicate + + def accepts(self, x: Any) -> bool: + return self._predicate(x) + + +class SpacePredicateArray(SpacePredicate): + def __init__(self, space: list[tuple[float, float]]): + SpacePredicate.__init__(self, lambda x: doesAccept(x, space)) + self._space = space + + +class NaturalExponentialFamily(ParametricFamily): def __init__( self, *, - A: Callable[[ExponentialClassParametrization], float], - T: Callable[[Any], Any], - h: Callable[[Any], float], - eta: Callable[[Any], Any], - support: list[tuple[float, float]], - param_space: list[tuple[float, float]], - natural_param_space: list[tuple[float, float]], - name: str = "ExponentialFamily", - theta_from_eta: Callable[[Any], Any] = None, + log_partition: Callable[[ExponentialFamilyParametrization], float], + sufficient_statistics: Callable[[Any], Any], + normalization_constant: Callable[[Any], Any], + support: SpacePredicate, + parameter_space: SpacePredicate, + sufficient_statistics_values: SpacePredicate, + name: str = "NaturalExponentialFamily", distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], sampling_strategy: SamplingStrategy, support_by_parametrization: SupportArg = None, ): + self._sufficient = sufficient_statistics + self._log_partition = log_partition + self._normalization = normalization_constant - self._A = A - self._T = T - self._h = h - - self._eta = eta if eta is not None else (lambda th: th) - self._theta_from_eta = theta_from_eta - self._natural_param_space = natural_param_space - self._param_space = param_space self._support = support + self._parameter_space = parameter_space + self._sufficient_statistics_values = sufficient_statistics_values distr_characteristics = { PDF: self.density, @@ -101,23 +115,25 @@ def __init__( sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) - parametrization(family=self, name="theta")((ExponentialClassParametrization)) + parametrization(family=self, name="theta")((ExponentialFamilyParametrization)) @property def log_density(self) -> ParametrizedFunction: def log_density_func( - parametrization: ExponentialClassParametrization, x: Any + parametrization: ExponentialFamilyParametrization, x: Any ) -> Any: - if not accepts(x, self._support): + if not self._support.accepts(x): return float("-inf") - params = cast(ExponentialClassParametrization, parametrization) + params = cast(ExponentialFamilyParametrization, parametrization) theta = params.parameters.get("theta") - eta = self._eta(theta) - sufficient = self._T(x) - dot = np.dot(eta, sufficient) - - result = float(np.log(self._h(x)) + dot + self._A(parametrization)) + sufficient = self._sufficient(x) + dot = np.dot(theta, sufficient) + result = float( + np.log(self._normalization(x)) + + dot + + self._log_partition(parametrization) + ) return result return log_density_func @@ -128,41 +144,59 @@ def density(self) -> ParametrizedFunction: @property def conjugate_prior_family(self): - def conjugate_sufficient(eta: Any): - theta = [self._theta_from_eta(eta)] - if not accepts(theta, self._param_space): + def conjugate_sufficient(theta: Any): + if not self._parameter_space.accepts(theta): return [float("-inf"), float("-inf")] - return [eta, self._A(ExponentialClassParametrization(theta=theta))] + return [ + theta, + self._log_partition(ExponentialFamilyParametrization(theta=[theta])), + ] - def conjugate_log_partition(parametrization: ExponentialClassParametrization): + def conjugate_log_partition(parametrization: ExponentialFamilyParametrization): alpha = parametrization.theta[0] beta = parametrization.theta[1] - def pdf(eta: Any): - theta = self._theta_from_eta(eta) + def pdf(theta: Any): if not hasattr(theta, "__len__"): theta = [theta] - parametrization = ExponentialClassParametrization( + parametrization = ExponentialFamilyParametrization( theta=theta, ) - return np.exp(np.dot(eta, alpha) + beta * self._A(parametrization)) + return np.exp( + np.dot(theta, alpha) + beta * self._log_partition(parametrization) + )[0] - all_value = nquad(pdf, self._natural_param_space)[0] + all_value = nquad( + lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, + [(float("-inf"), float("+inf"))], + )[0] return -np.log(all_value) - if self._theta_from_eta is None: - raise RuntimeError("Theta from eta wasn't specified") - - return ExponentialFamily( - A=conjugate_log_partition, - T=conjugate_sufficient, - h=lambda _: 1, - eta=lambda x: x, - theta_from_eta=lambda eta: eta, - support=self._natural_param_space, - natural_param_space=[(float("-inf"), float("inf"))] * 2, - param_space=[(float("-inf"), float("inf"))] * 2, + # TODO: remove hardcoding - Done, all hardcoding is only on user's hands + # 1. pr with prototype/draft - in progress + # 2. write instruction about to add distributions as member of exponential family - not started + # 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting + + def conjugate_sufficient_accepts( + parametrization: ExponentialFamilyParametrization, + ): + parametrization = cast(parametrization, ExponentialFamilyParametrization) + theta = parametrization.parameters.get("theta") + xi = theta[:-1] + nu = theta[-1] + + return self._sufficient_statistics_values(xi) and SpacePredicateArray( + [(0, float("+inf"))] + ).accepts(nu) + + return NaturalExponentialFamily( + log_partition=conjugate_log_partition, + sufficient_statistics=conjugate_sufficient, + normalization_constant=lambda _: 1, + support=self._parameter_space, + sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this + parameter_space=SpacePredicate(conjugate_sufficient_accepts), sampling_strategy=self.sampling_strategy, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, @@ -172,13 +206,15 @@ def pdf(eta: Any): @property def _mean(self) -> ParametrizedFunction: def mean_func(parametrization: Parametrization, x: Any) -> Any: + dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) - else: - dimension_size = 1 - print(dimension_size) return nquad( - lambda x: np.dot(x, self.density(parametrization, x)), + lambda x: ( + np.dot(x, self.density(parametrization, x)) + if self._support.accepts(x) + else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -187,12 +223,15 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: @property def _second_moment(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: + dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) - else: - dimension_size = 1 return nquad( - lambda x: x**2 * self.density(parametrization, x), + lambda x: ( + x**2 * self.density(parametrization, x) + if self._support.accepts(x) + else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -217,12 +256,62 @@ def posterior_hyperparameters( alpha_post = None beta_post = None if hasattr(sample, "__iter__") and not isinstance(sample, str): - alpha_post = np.sum([self._T(x) for x in sample], axis=0) + alpha_post = np.sum([self._sufficient(x) for x in sample], axis=0) beta_post = len(sample) else: - alpha_post = self.T(sample) + alpha_post = self._sufficient(sample) beta_post = 1 return ExponentialConjugateHyperparameters( alpha=alpha + alpha_post, beta=beta + beta_post ) + + +class ExponentialFamily(NaturalExponentialFamily): + def __init__( + self, + *, + log_partition: Callable[[ExponentialFamilyParametrization], float], + sufficient_statistics: Callable[[Any], Any], + normalization_constant: Callable[[Any], Any], + parameter_from_natural_parameter: Callable[[Any], Any], + support: SpacePredicate, + parameter_space: SpacePredicate, + sufficient_statistics_values: SpacePredicate, + distr_type: DistributionType | Callable[[Parametrization], DistributionType], + distr_parametrizations: list[ParametrizationName], + sampling_strategy: SamplingStrategy, + name: str = "ExponentialFamily", + support_by_parametrization: SupportArg = None, + ): + def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): + eta_parametrizaion = cast( + ExponentialFamilyParametrization, eta_parametrizaion + ) + eta = eta_parametrizaion.parameters.get("theta") + theta = parameter_from_natural_parameter(eta) + return log_partition(ExponentialFamilyParametrization(theta=[theta])) + + natural_sufficient_statistics_values = SpacePredicate( + lambda eta: sufficient_statistics_values.accepts( + parameter_from_natural_parameter(eta) + ) + ) + natural_parameter_space = SpacePredicate( + lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)), + ) + + NaturalExponentialFamily.__init__( + self, + log_partition=natural_log_partition, + sufficient_statistics=sufficient_statistics, + normalization_constant=normalization_constant, + support=support, + parameter_space=natural_parameter_space, + sufficient_statistics_values=natural_sufficient_statistics_values, + name=name, + distr_parametrizations=distr_parametrizations, + distr_type=distr_type, + sampling_strategy=sampling_strategy, + support_by_parametrization=support_by_parametrization, + ) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 620ec325..171d0523 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,53 +1,101 @@ -from typing import cast -import pytest import numpy as np +import pytest +import scipy +from typing import cast # from pysatl_core.distributions.computation import PDF from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy -from pysatl_core.distributions.support import ContinuousSupport from pysatl_core.families import ( ExponentialFamily, - ExponentialConjugateHyperparameters, - ExponentialClassParametrization, + ExponentialFamilyParametrization, + SpacePredicateArray, ) from pysatl_core.families.registry import ParametricFamilyRegister from pysatl_core.types import UnivariateContinuous -import math -# TODO: WRITE TEEEEEEESTS +# TODO: WRITE TEEEEEEESTS, MANY TESTS. def test_exponential(): - pass - # fam = ExponentialFamily( - # A=lambda parametrization: np.log(parametrization.theta[0]), - # T=lambda x: x, - # h=lambda _: 1, - # eta=lambda theta: -1 * theta, - # theta_from_eta=lambda eta: -1 * eta, - # param_space=[(0, float("+inf"))], - # support=[(0, float("+inf"))], - # natural_param_space=[(float("-inf"), 0)], + # pass + # fam = NaturalExponentialFamily( + # log_partition=lambda parametrization: np.log(-parametrization.theta[0]), + # sufficient_statistics=lambda x: x, + # normalization_constant=lambda _: 1, + # # param_space=SpacePredicateArray([(0, float("+inf"))]), + # support=SpacePredicateArray([(0, float("+inf"))]), + # parameter_space=SpacePredicateArray([(float("-inf"), 0)]), + # sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), # distr_type=UnivariateContinuous, # distr_parametrizations=["theta"], # sampling_strategy=DefaultSamplingUnivariateStrategy(), # ) - # conjugate_fam = fam - # conjugate_fam = fam.conjugate_prior_family - # params = ExponentialClassParametrization(theta=np.array([2, -1])) - # print(conjugate_fam._A(params)) - # ParametricFamilyRegister().register(conjugate_fam) - # # print( - # # fam.posterior_hyperparameters( - # # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] - # # ) - # # ) - # gamma_family: ExponentialFamily = cast( - # ExponentialFamily, ParametricFamilyRegister().get("ExponentialFamily") + def get_parameter_from_natural_parameter( + eta_parametrization: ExponentialFamilyParametrization, + ): + if hasattr(eta_parametrization, "__len__"): + if len(eta_parametrization) > 1: + return list(-1 * np.array(eta_parametrization)) + eta_parametrization = eta_parametrization[0] + return -eta_parametrization + + fam = ExponentialFamily( + log_partition=lambda parametrization: np.log(parametrization.theta[0]), + sufficient_statistics=lambda x: x, + normalization_constant=lambda _: 1, + parameter_from_natural_parameter=get_parameter_from_natural_parameter, + parameter_space=SpacePredicateArray([(0, float("+inf"))]), + sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), + support=SpacePredicateArray([(0, float("+inf"))]), + distr_type=UnivariateContinuous, + distr_parametrizations=["theta"], + sampling_strategy=DefaultSamplingUnivariateStrategy(), + ) + + conjugate_fam = fam + conjugate_fam = fam.conjugate_prior_family + ParametricFamilyRegister().register(conjugate_fam) + # print( + # fam.posterior_hyperparameters( + # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] + # ) # ) - # print(type(gamma_family)) - # # conjugate = gamma_family.conjugate_prior_family - # # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") - # exponential = gamma_family(theta=np.array([0, -1]), parametrization_name="theta") - # pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) - # print(pdf(-1)) + gamma_family: ExponentialFamily = cast( + ExponentialFamily, ParametricFamilyRegister().get("NaturalExponentialFamily") + ) + print(type(gamma_family)) + # conjugate = gamma_family.conjugate_prior_family + # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") + theta1 = 4 + theta2 = 4 + + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" + ) + pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) + + def gamma_pdf(alpha: float, beta: float, x: float): + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + + x = [i / 10 for i in range(-100, 100)] + # print(pdf(-x)) + import matplotlib.pyplot as plt + + plt.plot(x, [pdf(-xx) for xx in x], label="conjugate") + plt.plot( + x, + [gamma_pdf(alpha, beta, xx) for xx in x], + label=f"gamma({alpha}, {beta}) test", + ) + + from scipy.integrate import quad + + print(quad(pdf, float("-inf"), float("inf"))) + # mean = exponential.computation_strategy.query_method("mean", distr=exponential) + # print(mean(12)) + plt.legend() + plt.savefig("a.png") + # print(gamma_pdf(alpha, beta, x)) From 2ef19cff7612bb9d561e3e219b9bd05cc3862f18 Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:34:39 +0300 Subject: [PATCH 3/9] feat(exponential): added transform method to ExponentialFamily --- src/pysatl_core/families/__init__.py | 14 +- .../families/exponential_family.py | 152 +++++++++++++----- .../unit/families/test_exponential_family.py | 114 +++++++------ 3 files changed, 184 insertions(+), 96 deletions(-) diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 3fd94c7d..44994c25 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -14,13 +14,6 @@ from .builtins import __all__ as _builtins_all from .configuration import configure_families_register from .distribution import ParametricFamilyDistribution -from .parametric_family import ParametricFamily -from .parametrizations import ( - Parametrization, - ParametrizationConstraint, - constraint, - parametrization, -) from .exponential_family import ( ExponentialConjugateHyperparameters, ExponentialFamily, @@ -29,6 +22,13 @@ SpacePredicate, SpacePredicateArray, ) +from .parametric_family import ParametricFamily +from .parametrizations import ( + Parametrization, + ParametrizationConstraint, + constraint, + parametrization, +) from .registry import ParametricFamilyRegister __all__ = [ diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index dfa14f11..f2c0fcc3 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,19 +1,24 @@ from __future__ import annotations + from collections.abc import Callable -from typing import Any, cast, TYPE_CHECKING -from scipy.integrate import nquad, quad +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Iterable, Sized, cast + import numpy as np +from scipy.integrate import nquad +from scipy.linalg import det +from scipy.differentiate import jacobian -from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf +from pysatl_core.distributions import ( + SamplingStrategy, +) from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( + GenericCharacteristicName, DistributionType, ParametrizationName, ) -from pysatl_core.distributions import ( - SamplingStrategy, -) if TYPE_CHECKING: from pysatl_core.distributions.support import Support @@ -32,12 +37,13 @@ KURT = "kurtosis" +@dataclass class ExponentialFamilyParametrization(Parametrization): """ Standard parametrization of Exponential Family. """ - theta: list[Callable[[float], float]] # TODO: mb more clever + theta: list[float] # TODO: mb more clever class ExponentialConjugateHyperparameters: @@ -45,21 +51,23 @@ def __init__(self, alpha: Any, beta: int): self.alpha = alpha self.beta = beta - def __str__(self): + def __str__(self) -> str: return f"alpha={self.alpha}, beta={self.beta}" -def doesAccept(x, support): +def doesAccept(x: list[float] | float, support: list[tuple[float, float]]) -> bool: if not hasattr(x, "__len__"): x = [x] - def accept_1D(x, borders): + x = cast(list[float], x) + + def accept_1D(x: float, borders: tuple[float, float]) -> bool: left, right = borders if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0): return False return left <= x <= right - return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) + return all(accept_1D(x_i, border) for x_i, border in zip(x, support, strict=False)) class SpacePredicate: @@ -100,7 +108,10 @@ def __init__( self._parameter_space = parameter_space self._sufficient_statistics_values = sufficient_statistics_values - distr_characteristics = { + distr_characteristics: dict[ + GenericCharacteristicName, + dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, + ] = { PDF: self.density, MEAN: self._mean, VAR: self._var, @@ -115,20 +126,29 @@ def __init__( sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) - parametrization(family=self, name="theta")((ExponentialFamilyParametrization)) + parametrization(family=self, name="theta")(ExponentialFamilyParametrization) + + def _transform_to_natural_parametrization( + self, theta_parametrization: ExponentialFamilyParametrization + ) -> ExponentialFamilyParametrization: + return theta_parametrization @property def log_density(self) -> ParametrizedFunction: - def log_density_func( - parametrization: ExponentialFamilyParametrization, x: Any - ) -> Any: + def log_density_func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + parametrization = self._transform_to_natural_parametrization( + parametrization + ) if not self._support.accepts(x): return float("-inf") - params = cast(ExponentialFamilyParametrization, parametrization) - theta = params.parameters.get("theta") + theta = parametrization.theta sufficient = self._sufficient(x) dot = np.dot(theta, sufficient) + if hasattr(dot, "__len__"): + dot = dot[0] + result = float( np.log(self._normalization(x)) + dot @@ -143,26 +163,31 @@ def density(self) -> ParametrizedFunction: return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) @property - def conjugate_prior_family(self): - def conjugate_sufficient(theta: Any): + def conjugate_prior_family(self) -> NaturalExponentialFamily: + def conjugate_sufficient( + theta: float, + ) -> list[Any]: if not self._parameter_space.accepts(theta): return [float("-inf"), float("-inf")] + parametrization = ExponentialFamilyParametrization([theta]) + # parametrization.theta = [theta] return [ theta, - self._log_partition(ExponentialFamilyParametrization(theta=[theta])), + self._log_partition(parametrization), ] - def conjugate_log_partition(parametrization: ExponentialFamilyParametrization): + def conjugate_log_partition( + parametrization: ExponentialFamilyParametrization, + ) -> Any: alpha = parametrization.theta[0] beta = parametrization.theta[1] - def pdf(theta: Any): + def pdf(theta: Any) -> Any: if not hasattr(theta, "__len__"): theta = [theta] - parametrization = ExponentialFamilyParametrization( - theta=theta, - ) + parametrization = ExponentialFamilyParametrization(theta=theta) + # parametrization.theta = theta return np.exp( np.dot(theta, alpha) + beta * self._log_partition(parametrization) )[0] @@ -180,15 +205,14 @@ def pdf(theta: Any): def conjugate_sufficient_accepts( parametrization: ExponentialFamilyParametrization, - ): - parametrization = cast(parametrization, ExponentialFamilyParametrization) - theta = parametrization.parameters.get("theta") + ) -> bool: + theta = parametrization.theta xi = theta[:-1] nu = theta[-1] - return self._sufficient_statistics_values(xi) and SpacePredicateArray( - [(0, float("+inf"))] - ).accepts(nu) + return self._sufficient_statistics_values.accepts( + xi + ) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu) return NaturalExponentialFamily( log_partition=conjugate_log_partition, @@ -197,15 +221,50 @@ def conjugate_sufficient_accepts( support=self._parameter_space, sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this parameter_space=SpacePredicate(conjugate_sufficient_accepts), + name=self.name, sampling_strategy=self.sampling_strategy, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, support_by_parametrization=self.support_resolver, ) + def transform( + self, + transform_function: Callable[[Any], Any], + ) -> NaturalExponentialFamily: + def calculate_jacobian(x: Any) -> Any: + if type(x) is not list: + x = np.array([x]) + + return np.abs(det(jacobian(transform_function, x).df)) + + def new_support(x: Any) -> bool: + return self._support.accepts(transform_function(x)) + + def new_sufficient(x: Any) -> Any: + return self._sufficient(transform_function(x)) + + def new_normalization(x: Any) -> Any: + return self._normalization(x) * calculate_jacobian(x) + + return NaturalExponentialFamily( + log_partition=self._log_partition, + sufficient_statistics=new_sufficient, + normalization_constant=new_normalization, + support=SpacePredicate(new_support), + parameter_space=self._parameter_space, + sufficient_statistics_values=self._sufficient_statistics_values, + name=f"Transformed{self._name}", + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + sampling_strategy=self.sampling_strategy, + support_by_parametrization=self.support_resolver, + ) + @property def _mean(self) -> ParametrizedFunction: def mean_func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) @@ -223,6 +282,7 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: @property def _second_moment(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) @@ -238,8 +298,9 @@ def func(parametrization: Parametrization, x: Any) -> Any: return func @property - def _var(self): - def func(parametrization, x: Any): + def _var(self) -> ParametrizedFunction: + def func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) return ( self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 @@ -248,8 +309,8 @@ def func(parametrization, x: Any): return func def posterior_hyperparameters( - self, prior_hyper: ExponentialConjugateHyperparameters, sample - ): + self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any] + ) -> ExponentialConjugateHyperparameters: alpha = prior_hyper.alpha beta = prior_hyper.beta @@ -275,6 +336,9 @@ def __init__( sufficient_statistics: Callable[[Any], Any], normalization_constant: Callable[[Any], Any], parameter_from_natural_parameter: Callable[[Any], Any], + natural_parameter: Callable[ + [ExponentialFamilyParametrization], ExponentialFamilyParametrization + ], support: SpacePredicate, parameter_space: SpacePredicate, sufficient_statistics_values: SpacePredicate, @@ -284,11 +348,10 @@ def __init__( name: str = "ExponentialFamily", support_by_parametrization: SupportArg = None, ): - def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): - eta_parametrizaion = cast( - ExponentialFamilyParametrization, eta_parametrizaion - ) - eta = eta_parametrizaion.parameters.get("theta") + def natural_log_partition( + eta_parametrizaion: ExponentialFamilyParametrization, + ) -> Any: + eta = eta_parametrizaion.theta theta = parameter_from_natural_parameter(eta) return log_partition(ExponentialFamilyParametrization(theta=[theta])) @@ -297,6 +360,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): parameter_from_natural_parameter(eta) ) ) + + self._natural_parameter = natural_parameter natural_parameter_space = SpacePredicate( lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)), ) @@ -315,3 +380,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) + + def _transform_to_natural_parametrization( + self, theta_parametrization: ExponentialFamilyParametrization + ) -> ExponentialFamilyParametrization: + return self._natural_parameter(theta_parametrization) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 171d0523..35ea7cc7 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,9 +1,10 @@ +from typing import Any, cast + import numpy as np import pytest import scipy -from typing import cast +from numpy.testing import assert_allclose -# from pysatl_core.distributions.computation import PDF from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy from pysatl_core.families import ( ExponentialFamily, @@ -14,22 +15,12 @@ from pysatl_core.types import UnivariateContinuous -# TODO: WRITE TEEEEEEESTS, MANY TESTS. -def test_exponential(): - # pass - # fam = NaturalExponentialFamily( - # log_partition=lambda parametrization: np.log(-parametrization.theta[0]), - # sufficient_statistics=lambda x: x, - # normalization_constant=lambda _: 1, - # # param_space=SpacePredicateArray([(0, float("+inf"))]), - # support=SpacePredicateArray([(0, float("+inf"))]), - # parameter_space=SpacePredicateArray([(float("-inf"), 0)]), - # sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), - # distr_type=UnivariateContinuous, - # distr_parametrizations=["theta"], - # sampling_strategy=DefaultSamplingUnivariateStrategy(), - # ) +def gamma_pdf(alpha: float, beta: float, x: float) -> float: + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + +@pytest.fixture(scope="function") +def conjugate_for_exponential() -> ExponentialFamily: def get_parameter_from_natural_parameter( eta_parametrization: ExponentialFamilyParametrization, ): @@ -39,11 +30,29 @@ def get_parameter_from_natural_parameter( eta_parametrization = eta_parametrization[0] return -eta_parametrization + def natural_parameter( + theta_parametrization: Any, + ) -> Any: + if type(theta_parametrization) is ExponentialFamilyParametrization: + theta_parametrization = cast( + ExponentialFamilyParametrization, theta_parametrization + ) + eta = -theta_parametrization.theta + return ExponentialFamilyParametrization(theta=eta) + + return -1 * theta_parametrization + + def transform_function(x: list[Any]) -> list[Any]: + if type(x) is not list: + return -x + return [-x[0]] + fam = ExponentialFamily( log_partition=lambda parametrization: np.log(parametrization.theta[0]), sufficient_statistics=lambda x: x, normalization_constant=lambda _: 1, parameter_from_natural_parameter=get_parameter_from_natural_parameter, + natural_parameter=natural_parameter, parameter_space=SpacePredicateArray([(0, float("+inf"))]), sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), support=SpacePredicateArray([(0, float("+inf"))]), @@ -52,22 +61,18 @@ def get_parameter_from_natural_parameter( sampling_strategy=DefaultSamplingUnivariateStrategy(), ) - conjugate_fam = fam - conjugate_fam = fam.conjugate_prior_family + conjugate_fam = fam.conjugate_prior_family.transform(transform_function) ParametricFamilyRegister().register(conjugate_fam) - # print( - # fam.posterior_hyperparameters( - # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] - # ) - # ) - gamma_family: ExponentialFamily = cast( - ExponentialFamily, ParametricFamilyRegister().get("NaturalExponentialFamily") + return cast( + ExponentialFamily, + ParametricFamilyRegister().get("TransformedExponentialFamily"), ) - print(type(gamma_family)) - # conjugate = gamma_family.conjugate_prior_family - # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") - theta1 = 4 - theta2 = 4 + + +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential alpha = theta2 + 1 beta = theta1 @@ -77,25 +82,38 @@ def get_parameter_from_natural_parameter( ) pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) - def gamma_pdf(alpha: float, beta: float, x: float): - return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + x = [i / 10 for i in range(100)] + + assert_allclose( + [pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6 + ) + + +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_mean(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential - x = [i / 10 for i in range(-100, 100)] - # print(pdf(-x)) - import matplotlib.pyplot as plt + alpha = theta2 + 1 + beta = theta1 - plt.plot(x, [pdf(-xx) for xx in x], label="conjugate") - plt.plot( - x, - [gamma_pdf(alpha, beta, xx) for xx in x], - label=f"gamma({alpha}, {beta}) test", + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" ) + mean = exponential.computation_strategy.query_method("mean", distr=exponential) + assert np.isclose(mean(12), alpha / beta, rtol=1e-6) + - from scipy.integrate import quad +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_var(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential - print(quad(pdf, float("-inf"), float("inf"))) - # mean = exponential.computation_strategy.query_method("mean", distr=exponential) - # print(mean(12)) - plt.legend() - plt.savefig("a.png") - # print(gamma_pdf(alpha, beta, x)) + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" + ) + var = exponential.computation_strategy.query_method("var", distr=exponential) + assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) From 8cfa0b885855b34bb6369eb78f5ba20178efb85c Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:34:55 +0300 Subject: [PATCH 4/9] refactor(exponential): refactoring for current version of master --- src/pysatl_core/distributions/support.py | 36 ++- src/pysatl_core/families/__init__.py | 12 +- .../families/exponential_family.py | 291 ++++++------------ src/pysatl_core/types.py | 20 ++ .../unit/families/test_exponential_family.py | 84 ++--- 5 files changed, 171 insertions(+), 272 deletions(-) diff --git a/src/pysatl_core/distributions/support.py b/src/pysatl_core/distributions/support.py index 40f49bfc..c5d32a3b 100644 --- a/src/pysatl_core/distributions/support.py +++ b/src/pysatl_core/distributions/support.py @@ -14,13 +14,20 @@ __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" +from collections.abc import Callable from dataclasses import dataclass from math import floor -from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable +from typing import ( + TYPE_CHECKING, + Protocol, + cast, + overload, + runtime_checkable, +) import numpy as np -from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray +from pysatl_core.types import BoolArray, Interval1D, IntervalND, Number, NumericArray if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -49,6 +56,15 @@ class ContinuousSupport(Interval1D, Support): """ +class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc] + """ + Support for continuous distributions represented as an array of intervals. + + This class inherits from IntervalND and implements the Support protocol + for continuous distributions defined on a list of intervals [left, right]. + """ + + @runtime_checkable class DiscreteSupport(Support, Protocol): """ @@ -430,10 +446,26 @@ def is_right_bounded(self) -> bool: __iter__ = iter_points +class SupportByPredicate: + def __init__(self, predicate: Callable[[NumericArray | Number], bool]): + self._predicate = predicate + + def __contains__(self, item: NumericArray | Number) -> bool: + return self._predicate(item) + + +class SupportByIntervals(SupportByPredicate): + def __init__(self, support: ContinuousNDSupport): + SupportByPredicate.__init__(self, lambda x: x in support) + + __all__ = [ # Base support protocol "Support", "ContinuousSupport", + "ContinuousNDSupport", + "SupportByPredicate", + "SupportByIntervals", # Discrete support protocol and implementations "DiscreteSupport", "ExplicitTableDiscreteSupport", diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 44994c25..3eb3fb8a 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -15,12 +15,10 @@ from .configuration import configure_families_register from .distribution import ParametricFamilyDistribution from .exponential_family import ( + # CanonicalContinuousExponentialClassFamily, + ContinuousExponentialClassFamily, ExponentialConjugateHyperparameters, - ExponentialFamily, ExponentialFamilyParametrization, - NaturalExponentialFamily, - SpacePredicate, - SpacePredicateArray, ) from .parametric_family import ParametricFamily from .parametrizations import ( @@ -42,12 +40,10 @@ "configure_families_register", # builtins *_builtins_all, - "ExponentialFamily", + "ContinuousExponentialClassFamily", "ExponentialFamilyParametrization", "ExponentialConjugateHyperparameters", - "SpacePredicate", - "SpacePredicateArray", - "NaturalExponentialFamily", + # "CanonicalContinuousExponentialClassFamily", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index f2c0fcc3..5f277631 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -2,39 +2,33 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Iterable, Sized, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np +from scipy.differentiate import jacobian from scipy.integrate import nquad from scipy.linalg import det -from scipy.differentiate import jacobian -from pysatl_core.distributions import ( - SamplingStrategy, +from pysatl_core.distributions.support import ( + ContinuousSupport, + SupportByPredicate, ) from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( - GenericCharacteristicName, + CharacteristicName, DistributionType, + GenericCharacteristicName, ParametrizationName, ) if TYPE_CHECKING: from pysatl_core.distributions.support import Support + from pysatl_core.types import Number, NumericArray type ParametrizedFunction = Callable[[Parametrization, Any], Any] type SupportArg = Callable[[Parametrization], Support | None] | None - - -PDF = "pdf" -CDF = "cdf" -PPF = "ppf" -CF = "char_func" -MEAN = "mean" -VAR = "var" -SKEW = "skewness" -KURT = "kurtosis" + type NumberParameter = Number | NumericArray @dataclass @@ -43,61 +37,39 @@ class ExponentialFamilyParametrization(Parametrization): Standard parametrization of Exponential Family. """ - theta: list[float] # TODO: mb more clever - - -class ExponentialConjugateHyperparameters: - def __init__(self, alpha: Any, beta: int): - self.alpha = alpha - self.beta = beta - - def __str__(self) -> str: - return f"alpha={self.alpha}, beta={self.beta}" - - -def doesAccept(x: list[float] | float, support: list[tuple[float, float]]) -> bool: - if not hasattr(x, "__len__"): - x = [x] - - x = cast(list[float], x) + theta: NumberParameter - def accept_1D(x: float, borders: tuple[float, float]) -> bool: - left, right = borders - if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0): - return False - return left <= x <= right + def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + return self - return all(accept_1D(x_i, border) for x_i, border in zip(x, support, strict=False)) - -class SpacePredicate: - def __init__(self, predicate: Callable[[Any], bool]): - self._predicate = predicate - - def accepts(self, x: Any) -> bool: - return self._predicate(x) +@dataclass +class ExponentialConjugateHyperparameters: + effective_suff_stat_value: NumberParameter + effective_sample_size: int -class SpacePredicateArray(SpacePredicate): - def __init__(self, space: list[tuple[float, float]]): - SpacePredicate.__init__(self, lambda x: doesAccept(x, space)) - self._space = space +class ContinuousExponentialClassFamily(ParametricFamily): + """ + Representation of exponential class with density = h(x) * exp( + A(t)), + where canonical parametrization is that, when n = t + Usage of this class: + - you can use method transform_to_another to replace x to smth else, for example, into + """ -class NaturalExponentialFamily(ParametricFamily): def __init__( self, *, - log_partition: Callable[[ExponentialFamilyParametrization], float], - sufficient_statistics: Callable[[Any], Any], - normalization_constant: Callable[[Any], Any], - support: SpacePredicate, - parameter_space: SpacePredicate, - sufficient_statistics_values: SpacePredicate, - name: str = "NaturalExponentialFamily", + log_partition: Callable[[NumberParameter], NumberParameter], + sufficient_statistics: Callable[[NumberParameter], NumberParameter], + normalization_constant: Callable[[NumberParameter], NumberParameter], + support: SupportByPredicate, + parameter_space: SupportByPredicate, + sufficient_statistics_values: SupportByPredicate, + name: str = "ExponentialFamily", distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], - sampling_strategy: SamplingStrategy, support_by_parametrization: SupportArg = None, ): self._sufficient = sufficient_statistics @@ -112,9 +84,9 @@ def __init__( GenericCharacteristicName, dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, ] = { - PDF: self.density, - MEAN: self._mean, - VAR: self._var, + CharacteristicName.PDF: self.density, + CharacteristicName.MEAN: self._mean, + CharacteristicName.VAR: self._var, } ParametricFamily.__init__( @@ -123,25 +95,17 @@ def __init__( distr_type=distr_type, distr_parametrizations=distr_parametrizations, distr_characteristics=distr_characteristics, - sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) parametrization(family=self, name="theta")(ExponentialFamilyParametrization) - def _transform_to_natural_parametrization( - self, theta_parametrization: ExponentialFamilyParametrization - ) -> ExponentialFamilyParametrization: - return theta_parametrization - @property def log_density(self) -> ParametrizedFunction: - def log_density_func(parametrization: Parametrization, x: Any) -> Any: + def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number: parametrization = cast(ExponentialFamilyParametrization, parametrization) - parametrization = self._transform_to_natural_parametrization( - parametrization - ) - if not self._support.accepts(x): - return float("-inf") + parametrization = parametrization.transform_to_base_parametrization() + if x not in self._support: + return -np.inf theta = parametrization.theta sufficient = self._sufficient(x) @@ -149,12 +113,8 @@ def log_density_func(parametrization: Parametrization, x: Any) -> Any: if hasattr(dot, "__len__"): dot = dot[0] - result = float( - np.log(self._normalization(x)) - + dot - + self._log_partition(parametrization) - ) - return result + result = np.log(self._normalization(x)) + dot + self._log_partition(theta) + return cast(np.floating, result.item()) return log_density_func @@ -163,66 +123,55 @@ def density(self) -> ParametrizedFunction: return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) @property - def conjugate_prior_family(self) -> NaturalExponentialFamily: + def conjugate_prior_family(self) -> ContinuousExponentialClassFamily: def conjugate_sufficient( - theta: float, - ) -> list[Any]: - if not self._parameter_space.accepts(theta): - return [float("-inf"), float("-inf")] - - parametrization = ExponentialFamilyParametrization([theta]) - # parametrization.theta = [theta] - return [ - theta, - self._log_partition(parametrization), - ] + theta: NumberParameter, + ) -> NumberParameter: + if not hasattr(theta, "__len__"): + theta = np.array([theta]) - def conjugate_log_partition( - parametrization: ExponentialFamilyParametrization, - ) -> Any: - alpha = parametrization.theta[0] - beta = parametrization.theta[1] + if theta not in self._parameter_space: + return np.full(len(theta) + 1, float("-inf")) + return np.append(theta, self._log_partition(theta)) - def pdf(theta: Any) -> Any: + def conjugate_log_partition( + parametrization: NumberParameter, + ) -> NumberParameter: + def pdf(theta: NumberParameter) -> NumberParameter: if not hasattr(theta, "__len__"): - theta = [theta] - parametrization = ExponentialFamilyParametrization(theta=theta) - # parametrization.theta = theta - return np.exp( - np.dot(theta, alpha) + beta * self._log_partition(parametrization) - )[0] + theta = np.array([theta]) + return cast( + np.floating, + np.exp( + np.dot( + conjugate_sufficient(theta), + parametrization, + ) + ).item(), + ) all_value = nquad( - lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, + lambda x: pdf(x) if x in self._parameter_space else 0, # type: ignore[arg-type] [(float("-inf"), float("+inf"))], )[0] - return -np.log(all_value) - - # TODO: remove hardcoding - Done, all hardcoding is only on user's hands - # 1. pr with prototype/draft - in progress - # 2. write instruction about to add distributions as member of exponential family - not started - # 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting + return cast(np.float64, -np.log(all_value)) def conjugate_sufficient_accepts( - parametrization: ExponentialFamilyParametrization, + theta: NumericArray, ) -> bool: - theta = parametrization.theta xi = theta[:-1] nu = theta[-1] - return self._sufficient_statistics_values.accepts( - xi - ) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu) + return xi in self._sufficient_statistics_values and nu in ContinuousSupport(0, np.inf) - return NaturalExponentialFamily( + return ContinuousExponentialClassFamily( log_partition=conjugate_log_partition, sufficient_statistics=conjugate_sufficient, normalization_constant=lambda _: 1, support=self._parameter_space, sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this - parameter_space=SpacePredicate(conjugate_sufficient_accepts), + parameter_space=SupportByPredicate(conjugate_sufficient_accepts), # type: ignore[arg-type] name=self.name, - sampling_strategy=self.sampling_strategy, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, support_by_parametrization=self.support_resolver, @@ -231,7 +180,7 @@ def conjugate_sufficient_accepts( def transform( self, transform_function: Callable[[Any], Any], - ) -> NaturalExponentialFamily: + ) -> ContinuousExponentialClassFamily: def calculate_jacobian(x: Any) -> Any: if type(x) is not list: x = np.array([x]) @@ -239,7 +188,7 @@ def calculate_jacobian(x: Any) -> Any: return np.abs(det(jacobian(transform_function, x).df)) def new_support(x: Any) -> bool: - return self._support.accepts(transform_function(x)) + return transform_function(x) in self._support def new_sufficient(x: Any) -> Any: return self._sufficient(transform_function(x)) @@ -247,17 +196,16 @@ def new_sufficient(x: Any) -> Any: def new_normalization(x: Any) -> Any: return self._normalization(x) * calculate_jacobian(x) - return NaturalExponentialFamily( + return ContinuousExponentialClassFamily( log_partition=self._log_partition, sufficient_statistics=new_sufficient, normalization_constant=new_normalization, - support=SpacePredicate(new_support), + support=SupportByPredicate(new_support), parameter_space=self._parameter_space, sufficient_statistics_values=self._sufficient_statistics_values, name=f"Transformed{self._name}", distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, - sampling_strategy=self.sampling_strategy, support_by_parametrization=self.support_resolver, ) @@ -269,10 +217,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( - np.dot(x, self.density(parametrization, x)) - if self._support.accepts(x) - else 0 + lambda x: ( # type: ignore[arg-type] + np.dot(x, self.density(parametrization, x)) if x in self._support else 0 ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -287,10 +233,8 @@ def func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( - x**2 * self.density(parametrization, x) - if self._support.accepts(x) - else 0 + lambda x: ( # type: ignore[arg-type] + x**2 * self.density(parametrization, x) if x in self._support else 0 ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -301,87 +245,26 @@ def func(parametrization: Parametrization, x: Any) -> Any: def _var(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - return ( - self._second_moment(parametrization, x) - - self._mean(parametrization, x) ** 2 - ) + return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 return func def posterior_hyperparameters( self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any] ) -> ExponentialConjugateHyperparameters: - alpha = prior_hyper.alpha - beta = prior_hyper.beta - - alpha_post = None - beta_post = None + posterior_effective_suff_stat_value = prior_hyper.effective_suff_stat_value + posterior_effective_sample_size = prior_hyper.effective_sample_size if hasattr(sample, "__iter__") and not isinstance(sample, str): - alpha_post = np.sum([self._sufficient(x) for x in sample], axis=0) - beta_post = len(sample) + posterior_effective_suff_stat_value += np.sum( + [self._sufficient(x) for x in sample], # type: ignore[arg-type] + axis=0, + ) + posterior_effective_sample_size += len(sample) else: - alpha_post = self._sufficient(sample) - beta_post = 1 + posterior_effective_suff_stat_value += self._sufficient(sample) # type: ignore[arg-type] + posterior_effective_sample_size += 1 return ExponentialConjugateHyperparameters( - alpha=alpha + alpha_post, beta=beta + beta_post - ) - - -class ExponentialFamily(NaturalExponentialFamily): - def __init__( - self, - *, - log_partition: Callable[[ExponentialFamilyParametrization], float], - sufficient_statistics: Callable[[Any], Any], - normalization_constant: Callable[[Any], Any], - parameter_from_natural_parameter: Callable[[Any], Any], - natural_parameter: Callable[ - [ExponentialFamilyParametrization], ExponentialFamilyParametrization - ], - support: SpacePredicate, - parameter_space: SpacePredicate, - sufficient_statistics_values: SpacePredicate, - distr_type: DistributionType | Callable[[Parametrization], DistributionType], - distr_parametrizations: list[ParametrizationName], - sampling_strategy: SamplingStrategy, - name: str = "ExponentialFamily", - support_by_parametrization: SupportArg = None, - ): - def natural_log_partition( - eta_parametrizaion: ExponentialFamilyParametrization, - ) -> Any: - eta = eta_parametrizaion.theta - theta = parameter_from_natural_parameter(eta) - return log_partition(ExponentialFamilyParametrization(theta=[theta])) - - natural_sufficient_statistics_values = SpacePredicate( - lambda eta: sufficient_statistics_values.accepts( - parameter_from_natural_parameter(eta) - ) + effective_suff_stat_value=posterior_effective_suff_stat_value, + effective_sample_size=posterior_effective_sample_size, ) - - self._natural_parameter = natural_parameter - natural_parameter_space = SpacePredicate( - lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)), - ) - - NaturalExponentialFamily.__init__( - self, - log_partition=natural_log_partition, - sufficient_statistics=sufficient_statistics, - normalization_constant=normalization_constant, - support=support, - parameter_space=natural_parameter_space, - sufficient_statistics_values=natural_sufficient_statistics_values, - name=name, - distr_parametrizations=distr_parametrizations, - distr_type=distr_type, - sampling_strategy=sampling_strategy, - support_by_parametrization=support_by_parametrization, - ) - - def _transform_to_natural_parametrization( - self, theta_parametrization: ExponentialFamilyParametrization - ) -> ExponentialFamilyParametrization: - return self._natural_parameter(theta_parametrization) diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index 25aa63e4..067cc1df 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -250,6 +250,25 @@ def shape(self) -> ContinuousSupportShape1D: type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out] """Type alias for a distribution computation method (analytical or fitted).""" + +@dataclass(frozen=True, slots=True) +class IntervalND: + intervals: list[Interval1D] + + def contains(self, x: Number | NumericArray) -> bool | BoolArray: + if not hasattr(x, "__iter__"): + x = np.array([x]) + + return all( + x_coordinate in interval + for interval, x_coordinate in zip(self.intervals, x, strict=True) + ) + + def __contains__(self, x: object) -> bool: + """Check if a single point is in the interval.""" + return bool(self.contains(cast(Number, x))) + + type GenericCharacteristicName = str """Type alias for characteristic names (e.g., 'pdf', 'cdf').""" @@ -465,6 +484,7 @@ class FamilyName(StrEnum): "TransformationMethodSpecsMap", "DistributionType", "Interval1D", + "IntervalND", "ContinuousSupportShape1D", "BoolArray", "NumPyNumber", diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 35ea7cc7..feb19a40 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,70 +1,46 @@ -from typing import Any, cast +from typing import cast import numpy as np import pytest import scipy from numpy.testing import assert_allclose -from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy +from pysatl_core.distributions.support import ContinuousNDSupport, SupportByIntervals from pysatl_core.families import ( - ExponentialFamily, - ExponentialFamilyParametrization, - SpacePredicateArray, + ContinuousExponentialClassFamily, ) from pysatl_core.families.registry import ParametricFamilyRegister -from pysatl_core.types import UnivariateContinuous +from pysatl_core.types import Interval1D, UnivariateContinuous def gamma_pdf(alpha: float, beta: float, x: float) -> float: - return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() # type: ignore[attr-defined] @pytest.fixture(scope="function") -def conjugate_for_exponential() -> ExponentialFamily: - def get_parameter_from_natural_parameter( - eta_parametrization: ExponentialFamilyParametrization, - ): - if hasattr(eta_parametrization, "__len__"): - if len(eta_parametrization) > 1: - return list(-1 * np.array(eta_parametrization)) - eta_parametrization = eta_parametrization[0] - return -eta_parametrization - - def natural_parameter( - theta_parametrization: Any, - ) -> Any: - if type(theta_parametrization) is ExponentialFamilyParametrization: - theta_parametrization = cast( - ExponentialFamilyParametrization, theta_parametrization - ) - eta = -theta_parametrization.theta - return ExponentialFamilyParametrization(theta=eta) - - return -1 * theta_parametrization - - def transform_function(x: list[Any]) -> list[Any]: - if type(x) is not list: - return -x - return [-x[0]] - - fam = ExponentialFamily( - log_partition=lambda parametrization: np.log(parametrization.theta[0]), +def conjugate_for_exponential() -> ContinuousExponentialClassFamily: + def transform_function(x: list[float] | float) -> list[float] | float: + if type(x) is list: + return [-x[0]] + return -x # type: ignore[operator] + + support_neg = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)])) + support_pos = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(0, np.inf)])) + fam = ContinuousExponentialClassFamily( + log_partition=lambda parametrization: np.log(-parametrization), sufficient_statistics=lambda x: x, normalization_constant=lambda _: 1, - parameter_from_natural_parameter=get_parameter_from_natural_parameter, - natural_parameter=natural_parameter, - parameter_space=SpacePredicateArray([(0, float("+inf"))]), - sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), - support=SpacePredicateArray([(0, float("+inf"))]), + parameter_space=support_neg, + sufficient_statistics_values=support_pos, + support=support_pos, distr_type=UnivariateContinuous, distr_parametrizations=["theta"], - sampling_strategy=DefaultSamplingUnivariateStrategy(), ) conjugate_fam = fam.conjugate_prior_family.transform(transform_function) ParametricFamilyRegister().register(conjugate_fam) return cast( - ExponentialFamily, + ContinuousExponentialClassFamily, ParametricFamilyRegister().get("TransformedExponentialFamily"), ) @@ -72,34 +48,28 @@ def transform_function(x: list[Any]) -> list[Any]: @pytest.mark.parametrize("theta1", range(2, 5)) @pytest.mark.parametrize("theta2", range(2, 5)) def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): - gamma_family: ExponentialFamily = conjugate_for_exponential + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) x = [i / 10 for i in range(100)] - assert_allclose( - [pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6 - ) + assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6) @pytest.mark.parametrize("theta1", range(2, 5)) @pytest.mark.parametrize("theta2", range(2, 5)) def test_exponential_mean(theta1, theta2, conjugate_for_exponential): - gamma_family: ExponentialFamily = conjugate_for_exponential + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") mean = exponential.computation_strategy.query_method("mean", distr=exponential) assert np.isclose(mean(12), alpha / beta, rtol=1e-6) @@ -107,13 +77,11 @@ def test_exponential_mean(theta1, theta2, conjugate_for_exponential): @pytest.mark.parametrize("theta1", range(2, 5)) @pytest.mark.parametrize("theta2", range(2, 5)) def test_exponential_var(theta1, theta2, conjugate_for_exponential): - gamma_family: ExponentialFamily = conjugate_for_exponential + gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") var = exponential.computation_strategy.query_method("var", distr=exponential) assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) From a1684e4bf00432a810e37dbfc16bdeb9718eeb63 Mon Sep 17 00:00:00 2001 From: domosedy Date: Sun, 17 May 2026 15:15:35 +0300 Subject: [PATCH 5/9] feat(exponential): add posterior predictive function --- .../families/exponential_family.py | 51 ++++++++++++++++--- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index 5f277631..6ae62c8e 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -20,6 +20,7 @@ DistributionType, GenericCharacteristicName, ParametrizationName, + UnivariateContinuous, ) if TYPE_CHECKING: @@ -44,9 +45,14 @@ def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: @dataclass -class ExponentialConjugateHyperparameters: - effective_suff_stat_value: NumberParameter - effective_sample_size: int +class ExponentialConjugateHyperparameters(Parametrization): + effective_suff_stat_value: NumericArray + effective_sample_size: Number + + def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + return ExponentialFamilyParametrization( + np.append(self.effective_suff_stat_value, self.effective_sample_size) + ) class ContinuousExponentialClassFamily(ParametricFamily): @@ -250,10 +256,10 @@ def func(parametrization: Parametrization, x: Any) -> Any: return func def posterior_hyperparameters( - self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any] + self, parametrizaiton: ExponentialConjugateHyperparameters, sample: list[Any] ) -> ExponentialConjugateHyperparameters: - posterior_effective_suff_stat_value = prior_hyper.effective_suff_stat_value - posterior_effective_sample_size = prior_hyper.effective_sample_size + posterior_effective_suff_stat_value = parametrizaiton.effective_suff_stat_value + posterior_effective_sample_size = parametrizaiton.effective_sample_size if hasattr(sample, "__iter__") and not isinstance(sample, str): posterior_effective_suff_stat_value += np.sum( [self._sufficient(x) for x in sample], # type: ignore[arg-type] @@ -268,3 +274,36 @@ def posterior_hyperparameters( effective_suff_stat_value=posterior_effective_suff_stat_value, effective_sample_size=posterior_effective_sample_size, ) + + @property + def posterior_predictive(self) -> ParametricFamily: + def conjugate_log_partition( + parametrization: ExponentialConjugateHyperparameters, + ) -> NumberParameter: + conjugate_value = self.conjugate_prior_family._log_partition( + parametrization.transform_to_base_parametrization().theta + ) + return np.exp(conjugate_value) + + def posterior_density(parametrization: Parametrization, x: NumberParameter) -> Number: + parametrization = cast(ExponentialConjugateHyperparameters, parametrization) + return cast( + np.float32, + self._normalization(x) + * conjugate_log_partition(parametrization) + / conjugate_log_partition( + self.posterior_hyperparameters( + parametrizaiton=parametrization, sample=[self._sufficient(x)] + ) + ), + ) + + family = ParametricFamily( + name=f"PosteriorPredictive{self.name}", + distr_type=UnivariateContinuous, + distr_characteristics={CharacteristicName.PDF: posterior_density}, + distr_parametrizations=["posterior"], + support_by_parametrization=lambda _: ContinuousSupport(), + ) + parametrization(family=family, name="posterior")(ExponentialConjugateHyperparameters) + return family From 2e6fe1c5c134a0b1b13f2804afaea42ea34e4abf Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:35:36 +0300 Subject: [PATCH 6/9] refactor(exponential): fix pr issues --- src/pysatl_core/distributions/support.py | 31 +++++++---- src/pysatl_core/families/__init__.py | 1 - .../families/exponential_family.py | 32 +++++++----- src/pysatl_core/types.py | 11 ++++ .../unit/families/test_exponential_family.py | 52 +++++++++++++------ 5 files changed, 86 insertions(+), 41 deletions(-) diff --git a/src/pysatl_core/distributions/support.py b/src/pysatl_core/distributions/support.py index c5d32a3b..ce9e1e27 100644 --- a/src/pysatl_core/distributions/support.py +++ b/src/pysatl_core/distributions/support.py @@ -27,7 +27,14 @@ import numpy as np -from pysatl_core.types import BoolArray, Interval1D, IntervalND, Number, NumericArray +from pysatl_core.types import ( + BoolArray, + Interval1D, + IntervalND, + Number, + NumberParameter, + NumericArray, +) if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -56,7 +63,7 @@ class ContinuousSupport(Interval1D, Support): """ -class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc] +class ContinuousNDSupport(IntervalND, Support): """ Support for continuous distributions represented as an array of intervals. @@ -446,17 +453,20 @@ def is_right_bounded(self) -> bool: __iter__ = iter_points -class SupportByPredicate: - def __init__(self, predicate: Callable[[NumericArray | Number], bool]): - self._predicate = predicate +@dataclass(slots=True) +class SupportByPredicate(Support): + predicate: Callable[[NumberParameter], bool] - def __contains__(self, item: NumericArray | Number) -> bool: - return self._predicate(item) + @overload + def contains(self, x: Number) -> bool: ... + @overload + def contains(self, x: NumericArray) -> BoolArray: ... + def contains(self, x: NumberParameter) -> bool | BoolArray: + return self.predicate(x) -class SupportByIntervals(SupportByPredicate): - def __init__(self, support: ContinuousNDSupport): - SupportByPredicate.__init__(self, lambda x: x in support) + def __contains__(self, item: object) -> bool | BoolArray: + return self.contains(cast(NumberParameter, item)) __all__ = [ @@ -465,7 +475,6 @@ def __init__(self, support: ContinuousNDSupport): "ContinuousSupport", "ContinuousNDSupport", "SupportByPredicate", - "SupportByIntervals", # Discrete support protocol and implementations "DiscreteSupport", "ExplicitTableDiscreteSupport", diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 3eb3fb8a..502dac91 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -43,7 +43,6 @@ "ContinuousExponentialClassFamily", "ExponentialFamilyParametrization", "ExponentialConjugateHyperparameters", - # "CanonicalContinuousExponentialClassFamily", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index 6ae62c8e..ead485c9 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,6 +1,11 @@ from __future__ import annotations -from collections.abc import Callable +__author__ = "Vinogradov Ilya" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + + +from collections.abc import Callable, Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -25,11 +30,10 @@ if TYPE_CHECKING: from pysatl_core.distributions.support import Support - from pysatl_core.types import Number, NumericArray + from pysatl_core.types import Number, NumberParameter, NumericArray type ParametrizedFunction = Callable[[Parametrization, Any], Any] type SupportArg = Callable[[Parametrization], Support | None] | None - type NumberParameter = Number | NumericArray @dataclass @@ -77,6 +81,7 @@ def __init__( distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], support_by_parametrization: SupportArg = None, + base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None, ): self._sufficient = sufficient_statistics self._log_partition = log_partition @@ -91,8 +96,8 @@ def __init__( dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, ] = { CharacteristicName.PDF: self.density, - CharacteristicName.MEAN: self._mean, - CharacteristicName.VAR: self._var, + CharacteristicName.MEAN_DEFAULT: self._mean, + CharacteristicName.VAR_DEFAULT: self._var, } ParametricFamily.__init__( @@ -102,6 +107,7 @@ def __init__( distr_parametrizations=distr_parametrizations, distr_characteristics=distr_characteristics, support_by_parametrization=support_by_parametrization, + base_score=base_score, ) parametrization(family=self, name="theta")(ExponentialFamilyParametrization) @@ -176,7 +182,7 @@ def conjugate_sufficient_accepts( normalization_constant=lambda _: 1, support=self._parameter_space, sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this - parameter_space=SupportByPredicate(conjugate_sufficient_accepts), # type: ignore[arg-type] + parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), # type: ignore[arg-type] name=self.name, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, @@ -185,28 +191,28 @@ def conjugate_sufficient_accepts( def transform( self, - transform_function: Callable[[Any], Any], + transform_function: Callable[[NumberParameter], NumberParameter], ) -> ContinuousExponentialClassFamily: - def calculate_jacobian(x: Any) -> Any: - if type(x) is not list: + def calculate_jacobian(x: NumberParameter) -> NumberParameter: + if not isinstance(x, Iterable): x = np.array([x]) return np.abs(det(jacobian(transform_function, x).df)) - def new_support(x: Any) -> bool: + def new_support(x: NumberParameter) -> bool: return transform_function(x) in self._support - def new_sufficient(x: Any) -> Any: + def new_sufficient(x: NumberParameter) -> NumberParameter: return self._sufficient(transform_function(x)) - def new_normalization(x: Any) -> Any: + def new_normalization(x: NumberParameter) -> NumberParameter: return self._normalization(x) * calculate_jacobian(x) return ContinuousExponentialClassFamily( log_partition=self._log_partition, sufficient_statistics=new_sufficient, normalization_constant=new_normalization, - support=SupportByPredicate(new_support), + support=SupportByPredicate(predicate=new_support), parameter_space=self._parameter_space, sufficient_statistics_values=self._sufficient_statistics_values, name=f"Transformed{self._name}", diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index 067cc1df..50ddbe60 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -121,6 +121,9 @@ class EuclideanDistributionType(DistributionType): type BoolArray = NDArray[np.bool_] """Type alias for boolean arrays.""" +type NumberParameter = Number | NumericArray +"""Type alias for numeric or list parameter""" + class ContinuousSupportShape1D(Enum): """ @@ -255,6 +258,12 @@ def shape(self) -> ContinuousSupportShape1D: class IntervalND: intervals: list[Interval1D] + @overload + def contains(self, x: Number) -> bool: ... + + @overload + def contains(self, x: NumericArray) -> BoolArray: ... + def contains(self, x: Number | NumericArray) -> bool | BoolArray: if not hasattr(x, "__iter__"): x = np.array([x]) @@ -335,6 +344,8 @@ class CharacteristicName(StrEnum): CDF = "cdf" PPF = "ppf" PMF = "pmf" + MEAN_DEFAULT = "MEAN_DEFAULT" # defined in class implementation of mean + VAR_DEFAULT = "VAR_DEFAULT" # defined in class implementation of var LPDF = "lpdf" # unimplemented in graph yet CF = "cf" # unimplemented in graph yet SF = "sf" # unimplemented in graph yet diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index feb19a40..d4896ce5 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,3 +1,9 @@ +__author__ = "Leonid Elkin" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +import itertools +from collections.abc import Iterable from typing import cast import numpy as np @@ -5,12 +11,12 @@ import scipy from numpy.testing import assert_allclose -from pysatl_core.distributions.support import ContinuousNDSupport, SupportByIntervals +from pysatl_core.distributions.support import ContinuousNDSupport, SupportByPredicate from pysatl_core.families import ( ContinuousExponentialClassFamily, ) from pysatl_core.families.registry import ParametricFamilyRegister -from pysatl_core.types import Interval1D, UnivariateContinuous +from pysatl_core.types import CharacteristicName, Interval1D, NumberParameter, UnivariateContinuous def gamma_pdf(alpha: float, beta: float, x: float) -> float: @@ -19,13 +25,17 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float: @pytest.fixture(scope="function") def conjugate_for_exponential() -> ContinuousExponentialClassFamily: - def transform_function(x: list[float] | float) -> list[float] | float: - if type(x) is list: - return [-x[0]] - return -x # type: ignore[operator] + def transform_function(x: NumberParameter) -> NumberParameter: + if isinstance(x, Iterable): + return np.array([-x[0]]) + return -x - support_neg = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)])) - support_pos = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(0, np.inf)])) + support_neg = SupportByPredicate( + predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]) + ) + support_pos = SupportByPredicate( + predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]) + ) fam = ContinuousExponentialClassFamily( log_partition=lambda parametrization: np.log(-parametrization), sufficient_statistics=lambda x: x, @@ -45,8 +55,10 @@ def transform_function(x: list[float] | float) -> list[float] | float: ) -@pytest.mark.parametrize("theta1", range(2, 5)) -@pytest.mark.parametrize("theta2", range(2, 5)) +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential @@ -61,8 +73,10 @@ def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6) -@pytest.mark.parametrize("theta1", range(2, 5)) -@pytest.mark.parametrize("theta2", range(2, 5)) +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) def test_exponential_mean(theta1, theta2, conjugate_for_exponential): gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential @@ -70,12 +84,16 @@ def test_exponential_mean(theta1, theta2, conjugate_for_exponential): beta = theta1 exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") - mean = exponential.computation_strategy.query_method("mean", distr=exponential) + mean = exponential.computation_strategy.query_method( + CharacteristicName.MEAN_DEFAULT, distr=exponential + ) assert np.isclose(mean(12), alpha / beta, rtol=1e-6) -@pytest.mark.parametrize("theta1", range(2, 5)) -@pytest.mark.parametrize("theta2", range(2, 5)) +@pytest.mark.parametrize( + ("theta1", "theta2"), + itertools.product(range(2, 5), range(2, 5)), +) def test_exponential_var(theta1, theta2, conjugate_for_exponential): gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential @@ -83,5 +101,7 @@ def test_exponential_var(theta1, theta2, conjugate_for_exponential): beta = theta1 exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") - var = exponential.computation_strategy.query_method("var", distr=exponential) + var = exponential.computation_strategy.query_method( + CharacteristicName.VAR_DEFAULT, distr=exponential + ) assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) From f3b65759d3f6ca3a4d8c1360b1f2623d8fd5113d Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 11 May 2026 15:35:40 +0300 Subject: [PATCH 7/9] docs(exponential): add docstrings --- .../families/exponential_family.py | 165 ++++++++++++++++-- src/pysatl_core/types.py | 3 + .../unit/families/test_exponential_family.py | 2 +- 3 files changed, 156 insertions(+), 14 deletions(-) diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index ead485c9..666dc2f9 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,10 +1,16 @@ +""" +Exponential family distributions in continuous spaces. + +This module implements the continuous exponential family of probability distributions, +their conjugate priors, posterior inference, and posterior predictive distributions. +""" + from __future__ import annotations __author__ = "Vinogradov Ilya" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" - from collections.abc import Callable, Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -39,21 +45,48 @@ @dataclass class ExponentialFamilyParametrization(Parametrization): """ - Standard parametrization of Exponential Family. + Standard parametrization of an exponential family distribution. + + This parametrization uses the natural (canonical) parameter vector `theta` + The density is expressed as: + f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ)) + + Attributes: + theta (NumberParameter): Natural parameter vector (can be a scalar or array) """ theta: NumberParameter def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + """Return the base parametrization (identity transform for canonical form).""" return self @dataclass class ExponentialConjugateHyperparameters(Parametrization): + """ + Hyperparameters for the conjugate prior of an exponential family + + For a prior of the form: + p(θ) ∝ exp(ν₀ᵀ T(θ) + n₀ A(θ)) + the hyperparameters are: + effective_suff_stat_value = ν₀ + effective_sample_size = n₀ + + Attributes: + effective_suff_stat_value (NumericArray): Pseudo‑sufficient statistic ν₀ + effective_sample_size (Number): Pseudo‑sample size n₀ (a non‑negative scalar) + """ + effective_suff_stat_value: NumericArray effective_sample_size: Number def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: + """ + Convert hyperparameters to a canonical parametrization. + + The resulting parameter vector is [ν₀, n₀] concatenated. + """ return ExponentialFamilyParametrization( np.append(self.effective_suff_stat_value, self.effective_sample_size) ) @@ -61,11 +94,27 @@ def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: class ContinuousExponentialClassFamily(ParametricFamily): """ - Representation of exponential class with density = h(x) * exp( + A(t)), - where canonical parametrization is that, when n = t - - Usage of this class: - - you can use method transform_to_another to replace x to smth else, for example, into + Representation of a continuous exponential family distribution. + + The density is given by: + f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ)) + + where: + - θ is the natural parameter, + - T(x) is the sufficient statistic vector, + - h(x) is the base measure (the `normalization_constant`), + - A(θ) is the log‑partition function. + + This class supports: + - Canonical parametrization (θ) via `ExponentialFamilyParametrization`. + - Conjugate prior families. + - Posterior updates and posterior predictive distributions. + - Transformation of the random variable (change of variable with Jacobian). + + The user must supply functions for the log‑partition `log_partition`, + sufficient statistics `sufficient_statistics`, + base measure `normalization_constant`, as well as the support of the distribution, + the natural parameter space and the range of the sufficient statistic. """ def __init__( @@ -83,6 +132,22 @@ def __init__( support_by_parametrization: SupportArg = None, base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None, ): + """ + Initialize a continuous exponential family distribution. + + Args: + log_partition: Function A(θ) – the log‑partition function. + sufficient_statistics: Function T(x) – the sufficient statistic vector. + normalization_constant: Function h(x) – the base measure. + support: Predicate defining the support of the distribution. + parameter_space: Predicate defining the natural parameter space. + sufficient_statistics_values: Predicate defining the range of T(x). + name: Name of the family. + distr_type: Type of distribution or a callable returning it. + distr_parametrizations: List of parametrization names this family supports. + support_by_parametrization: Callable that returns the support given a parametrization. + base_score: Optional base score function. + """ self._sufficient = sufficient_statistics self._log_partition = log_partition self._normalization = normalization_constant @@ -113,6 +178,16 @@ def __init__( @property def log_density(self) -> ParametrizedFunction: + """ + Log‑density function for the exponential family. + + The function takes a parametrization (must be `ExponentialFamilyParametrization`) + and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support. + + Returns: + Callable[[Parametrization, NumberParameter], Number] + """ + def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number: parametrization = cast(ExponentialFamilyParametrization, parametrization) parametrization = parametrization.transform_to_base_parametrization() @@ -132,10 +207,28 @@ def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Nu @property def density(self) -> ParametrizedFunction: + """ + Density function (exponentiated log‑density). + + Returns: + Callable[[Parametrization, NumberParameter], Number] + """ return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) @property def conjugate_prior_family(self) -> ContinuousExponentialClassFamily: + """ + Build the conjugate prior family for this exponential family. + + The conjugate prior is an exponential family in the natural parameter θ, + with sufficient statistic [θ, A(θ)] and base measure 1. The resulting + family has its own [log_partition, sufficient_statistics, ...] such that + the posterior updates are given by adding the observed sufficient statistics. + + Returns: + ContinuousExponentialClassFamily: The conjugate prior family. + """ + def conjugate_sufficient( theta: NumberParameter, ) -> NumberParameter: @@ -193,6 +286,21 @@ def transform( self, transform_function: Callable[[NumberParameter], NumberParameter], ) -> ContinuousExponentialClassFamily: + """ + Transform the random variable by a monotonic, differentiable function. + + The new density is obtained via the change‑of‑variable formula. + The sufficient statistic becomes T(transform⁻¹(x)) and the base measure + gains the Jacobian factor. + + Args: + transform_function: Invertible, differentiable function g(x) such that + y = g(x). Must be defined on the original support. + + Returns: + ContinuousExponentialClassFamily: A new family for the transformed variable. + """ + def calculate_jacobian(x: NumberParameter) -> NumberParameter: if not isinstance(x, Iterable): x = np.array([x]) @@ -223,15 +331,15 @@ def new_normalization(x: NumberParameter) -> NumberParameter: @property def _mean(self) -> ParametrizedFunction: + """Compute the mean E[X] by numerical integration over the density.""" + def mean_func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( # type: ignore[arg-type] - np.dot(x, self.density(parametrization, x)) if x in self._support else 0 - ), + lambda x: np.dot(x, self.density(parametrization, x)) if x in self._support else 0, [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -239,15 +347,15 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: @property def _second_moment(self) -> ParametrizedFunction: + """Compute the second moment E[X²] by numerical integration.""" + def func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( # type: ignore[arg-type] - x**2 * self.density(parametrization, x) if x in self._support else 0 - ), + lambda x: x**2 * self.density(parametrization, x) if x in self._support else 0, [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -255,6 +363,8 @@ def func(parametrization: Parametrization, x: Any) -> Any: @property def _var(self) -> ParametrizedFunction: + """Compute the variance Var[X] = E[X²] - (E[X])².""" + def func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 @@ -264,6 +374,22 @@ def func(parametrization: Parametrization, x: Any) -> Any: def posterior_hyperparameters( self, parametrizaiton: ExponentialConjugateHyperparameters, sample: list[Any] ) -> ExponentialConjugateHyperparameters: + """ + Update the conjugate prior hyperparameters given observed data. + + For a conjugate prior with hyperparameters (ν₀, n₀), the posterior + hyperparameters become: + ν = ν₀ + Σ_{i} T(x_i) + n = n₀ + N + + Args: + parametrizaiton: Current conjugate hyperparameters. + sample: List of observations (each can be scalar or array). + + Returns: + ExponentialConjugateHyperparameters: + Updated hyperparameters after incorporating the sample. + """ posterior_effective_suff_stat_value = parametrizaiton.effective_suff_stat_value posterior_effective_sample_size = parametrizaiton.effective_sample_size if hasattr(sample, "__iter__") and not isinstance(sample, str): @@ -283,6 +409,19 @@ def posterior_hyperparameters( @property def posterior_predictive(self) -> ParametricFamily: + """ + Construct the posterior predictive distribution. + + For a conjugate prior, the posterior predictive density of a new observation x + given hyperparameters (ν, n) is: + p(x | ν, n) = h(x) * exp( A(ν) - A(ν + T(x)) ) + where A(·) is the log‑partition function of the conjugate prior family. + + Returns: + ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters` + and a `pdf` method implementing the posterior predictive density. + """ + def conjugate_log_partition( parametrization: ExponentialConjugateHyperparameters, ) -> NumberParameter: diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index 50ddbe60..bc8e396b 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -268,6 +268,9 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: if not hasattr(x, "__iter__"): x = np.array([x]) + x = np.array(x) + assert len(x) == len(self.intervals) + return all( x_coordinate in interval for interval, x_coordinate in zip(self.intervals, x, strict=True) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index d4896ce5..e57a0a93 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,4 +1,4 @@ -__author__ = "Leonid Elkin" +__author__ = "Vinogradov Ilya" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" From 2500678a978ee6d47e6c218576bfa4d780158aed Mon Sep 17 00:00:00 2001 From: domosedy Date: Sun, 17 May 2026 15:14:07 +0300 Subject: [PATCH 8/9] feat(exponential): remove NumberParameter logic --- src/pysatl_core/distributions/support.py | 10 +-- .../families/exponential_family.py | 68 +++++++++++-------- src/pysatl_core/types.py | 30 +++----- .../unit/families/test_exponential_family.py | 11 ++- 4 files changed, 58 insertions(+), 61 deletions(-) diff --git a/src/pysatl_core/distributions/support.py b/src/pysatl_core/distributions/support.py index ce9e1e27..9705f409 100644 --- a/src/pysatl_core/distributions/support.py +++ b/src/pysatl_core/distributions/support.py @@ -32,7 +32,6 @@ Interval1D, IntervalND, Number, - NumberParameter, NumericArray, ) @@ -63,7 +62,8 @@ class ContinuousSupport(Interval1D, Support): """ -class ContinuousNDSupport(IntervalND, Support): +# Support want to have Number as a parameter of contains, but we decided that we should avoid this +class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc] """ Support for continuous distributions represented as an array of intervals. @@ -455,18 +455,18 @@ def is_right_bounded(self) -> bool: @dataclass(slots=True) class SupportByPredicate(Support): - predicate: Callable[[NumberParameter], bool] + predicate: Callable[[NumericArray], bool] @overload def contains(self, x: Number) -> bool: ... @overload def contains(self, x: NumericArray) -> BoolArray: ... - def contains(self, x: NumberParameter) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: # type: ignore[misc] return self.predicate(x) def __contains__(self, item: object) -> bool | BoolArray: - return self.contains(cast(NumberParameter, item)) + return self.contains(cast(NumericArray, item)) __all__ = [ diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index 666dc2f9..de36b6b7 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from pysatl_core.distributions.support import Support - from pysatl_core.types import Number, NumberParameter, NumericArray + from pysatl_core.types import Number, NumericArray type ParametrizedFunction = Callable[[Parametrization, Any], Any] type SupportArg = Callable[[Parametrization], Support | None] | None @@ -52,10 +52,10 @@ class ExponentialFamilyParametrization(Parametrization): f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ)) Attributes: - theta (NumberParameter): Natural parameter vector (can be a scalar or array) + theta (NumericArray): Natural parameter vector (can be a scalar or array) """ - theta: NumberParameter + theta: NumericArray def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization: """Return the base parametrization (identity transform for canonical form).""" @@ -120,9 +120,9 @@ class ContinuousExponentialClassFamily(ParametricFamily): def __init__( self, *, - log_partition: Callable[[NumberParameter], NumberParameter], - sufficient_statistics: Callable[[NumberParameter], NumberParameter], - normalization_constant: Callable[[NumberParameter], NumberParameter], + log_partition: Callable[[NumericArray], NumericArray], + sufficient_statistics: Callable[[NumericArray], NumericArray], + normalization_constant: Callable[[NumericArray], Number], support: SupportByPredicate, parameter_space: SupportByPredicate, sufficient_statistics_values: SupportByPredicate, @@ -185,13 +185,13 @@ def log_density(self) -> ParametrizedFunction: and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support. Returns: - Callable[[Parametrization, NumberParameter], Number] + Callable[[Parametrization, NumericArray], Number] """ - def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number: + def log_density_func(parametrization: Parametrization, x: NumericArray) -> Number: parametrization = cast(ExponentialFamilyParametrization, parametrization) parametrization = parametrization.transform_to_base_parametrization() - if x not in self._support: + if np.array([x]) not in self._support: return -np.inf theta = parametrization.theta @@ -211,7 +211,7 @@ def density(self) -> ParametrizedFunction: Density function (exponentiated log‑density). Returns: - Callable[[Parametrization, NumberParameter], Number] + Callable[[Parametrization, NumericArray], Number] """ return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) @@ -230,8 +230,8 @@ def conjugate_prior_family(self) -> ContinuousExponentialClassFamily: """ def conjugate_sufficient( - theta: NumberParameter, - ) -> NumberParameter: + theta: NumericArray, + ) -> NumericArray: if not hasattr(theta, "__len__"): theta = np.array([theta]) @@ -240,9 +240,9 @@ def conjugate_sufficient( return np.append(theta, self._log_partition(theta)) def conjugate_log_partition( - parametrization: NumberParameter, - ) -> NumberParameter: - def pdf(theta: NumberParameter) -> NumberParameter: + parametrization: NumericArray, + ) -> NumericArray: + def pdf(theta: NumericArray) -> Number: if not hasattr(theta, "__len__"): theta = np.array([theta]) return cast( @@ -259,7 +259,7 @@ def pdf(theta: NumberParameter) -> NumberParameter: lambda x: pdf(x) if x in self._parameter_space else 0, # type: ignore[arg-type] [(float("-inf"), float("+inf"))], )[0] - return cast(np.float64, -np.log(all_value)) + return np.array([cast(np.float64, -np.log(all_value))]) def conjugate_sufficient_accepts( theta: NumericArray, @@ -267,15 +267,17 @@ def conjugate_sufficient_accepts( xi = theta[:-1] nu = theta[-1] - return xi in self._sufficient_statistics_values and nu in ContinuousSupport(0, np.inf) + return xi in self._sufficient_statistics_values and np.array([nu]) in ContinuousSupport( + 0, np.inf + ) return ContinuousExponentialClassFamily( log_partition=conjugate_log_partition, sufficient_statistics=conjugate_sufficient, normalization_constant=lambda _: 1, support=self._parameter_space, - sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this - parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), # type: ignore[arg-type] + sufficient_statistics_values=self._parameter_space, + parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), name=self.name, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, @@ -284,7 +286,7 @@ def conjugate_sufficient_accepts( def transform( self, - transform_function: Callable[[NumberParameter], NumberParameter], + transform_function: Callable[[NumericArray], NumericArray], ) -> ContinuousExponentialClassFamily: """ Transform the random variable by a monotonic, differentiable function. @@ -301,20 +303,20 @@ def transform( ContinuousExponentialClassFamily: A new family for the transformed variable. """ - def calculate_jacobian(x: NumberParameter) -> NumberParameter: + def calculate_jacobian(x: NumericArray) -> NumericArray: if not isinstance(x, Iterable): x = np.array([x]) return np.abs(det(jacobian(transform_function, x).df)) - def new_support(x: NumberParameter) -> bool: + def new_support(x: NumericArray) -> bool: return transform_function(x) in self._support - def new_sufficient(x: NumberParameter) -> NumberParameter: + def new_sufficient(x: NumericArray) -> NumericArray: return self._sufficient(transform_function(x)) - def new_normalization(x: NumberParameter) -> NumberParameter: - return self._normalization(x) * calculate_jacobian(x) + def new_normalization(x: NumericArray) -> Number: + return cast(np.float64, self._normalization(x) * calculate_jacobian(x)) return ContinuousExponentialClassFamily( log_partition=self._log_partition, @@ -339,7 +341,11 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: np.dot(x, self.density(parametrization, x)) if x in self._support else 0, + lambda x: ( + np.dot(x, self.density(parametrization, x)) + if np.array([x]) in self._support + else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -355,7 +361,9 @@ def func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: x**2 * self.density(parametrization, x) if x in self._support else 0, + lambda x: ( + x**2 * self.density(parametrization, x) if np.array([x]) in self._support else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -394,7 +402,7 @@ def posterior_hyperparameters( posterior_effective_sample_size = parametrizaiton.effective_sample_size if hasattr(sample, "__iter__") and not isinstance(sample, str): posterior_effective_suff_stat_value += np.sum( - [self._sufficient(x) for x in sample], # type: ignore[arg-type] + [self._sufficient(x) for x in sample], axis=0, ) posterior_effective_sample_size += len(sample) @@ -424,13 +432,13 @@ def posterior_predictive(self) -> ParametricFamily: def conjugate_log_partition( parametrization: ExponentialConjugateHyperparameters, - ) -> NumberParameter: + ) -> NumericArray: conjugate_value = self.conjugate_prior_family._log_partition( parametrization.transform_to_base_parametrization().theta ) return np.exp(conjugate_value) - def posterior_density(parametrization: Parametrization, x: NumberParameter) -> Number: + def posterior_density(parametrization: Parametrization, x: NumericArray) -> Number: parametrization = cast(ExponentialConjugateHyperparameters, parametrization) return cast( np.float32, diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index bc8e396b..fe3cc8b4 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -121,9 +121,6 @@ class EuclideanDistributionType(DistributionType): type BoolArray = NDArray[np.bool_] """Type alias for boolean arrays.""" -type NumberParameter = Number | NumericArray -"""Type alias for numeric or list parameter""" - class ContinuousSupportShape1D(Enum): """ @@ -258,27 +255,22 @@ def shape(self) -> ContinuousSupportShape1D: class IntervalND: intervals: list[Interval1D] - @overload - def contains(self, x: Number) -> bool: ... + def contains(self, x: NumericArray) -> bool | BoolArray: + def contains_for_point(point: NumericArray) -> bool: + assert len(point) == len(self.intervals) + return all( + x_coordinate in interval + for interval, x_coordinate in zip(self.intervals, point, strict=True) + ) - @overload - def contains(self, x: NumericArray) -> BoolArray: ... + if len(x.shape) == 1: + return contains_for_point(x) - def contains(self, x: Number | NumericArray) -> bool | BoolArray: - if not hasattr(x, "__iter__"): - x = np.array([x]) - - x = np.array(x) - assert len(x) == len(self.intervals) - - return all( - x_coordinate in interval - for interval, x_coordinate in zip(self.intervals, x, strict=True) - ) + return np.array([contains_for_point(point) for point in x]) def __contains__(self, x: object) -> bool: """Check if a single point is in the interval.""" - return bool(self.contains(cast(Number, x))) + return bool(self.contains(cast(NumericArray, x))) type GenericCharacteristicName = str diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index e57a0a93..a8b7a254 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -3,7 +3,6 @@ __license__ = "SPDX-License-Identifier: MIT" import itertools -from collections.abc import Iterable from typing import cast import numpy as np @@ -16,7 +15,7 @@ ContinuousExponentialClassFamily, ) from pysatl_core.families.registry import ParametricFamilyRegister -from pysatl_core.types import CharacteristicName, Interval1D, NumberParameter, UnivariateContinuous +from pysatl_core.types import CharacteristicName, Interval1D, NumericArray, UnivariateContinuous def gamma_pdf(alpha: float, beta: float, x: float) -> float: @@ -25,16 +24,14 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float: @pytest.fixture(scope="function") def conjugate_for_exponential() -> ContinuousExponentialClassFamily: - def transform_function(x: NumberParameter) -> NumberParameter: - if isinstance(x, Iterable): - return np.array([-x[0]]) + def transform_function(x: NumericArray) -> NumericArray: return -x support_neg = SupportByPredicate( - predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]) + predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]) ) support_pos = SupportByPredicate( - predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]) + predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]) ) fam = ContinuousExponentialClassFamily( log_partition=lambda parametrization: np.log(-parametrization), From ec832d57afdd1da9f42963595f4402c32b01495b Mon Sep 17 00:00:00 2001 From: domosedy Date: Mon, 25 May 2026 22:18:25 +0300 Subject: [PATCH 9/9] refactor(exponential): update support protocol and fix some minor problems --- src/pysatl_core/distributions/support.py | 49 +--- .../families/exponential_family.py | 129 ++++++---- src/pysatl_core/types.py | 22 +- tests/unit/distributions/test_support.py | 5 - .../builtins/continuous/test_exponential.py | 8 +- .../builtins/continuous/test_normal.py | 6 +- .../builtins/continuous/test_uniform.py | 10 +- .../unit/families/test_exponential_family.py | 230 ++++++++++++++++-- 8 files changed, 309 insertions(+), 150 deletions(-) diff --git a/src/pysatl_core/distributions/support.py b/src/pysatl_core/distributions/support.py index 9705f409..9aaf2e72 100644 --- a/src/pysatl_core/distributions/support.py +++ b/src/pysatl_core/distributions/support.py @@ -21,7 +21,6 @@ TYPE_CHECKING, Protocol, cast, - overload, runtime_checkable, ) @@ -47,10 +46,7 @@ class Support(Protocol): Support defines the set of values where a distribution is defined. """ - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... + def contains(self, x: NumericArray) -> bool | BoolArray: ... class ContinuousSupport(Interval1D, Support): @@ -62,8 +58,7 @@ class ContinuousSupport(Interval1D, Support): """ -# Support want to have Number as a parameter of contains, but we decided that we should avoid this -class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc] +class ContinuousNDSupport(IntervalND, Support): """ Support for continuous distributions represented as an array of intervals. @@ -151,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non self._points = arr[unique_mask] - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are in the support. @@ -185,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return bool(result) return cast(BoolArray, result) - def __contains__(self, x: object) -> bool: - """Check if a point is in the support.""" - return bool(self.contains(cast(Number, x))) - def iter_points(self) -> Iterator[Number]: """Iterate through all points in the support.""" return iter(self._points) @@ -275,12 +261,7 @@ def __post_init__(self) -> None: if self.modulus <= 0: raise ValueError("modulus must be a positive integer.") - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are in the integer lattice support. @@ -306,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return bool(result) return cast(BoolArray, result) - def __contains__(self, x: object) -> bool: - """Check if a point is in the integer lattice support.""" - return bool(self.contains(cast(Number, x))) - def iter_points(self) -> Iterator[int]: """ Iterate through all points in the integer lattice support. @@ -453,28 +430,20 @@ def is_right_bounded(self) -> bool: __iter__ = iter_points -@dataclass(slots=True) -class SupportByPredicate(Support): - predicate: Callable[[NumericArray], bool] - - @overload - def contains(self, x: Number) -> bool: ... - @overload - def contains(self, x: NumericArray) -> BoolArray: ... +@dataclass(frozen=True, slots=True) +class PredicateSupport(Support): + predicate: Callable[[NumericArray], bool | BoolArray] - def contains(self, x: NumericArray) -> bool | BoolArray: # type: ignore[misc] + def contains(self, x: NumericArray) -> bool | BoolArray: return self.predicate(x) - def __contains__(self, item: object) -> bool | BoolArray: - return self.contains(cast(NumericArray, item)) - __all__ = [ # Base support protocol "Support", "ContinuousSupport", "ContinuousNDSupport", - "SupportByPredicate", + "PredicateSupport", # Discrete support protocol and implementations "DiscreteSupport", "ExplicitTableDiscreteSupport", diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index de36b6b7..9f835bf4 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -22,23 +22,26 @@ from pysatl_core.distributions.support import ( ContinuousSupport, - SupportByPredicate, + PredicateSupport, ) from pysatl_core.families.parametric_family import ParametricFamily -from pysatl_core.families.parametrizations import Parametrization, parametrization +from pysatl_core.families.parametrizations import Parametrization, constraint, parametrization from pysatl_core.types import ( CharacteristicName, DistributionType, - GenericCharacteristicName, + Number, + NumericArray, ParametrizationName, UnivariateContinuous, ) if TYPE_CHECKING: from pysatl_core.distributions.support import Support - from pysatl_core.types import Number, NumericArray + from pysatl_core.families.parametric_family import ( + CharacteristicsMap, + ParametricFamilyCharacteristic, + ) - type ParametrizedFunction = Callable[[Parametrization, Any], Any] type SupportArg = Callable[[Parametrization], Support | None] | None @@ -123,12 +126,13 @@ def __init__( log_partition: Callable[[NumericArray], NumericArray], sufficient_statistics: Callable[[NumericArray], NumericArray], normalization_constant: Callable[[NumericArray], Number], - support: SupportByPredicate, - parameter_space: SupportByPredicate, - sufficient_statistics_values: SupportByPredicate, - name: str = "ExponentialFamily", + support: Support, + parameter_space: Support, + sufficient_statistics_values: Support, + name: str, distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], + distr_characteristics: CharacteristicsMap | None = None, support_by_parametrization: SupportArg = None, base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None, ): @@ -145,6 +149,7 @@ def __init__( name: Name of the family. distr_type: Type of distribution or a callable returning it. distr_parametrizations: List of parametrization names this family supports. + distr_characteristics: Additional analytical characteristics to register. support_by_parametrization: Callable that returns the support given a parametrization. base_score: Optional base score function. """ @@ -156,28 +161,33 @@ def __init__( self._parameter_space = parameter_space self._sufficient_statistics_values = sufficient_statistics_values - distr_characteristics: dict[ - GenericCharacteristicName, - dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, - ] = { + family_characteristics: CharacteristicsMap = { CharacteristicName.PDF: self.density, - CharacteristicName.MEAN_DEFAULT: self._mean, - CharacteristicName.VAR_DEFAULT: self._var, + CharacteristicName.MEAN: self._mean, + CharacteristicName.VAR: self._var, } + merged_characteristics = dict(distr_characteristics or {}) + merged_characteristics.update(family_characteristics) ParametricFamily.__init__( self, name=name, distr_type=distr_type, distr_parametrizations=distr_parametrizations, - distr_characteristics=distr_characteristics, + distr_characteristics=merged_characteristics, support_by_parametrization=support_by_parametrization, base_score=base_score, ) - parametrization(family=self, name="theta")(ExponentialFamilyParametrization) + + @parametrization(family=self, name="theta") + class ThetaParametrization(ExponentialFamilyParametrization): + @constraint(description="theta belongs to parameter_space") + def check_theta_in_parameter_space(self) -> bool: + theta = np.atleast_1d(np.asarray(self.theta, dtype=float)) + return bool(self.__family__._parameter_space.contains(theta)) # type: ignore[attr-defined] @property - def log_density(self) -> ParametrizedFunction: + def log_density(self) -> ParametricFamilyCharacteristic[NumericArray, Number]: """ Log‑density function for the exponential family. @@ -191,7 +201,7 @@ def log_density(self) -> ParametrizedFunction: def log_density_func(parametrization: Parametrization, x: NumericArray) -> Number: parametrization = cast(ExponentialFamilyParametrization, parametrization) parametrization = parametrization.transform_to_base_parametrization() - if np.array([x]) not in self._support: + if not self._support.contains(np.array([x])): return -np.inf theta = parametrization.theta @@ -206,14 +216,19 @@ def log_density_func(parametrization: Parametrization, x: NumericArray) -> Numbe return log_density_func @property - def density(self) -> ParametrizedFunction: + def density(self) -> ParametricFamilyCharacteristic[NumericArray, Number]: """ Density function (exponentiated log‑density). Returns: Callable[[Parametrization, NumericArray], Number] """ - return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) + log_density = cast(Callable[[Parametrization, NumericArray], Number], self.log_density) + + def density_func(parametrization: Parametrization, x: NumericArray) -> Number: + return cast(Number, np.exp(log_density(parametrization, x))) + + return density_func @property def conjugate_prior_family(self) -> ContinuousExponentialClassFamily: @@ -235,7 +250,7 @@ def conjugate_sufficient( if not hasattr(theta, "__len__"): theta = np.array([theta]) - if theta not in self._parameter_space: + if not self._parameter_space.contains(theta): return np.full(len(theta) + 1, float("-inf")) return np.append(theta, self._log_partition(theta)) @@ -255,10 +270,13 @@ def pdf(theta: NumericArray) -> Number: ).item(), ) - all_value = nquad( - lambda x: pdf(x) if x in self._parameter_space else 0, # type: ignore[arg-type] - [(float("-inf"), float("+inf"))], - )[0] + def integrand(x: float) -> float: + theta = np.asarray([x], dtype=float) + if not self._parameter_space.contains(theta): + return 0.0 + return float(pdf(theta)) + + all_value = nquad(integrand, [(float("-inf"), float("+inf"))])[0] return np.array([cast(np.float64, -np.log(all_value))]) def conjugate_sufficient_accepts( @@ -267,8 +285,8 @@ def conjugate_sufficient_accepts( xi = theta[:-1] nu = theta[-1] - return xi in self._sufficient_statistics_values and np.array([nu]) in ContinuousSupport( - 0, np.inf + return bool(self._sufficient_statistics_values.contains(xi)) and bool( + ContinuousSupport(0, np.inf).contains(np.array([nu])) ) return ContinuousExponentialClassFamily( @@ -277,7 +295,7 @@ def conjugate_sufficient_accepts( normalization_constant=lambda _: 1, support=self._parameter_space, sufficient_statistics_values=self._parameter_space, - parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), + parameter_space=PredicateSupport(predicate=conjugate_sufficient_accepts), name=self.name, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, @@ -305,12 +323,14 @@ def transform( def calculate_jacobian(x: NumericArray) -> NumericArray: if not isinstance(x, Iterable): - x = np.array([x]) + x = np.array([x], dtype=float) + else: + x = np.atleast_1d(np.asarray(x, dtype=float)) return np.abs(det(jacobian(transform_function, x).df)) def new_support(x: NumericArray) -> bool: - return transform_function(x) in self._support + return bool(self._support.contains(transform_function(x))) def new_sufficient(x: NumericArray) -> NumericArray: return self._sufficient(transform_function(x)) @@ -322,7 +342,7 @@ def new_normalization(x: NumericArray) -> Number: log_partition=self._log_partition, sufficient_statistics=new_sufficient, normalization_constant=new_normalization, - support=SupportByPredicate(predicate=new_support), + support=PredicateSupport(predicate=new_support), parameter_space=self._parameter_space, sufficient_statistics_values=self._sufficient_statistics_values, name=f"Transformed{self._name}", @@ -332,50 +352,50 @@ def new_normalization(x: NumericArray) -> Number: ) @property - def _mean(self) -> ParametrizedFunction: + def _mean(self) -> ParametricFamilyCharacteristic[Any, Any]: """Compute the mean E[X] by numerical integration over the density.""" - def mean_func(parametrization: Parametrization, x: Any) -> Any: + def mean_func(parametrization: Parametrization) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - dimension_size = 1 - if hasattr(x, "__len__"): - dimension_size = len(x) + density = cast(Callable[[Parametrization, NumericArray], Number], self.density) return nquad( lambda x: ( - np.dot(x, self.density(parametrization, x)) - if np.array([x]) in self._support + np.dot(x, density(parametrization, x)) + if self._support.contains(np.array([x])) else 0 ), - [(float("-inf"), float("inf"))] * dimension_size, + [(float("-inf"), float("inf"))], )[0] return mean_func @property - def _second_moment(self) -> ParametrizedFunction: + def _second_moment(self) -> ParametricFamilyCharacteristic[Any, Any]: """Compute the second moment E[X²] by numerical integration.""" - def func(parametrization: Parametrization, x: Any) -> Any: + def func(parametrization: Parametrization) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - dimension_size = 1 - if hasattr(x, "__len__"): - dimension_size = len(x) + density = cast(Callable[[Parametrization, NumericArray], Number], self.density) return nquad( lambda x: ( - x**2 * self.density(parametrization, x) if np.array([x]) in self._support else 0 + x**2 * density(parametrization, x) + if self._support.contains(np.array([x])) + else 0 ), - [(float("-inf"), float("inf"))] * dimension_size, + [(float("-inf"), float("inf"))], )[0] return func @property - def _var(self) -> ParametrizedFunction: + def _var(self) -> ParametricFamilyCharacteristic[Any, Any]: """Compute the variance Var[X] = E[X²] - (E[X])².""" - def func(parametrization: Parametrization, x: Any) -> Any: + def func(parametrization: Parametrization) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 + second_moment = cast(Callable[[Parametrization], Any], self._second_moment) + mean = cast(Callable[[Parametrization], Any], self._mean) + return second_moment(parametrization) - mean(parametrization) ** 2 return func @@ -398,7 +418,10 @@ def posterior_hyperparameters( ExponentialConjugateHyperparameters: Updated hyperparameters after incorporating the sample. """ - posterior_effective_suff_stat_value = parametrizaiton.effective_suff_stat_value + posterior_effective_suff_stat_value = np.array( + parametrizaiton.effective_suff_stat_value, + copy=True, + ) posterior_effective_sample_size = parametrizaiton.effective_sample_size if hasattr(sample, "__iter__") and not isinstance(sample, str): posterior_effective_suff_stat_value += np.sum( @@ -445,9 +468,7 @@ def posterior_density(parametrization: Parametrization, x: NumericArray) -> Numb self._normalization(x) * conjugate_log_partition(parametrization) / conjugate_log_partition( - self.posterior_hyperparameters( - parametrizaiton=parametrization, sample=[self._sufficient(x)] - ) + self.posterior_hyperparameters(parametrizaiton=parametrization, sample=[x]) ), ) diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index fe3cc8b4..33de0af7 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import Enum, StrEnum, auto from math import inf -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pysatl_core.distributions.computations.computation import ( @@ -179,13 +179,7 @@ def __post_init__(self) -> None: if self.right == inf and self.right_closed: object.__setattr__(self, "right_closed", False) - @overload - def contains(self, x: Number) -> bool: ... - - @overload - def contains(self, x: NumericArray) -> BoolArray: ... - - def contains(self, x: Number | NumericArray) -> bool | BoolArray: + def contains(self, x: NumericArray) -> bool | BoolArray: """ Check if point(s) are contained in the interval. @@ -210,10 +204,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray: return result - def __contains__(self, x: object) -> bool: - """Check if a single point is in the interval.""" - return bool(self.contains(cast(Number, x))) - @property def is_empty(self) -> bool: """Check if the interval is empty.""" @@ -259,7 +249,7 @@ def contains(self, x: NumericArray) -> bool | BoolArray: def contains_for_point(point: NumericArray) -> bool: assert len(point) == len(self.intervals) return all( - x_coordinate in interval + bool(interval.contains(np.asarray(x_coordinate))) for interval, x_coordinate in zip(self.intervals, point, strict=True) ) @@ -268,10 +258,6 @@ def contains_for_point(point: NumericArray) -> bool: return np.array([contains_for_point(point) for point in x]) - def __contains__(self, x: object) -> bool: - """Check if a single point is in the interval.""" - return bool(self.contains(cast(NumericArray, x))) - type GenericCharacteristicName = str """Type alias for characteristic names (e.g., 'pdf', 'cdf').""" @@ -339,8 +325,6 @@ class CharacteristicName(StrEnum): CDF = "cdf" PPF = "ppf" PMF = "pmf" - MEAN_DEFAULT = "MEAN_DEFAULT" # defined in class implementation of mean - VAR_DEFAULT = "VAR_DEFAULT" # defined in class implementation of var LPDF = "lpdf" # unimplemented in graph yet CF = "cf" # unimplemented in graph yet SF = "sf" # unimplemented in graph yet diff --git a/tests/unit/distributions/test_support.py b/tests/unit/distributions/test_support.py index 11d70bf4..a199ac3f 100644 --- a/tests/unit/distributions/test_support.py +++ b/tests/unit/distributions/test_support.py @@ -40,7 +40,6 @@ class TestContinuousSupport: ], ) def test_continuous_support_contains_scalar(self, point, expected_result): - assert (point in self.support_example) is expected_result assert self.support_example.contains(point) is expected_result @pytest.mark.parametrize("infinity", [-inf, inf]) @@ -48,7 +47,6 @@ def test_continuous_support_doesnt_contain_inf(self, infinity): # inf isn't considered as a number # but as a limit so support doesn't contain it even if it's a real line support = ContinuousSupport() - assert infinity not in support assert support.contains(infinity) is False @pytest.mark.parametrize( @@ -59,7 +57,6 @@ def test_continuous_support_doesnt_contain_inf(self, infinity): ], ) def test_continuous_support_contains_array(self, points, expected_result): - # np.array doesn't have `in`(__contains__) syntax result = self.support_example.contains(points) assert isinstance(result, np.ndarray) assert result.tolist() == expected_result @@ -104,7 +101,6 @@ def test_table_is_sorted_and_deduplicated(self): ], ) def test_contains_scalar(self, point, expected_result): - assert (point in self.support_example) is expected_result assert self.support_example.contains(point) is expected_result @pytest.mark.parametrize( @@ -196,7 +192,6 @@ def test_invalid_modulus_raises(self): ) def test_contains_scalar(self, support_name, point, expected_result): support = self.support_examples[support_name] - assert (point in support) is expected_result assert support.contains(point) is expected_result @pytest.mark.parametrize( diff --git a/tests/unit/families/builtins/continuous/test_exponential.py b/tests/unit/families/builtins/continuous/test_exponential.py index 6363c47b..e92ce656 100644 --- a/tests/unit/families/builtins/continuous/test_exponential.py +++ b/tests/unit/families/builtins/continuous/test_exponential.py @@ -206,10 +206,10 @@ def test_exponential_support(self): assert not dist.support.right_closed # Test containment - assert dist.support.contains(0.0) is True - assert dist.support.contains(1.0) is True - assert dist.support.contains(-0.1) is False - assert dist.support.contains(float("inf")) is False + assert dist.support.contains(np.asarray(0.0)) is True + assert dist.support.contains(np.asarray(1.0)) is True + assert dist.support.contains(np.asarray(-0.1)) is False + assert dist.support.contains(np.asarray(float("inf"))) is False # Test array test_points = np.array([-0.1, 0.0, 1.0, 10.0]) diff --git a/tests/unit/families/builtins/continuous/test_normal.py b/tests/unit/families/builtins/continuous/test_normal.py index ca593495..0f47a422 100644 --- a/tests/unit/families/builtins/continuous/test_normal.py +++ b/tests/unit/families/builtins/continuous/test_normal.py @@ -226,9 +226,9 @@ def test_normal_support(self): assert not dist.support.left_closed assert not dist.support.right_closed - assert dist.support.contains(0) is True - assert dist.support.contains(float("inf")) is False - assert dist.support.contains(float("-inf")) is False + assert dist.support.contains(np.asarray(0)) is True + assert dist.support.contains(np.asarray(float("inf"))) is False + assert dist.support.contains(np.asarray(float("-inf"))) is False test_points = np.array([-500, 0, 5]) results = dist.support.contains(test_points) diff --git a/tests/unit/families/builtins/continuous/test_uniform.py b/tests/unit/families/builtins/continuous/test_uniform.py index 8d63f7fd..d0cffc38 100644 --- a/tests/unit/families/builtins/continuous/test_uniform.py +++ b/tests/unit/families/builtins/continuous/test_uniform.py @@ -230,11 +230,11 @@ def test_uniform_support(self): assert dist.support.right_closed # Test containment - assert dist.support.contains(2.0) is True - assert dist.support.contains(5.0) is True - assert dist.support.contains(3.5) is True - assert dist.support.contains(1.9) is False - assert dist.support.contains(5.1) is False + assert dist.support.contains(np.asarray(2.0)) is True + assert dist.support.contains(np.asarray(5.0)) is True + assert dist.support.contains(np.asarray(3.5)) is True + assert dist.support.contains(np.asarray(1.9)) is False + assert dist.support.contains(np.asarray(5.1)) is False # Test array test_points = np.array([1.9, 2.0, 3.5, 5.0, 5.1]) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index a8b7a254..d03667c1 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + __author__ = "Vinogradov Ilya" __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" @@ -10,31 +12,52 @@ import scipy from numpy.testing import assert_allclose -from pysatl_core.distributions.support import ContinuousNDSupport, SupportByPredicate +from pysatl_core.distributions.support import ContinuousNDSupport, PredicateSupport from pysatl_core.families import ( ContinuousExponentialClassFamily, + ExponentialConjugateHyperparameters, + ExponentialFamilyParametrization, ) from pysatl_core.families.registry import ParametricFamilyRegister -from pysatl_core.types import CharacteristicName, Interval1D, NumericArray, UnivariateContinuous +from pysatl_core.types import ( + CharacteristicName, + Interval1D, + Number, + NumericArray, + UnivariateContinuous, +) def gamma_pdf(alpha: float, beta: float, x: float) -> float: return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() # type: ignore[attr-defined] -@pytest.fixture(scope="function") -def conjugate_for_exponential() -> ContinuousExponentialClassFamily: - def transform_function(x: NumericArray) -> NumericArray: - return -x +def lomax_pdf(shape: float, scale: float, x: float) -> float: + return scipy.stats.lomax(c=shape, scale=scale).pdf(x).item() # type: ignore[attr-defined] + + +def exponential_log_partition(parametrization): + return np.log(-parametrization) + - support_neg = SupportByPredicate( - predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]) +def _make_exponential_family() -> ContinuousExponentialClassFamily: + support_neg = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport( + intervals=[Interval1D(-np.inf, 0, left_closed=False, right_closed=False)] + ).contains(np.array([x])) + ) ) - support_pos = SupportByPredicate( - predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]) + support_pos = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport( + intervals=[Interval1D(0, np.inf, left_closed=False, right_closed=False)] + ).contains(np.array([x])) + ) ) - fam = ContinuousExponentialClassFamily( - log_partition=lambda parametrization: np.log(-parametrization), + return ContinuousExponentialClassFamily( + name="ExponentialFamily", + log_partition=exponential_log_partition, sufficient_statistics=lambda x: x, normalization_constant=lambda _: 1, parameter_space=support_neg, @@ -44,6 +67,18 @@ def transform_function(x: NumericArray) -> NumericArray: distr_parametrizations=["theta"], ) + +@pytest.fixture(scope="function") +def exponential_family() -> ContinuousExponentialClassFamily: + return _make_exponential_family() + + +@pytest.fixture(scope="function") +def conjugate_for_exponential() -> ContinuousExponentialClassFamily: + def transform_function(x: NumericArray) -> NumericArray: + return -x + + fam = _make_exponential_family() conjugate_fam = fam.conjugate_prior_family.transform(transform_function) ParametricFamilyRegister().register(conjugate_fam) return cast( @@ -52,6 +87,165 @@ def transform_function(x: NumericArray) -> NumericArray: ) +def test_log_density_matches_exponential_form( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + params = ExponentialFamilyParametrization(theta=np.array([-2.0])) + log_density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.density, + ) + + log_density = log_density_func(params, np.asarray(0.5)) + density = density_func(params, np.asarray(0.5)) + + assert log_density == pytest.approx(np.log(2.0) - 1.0) + assert density == pytest.approx(2.0 * np.exp(-1.0)) + + +def test_constructor_merges_custom_characteristics() -> None: + support_neg = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]).contains(np.array([x])) + ) + ) + support_pos = PredicateSupport( + predicate=lambda x: bool( + ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]).contains(np.array([x])) + ) + ) + family = ContinuousExponentialClassFamily( + name="ExponentialFamily", + log_partition=exponential_log_partition, + sufficient_statistics=lambda x: x, + normalization_constant=lambda _: 1, + parameter_space=support_neg, + sufficient_statistics_values=support_pos, + support=support_pos, + distr_type=UnivariateContinuous, + distr_parametrizations=["theta"], + distr_characteristics={CharacteristicName.CDF: lambda _params, x: x / (1 + x)}, + ) + + assert CharacteristicName.CDF in family.distr_characteristics + assert CharacteristicName.PDF in family.distr_characteristics + assert CharacteristicName.MEAN in family.distr_characteristics + assert CharacteristicName.VAR in family.distr_characteristics + + +def test_log_density_is_minus_infinity_outside_support( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + params = ExponentialFamilyParametrization(theta=np.array([-2.0])) + log_density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + density_func = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.density, + ) + + assert log_density_func(params, np.asarray(-0.1)) == -np.inf + assert density_func(params, np.asarray(-0.1)) == 0.0 + + +def test_distribution_rejects_theta_outside_parameter_space( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + with pytest.raises(ValueError, match="theta belongs to parameter_space"): + exponential_family(theta=np.array([0.0]), parametrization_name="theta") + + +def test_transform_with_negation_moves_support_and_preserves_density( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + transformed = exponential_family.transform(lambda x: -x) + params = ExponentialFamilyParametrization(theta=np.array([-1.5])) + transformed_log_density = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + transformed.log_density, + ) + log_density = cast( + Callable[[ExponentialFamilyParametrization, NumericArray], Number], + exponential_family.log_density, + ) + + assert transformed.name == "TransformedExponentialFamily" + assert transformed_log_density(params, np.asarray(-2.0)) == pytest.approx( + log_density(params, np.asarray(2.0)) + ) + assert transformed_log_density(params, np.asarray(2.0)) == -np.inf + + +def test_posterior_hyperparameters_updates_sample_without_mutating_input( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + prior = ExponentialConjugateHyperparameters( + effective_suff_stat_value=np.array([3.0]), + effective_sample_size=2.0, + ) + + posterior = exponential_family.posterior_hyperparameters(prior, sample=[0.5, 1.5]) + + assert_allclose(posterior.effective_suff_stat_value, np.array([5.0])) + assert posterior.effective_sample_size == 4.0 + assert_allclose(prior.effective_suff_stat_value, np.array([3.0])) + assert prior.effective_sample_size == 2.0 + + +def test_posterior_hyperparameters_accepts_single_observation( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + prior = ExponentialConjugateHyperparameters( + effective_suff_stat_value=np.array([3.0]), + effective_sample_size=2.0, + ) + + posterior = exponential_family.posterior_hyperparameters(prior, sample=0.5) # type: ignore[arg-type] + + assert_allclose(posterior.effective_suff_stat_value, np.array([3.5])) + assert posterior.effective_sample_size == 3.0 + + +def test_posterior_predictive_builds_family_with_posterior_parametrization( + exponential_family: ContinuousExponentialClassFamily, +) -> None: + predictive_family = exponential_family.posterior_predictive + + assert predictive_family.name == "PosteriorPredictiveExponentialFamily" + assert predictive_family.parametrization_names == ["posterior"] + assert predictive_family.get_parametrization("posterior") is ExponentialConjugateHyperparameters + assert CharacteristicName.PDF in predictive_family.distr_characteristics + + +@pytest.mark.parametrize( + ("xi", "nu"), + itertools.product((2.0, 3.0, 4.0), (2.0, 3.0, 4.0)), +) +def test_posterior_predictive_matches_lomax_density( + exponential_family: ContinuousExponentialClassFamily, + xi: float, + nu: float, +) -> None: + predictive = exponential_family.posterior_predictive.distribution( + parametrization_name="posterior", + effective_suff_stat_value=np.array([xi]), + effective_sample_size=nu, + ) + pdf = predictive.computation_strategy.query_method("pdf", distr=predictive) + x_values = np.array([0.0, 0.5, 1.5, 3.0, 6.0]) + + actual = np.asarray([pdf(x) for x in x_values], dtype=float).reshape(-1) + expected = np.asarray([lomax_pdf(shape=nu + 1, scale=xi, x=x) for x in x_values]) + + assert_allclose(actual, expected, rtol=1e-6) + + @pytest.mark.parametrize( ("theta1", "theta2"), itertools.product(range(2, 5), range(2, 5)), @@ -81,10 +275,8 @@ def test_exponential_mean(theta1, theta2, conjugate_for_exponential): beta = theta1 exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") - mean = exponential.computation_strategy.query_method( - CharacteristicName.MEAN_DEFAULT, distr=exponential - ) - assert np.isclose(mean(12), alpha / beta, rtol=1e-6) + mean = exponential.computation_strategy.query_method(CharacteristicName.MEAN, distr=exponential) + assert np.isclose(mean(), alpha / beta, rtol=1e-6) @pytest.mark.parametrize( @@ -98,7 +290,5 @@ def test_exponential_var(theta1, theta2, conjugate_for_exponential): beta = theta1 exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") - var = exponential.computation_strategy.query_method( - CharacteristicName.VAR_DEFAULT, distr=exponential - ) - assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) + var = exponential.computation_strategy.query_method(CharacteristicName.VAR, distr=exponential) + assert np.isclose(var(), alpha / beta**2, rtol=1e-6)