
Commit 14b90ee: added a 100 example test for evaluation
Parent: 6653cef

37 files changed, 11,345 additions and 5 deletions

.github/workflows/ci.yml

Lines changed: 23 additions & 0 deletions
```yaml
name: CI

on:
  pull_request:
  push:
    branches: [ main ]

jobs:
  tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
      - name: Run tests
        run: pytest -q
```

README.md

Lines changed: 105 additions & 0 deletions
## MLOps Additions (Open Source + Free)

- Reproducible training: `src/train.py` now accepts `--seed` and fixes RNGs. After training it writes `metrics.json` and `run_info.json` into the output model folder with the args, data hash, and git commit (if available) for traceability.
- Health/CORS/metrics: FastAPI exposes `/healthz` and `/readyz`. CORS is enabled for React by default. Optional Prometheus metrics are available (set `ENABLE_METRICS=1`) if `prometheus-fastapi-instrumentator` is installed.
- Tests: unit and API tests live under `tests/`, with dev dependencies in `dev-requirements.txt`. CI via GitHub Actions runs the tests on pushes and PRs, free for public repos.

### Configure CORS for React

By default all origins are allowed. To restrict:

```bash
export ALLOW_ORIGINS=http://localhost:5173,http://localhost:3000  # Windows PowerShell: $env:ALLOW_ORIGINS="http://..."
uvicorn app.main:app --reload --port 8000
```
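The comma-separated value becomes the CORS allow-list, and an unset variable means "allow everything". A minimal sketch of that parsing, mirroring the logic in `app/main.py` (the helper name and dict parameter are illustrative, not the actual code):

```python
def parse_allow_origins(env: dict) -> list[str]:
    # Comma-separated ALLOW_ORIGINS -> list of origins; unset means allow all.
    raw = env.get("ALLOW_ORIGINS")
    return raw.split(",") if raw else ["*"]

print(parse_allow_origins({}))  # ['*']
print(parse_allow_origins({"ALLOW_ORIGINS": "http://localhost:5173,http://localhost:3000"}))
```

Note that each origin must match exactly (scheme, host, and port), so list every dev-server URL your React app may run on.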
### Enable Prometheus Metrics (optional)

```bash
export ENABLE_METRICS=1  # Windows PowerShell: $env:ENABLE_METRICS=1
uvicorn app.main:app --reload --port 8000
```

### Run Tests Locally

```bash
pip install -r requirements.txt -r dev-requirements.txt
pytest -q
```
### Demo-Ready Metrics Report

Share visuals and plain-English talking points with stakeholders in one command:

```bash
pip install -r requirements.txt -r dev-requirements.txt
python reports/generate_report.py \
  --model_dir models/distilbert_component_classifier \
  --train_path data/train.csv \
  --report_dir reports/latest
```

You will get:

- `reports/latest/report_summary.md`: a non-technical explanation of precision, recall, F1, exact-match accuracy, and loss trends.
- `reports/latest/report_data.json`: structured metrics and per-label stats ready for slide tables or dashboards.
- `reports/latest/figures/*.png`: validation F1/loss curves plus a top-component coverage bar chart for quick storytelling.

Run `python reports/generate_report.py -h` to tweak thresholds, the validation split, or the destination folder.
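For reference, the reported quantities follow standard multi-label definitions. A minimal sketch of micro-averaged precision/recall/F1 and exact-match accuracy (illustrative only, not the code in `reports/generate_report.py`):

```python
def multilabel_metrics(y_true: list, y_pred: list) -> dict:
    # y_true / y_pred: one set of labels per example.
    tp = sum(len(t & p) for t, p in zip(y_true, y_pred))  # correctly predicted labels
    fp = sum(len(p - t) for t, p in zip(y_true, y_pred))  # spurious labels
    fn = sum(len(t - p) for t, p in zip(y_true, y_pred))  # missed labels
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # Exact match: the full label set must be predicted correctly.
    exact = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return {"precision": precision, "recall": recall, "f1": f1, "exact_match": exact}

metrics = multilabel_metrics(
    [{"Lending", "KYC"}, {"Payments"}],
    [{"Lending"}, {"Payments"}],
)
```

Exact-match is the strictest of the four, which is why it usually trails the micro-averaged scores.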
### Manual Scenario Audit (Edge Cases)

When you need to show strengths *and* improvement areas, run the curated scenario harness:

```bash
python reports/run_manual_eval.py \
  --cases_path reports/manual_eval_cases_100.jsonl \
  --model_dir models/distilbert_component_classifier \
  --output_dir reports/manual_eval \
  --top_k_fallback 3
```

Artifacts:

- `reports/manual_eval/manual_eval_results.{json,csv}`: per-scenario expectations vs. predictions, true/false positives, and misses.
- `reports/manual_eval/manual_eval_summary.md`: a bullet-point narrative calling out gaps for non-technical leads.
- `reports/manual_eval/figures/*.png`: outcome distribution, per-case recall, and top missed components for slide-ready visuals.

Use `--top_k_fallback` (default 0) to add the best-scoring labels even when the sigmoid score is below the threshold; this is handy for exploratory edge-case analysis. Edit `reports/manual_eval_cases_100.jsonl` directly, or regenerate it with:

```bash
python reports/build_manual_cases.py \
  --limit 100 \
  --output_path reports/manual_eval_cases_100.jsonl
```

The generator spans authentication, lending, collections, KYC, payments, reporting, disputes, core integration, omni-channel comms, and regulatory scenarios, so each component in the taxonomy appears multiple times.
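A sketch of how such a fallback can work, assuming per-label sigmoid scores. Here the fallback fires only when no label clears the threshold, which may differ from the harness's exact behaviour; the function name and signature are illustrative:

```python
def select_labels(scores: dict, threshold: float = 0.5, top_k_fallback: int = 0) -> list:
    # Rank labels by sigmoid score, best first.
    ranked = sorted(scores, key=scores.get, reverse=True)
    selected = [label for label in ranked if scores[label] >= threshold]
    # If nothing clears the threshold, optionally keep the top-k scorers anyway.
    if not selected and top_k_fallback > 0:
        selected = ranked[:top_k_fallback]
    return selected

scores = {"Lending": 0.41, "KYC": 0.35, "Payments": 0.12}
print(select_labels(scores))                     # [] - nothing reaches 0.5
print(select_labels(scores, top_k_fallback=3))   # ['Lending', 'KYC', 'Payments']
```

This is what makes the audit useful for edge cases: instead of an empty prediction you still see what the model considered most likely.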
### Lightweight Run Tracking Artifacts

After training, check your output dir (e.g., `models/distilbert_component_classifier/`) for:

- `metrics.json`: evaluation metrics from the Trainer
- `run_info.json`: training args, data SHA256, git commit
- `label2id.json`: label mapping used at inference

These files are simple, portable, and versionable in git or any storage.
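A minimal sketch of how a `run_info.json` like this can be produced, hashing the training data and capturing the current git commit (a hypothetical helper, not necessarily what `src/train.py` does):

```python
import hashlib
import json
import subprocess
from pathlib import Path

def write_run_info(output_dir: Path, args: dict, data_path: Path) -> dict:
    # Hash the training data so the exact inputs are traceable.
    data_sha256 = hashlib.sha256(data_path.read_bytes()).hexdigest()
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
    except Exception:
        commit = None  # not a git checkout, or git unavailable
    info = {"args": args, "data_sha256": data_sha256, "git_commit": commit}
    (output_dir / "run_info.json").write_text(json.dumps(info, indent=2))
    return info
```

Because the file is plain JSON, it diffs cleanly in git and needs no tracking server.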
## Optional: Docker

- Training defaults target CPU-friendly settings (batch size 4, max length 256, 3–5 epochs). Adjust `--num_epochs`, `--learning_rate`, and other CLI flags as needed.
- The provided synthetic dataset is only for demonstration. Replace it with real, labeled production data for meaningful predictions.
## Frontend Integration (React)

Point your React app at the FastAPI endpoint:

```ts
// Example using fetch
async function predictComponents(text: string, threshold = 0.5) {
  const resp = await fetch("http://localhost:8000/predict", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ text, threshold }),
  });
  if (!resp.ok) throw new Error("Prediction failed");
  return await resp.json();
}
```

Ensure the backend has CORS configured (the default is permissive) or set `ALLOW_ORIGINS` accordingly in production.

app/main.py

Lines changed: 37 additions & 5 deletions
```diff
@@ -1,8 +1,10 @@
 from pathlib import Path
+import os
 import sys
 from typing import List
 
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 
 PROJECT_ROOT = Path(__file__).resolve().parents[1]
@@ -34,6 +36,16 @@ class PredictResponse(BaseModel):
     summary="Predict impacted components from requirement statements.",
 )
 
+# CORS for React UI
+ALLOW_ORIGINS = os.getenv("ALLOW_ORIGINS", "*").split(",") if os.getenv("ALLOW_ORIGINS") else ["*"]
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=ALLOW_ORIGINS,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 MODEL_DIR = PROJECT_ROOT / "models" / "distilbert_component_classifier"
 MODEL = None
 TOKENIZER = None
@@ -43,11 +55,21 @@ class PredictResponse(BaseModel):
 @app.on_event("startup")
 async def _load_model() -> None:
     global MODEL, TOKENIZER, LABELS  # pylint: disable=global-statement
-    if not MODEL_DIR.exists():
-        raise RuntimeError(
-            f"Model directory '{MODEL_DIR}' not found. Train the model before starting the API."
-        )
-    MODEL, TOKENIZER, LABELS = load_assets(str(MODEL_DIR))
+    if MODEL_DIR.exists():
+        MODEL, TOKENIZER, LABELS = load_assets(str(MODEL_DIR))
+    else:
+        # Defer loading; endpoint guard will respond 503 until trained
+        MODEL, TOKENIZER, LABELS = None, None, None
+
+
+@app.get("/healthz")
+async def healthz() -> dict:
+    return {"status": "ok", "version": app.version}
+
+
+@app.get("/readyz")
+async def readyz() -> dict:
+    return {"ready": MODEL is not None}
 
 
 @app.post("/predict", response_model=PredictResponse)
@@ -67,3 +89,13 @@ async def predict_components(payload: PredictRequest) -> PredictResponse:
     predictions.sort(key=lambda item: item["score"], reverse=True)
     scores = [ComponentScore(**item) for item in predictions]
     return PredictResponse(components=scores, threshold=payload.threshold)
+
+
+# Optional Prometheus metrics if installed and enabled
+if os.getenv("ENABLE_METRICS", "0") == "1":
+    try:
+        from prometheus_fastapi_instrumentator import Instrumentator  # type: ignore
+
+        Instrumentator().instrument(app).expose(app)
+    except Exception:  # pragma: no cover - optional dependency
+        pass
```

dev-requirements.txt

Lines changed: 4 additions & 0 deletions
```
pytest>=7.4
requests>=2.31
matplotlib>=3.8
seaborn>=0.13
```
