Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/causaltensor/cauest/MCNNM.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,59 @@ def __init__(self, Z=None, X=None, Omega=None, fixed_effects = 'two-way'):
self.FE_beta_solver = FixedEffectPanelSolver(fixed_effects=self.fixed_effects, X=self.X, Omega=self.Omega)
self.return_tau_scalar = False

def fit(
self,
O=None,
suggest_r=None,
l=None,
K=None,
list_l=None,
M_init=None,
eps=1e-7,
max_iter=2000,
):
"""Fit MC-NNM using the solver implied by the provided arguments.

Dispatch order is ``suggest_r`` > ``l`` > cross-validation. If neither
``suggest_r`` nor ``l`` is provided, cross-validation is used with
``K=2`` unless another ``K`` is specified.

Parameters
----------
O: 2D numpy array
The observation matrix.
suggest_r: int or None
Suggested rank for ``solve_with_suggested_rank``.
l: float or None
Nuclear norm regularizer for ``solve_with_regularizer``.
K: int or None
Number of cross-validation folds for ``solve_with_cross_validation``.
list_l: iterable or None
Candidate regularizers for cross-validation.
M_init: 2D numpy array or None
Initial low-rank matrix for ``solve_with_regularizer``.
eps: float
Convergence threshold for ``solve_with_regularizer``.
max_iter: int
Maximum iterations for ``solve_with_regularizer``.
"""
if O is None:
raise ValueError("O must be provided.")

if suggest_r is not None:
return self.solve_with_suggested_rank(O=O, suggest_r=suggest_r)
if l is not None:
return self.solve_with_regularizer(
O=O,
l=l,
M_init=M_init,
eps=eps,
max_iter=max_iter,
)
if K is None:
K = 2
return self.solve_with_cross_validation(O=O, K=K, list_l=list_l)

def solve_with_regularizer(self, O=None, l=None, M_init=None, eps=1e-7, max_iter=2000):
""" Solve the matrix completion problem with nuclear norm regularizer and fixed effects
Parameters
Expand Down Expand Up @@ -214,4 +267,4 @@ def MC_NNM_with_suggested_rank(O, Omega, suggest_r=1):
def MC_NNM_with_cross_validation(O, Omega, K=5, list_l=None):
solver = MCNNMPanelSolver(Z = 1-Omega)
res = solver.solve_with_cross_validation(O, K, list_l)
return res.M, res.row_fixed_effects, res.column_fixed_effects, res.tau
return res.M, res.row_fixed_effects, res.column_fixed_effects, res.tau
97 changes: 97 additions & 0 deletions tests/test_mcnnm_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import pytest

from causaltensor.cauest.MCNNM import MCNNMPanelSolver


def make_solver():
Z = np.zeros((3, 3), dtype=int)
Z[2, 2] = 1
return MCNNMPanelSolver(Z=Z)


def test_fit_uses_suggested_rank_first(monkeypatch):
solver = make_solver()
O = np.arange(9, dtype=float).reshape(3, 3)
result = object()
calls = {}

def fake_suggested_rank(*, O=None, suggest_r=1):
calls["O"] = O
calls["suggest_r"] = suggest_r
return result

monkeypatch.setattr(solver, "solve_with_suggested_rank", fake_suggested_rank)

assert solver.fit(O=O, suggest_r=2, l=0.5, K=3) is result
assert calls["O"] is O
assert calls["suggest_r"] == 2


def test_fit_uses_regularizer_when_l_is_provided(monkeypatch):
solver = make_solver()
O = np.arange(9, dtype=float).reshape(3, 3)
M_init = np.ones_like(O)
result = object()
calls = {}

def fake_regularizer(
*,
O=None,
l=None,
M_init=None,
eps=1e-7,
max_iter=2000,
):
calls["O"] = O
calls["l"] = l
calls["M_init"] = M_init
calls["eps"] = eps
calls["max_iter"] = max_iter
return result

monkeypatch.setattr(solver, "solve_with_regularizer", fake_regularizer)

assert (
solver.fit(O=O, l=0.5, K=3, M_init=M_init, eps=1e-5, max_iter=7)
is result
)
assert calls["O"] is O
assert calls["l"] == 0.5
assert calls["M_init"] is M_init
assert calls["eps"] == 1e-5
assert calls["max_iter"] == 7


def test_fit_uses_cross_validation_by_default(monkeypatch):
solver = make_solver()
O = np.arange(9, dtype=float).reshape(3, 3)
list_l = [0.1, 0.2]
result = object()
calls = {}

def fake_cross_validation(*, O=None, K=2, list_l=None):
calls["O"] = O
calls["K"] = K
calls["list_l"] = list_l
return result

monkeypatch.setattr(solver, "solve_with_cross_validation", fake_cross_validation)

assert solver.fit(O=O, list_l=list_l) is result
assert calls["O"] is O
assert calls["K"] == 2
assert calls["list_l"] is list_l

calls.clear()
assert solver.fit(O=O, K=4, list_l=list_l) is result
assert calls["O"] is O
assert calls["K"] == 4
assert calls["list_l"] is list_l


def test_fit_requires_observation_matrix():
solver = make_solver()

with pytest.raises(ValueError, match="O must be provided"):
solver.fit()
1,266 changes: 639 additions & 627 deletions tutorials/Panel Data Example.ipynb

Large diffs are not rendered by default.

1,300 changes: 650 additions & 650 deletions tutorials/Panel_Data_Example.ipynb

Large diffs are not rendered by default.

Loading