-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_interface.py
More file actions
56 lines (40 loc) · 2.13 KB
/
Copy pathmodel_interface.py
File metadata and controls
56 lines (40 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import Callable, Dict, List, NamedTuple, Optional, Protocol, Tuple

import numpy as np
import pandas as pd
# --- Data Structures ---
class EStepResult(NamedTuple):
    """Per-subject output of the E-step.

    Both array fields are expressed in the *transformed* parameter space.
    """

    # Position of the subject in the original ordering, kept so that
    # results can be re-sorted back into that order.
    pid_index: int
    # MAP estimate of the subject's parameters (transformed space).
    map_estimate: np.ndarray
    # Approximate covariance at the MAP, taken from the inverse Hessian
    # (transformed space).
    hessian_inv: np.ndarray
    # Value of the objective function (negative log posterior) at the MAP.
    neg_log_posterior: float
class FitResult(NamedTuple):
    """Final output of the EM fitting procedure."""

    # Final group-level mean and variance for each parameter, keyed by
    # parameter name (inner-dict key names are set by the fitter — confirm there).
    group_params: Dict[str, Dict[str, float]]
    # One EStepResult per subject.
    individual_params: List[EStepResult]
    # Integrated BIC score of the fit.
    bic_int: float
    # How many EM iterations were run.
    convergence_iterations: int
    # Whether the algorithm converged.
    converged: bool
    # Final value of the EM objective proxy.
    final_objective: float
# --- Model Interface Definition ---
class ModelLikelihoodInfo(NamedTuple):
    """Model-supplied information that the EM algorithm needs."""

    # Parameter names, in order (presumably the order of entries in the
    # parameter vectors passed to log_likelihood_func — confirm with the fitter).
    param_names: List[str]
    # Per-parameter transform pair, shaped like
    # {'param': {'forward': f, 'inverse': inv_f}}.
    transform_funcs: Dict[str, Dict[str, Callable]]
    # Callable (transformed_params, data) -> log-likelihood.
    log_likelihood_func: Callable[[np.ndarray, pd.DataFrame], float]
class ModelProtocol(Protocol):
    """Defines the interface required for any model used with the EM fitter.

    Conformance is structural (PEP 544): a model class need not inherit from
    this Protocol, but it must then define both methods below. The default
    ``get_param_bounds`` body is only inherited by classes that explicitly
    subclass ``ModelProtocol``.
    """

    def get_likelihood_info(self) -> ModelLikelihoodInfo:
        """Return the information needed to calculate likelihoods and handle parameters.

        Returns:
            A ``ModelLikelihoodInfo`` bundling the ordered parameter names,
            the per-parameter forward/inverse transforms, and the
            log-likelihood callable.
        """
        ...

    def get_param_bounds(
        self,
    ) -> Optional[Dict[str, Tuple[Optional[float], Optional[float]]]]:
        """Return reasonable bounds for parameters in their *native* space.

        Can be used for optimizer constraints if not transforming, or for
        sanity checks. Return ``None`` or an empty dict if not applicable.
        Either side of a bound may be ``None``, meaning "unbounded on that side".

        Example: ``{'alpha': (0, 1), 'beta': (0, None)}``

        Returns:
            Mapping of parameter name to ``(lower, upper)``, or ``None``.
        """
        # Default for models that explicitly subclass this Protocol and have
        # no bounds to report. The annotation admits None and None-valued
        # bounds, matching the contract documented above.
        return {}