Skip to content
62 changes: 36 additions & 26 deletions src/pysatl_core/distributions/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from collections.abc import Callable
from dataclasses import dataclass
from math import floor
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable
from typing import (
TYPE_CHECKING,
Protocol,
cast,
runtime_checkable,
)

import numpy as np

from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray
from pysatl_core.types import (
BoolArray,
Interval1D,
IntervalND,
Number,
NumericArray,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand All @@ -34,10 +46,7 @@ class Support(Protocol):
Support defines the set of values where a distribution is defined.
"""

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...
def contains(self, x: NumericArray) -> bool | BoolArray: ...


class ContinuousSupport(Interval1D, Support):
Expand All @@ -49,6 +58,15 @@ class ContinuousSupport(Interval1D, Support):
"""


class ContinuousNDSupport(IntervalND, Support):
"""
Support for continuous distributions represented as an array of intervals.

This class inherits from IntervalND and implements the Support protocol
for continuous distributions defined on a list of intervals [left, right].
"""


@runtime_checkable
class DiscreteSupport(Support, Protocol):
"""
Expand Down Expand Up @@ -128,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non

self._points = arr[unique_mask]

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...

def contains(self, x: Number | NumericArray) -> bool | BoolArray:
def contains(self, x: NumericArray) -> bool | BoolArray:
"""
Check if point(s) are in the support.

Expand Down Expand Up @@ -162,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
return bool(result)
return cast(BoolArray, result)

def __contains__(self, x: object) -> bool:
"""Check if a point is in the support."""
return bool(self.contains(cast(Number, x)))

def iter_points(self) -> Iterator[Number]:
"""Iterate through all points in the support."""
return iter(self._points)
Expand Down Expand Up @@ -252,12 +261,7 @@ def __post_init__(self) -> None:
if self.modulus <= 0:
raise ValueError("modulus must be a positive integer.")

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...

def contains(self, x: Number | NumericArray) -> bool | BoolArray:
def contains(self, x: NumericArray) -> bool | BoolArray:
"""
Check if point(s) are in the integer lattice support.

Expand All @@ -283,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
return bool(result)
return cast(BoolArray, result)

def __contains__(self, x: object) -> bool:
"""Check if a point is in the integer lattice support."""
return bool(self.contains(cast(Number, x)))

def iter_points(self) -> Iterator[int]:
"""
Iterate through all points in the integer lattice support.
Expand Down Expand Up @@ -430,10 +430,20 @@ def is_right_bounded(self) -> bool:
__iter__ = iter_points


@dataclass(frozen=True, slots=True)
class PredicateSupport(Support):
predicate: Callable[[NumericArray], bool | BoolArray]

def contains(self, x: NumericArray) -> bool | BoolArray:
return self.predicate(x)


__all__ = [
# Base support protocol
"Support",
"ContinuousSupport",
"ContinuousNDSupport",
"PredicateSupport",
# Discrete support protocol and implementations
"DiscreteSupport",
"ExplicitTableDiscreteSupport",
Expand Down
9 changes: 9 additions & 0 deletions src/pysatl_core/families/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from .builtins import __all__ as _builtins_all
from .configuration import configure_families_register
from .distribution import ParametricFamilyDistribution
from .exponential_family import (
# CanonicalContinuousExponentialClassFamily,
ContinuousExponentialClassFamily,
ExponentialConjugateHyperparameters,
ExponentialFamilyParametrization,
)
from .parametric_family import ParametricFamily
from .parametrizations import (
Parametrization,
Expand All @@ -34,6 +40,9 @@
"configure_families_register",
# builtins
*_builtins_all,
"ContinuousExponentialClassFamily",
"ExponentialFamilyParametrization",
"ExponentialConjugateHyperparameters",
]

del _builtins_all
Loading
Loading