Skip to content

Commit dff6484

Browse files
authored
Add high-level APIs, benchmarks, and identity tests (#28)
1 parent 6feb5e0 commit dff6484

10 files changed

Lines changed: 834 additions & 64 deletions

File tree

README.md

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,26 @@ pip install -e .
2525
### Usage
2626

2727
```python
28-
from alsgls import als_gls, simulate_sur, nll_per_row, XB_from_Blist
28+
from alsgls import ALSGLS, ALSGLSSystem, simulate_sur
2929

3030
Xs_tr, Y_tr, Xs_te, Y_te = simulate_sur(N_tr=240, N_te=120, K=60, p=3, k=4)
31-
B, F, D, mem, _ = als_gls(Xs_tr, Y_tr, k=4)
32-
Yhat_te = XB_from_Blist(Xs_te, B)
33-
nll = nll_per_row(Y_te - Yhat_te, F, D)
31+
32+
# Scikit-learn style estimator
33+
est = ALSGLS(rank="auto", max_sweeps=12)
34+
est.fit(Xs_tr, Y_tr)
35+
test_score = est.score(Xs_te, Y_te) # negative test NLL per observation
36+
37+
# Statsmodels-style system interface
38+
system = {f"eq{j}": (Y_tr[:, j], Xs_tr[j]) for j in range(Y_tr.shape[1])}
39+
sys_model = ALSGLSSystem(system, rank="auto")
40+
sys_results = sys_model.fit()
41+
params = sys_results.params_as_series() # pandas optional
3442
```
3543

36-
See `examples/compare_als_vs_em.py` for a complete ALS versus EM comparison.
44+
See `examples/compare_als_vs_em.py` for a complete ALS versus EM comparison. The
45+
`benchmarks/compare_sur.py` script contrasts ALS-GLS with `statsmodels` and
46+
`linearmodels` SUR implementations on matched simulation grids while recording
47+
peak memory (via Memray, Fil, or the POSIX RSS high-water mark).
3748

3849
### Documentation and notebooks
3950

@@ -63,3 +74,24 @@ To show the magnitude, we ran a Monte‑Carlo experiment with N = 300 observations
6374

6475
Statistically, the two estimators are indistinguishable (paired‑test p ≥ 0.14). Computationally, ALS needs only a few megabytes whereas EM needs tens to hundreds.
6576

77+
### Defaults, tuning knobs, and failure modes
78+
79+
- **Rank (`k`)** – By default the high-level APIs pick `min(8, ceil(K / 10))`, a
80+
conservative fraction of the number of equations. Increase `rank` if the
81+
cross-equation correlation matrix is slow to decay; decrease it when the
82+
diagonal dominates.
83+
- **ALS ridge terms (`lam_F`, `lam_B`)** – Defaults to `1e-3` for both the
84+
latent-factor and regression updates; raise them slightly (e.g. `1e-2`) if CG
85+
struggles to converge or the NLL trace plateaus early.
86+
- **Noise floor (`d_floor`)** – Keeps the diagonal component positive; the
87+
default `1e-8` protects against breakdowns when an equation is nearly
88+
deterministic. Increase it in highly ill-conditioned settings.
89+
- **Stopping criteria** – ALS stops when the relative drop in NLL per sweep is
90+
below `1e-6` (configurable via `rel_tol`) or after `max_sweeps`. Inspect
91+
`info["nll_trace"]` to diagnose stagnation.
92+
- **Possible failures** – Large condition numbers or nearly-collinear regressors
93+
can make the β-step CG solve slow; adjust `cg_tol`/`cg_maxit`, add stronger
94+
ridge, or re-scale predictors. If `info["accept_t"]` stays at zero and the
95+
NLL does not improve, the factor rank may be too large relative to the sample
96+
size.
97+

alsgls/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
from .als import als_gls
2+
from .api import ALSGLS, ALSGLSSystem, ALSGLSSystemResults
23
from .em import em_gls
34
from .metrics import mse, nll_per_row
4-
from .sim import simulate_sur, simulate_gls
55
from .ops import XB_from_Blist
6+
from .sim import simulate_gls, simulate_sur
67

7-
__all__ = ["als_gls", "em_gls", "mse", "nll_per_row", "simulate_sur", "simulate_gls", "XB_from_Blist"]
8+
__all__ = [
9+
"ALSGLS",
10+
"ALSGLSSystem",
11+
"ALSGLSSystemResults",
12+
"XB_from_Blist",
13+
"als_gls",
14+
"em_gls",
15+
"mse",
16+
"nll_per_row",
17+
"simulate_gls",
18+
"simulate_sur",
19+
]

alsgls/als.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def als_gls(
2424
*,
2525
scale_correct: bool = True,
2626
scale_floor: float = 1e-8,
27+
rel_tol: float = 1e-6,
2728
):
2829
"""
2930
Alternating-least-squares GLS with low-rank-plus-diagonal covariance.
@@ -51,6 +52,7 @@ def als_gls(
5152
cg_tol : float CG relative tolerance
5253
scale_correct : bool if True, try guarded MLE scale fix on Σ each sweep
5354
scale_floor : float min scalar for scale correction
55+
rel_tol : float relative NLL improvement threshold for early stopping
5456
5557
Returns
5658
-------
@@ -81,6 +83,8 @@ def als_gls(
8183
raise ValueError(f"k must be between 1 and min(K={K}, N={N})")
8284
if lam_F < 0 or lam_B < 0:
8385
raise ValueError("Regularization parameters must be non-negative")
86+
if rel_tol < 0:
87+
raise ValueError("rel_tol must be non-negative")
8488

8589
p_list = [X.shape[1] for X in Xs]
8690

@@ -242,7 +246,7 @@ def try_with_scale(F_try, D_try):
242246
# Convergence: stop if relative improvement w.r.t previous post-Σ NLL is tiny
243247
rel_impr = (nll_prev - nll_curr) / max(1.0, abs(nll_prev))
244248
nll_prev = nll_curr
245-
if rel_impr < 1e-6:
249+
if rel_impr < rel_tol:
246250
break
247251

248252
# Memory estimate: F (K×k) + D (K) + U (N×k) doubles

0 commit comments

Comments (0)