Commit 2ccb620

test(extraction): add unit tests for TrapPruner, MissingnessRecognizer, TargetLeakageAuditor
- 13 tests each (39 total) covering profiling, prompt construction, verdict parsing, verification, and mock-LLM integration
Parent: 8fd7dc6

3 files changed

Lines changed: 670 additions & 0 deletions
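The tests in this commit drive TargetLeakageAuditor end to end against a stubbed inference engine rather than a live LLM: a MagicMock stands in for the engine, its generate() call returns a canned JSON verdict array, and the assertions check which columns audit() drops. As a minimal sketch of that pattern, using only the constructor and audit() signature that appear in the diff below (the two-row DataFrame here is illustrative, not taken from the test file):

import json
from unittest.mock import MagicMock

import polars as pl

from loclean.extraction.leakage_auditor import TargetLeakageAuditor

# Stub engine: generate() returns a canned JSON verdict array.
engine = MagicMock()
engine.generate.return_value = json.dumps(
    [
        {"column": "age", "is_leakage": False, "reason": "Pre-event"},
        {"column": "approved_date", "is_leakage": True, "reason": "Post-event"},
    ]
)

auditor = TargetLeakageAuditor(inference_engine=engine)
df = pl.DataFrame(
    {
        "age": [25, 30],
        "approved_date": ["2024-01-15", "2024-01-20"],
        "approved": [True, False],
    }
)

# audit() returns the pruned frame plus a summary carrying
# "dropped_columns" and "verdicts".
pruned, summary = auditor.audit(df, "approved", domain="loan approval")
assert "approved_date" not in pruned.columns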

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
"""Unit tests for the TargetLeakageAuditor module."""

from __future__ import annotations

import json
from unittest.mock import MagicMock

import narwhals as nw
import polars as pl
import pytest

from loclean.extraction.leakage_auditor import TargetLeakageAuditor

# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------


def _make_engine(response: str) -> MagicMock:
    engine = MagicMock()
    engine.generate.return_value = response
    return engine


def _sample_df() -> pl.DataFrame:
    return pl.DataFrame(
        {
            "age": [25, 30, 45, 50, 35],
            "income": [50000, 60000, 80000, 90000, 55000],
            "approved_date": [
                "2024-01-15",
                "2024-01-20",
                "2024-02-01",
                "2024-02-10",
                "2024-01-25",
            ],
            "feedback_score": [4, 5, 3, 5, 4],
            "approved": [True, True, False, True, True],
        }
    )


# ------------------------------------------------------------------
# _extract_state
# ------------------------------------------------------------------


class TestExtractState:
    def test_extracts_features_and_samples(self) -> None:
        df = _sample_df()
        df_nw = nw.from_native(df)
        features = ["age", "income", "approved_date", "feedback_score"]
        state = TargetLeakageAuditor._extract_state(df_nw, "approved", features)

        assert state["target_col"] == "approved"
        assert state["features"] == features
        assert len(state["sample_rows"]) <= 10
        assert "age" in state["dtypes"]

    def test_respects_sample_n(self) -> None:
        df = _sample_df()
        df_nw = nw.from_native(df)
        state = TargetLeakageAuditor._extract_state(
            df_nw, "approved", ["age"], sample_n=2
        )
        assert len(state["sample_rows"]) == 2


# ------------------------------------------------------------------
# _build_prompt
# ------------------------------------------------------------------


class TestBuildPrompt:
    def test_includes_domain_and_target(self) -> None:
        state = {
            "target_col": "approved",
            "features": ["age", "income"],
            "dtypes": {"age": "Int64", "income": "Int64"},
            "sample_rows": [{"age": 25, "income": 50000, "approved": True}],
        }
        prompt = TargetLeakageAuditor._build_prompt(state, "loan approval prediction")
        assert "loan approval prediction" in prompt
        assert "approved" in prompt
        assert "age" in prompt
        assert "is_leakage" in prompt

    def test_no_domain(self) -> None:
        state = {
            "target_col": "y",
            "features": ["x"],
            "dtypes": {"x": "Float64"},
            "sample_rows": [{"x": 1.0, "y": 0}],
        }
        prompt = TargetLeakageAuditor._build_prompt(state, "")
        assert "Dataset domain:" not in prompt


# ------------------------------------------------------------------
# _parse_verdict
# ------------------------------------------------------------------


class TestParseVerdict:
    def test_parses_valid_json(self) -> None:
        response = json.dumps(
            [
                {"column": "approved_date", "is_leakage": True, "reason": "Post-event"},
                {"column": "age", "is_leakage": False, "reason": "Pre-event"},
            ]
        )
        verdicts = TargetLeakageAuditor._parse_verdict(response)
        assert len(verdicts) == 2
        assert verdicts[0]["column"] == "approved_date"
        assert verdicts[0]["is_leakage"] is True
        assert verdicts[1]["is_leakage"] is False

    def test_handles_extra_text(self) -> None:
        response = (
            'Analysis:\n[{"column": "x", "is_leakage": false, "reason": "ok"}]\nEnd.'
        )
        verdicts = TargetLeakageAuditor._parse_verdict(response)
        assert len(verdicts) == 1

    def test_raises_on_no_json(self) -> None:
        with pytest.raises(ValueError, match="No JSON array"):
            TargetLeakageAuditor._parse_verdict("no json here")


# ------------------------------------------------------------------
# audit (integration with mock LLM)
# ------------------------------------------------------------------


class TestAudit:
    def test_drops_leaked_columns(self) -> None:
        df = _sample_df()
        response = json.dumps(
            [
                {"column": "age", "is_leakage": False, "reason": "ok"},
                {"column": "income", "is_leakage": False, "reason": "ok"},
                {"column": "approved_date", "is_leakage": True, "reason": "Post-event"},
                {
                    "column": "feedback_score",
                    "is_leakage": True,
                    "reason": "Post-event",
                },
            ]
        )
        engine = _make_engine(response)
        auditor = TargetLeakageAuditor(inference_engine=engine)

        pruned, summary = auditor.audit(df, "approved", "loan approval")

        assert "approved_date" not in pruned.columns
        assert "feedback_score" not in pruned.columns
        assert "age" in pruned.columns
        assert "income" in pruned.columns
        assert "approved" in pruned.columns
        assert "approved_date" in summary["dropped_columns"]
        assert "feedback_score" in summary["dropped_columns"]

    def test_keeps_all_if_no_leakage(self) -> None:
        df = _sample_df()
        response = json.dumps(
            [
                {"column": "age", "is_leakage": False, "reason": "ok"},
                {"column": "income", "is_leakage": False, "reason": "ok"},
                {"column": "approved_date", "is_leakage": False, "reason": "ok"},
                {"column": "feedback_score", "is_leakage": False, "reason": "ok"},
            ]
        )
        engine = _make_engine(response)
        auditor = TargetLeakageAuditor(inference_engine=engine)

        pruned, summary = auditor.audit(df, "approved")

        assert set(pruned.columns) == set(df.columns)
        assert summary["dropped_columns"] == []

    def test_missing_target_raises(self) -> None:
        df = _sample_df()
        engine = _make_engine("[]")
        auditor = TargetLeakageAuditor(inference_engine=engine)

        with pytest.raises(ValueError, match="not found"):
            auditor.audit(df, "nonexistent")

    def test_no_feature_columns(self) -> None:
        df = pl.DataFrame({"target": [1, 2, 3]})
        engine = _make_engine("[]")
        auditor = TargetLeakageAuditor(inference_engine=engine)

        pruned, summary = auditor.audit(df, "target")

        assert pruned.columns == ["target"]
        assert summary["dropped_columns"] == []
        engine.generate.assert_not_called()

    def test_summary_contains_verdicts(self) -> None:
        df = _sample_df()
        response = json.dumps(
            [
                {"column": "age", "is_leakage": False, "reason": "ok"},
            ]
        )
        engine = _make_engine(response)
        auditor = TargetLeakageAuditor(inference_engine=engine)

        _, summary = auditor.audit(df, "approved")

        assert "verdicts" in summary
        assert isinstance(summary["verdicts"], list)

    def test_domain_passed_to_prompt(self) -> None:
        df = _sample_df()
        response = json.dumps(
            [
                {"column": "age", "is_leakage": False, "reason": "ok"},
            ]
        )
        engine = _make_engine(response)
        auditor = TargetLeakageAuditor(inference_engine=engine)

        auditor.audit(df, "approved", domain="healthcare readmission")

        call_args = engine.generate.call_args[0][0]
        assert "healthcare readmission" in call_args
