Skip to content

Commit eb27e46

Browse files
fix: update tests to be isolated from LR model data
Test fixtures used real program IDs (cmu-mscf, baruch-mfe, rutgers-mqf) that now have trained LR models, causing classification to differ from the heuristic-based expectations. Replace with synthetic IDs (test-reach-*, test-target-*, test-safety-*) so tests always fall back to heuristics.

Also update test_each_result_has_required_keys to use subset check instead of exact equality, accommodating the new admission_prob field in results.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent bd75bd8 commit eb27e46

2 files changed

Lines changed: 28 additions & 26 deletions

File tree

tests/test_list_builder.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
5151
5252
We construct programmes to exercise each bucket.
5353
"""
54+
# Use synthetic IDs so the LR model has no data and falls back to heuristics.
5455
# Reach programmes: low acceptance rate or high avg GPA.
5556
reach_1 = ProgramData(
56-
id="cmu-mscf",
57+
id="test-reach-1",
5758
name="CMU MSCF",
5859
university="Carnegie Mellon University",
5960
acceptance_rate=0.05,
@@ -67,7 +68,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
6768
],
6869
)
6970
reach_2 = ProgramData(
70-
id="baruch-mfe",
71+
id="test-reach-2",
7172
name="Baruch MFE",
7273
university="Baruch College, CUNY",
7374
acceptance_rate=0.04,
@@ -79,7 +80,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
7980
],
8081
)
8182
reach_3 = ProgramData(
82-
id="princeton-mfin",
83+
id="test-reach-3",
8384
name="Princeton MFin",
8485
university="Princeton University",
8586
acceptance_rate=0.03,
@@ -90,7 +91,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
9091
],
9192
)
9293
reach_4 = ProgramData(
93-
id="mit-mfin",
94+
id="test-reach-4",
9495
name="MIT MFin",
9596
university="MIT",
9697
acceptance_rate=0.06,
@@ -104,7 +105,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
104105

105106
# Target programmes: moderate acceptance, GPA near user's.
106107
target_1 = ProgramData(
107-
id="bu-msmf",
108+
id="test-target-1",
108109
name="BU MSMF",
109110
university="Boston University",
110111
acceptance_rate=0.12,
@@ -116,7 +117,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
116117
],
117118
)
118119
target_2 = ProgramData(
119-
id="nyu-mfe",
120+
id="test-target-2",
120121
name="NYU MFE",
121122
university="New York University",
122123
acceptance_rate=0.10,
@@ -128,7 +129,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
128129
],
129130
)
130131
target_3 = ProgramData(
131-
id="gatech-qcf",
132+
id="test-target-3",
132133
name="GaTech QCF",
133134
university="Georgia Tech",
134135
acceptance_rate=0.14,
@@ -141,7 +142,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
141142

142143
# Safety programmes: high acceptance, avg GPA well below user's.
143144
safety_1 = ProgramData(
144-
id="rutgers-mqf",
145+
id="test-safety-1",
145146
name="Rutgers MQF",
146147
university="Rutgers University",
147148
acceptance_rate=0.25,
@@ -152,7 +153,7 @@ def _make_programs(count: int = 9) -> list[ProgramData]:
152153
],
153154
)
154155
safety_2 = ProgramData(
155-
id="uconn-msqf",
156+
id="test-safety-2",
156157
name="UConn MSQF",
157158
university="University of Connecticut",
158159
acceptance_rate=0.30,
@@ -248,8 +249,8 @@ def test_safety_programmes(self) -> None:
248249
def test_fewer_programmes_than_max(self) -> None:
249250
"""When only 1 safety programme exists, we should get 1."""
250251
progs = _make_programs()
251-
# Remove the second safety programme (uconn-msqf).
252-
progs = [p for p in progs if p.id != "uconn-msqf"]
252+
# Remove the second safety programme (test-safety-2).
253+
progs = [p for p in progs if p.id != "test-safety-2"]
253254
sl = build_school_list(
254255
_make_profile(), progs, _default_evaluation(), max_safety=2,
255256
)
@@ -313,7 +314,7 @@ def test_empty_program_list(self) -> None:
313314

314315
def test_single_programme(self) -> None:
315316
"""A single reach programme should appear in reach only."""
316-
progs = [_make_programs()[0]] # CMU = reach
317+
progs = [_make_programs()[0]] # test-reach-1 = reach
317318
sl = build_school_list(_make_profile(), progs, _default_evaluation())
318319
assert len(sl.reach) == 1
319320
assert len(sl.target) == 0

tests/test_school_ranker.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ def _make_profile(self) -> UserProfile:
168168
)
169169

170170
def _make_programs(self) -> list[ProgramData]:
171+
# Use synthetic IDs so the LR model has no data and the heuristic fallback is always used.
171172
# Reach: low acceptance, high avg GPA
172173
reach = ProgramData(
173-
id="cmu-mscf",
174+
id="test-reach-prog",
174175
name="CMU MSCF",
175176
university="Carnegie Mellon",
176177
acceptance_rate=0.05,
@@ -184,7 +185,7 @@ def _make_programs(self) -> list[ProgramData]:
184185
)
185186
# Target: moderate acceptance, matched GPA
186187
target = ProgramData(
187-
id="bu-msmf",
188+
id="test-target-prog",
188189
name="BU MSMF",
189190
university="Boston University",
190191
acceptance_rate=0.12,
@@ -196,7 +197,7 @@ def _make_programs(self) -> list[ProgramData]:
196197
)
197198
# Safety: high acceptance, lower avg GPA
198199
safety = ProgramData(
199-
id="rutgers-mqf",
200+
id="test-safety-prog",
200201
name="Rutgers MQF",
201202
university="Rutgers",
202203
acceptance_rate=0.25,
@@ -237,12 +238,12 @@ def test_each_result_has_required_keys(self) -> None:
237238
programs = self._make_programs()
238239
evaluation = EvaluationResult(overall_score=7.0)
239240
result = rank_schools(profile, programs, evaluation)
240-
expected_keys = {
241+
required_keys = {
241242
"program_id", "name", "university", "category",
242243
"fit_score", "prereq_match_score", "acceptance_rate", "avg_gpa",
243244
}
244245
for entry in result["all"]:
245-
assert set(entry.keys()) == expected_keys
246+
assert required_keys.issubset(set(entry.keys()))
246247

247248
def test_classification_matches_category(self) -> None:
248249
"""Programs in each bucket should have matching category values."""
@@ -258,22 +259,22 @@ def test_classification_matches_category(self) -> None:
258259
assert entry["category"] == "safety"
259260

260261
def test_reach_program_classified_correctly(self) -> None:
261-
"""CMU with 5% acceptance should be reach."""
262+
"""Program with 5% acceptance and avg GPA > user GPA should be reach."""
262263
profile = self._make_profile()
263264
programs = self._make_programs()
264265
evaluation = EvaluationResult(overall_score=7.0)
265266
result = rank_schools(profile, programs, evaluation)
266-
cmu = next(r for r in result["all"] if r["program_id"] == "cmu-mscf")
267-
assert cmu["category"] == "reach"
267+
reach = next(r for r in result["all"] if r["program_id"] == "test-reach-prog")
268+
assert reach["category"] == "reach"
268269

269270
def test_safety_program_classified_correctly(self) -> None:
270-
"""Rutgers with 25% acceptance and user GPA above avg+0.1 -> safety."""
271+
"""Program with 25% acceptance and avg GPA below user -> safety."""
271272
profile = self._make_profile()
272273
programs = self._make_programs()
273274
evaluation = EvaluationResult(overall_score=7.0)
274275
result = rank_schools(profile, programs, evaluation)
275-
rutgers = next(r for r in result["all"] if r["program_id"] == "rutgers-mqf")
276-
assert rutgers["category"] == "safety"
276+
safety = next(r for r in result["all"] if r["program_id"] == "test-safety-prog")
277+
assert safety["category"] == "safety"
277278

278279
def test_empty_programs_list(self) -> None:
279280
profile = self._make_profile()
@@ -290,6 +291,6 @@ def test_prereq_match_score_reflected(self) -> None:
290291
programs = self._make_programs()
291292
evaluation = EvaluationResult(overall_score=7.0)
292293
result = rank_schools(profile, programs, evaluation)
293-
# Rutgers only requires calculus, which profile has -> match_score = 1.0
294-
rutgers = next(r for r in result["all"] if r["program_id"] == "rutgers-mqf")
295-
assert rutgers["prereq_match_score"] == 1.0
294+
# test-safety-prog only requires calculus, which profile has -> match_score = 1.0
295+
safety = next(r for r in result["all"] if r["program_id"] == "test-safety-prog")
296+
assert safety["prereq_match_score"] == 1.0

0 commit comments

Comments
 (0)