diff --git a/eval/README.md b/eval/README.md index d4e72cf..18d05da 100644 --- a/eval/README.md +++ b/eval/README.md @@ -79,7 +79,20 @@ Tier 1 benchmark entrypoints: The runner loads benchmark inputs from the HF dataset and writes benchmark outputs to the HF experiments bucket when `--local-only` is not set. DDInter data is not stored in HF: if the local SQLite file is absent, configure `INTERACTION_DB_REPO`, `INTERACTION_DB_TAG`, and optionally `INTERACTION_DB_SHA256` so the runner can fetch `ddinter.db` from the pinned GitHub release source. -`predictions.jsonl` includes `elapsed_ms` keys for `ocr_clean`, `ner`, `rxnorm`, `ddinter_rxcui`, `ddinter_fts`, `openfda`, `severity`, `analyze`, `interactions`, and `total`. These measurements intentionally overlap: `analyze` includes OCR, NER, and RxNorm work; `interactions` includes DDInter, OpenFDA, and severity work; `total` includes the top-level phases plus benchmark overhead. Do not sum the keys as a disjoint latency partition. +`predictions.jsonl` includes benchmark-only diagnostics for each record: + +1. `elapsed_ms` preserves the original 10 timing keys: `ocr_clean`, `ner`, `rxnorm`, `ddinter_rxcui`, `ddinter_fts`, `openfda`, `severity`, `analyze`, `interactions`, and `total`. +2. `component_timings_ms` repeats those keys and adds `critical_path` plus `slowest_component_ms`; `critical_path` is the sum of the non-aggregate component buckets, and `slowest_component` names the largest non-aggregate component bucket for that record. +3. `ner_diagnostics` includes predicted entities plus per-record strict and lenient TP/FP/FN counts when `expected_names` is present. +4. `rxnorm_attempts` records benchmark-stage, method, query, returned RxCUI, status, elapsed time, output summary, and error metadata for RxNorm calls. +5. `interaction_attempts` records one row per checked pair, including pair names, RxCUIs, DDInter RxCUI lookup, DDInter FTS lookup, OpenFDA fallback, final source, final severity, and miss reason. +6. `pipeline_errors` records timeout or component errors tied to the record without requiring the whole benchmark to fail. + +Timing measurements intentionally overlap: `analyze` includes OCR, NER, and RxNorm work; `interactions` includes DDInter, OpenFDA, and severity work; `total` includes the top-level phases plus benchmark overhead. Do not sum all timing keys as a disjoint latency partition. Starting with `metric_schema_version: "benchmark-diagnostics-v1"`, the `rxnorm` timing bucket covers all benchmark-wrapped RxNorm calls (`get_rxcui`, `approximate_term`, `search_by_name`, and `get_drug_details`), so compare it with earlier runs only as a changed-instrumentation metric. + +`results.json` groups rollups by `overall`, `timing`, `ner`, `linking`, `rxnorm`, `interactions`, `errors`, and `fp_taxonomy`. `linking` is kept for backward compatibility with the original link-coverage fields; `rxnorm` carries those core fields plus RxNorm attempt diagnostics such as method hit/miss/error counts, unresolved queries, and canonicalization collisions. Interaction diagnostics report DDInter RxCUI hit rate, DDInter FTS rescue rate, OpenFDA rescue rate, source counts, and common unknown pairs. These are routing/source-coverage diagnostics unless reviewed `expected_interactions` and `known_safe_pairs` are present. + +`manifest.json` includes `metric_schema_version`, dataset revision, run id, sample size, model IDs, concurrency, and DDInter release metadata. `summary.md` highlights top-line metrics, timing bottlenecks, unresolved RxNorm queries, unknown interaction pairs, and an explicit warning when outputs are not accuracy-certified. Use `--record-timeout-seconds` to bound each input record so a stuck RxNorm, OpenFDA, or model path records a `record_timeout` error instead of hanging the whole run. Use `--local-only` for development and smoke runs. Without `--local-only`, result artifacts upload to the experiments bucket under an immutable `benchmark-results///` prefix; do not commit generated candidate JSON or benchmark result directories to GitHub. diff --git a/eval/benchmark_results.schema.json b/eval/benchmark_results.schema.json index 86703c7..3fc6a26 100644 --- a/eval/benchmark_results.schema.json +++ b/eval/benchmark_results.schema.json @@ -4,8 +4,42 @@ "title": "PillChecker benchmark results", "type": "object", "additionalProperties": true, - "required": ["ner", "linking", "interactions", "fp_taxonomy"], + "required": ["overall", "timing", "ner", "linking", "rxnorm", "interactions", "errors", "fp_taxonomy"], "properties": { + "overall": { + "type": "object", + "additionalProperties": true, + "required": [ + "records_total", + "records_completed", + "records_errored", + "error_rate", + "timeout_count", + "concurrency", + "wall_time_seconds", + "records_per_second" + ], + "properties": { + "records_total": {"type": "integer", "minimum": 0}, + "records_completed": {"type": "integer", "minimum": 0}, + "records_errored": {"type": "integer", "minimum": 0}, + "error_rate": {"type": "number", "minimum": 0, "maximum": 1}, + "timeout_count": {"type": "integer", "minimum": 0}, + "concurrency": {"type": "integer", "minimum": 1}, + "wall_time_seconds": {"type": "number", "minimum": 0}, + "records_per_second": {"type": ["number", "null"], "minimum": 0} + } + }, + "timing": { + "type": "object", + "additionalProperties": true, + "required": ["components", "slowest_component", "slowest_component_counts"], + "properties": { + "components": {"type": "object"}, + "slowest_component": {"type": ["string", "null"]}, + "slowest_component_counts": {"type": "object"} + } + }, "ner": { "type": "object", "additionalProperties": true, @@ -50,6 +84,23 @@ "incorrect_link_rate": {"type": ["number", "null"], "minimum": 0, "maximum": 1} } }, + "rxnorm": { + "type": "object", + "additionalProperties": true, + "required": [ + "coverage", + "fallback_rate", + "nil_rate", + "n_link_attempts", + "n_drugs_total", + "acc_at_1", + "incorrect_link_rate", + "n_rxnorm_attempts", + "rxnorm_by_method", + "unresolved_queries", + "canonicalization_collisions" + ] + }, "interactions": { "type": "object", "additionalProperties": true, @@ -65,7 +116,12 @@ "unknown_rate", "severity_distribution", "uncertain_rate", - "records_with_any_interaction" + "records_with_any_interaction", + "ddinter_rxcui_hit_rate", + "ddinter_fts_rescue_rate", + "openfda_rescue_rate", + "source_counts", + "top_unknown_pairs" ], "properties": { "total_pairs_checked": {"type": "integer", "minimum": 0}, @@ -74,7 +130,12 @@ "unknown_rate": {"type": "number", "minimum": 0, "maximum": 1}, "severity_distribution": {"$ref": "#/$defs/severity_counts"}, "uncertain_rate": {"type": "number", "minimum": 0, "maximum": 1}, - "records_with_any_interaction": {"type": "integer", "minimum": 0} + "records_with_any_interaction": {"type": "integer", "minimum": 0}, + "ddinter_rxcui_hit_rate": {"type": "number", "minimum": 0, "maximum": 1}, + "ddinter_fts_rescue_rate": {"type": "number", "minimum": 0, "maximum": 1}, + "openfda_rescue_rate": {"type": "number", "minimum": 0, "maximum": 1}, + "source_counts": {"type": "object"}, + "top_unknown_pairs": {"type": "array"} } }, "accuracy": {"type": ["object", "null"]}, @@ -91,6 +152,17 @@ } } }, + "errors": { + "type": "object", + "additionalProperties": true, + "required": ["total", "by_stage", "by_class", "records"], + "properties": { + "total": {"type": "integer", "minimum": 0}, + "by_stage": {"type": "object"}, + "by_class": {"type": "object"}, + "records": {"type": "array"} + } + }, "fp_taxonomy": { "type": "object", "additionalProperties": true, diff --git a/eval/benchmark_run_manifest.schema.json b/eval/benchmark_run_manifest.schema.json index 0ddd68c..12615d8 100644 --- a/eval/benchmark_run_manifest.schema.json +++ b/eval/benchmark_run_manifest.schema.json @@ -12,6 +12,9 @@ "dataset_revision", "command", "model_ids", + "sample_size", + "concurrency", + "metric_schema_version", "metrics" ], "properties": { @@ -63,6 +66,14 @@ "type": "integer", "minimum": 1 }, + "concurrency": { + "type": ["integer", "null"], + "minimum": 1 + }, + "metric_schema_version": { + "type": "string", + "const": "benchmark-diagnostics-v1" + }, "random_seed": { "type": ["integer", "string", "null"] }, diff --git a/eval/metrics/interactions.py b/eval/metrics/interactions.py index 504c948..1e75b9c 100644 --- a/eval/metrics/interactions.py +++ b/eval/metrics/interactions.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections import defaultdict +from collections import Counter, defaultdict from typing import Any, Awaitable, Callable @@ -140,6 +140,7 @@ def compute( uncertain = 0 returned = 0 records_with_any = 0 + attempts = [attempt for prediction in predictions for attempt in prediction.get("interaction_attempts", [])] for prediction in predictions: interactions_response = prediction.get("interactions") or {} @@ -159,6 +160,7 @@ def compute( uncertain += 1 total_pairs = sum(coverage.values()) + attempt_diagnostics = _attempt_diagnostics(attempts) return { "descriptive": { "total_pairs_checked": total_pairs, @@ -168,7 +170,40 @@ def compute( "severity_distribution": severity_distribution, "uncertain_rate": _rate(uncertain, returned), "records_with_any_interaction": records_with_any, + **attempt_diagnostics, }, "accuracy": _accuracy(predictions, dataset), "seed_smoke": compute_seed_smoke(seed_cases, seed_results), } + + +def _status(attempt: dict, component: str) -> str: + block = attempt.get(component) or {} + return str(block.get("status") or "skipped") + + +def _attempt_diagnostics(attempts: list[dict]) -> dict: + total = len(attempts) + source_counts = Counter(str(attempt.get("final_source") or "unknown") for attempt in attempts) + ddinter_rxcui_hits = sum(1 for attempt in attempts if _status(attempt, "ddinter_rxcui") == "hit") + ddinter_fts_hits = sum(1 for attempt in attempts if _status(attempt, "ddinter_fts") == "hit") + openfda_hits = sum(1 for attempt in attempts if _status(attempt, "openfda") == "hit") + unknown_pairs = Counter( + _pair_key(str(attempt.get("drug_a", "")), str(attempt.get("drug_b", ""))) + for attempt in attempts + if attempt.get("final_source") == "unknown" + ) + return { + "ddinter_rxcui_hit_rate": _rate(ddinter_rxcui_hits, total), + "ddinter_fts_rescue_rate": _rate(ddinter_fts_hits, total), + "openfda_rescue_rate": _rate(openfda_hits, total), + "source_counts": { + "ddinter": int(source_counts.get("ddinter", 0)), + "openfda": int(source_counts.get("openfda", 0)), + "unknown": int(source_counts.get("unknown", 0)), + }, + "top_unknown_pairs": [ + {"drug_a": drug_a, "drug_b": drug_b, "count": count} + for (drug_a, drug_b), count in unknown_pairs.most_common(10) + ], + } diff --git a/eval/metrics/linking.py b/eval/metrics/linking.py index 75a3bae..043a193 100644 --- a/eval/metrics/linking.py +++ b/eval/metrics/linking.py @@ -2,6 +2,8 @@ from __future__ import annotations +from collections import defaultdict + def _rate(numerator: int, denominator: int) -> float: return numerator / denominator if denominator else 0.0 @@ -10,6 +12,7 @@ def _rate(numerator: int, denominator: int) -> float: def compute(predictions: list[dict], dataset: list[dict]) -> dict: drugs = [drug for pred in predictions for drug in pred.get("drugs", [])] attempts = [attempt for pred in predictions for attempt in pred.get("link_attempts", [])] + rxnorm_attempts = [attempt for pred in predictions for attempt in pred.get("rxnorm_attempts", [])] resolved = sum(1 for drug in drugs if drug.get("rxcui")) fallback = sum(1 for drug in drugs if drug.get("source") == "rxnorm_fallback") nil_count = sum(1 for attempt in attempts if attempt.get("rxcui") is None) @@ -39,6 +42,7 @@ def compute(predictions: list[dict], dataset: list[dict]) -> dict: acc_at_1 = sum(acc_values) / len(acc_values) incorrect_link_rate = _rate(incorrect, predicted_with_rxcui_total) + diagnostics = _rxnorm_diagnostics(rxnorm_attempts) return { "coverage": _rate(resolved, len(drugs)), "fallback_rate": _rate(fallback, len(drugs)), @@ -47,4 +51,45 @@ def compute(predictions: list[dict], dataset: list[dict]) -> dict: "n_drugs_total": len(drugs), "acc_at_1": acc_at_1, "incorrect_link_rate": incorrect_link_rate, + **diagnostics, + } + + +def _rxnorm_diagnostics(attempts: list[dict]) -> dict: + by_method: dict[str, dict[str, int]] = defaultdict(lambda: {"hit": 0, "miss": 0, "error": 0}) + unresolved = [] + queries_by_rxcui: dict[str, set[str]] = defaultdict(set) + + for attempt in attempts: + method = str(attempt.get("method") or "unknown") + status = str(attempt.get("status") or "unknown") + if status not in {"hit", "miss", "error"}: + status = "miss" if attempt.get("rxcui") is None else "hit" + by_method[method][status] += 1 + + query = str(attempt.get("query") or attempt.get("name") or "") + rxcui = attempt.get("rxcui") + if rxcui: + if method != "get_drug_details" and query: + queries_by_rxcui[str(rxcui)].add(query) + elif status in {"miss", "error"}: + unresolved.append({ + "query": query, + "stage": attempt.get("stage"), + "method": method, + }) + + collisions = [ + {"rxcui": rxcui, "queries": sorted(query for query in queries if query)} + for rxcui, queries in sorted(queries_by_rxcui.items()) + if len({query.casefold() for query in queries if query}) > 1 + ] + return { + "n_rxnorm_attempts": len(attempts), + "rxnorm_by_method": { + method: counts + for method, counts in sorted(by_method.items()) + }, + "unresolved_queries": unresolved[:20], + "canonicalization_collisions": collisions[:20], } diff --git a/eval/metrics/ner.py b/eval/metrics/ner.py index 55eb96d..95f2906 100644 --- a/eval/metrics/ner.py +++ b/eval/metrics/ner.py @@ -46,6 +46,11 @@ def _prf(tp: int, fp: int, fn: int) -> dict[str, float]: def _record_metrics(predicted: list[str], expected: list[str], matcher: Callable[[str, str], bool]) -> dict[str, float]: + counts = _record_counts(predicted, expected, matcher) + return _prf(counts["tp"], counts["fp"], counts["fn"]) + + +def _record_counts(predicted: list[str], expected: list[str], matcher: Callable[[str, str], bool]) -> dict[str, int]: matched_expected: set[int] = set() tp = 0 for pred in predicted: @@ -58,7 +63,22 @@ def _record_metrics(predicted: list[str], expected: list[str], matcher: Callable tp += 1 fp = len(predicted) - tp fn = len(expected) - tp - return _prf(tp, fp, fn) + return {"tp": tp, "fp": fp, "fn": fn} + + +def diagnostics_for_entities(entities: list[dict], expected_names: list[str]) -> dict: + predicted = [str(entity.get("text", "")) for entity in entities] + expected = [str(name) for name in expected_names] + strict = _record_counts(predicted, expected, _strict_match) + lenient = _record_counts(predicted, expected, _lenient_match) + return { + "entities": entities, + "strict": strict, + "lenient": lenient, + "expected_count": len(expected), + "predicted_count": len(predicted), + "low_confidence_count": sum(1 for entity in entities if float(entity.get("score", 1.0)) < 0.85), + } def _average(blocks: list[dict[str, float]]) -> dict[str, float]: diff --git a/eval/metrics/pipeline.py b/eval/metrics/pipeline.py new file mode 100644 index 0000000..6d74ffd --- /dev/null +++ b/eval/metrics/pipeline.py @@ -0,0 +1,105 @@ +"""Pipeline-level benchmark diagnostics.""" + +from __future__ import annotations + +import math +from collections import Counter + + +TIMING_COMPONENTS = ( + "ocr_clean", + "ner", + "rxnorm", + "ddinter_rxcui", + "ddinter_fts", + "openfda", + "severity", + "analyze", + "interactions", + "total", +) +def _rate(numerator: int, denominator: int) -> float: + return numerator / denominator if denominator else 0.0 + + +def _percentile(values: list[float], percentile: int) -> float | None: + if not values: + return None + ordered = sorted(values) + index = max(0, math.ceil((percentile / 100) * len(ordered)) - 1) + return round(ordered[index], 3) + + +def timing(predictions: list[dict]) -> dict: + components: dict[str, dict] = {} + for component in TIMING_COMPONENTS: + values = [ + float(_timing_source(prediction).get(component, 0.0)) + for prediction in predictions + ] + components[component] = { + "p50_ms": _percentile(values, 50), + "p95_ms": _percentile(values, 95), + "p99_ms": _percentile(values, 99), + "max_ms": round(max(values), 3) if values else None, + "mean_ms": round(sum(values) / len(values), 3) if values else None, + } + + slowest_counts = Counter( + prediction.get("slowest_component") + for prediction in predictions + if prediction.get("slowest_component") + ) + return { + "components": components, + "slowest_component": slowest_counts.most_common(1)[0][0] if slowest_counts else None, + "slowest_component_counts": dict(sorted(slowest_counts.items())), + } + + +def _timing_source(prediction: dict) -> dict: + timings = prediction.get("component_timings_ms") + if timings is None: + timings = prediction.get("elapsed_ms", {}) + return timings + + +def overall( + predictions: list[dict], + errors: list[dict], + *, + concurrency: int, + wall_time_seconds: float, +) -> dict: + records_total = len(predictions) + records_errored = len({str(error.get("record_id")) for error in errors}) + records_completed = records_total - records_errored + return { + "records_total": records_total, + "records_completed": records_completed, + "records_errored": records_errored, + "error_rate": _rate(records_errored, records_total), + "timeout_count": sum(1 for error in errors if error.get("stage") == "record_timeout"), + "concurrency": concurrency, + "wall_time_seconds": round(wall_time_seconds, 3), + "records_per_second": round(records_total / wall_time_seconds, 3) if wall_time_seconds > 0 else None, + } + + +def errors(errors_: list[dict]) -> dict: + by_stage = Counter(str(error.get("stage", "unknown")) for error in errors_) + by_class = Counter(str(error.get("error_class", "unknown")) for error in errors_) + return { + "total": len(errors_), + "by_stage": dict(sorted(by_stage.items())), + "by_class": dict(sorted(by_class.items())), + "records": [ + { + "record_id": str(error.get("record_id", "")), + "stage": error.get("stage"), + "error_class": error.get("error_class"), + "message": error.get("message"), + } + for error in errors_[:20] + ], + } diff --git a/eval/run_benchmark.py b/eval/run_benchmark.py index df6e075..509279c 100644 --- a/eval/run_benchmark.py +++ b/eval/run_benchmark.py @@ -32,6 +32,7 @@ from eval.metrics import interactions as interaction_metrics from eval.metrics import linking as linking_metrics from eval.metrics import ner as ner_metrics +from eval.metrics import pipeline as pipeline_metrics from scripts import download_interaction_db ELAPSED_KEYS = ( @@ -46,6 +47,15 @@ "interactions", "total", ) +LEAF_ELAPSED_KEYS = ( + "ocr_clean", + "ner", + "rxnorm", + "ddinter_rxcui", + "ddinter_fts", + "openfda", + "severity", +) DEFAULT_SEED_CASES = Path(__file__).with_name("interaction_seed_cases.json") active_benchmark_trace: ContextVar["BenchmarkTrace | None"] = ContextVar( @@ -58,11 +68,44 @@ class BenchmarkTrace: elapsed_ms: dict[str, float] = field(default_factory=lambda: {key: 0.0 for key in ELAPSED_KEYS}) link_attempts: list[dict] = field(default_factory=list) + rxnorm_attempts: list[dict] = field(default_factory=list) + interaction_attempts: list[dict] = field(default_factory=list) ner_entities: list[dict] = field(default_factory=list) + pipeline_errors: list[dict] = field(default_factory=list) + error_signatures: set[tuple[str, str]] = field(default_factory=set) + phase: str | None = None + active_interaction_attempt: dict | None = None def add_elapsed(self, key: str, seconds: float) -> None: self.elapsed_ms[key] = round(self.elapsed_ms.get(key, 0.0) + seconds * 1000, 3) + def record_error(self, stage: str, exc: Exception) -> None: + signature = (exc.__class__.__name__, str(exc)) + if signature in self.error_signatures: + for item in self.pipeline_errors: + if item.get("error_class") == signature[0] and item.get("message") == signature[1]: + stages = item.setdefault("stages", [item.get("stage")]) + if stage not in stages: + stages.append(stage) + break + return + self.error_signatures.add(signature) + self.pipeline_errors.append({ + "stage": stage, + "stages": [stage], + "error_class": exc.__class__.__name__, + "message": str(exc), + }) + + def component_timings(self) -> dict[str, float]: + timings = {key: self.elapsed_ms.get(key, 0.0) for key in ELAPSED_KEYS} + timings["critical_path"] = round(sum(timings.get(key, 0.0) for key in LEAF_ELAPSED_KEYS), 3) + timings["slowest_component_ms"] = round(max(timings.get(key, 0.0) for key in LEAF_ELAPSED_KEYS), 3) + return timings + + def slowest_component(self) -> str: + return max(LEAF_ELAPSED_KEYS, key=lambda key: self.elapsed_ms.get(key, 0.0)) + @contextlib.contextmanager def install_benchmark_instrumentation(): @@ -82,19 +125,41 @@ def patch(target, attr, value) -> None: patch(rxnorm_client, "get_rxcui", _async_wrapper( rxnorm_client.get_rxcui, "rxnorm", - _record_link_attempt, + _record_rxnorm_attempt("get_rxcui"), + )) + patch(rxnorm_client, "approximate_term", _async_wrapper( + rxnorm_client.approximate_term, + "rxnorm", + _record_rxnorm_attempt("approximate_term"), + )) + patch(rxnorm_client, "search_by_name", _async_wrapper( + rxnorm_client.search_by_name, + "rxnorm", + _record_rxnorm_attempt("search_by_name"), + )) + patch(rxnorm_client, "get_drug_details", _async_wrapper( + rxnorm_client.get_drug_details, + "rxnorm", + _record_rxnorm_attempt("get_drug_details"), )) patch(ddinter_db.client, "lookup_by_rxcui", _async_wrapper( ddinter_db.client.lookup_by_rxcui, "ddinter_rxcui", + _record_interaction_component("ddinter_rxcui"), )) patch(ddinter_db.client, "lookup_by_name_fts", _async_wrapper( ddinter_db.client.lookup_by_name_fts, "ddinter_fts", + _record_interaction_component("ddinter_fts"), + )) + patch(openfda_client, "check_pair", _async_wrapper( + openfda_client.check_pair, + "openfda", + _record_interaction_component("openfda"), )) - patch(openfda_client, "check_pair", _async_wrapper(openfda_client.check_pair, "openfda")) patch(severity_classifier, "classify", _sync_wrapper(severity_classifier.classify, "severity")) patch(interaction_checker, "_format_openfda", _async_wrapper(interaction_checker._format_openfda, "severity")) + patch(interaction_checker, "_resolve_pair", _resolve_pair_wrapper(interaction_checker._resolve_pair)) try: yield @@ -109,14 +174,21 @@ def wrapper(*args, **kwargs): trace = active_benchmark_trace.get() start = time.perf_counter() result = None + error = None try: result = original(*args, **kwargs) return result + except Exception as exc: + error = exc + raise finally: if trace is not None: - trace.add_elapsed(key, time.perf_counter() - start) - if recorder is not None and result is not None: - recorder(trace, args, result) + elapsed = time.perf_counter() - start + trace.add_elapsed(key, elapsed) + if error is not None: + trace.record_error(key, error) + if recorder is not None: + recorder(trace, args, kwargs, result, error, round(elapsed * 1000, 3)) return wrapper @@ -127,19 +199,35 @@ async def wrapper(*args, **kwargs): trace = active_benchmark_trace.get() start = time.perf_counter() result = None + error = None try: result = await original(*args, **kwargs) return result + except Exception as exc: + error = exc + raise finally: if trace is not None: - trace.add_elapsed(key, time.perf_counter() - start) + elapsed = time.perf_counter() - start + trace.add_elapsed(key, elapsed) + if error is not None: + trace.record_error(key, error) if recorder is not None: - recorder(trace, args, result) + recorder(trace, args, kwargs, result, error, round(elapsed * 1000, 3)) return wrapper -def _record_ner_entities(trace: BenchmarkTrace, _args: tuple, entities: list) -> None: +def _record_ner_entities( + trace: BenchmarkTrace, + _args: tuple, + _kwargs: dict, + entities: list | None, + _error: Exception | None, + _elapsed_ms: float, +) -> None: + if entities is None: + return trace.ner_entities = [ { "text": entity.text, @@ -152,12 +240,214 @@ def _record_ner_entities(trace: BenchmarkTrace, _args: tuple, entities: list) -> ] -def _record_link_attempt(trace: BenchmarkTrace, args: tuple, rxcui: str | None) -> None: - trace.link_attempts.append({ - "name": args[0] if args else None, - "rxcui": rxcui, - "method": "rxnorm_exact", - }) +def _record_rxnorm_attempt(method: str): + def recorder( + trace: BenchmarkTrace, + args: tuple, + _kwargs: dict, + result, + error: Exception | None, + elapsed_ms: float, + ) -> None: + query = None if method == "get_drug_details" else args[0] if args else None + rxcui = _rxnorm_rxcui(method, result) + status = "error" if error is not None else "hit" if rxcui or _rxnorm_has_result(method, result) else "miss" + attempt = { + "stage": trace.phase, + "method": method, + "query": query, + "input_rxcui": args[0] if method == "get_drug_details" and args else None, + "rxcui": rxcui, + "status": status, + "elapsed_ms": elapsed_ms, + "output": _rxnorm_output(method, result), + } + if error is not None: + attempt["error"] = { + "class": error.__class__.__name__, + "message": str(error), + } + trace.rxnorm_attempts.append(attempt) + if method == "get_rxcui": + trace.link_attempts.append({ + "name": query, + "rxcui": rxcui, + "method": "rxnorm_exact", + "stage": trace.phase, + "status": status, + }) + + return recorder + + +def _rxnorm_has_result(method: str, result) -> bool: + if method in {"approximate_term", "search_by_name"}: + return bool(result) + return result is not None + + +def _rxnorm_rxcui(method: str, result) -> str | None: + if result is None: + return None + if method == "get_rxcui": + return str(result) if result else None + if method in {"approximate_term", "search_by_name"} and result: + return str(getattr(result[0], "rxcui", "") or "") or None + if method == "get_drug_details": + return str(result.get("rxcui") or "") if isinstance(result, dict) and result.get("rxcui") else None + return None + + +def _rxnorm_output(method: str, result): + if result is None: + return None + if method in {"approximate_term", "search_by_name"}: + return [ + { + "rxcui": getattr(candidate, "rxcui", None), + "name": getattr(candidate, "name", None), + "score": getattr(candidate, "score", None), + "synonym": getattr(candidate, "synonym", None), + "tty": getattr(candidate, "tty", None), + } + for candidate in result + ] + if isinstance(result, dict): + return { + key: result.get(key) + for key in ("rxcui", "name", "tty", "synonym") + if key in result + } or result + return result + + +def _resolve_pair_wrapper(original): + @wraps(original) + async def wrapper(drug_a: str, drug_b: str, rxcui_by_name: dict[str, str | None]): + trace = active_benchmark_trace.get() + previous_attempt = trace.active_interaction_attempt if trace is not None else None + attempt = None + if trace is not None: + attempt = { + "drug_a": drug_a, + "drug_b": drug_b, + "rxcui_a": rxcui_by_name.get(drug_a), + "rxcui_b": rxcui_by_name.get(drug_b), + } + trace.interaction_attempts.append(attempt) + trace.active_interaction_attempt = attempt + try: + entry, bucket = await original(drug_a, drug_b, rxcui_by_name) + except Exception as exc: + if trace is not None and attempt is not None: + attempt["final_source"] = "error" + attempt["miss_reason"] = "exception" + _finalize_interaction_attempt(attempt) + raise + finally: + if trace is not None: + trace.active_interaction_attempt = previous_attempt + if attempt is not None: + attempt["final_source"] = bucket + attempt["final_severity"] = entry.get("severity") if entry else None + attempt["miss_reason"] = None if entry else "no_source_hit" + _finalize_interaction_attempt(attempt) + return entry, bucket + + return wrapper + + +def _record_interaction_component(component: str): + def recorder( + trace: BenchmarkTrace, + args: tuple, + _kwargs: dict, + result, + error: Exception | None, + elapsed_ms: float, + ) -> None: + attempt = trace.active_interaction_attempt + if attempt is None: + return + block = attempt.setdefault(component, {"calls": []} if component == "openfda" else {}) + call = { + "input": _interaction_input(component, args), + "status": "error" if error is not None else "hit" if result else "miss", + "elapsed_ms": elapsed_ms, + "output": _interaction_output(component, result), + } + if error is not None: + call["error"] = { + "class": error.__class__.__name__, + "message": str(error), + } + if component == "openfda": + block.setdefault("calls", []).append(call) + if block.get("status") != "hit": + block.update({key: value for key, value in call.items() if key != "input"}) + block["input"] = call["input"] + else: + block.update(call) + + return recorder + + +def _interaction_input(component: str, args: tuple) -> dict: + if component == "ddinter_rxcui": + return { + "rxcui_a": args[0] if len(args) > 0 else None, + "rxcui_b": args[1] if len(args) > 1 else None, + } + return { + "drug_a": args[0] if len(args) > 0 else None, + "drug_b": args[1] if len(args) > 1 else None, + } + + +def _interaction_output(component: str, result): + if result is None: + return None + if component.startswith("ddinter"): + return { + key: result.get(key) + for key in ("drug_a_id", "drug_b_id", "drug_a_name", "drug_b_name", "severity", "atc_category", "source") + if key in result + } + if component == "openfda": + return { + "description_present": bool(result.get("description")) if isinstance(result, dict) else bool(result), + } + return result + + +def _finalize_interaction_attempt(attempt: dict) -> None: + rxcui_missing = not attempt.get("rxcui_a") or not attempt.get("rxcui_b") + defaults = { + "ddinter_rxcui": { + "status": "skipped", + "reason": "missing_rxcui" if rxcui_missing else "not_called", + "input": {"rxcui_a": attempt.get("rxcui_a"), "rxcui_b": attempt.get("rxcui_b")}, + "elapsed_ms": 0.0, + "output": None, + }, + "ddinter_fts": { + "status": "skipped", + "reason": "ddinter_rxcui_hit" if attempt.get("ddinter_rxcui", {}).get("status") == "hit" else "not_called", + "input": {"drug_a": attempt.get("drug_a"), "drug_b": attempt.get("drug_b")}, + "elapsed_ms": 0.0, + "output": None, + }, + "openfda": { + "status": "skipped", + "reason": "ddinter_hit" if attempt.get("final_source") == "ddinter" else "not_called", + "input": {"drug_a": attempt.get("drug_a"), "drug_b": attempt.get("drug_b")}, + "elapsed_ms": 0.0, + "output": None, + "calls": [], + }, + } + for key, value in defaults.items(): + attempt.setdefault(key, value) async def run_benchmark( @@ -173,6 +463,7 @@ async def run_benchmark( if record_timeout_seconds is not None and record_timeout_seconds <= 0: raise ValueError("record_timeout_seconds must be positive") semaphore = asyncio.Semaphore(concurrency) + wall_start = time.perf_counter() with install_benchmark_instrumentation(): record_results = await asyncio.gather(*[ _run_one_with_timeout( @@ -187,16 +478,39 @@ async def run_benchmark( seed_results = None if seed_cases is not None: seed_results = await interaction_metrics.run_seed_smoke(seed_cases, interaction_checker.check) - + wall_time_seconds = time.perf_counter() - wall_start + + rxnorm_results = linking_metrics.compute(predictions, records) + linking_results = { + key: rxnorm_results.get(key) + for key in ( + "coverage", + "fallback_rate", + "nil_rate", + "n_link_attempts", + "n_drugs_total", + "acc_at_1", + "incorrect_link_rate", + ) + } results = { + "overall": pipeline_metrics.overall( + predictions, + errors, + concurrency=concurrency, + wall_time_seconds=wall_time_seconds, + ), + "timing": pipeline_metrics.timing(predictions), "ner": ner_metrics.compute(predictions, records), - "linking": linking_metrics.compute(predictions, records), + "linking": linking_results, + "rxnorm": rxnorm_results, "interactions": interaction_metrics.compute( predictions, records, seed_cases=seed_cases, seed_results=seed_results, ), + "errors": pipeline_metrics.errors(errors), "fp_taxonomy": await fp_taxonomy.compute(predictions, records), } return { @@ -237,23 +551,29 @@ async def _run_one(record: dict) -> tuple[dict, dict | None]: error = None try: analyze_start = time.perf_counter() + trace.phase = "analyze" drugs = await drug_analyzer.analyze(str(record["ocr_text"])) trace.add_elapsed("analyze", time.perf_counter() - analyze_start) except Exception as exc: trace.add_elapsed("analyze", time.perf_counter() - analyze_start) + trace.record_error("analyze", exc) error = _error_record(record, "analyze", exc) else: interaction_start = time.perf_counter() try: + trace.phase = "interactions" interactions_response = await interaction_checker.check([drug["name"] for drug in drugs]) trace.add_elapsed("interactions", time.perf_counter() - interaction_start) except Exception as exc: trace.add_elapsed("interactions", time.perf_counter() - interaction_start) + trace.record_error("interactions", exc) error = _error_record(record, "interactions", exc) finally: + trace.phase = None trace.add_elapsed("total", time.perf_counter() - total_start) active_benchmark_trace.reset(token) + elapsed_ms = {key: trace.elapsed_ms.get(key, 0.0) for key in ELAPSED_KEYS} return { "record_id": str(record.get("id", "")), "category": record.get("category"), @@ -261,14 +581,24 @@ async def _run_one(record: dict) -> tuple[dict, dict | None]: "drugs": drugs, "interactions": interactions_response, "ner_entities": trace.ner_entities, + "ner_diagnostics": ner_metrics.diagnostics_for_entities( + trace.ner_entities, + [str(name) for name in record.get("expected_names", [])], + ), "link_attempts": trace.link_attempts, - "elapsed_ms": {key: trace.elapsed_ms.get(key, 0.0) for key in ELAPSED_KEYS}, + "rxnorm_attempts": trace.rxnorm_attempts, + "interaction_attempts": trace.interaction_attempts, + "pipeline_errors": trace.pipeline_errors, + "elapsed_ms": elapsed_ms, + "component_timings_ms": trace.component_timings(), + "slowest_component": trace.slowest_component(), }, error def _timeout_prediction(record: dict, timeout_seconds: float) -> dict: elapsed_ms = {key: 0.0 for key in ELAPSED_KEYS} elapsed_ms["total"] = round(timeout_seconds * 1000, 3) + pipeline_error = _timeout_error_record(record, timeout_seconds) return { "record_id": str(record.get("id", "")), "category": record.get("category"), @@ -276,8 +606,21 @@ def _timeout_prediction(record: dict, timeout_seconds: float) -> dict: "drugs": [], "interactions": None, "ner_entities": [], + "ner_diagnostics": ner_metrics.diagnostics_for_entities( + [], + [str(name) for name in record.get("expected_names", [])], + ), "link_attempts": [], + "rxnorm_attempts": [], + "interaction_attempts": [], + "pipeline_errors": [pipeline_error], "elapsed_ms": elapsed_ms, + "component_timings_ms": { + **elapsed_ms, + "critical_path": 0.0, + "slowest_component_ms": elapsed_ms["total"], + }, + "slowest_component": "total", } @@ -341,12 +684,18 @@ def summary_metrics(results: dict) -> dict: return { "ner_strict_f1": results.get("ner", {}).get("strict", {}).get("f1"), "ner_lenient_f1": results.get("ner", {}).get("lenient", {}).get("f1"), - "linking_coverage": results.get("linking", {}).get("coverage"), - "linking_nil_rate": results.get("linking", {}).get("nil_rate"), + "linking_coverage": results.get("rxnorm", results.get("linking", {})).get("coverage"), + "linking_nil_rate": results.get("rxnorm", results.get("linking", {})).get("nil_rate"), "interactions_total_pairs_checked": results.get("interactions", {}).get("descriptive", {}).get("total_pairs_checked"), "interactions_ddinter_hit_rate": results.get("interactions", {}).get("descriptive", {}).get("ddinter_hit_rate"), + "interactions_ddinter_rxcui_hit_rate": results.get("interactions", {}).get("descriptive", {}).get("ddinter_rxcui_hit_rate"), "interactions_openfda_hit_rate": results.get("interactions", {}).get("descriptive", {}).get("openfda_hit_rate"), "interactions_unknown_rate": results.get("interactions", {}).get("descriptive", {}).get("unknown_rate"), + "records_completed": results.get("overall", {}).get("records_completed"), + "records_errored": results.get("overall", {}).get("records_errored"), + "timeout_count": results.get("overall", {}).get("timeout_count"), + "slowest_component": results.get("timing", {}).get("slowest_component"), + "total_p95_ms": results.get("timing", {}).get("components", {}).get("total", {}).get("p95_ms"), "seed_smoke_recall": results.get("interactions", {}).get("seed_smoke", {}).get("recall"), "seed_smoke_false_alarm_rate": results.get("interactions", {}).get("seed_smoke", {}).get("false_alarm_rate"), "fp_total": results.get("fp_taxonomy", {}).get("total_fp"), @@ -355,9 +704,30 @@ def summary_metrics(results: dict) -> dict: def summary_markdown(results: dict) -> str: metrics = summary_metrics(results) - lines = ["# PillChecker benchmark summary", ""] + lines = [ + "# PillChecker benchmark summary", + "", + "> Interaction and RxNorm metrics are not accuracy-certified unless reviewed labels are present.", + "", + "## Top-line metrics", + "", + ] for key, value in metrics.items(): lines.append(f"- `{key}`: {value}") + lines.extend(["", "## Component timing", ""]) + for component, values in results.get("timing", {}).get("components", {}).items(): + lines.append( + f"- `{component}`: p50={values.get('p50_ms')} ms, " + f"p95={values.get('p95_ms')} ms, p99={values.get('p99_ms')} ms" + ) + lines.extend(["", "## RxNorm diagnostics", ""]) + for item in results.get("rxnorm", results.get("linking", {})).get("unresolved_queries", [])[:10]: + lines.append(f"- unresolved `{item.get('query')}` via `{item.get('method')}` in `{item.get('stage')}`") + lines.extend(["", "## Interaction diagnostics", ""]) + descriptive = results.get("interactions", {}).get("descriptive", {}) + lines.append(f"- source counts: `{descriptive.get('source_counts')}`") + for item in descriptive.get("top_unknown_pairs", [])[:10]: + lines.append(f"- unknown pair `{item.get('drug_a')}` + `{item.get('drug_b')}`: {item.get('count')}") lines.append("") return "\n".join(lines) @@ -371,6 +741,7 @@ def build_manifest( sample_size: int, output_prefix: str, results: dict, + concurrency: int | None = None, random_seed: int | str | None = None, ddinter_db: dict | None = None, ) -> dict: @@ -396,6 +767,8 @@ def build_manifest( "rxnorm_fallback_min_score": 10.0, }, "sample_size": sample_size, + "concurrency": concurrency, + "metric_schema_version": "benchmark-diagnostics-v1", "random_seed": random_seed, "metrics": summary_metrics(results), } @@ -476,6 +849,7 @@ async def _main_async(args: argparse.Namespace) -> int: dataset_path=meta.path, command=" ".join(sys.argv), sample_size=len(records), + concurrency=args.concurrency, output_prefix=output_prefix, results=output["results"], random_seed=args.random_seed, diff --git a/tests/eval/test_interaction_metrics.py b/tests/eval/test_interaction_metrics.py index 782affe..81ce97a 100644 --- a/tests/eval/test_interaction_metrics.py +++ b/tests/eval/test_interaction_metrics.py @@ -23,6 +23,62 @@ def test_interaction_descriptive_counts_all_buckets(): assert result["descriptive"]["uncertain_rate"] == 0.5 +def test_interaction_descriptive_uses_attempt_diagnostics(): + predictions = [{ + "record_id": "1", + "interactions": { + "coverage_summary": {"ddinter": 1, "openfda": 1, "unknown": 1}, + "interactions": [], + }, + "interaction_attempts": [ + { + "drug_a": "warfarin", + "drug_b": "ibuprofen", + "ddinter_rxcui": {"status": "hit"}, + "ddinter_fts": {"status": "skipped"}, + "openfda": {"status": "skipped"}, + "final_source": "ddinter", + "final_severity": "major", + }, + { + "drug_a": "a", + "drug_b": "b", + "ddinter_rxcui": {"status": "miss"}, + "ddinter_fts": {"status": "hit"}, + "openfda": {"status": "skipped"}, + "final_source": "ddinter", + "final_severity": "moderate", + }, + { + "drug_a": "c", + "drug_b": "d", + "ddinter_rxcui": {"status": "miss"}, + "ddinter_fts": {"status": "miss"}, + "openfda": {"status": "hit"}, + "final_source": "openfda", + "final_severity": "minor", + }, + { + "drug_a": "e", + "drug_b": "f", + "ddinter_rxcui": {"status": "skipped"}, + "ddinter_fts": {"status": "miss"}, + "openfda": {"status": "miss"}, + "final_source": "unknown", + "miss_reason": "no_source_hit", + }, + ], + }] + + result = interactions.compute(predictions, [{"id": "1"}]) + + assert result["descriptive"]["ddinter_rxcui_hit_rate"] == 0.25 + assert result["descriptive"]["ddinter_fts_rescue_rate"] == 0.25 + assert result["descriptive"]["openfda_rescue_rate"] == 0.25 + assert result["descriptive"]["source_counts"] == {"ddinter": 2, "openfda": 1, "unknown": 1} + assert result["descriptive"]["top_unknown_pairs"] == [{"drug_a": "e", "drug_b": "f", "count": 1}] + + def test_interaction_accuracy_none_without_reviewed_labels(): dataset = [{"id": "1", "expected_interactions": []}] diff --git a/tests/eval/test_linking_metrics.py b/tests/eval/test_linking_metrics.py index 3c8248d..276b2ad 100644 --- a/tests/eval/test_linking_metrics.py +++ b/tests/eval/test_linking_metrics.py @@ -22,6 +22,84 @@ def test_linking_nil_rate_from_explicit_attempts(): assert result["n_link_attempts"] == 3 +def test_linking_reports_rxnorm_attempt_diagnostics(): + predictions = [{ + "record_id": "1", + "drugs": [ + {"name": "Advil", "rxcui": "5640", "source": "rxnorm_fallback"}, + {"name": "unknown", "rxcui": None, "source": "ner"}, + ], + "rxnorm_attempts": [ + { + "stage": "analyze", + "method": "get_rxcui", + "query": "unknown", + "rxcui": None, + "status": "miss", + "elapsed_ms": 1.0, + }, + { + "stage": "analyze", + "method": "approximate_term", + "query": "Advil", + "rxcui": "5640", + "status": "hit", + "elapsed_ms": 2.0, + }, + { + "stage": "interactions", + "method": "get_rxcui", + "query": "Ibuprofen", + "rxcui": "5640", + "status": "hit", + "elapsed_ms": 1.5, + }, + ], + }] + + result = linking.compute(predictions, [{"id": "1"}]) + + assert result["n_rxnorm_attempts"] == 3 + assert result["rxnorm_by_method"]["get_rxcui"]["hit"] == 1 + assert result["rxnorm_by_method"]["get_rxcui"]["miss"] == 1 + assert result["rxnorm_by_method"]["approximate_term"]["hit"] == 1 + assert result["unresolved_queries"] == [{"query": "unknown", "stage": "analyze", "method": "get_rxcui"}] + assert result["canonicalization_collisions"] == [{ + "rxcui": "5640", + "queries": ["Advil", "Ibuprofen"], + }] + + +def test_linking_ignores_get_drug_details_for_canonicalization_collisions(): + predictions = [{ + "record_id": "1", + "drugs": [{"name": "Advil", "rxcui": "5640", "source": "rxnorm_fallback"}], + "rxnorm_attempts": [ + { + "stage": "analyze", + "method": "approximate_term", + "query": "Advil", + "rxcui": "5640", + "status": "hit", + "elapsed_ms": 2.0, + }, + { + "stage": "analyze", + "method": "get_drug_details", + "query": "5640", + "input_rxcui": "5640", + "rxcui": "5640", + "status": "hit", + "elapsed_ms": 1.0, + }, + ], + }] + + result = linking.compute(predictions, [{"id": "1"}]) + + assert result["canonicalization_collisions"] == [] + + def test_linking_nil_rate_none_without_attempts(): result = linking.compute([{"record_id": "1", "drugs": [], "link_attempts": []}], [{"id": "1"}]) diff --git a/tests/eval/test_run_benchmark.py b/tests/eval/test_run_benchmark.py index ed5d702..44abaed 100644 --- a/tests/eval/test_run_benchmark.py +++ b/tests/eval/test_run_benchmark.py @@ -90,6 +90,36 @@ async def no_openfda(_drug_a, _drug_b): assert prediction["interactions"]["interactions"][0]["source"] == "ddinter" assert prediction["ner_entities"][0]["text"] == "warfarin" assert len(prediction["link_attempts"]) >= 2 + assert prediction["component_timings_ms"]["total"] >= 0.0 + assert prediction["component_timings_ms"]["critical_path"] >= prediction["component_timings_ms"]["slowest_component_ms"] + assert prediction["slowest_component"] in { + "ocr_clean", + "ner", + "rxnorm", + "ddinter_rxcui", + "ddinter_fts", + "openfda", + "severity", + } + assert prediction["ner_diagnostics"]["strict"]["tp"] == 2 + assert prediction["ner_diagnostics"]["strict"]["fp"] == 0 + assert prediction["ner_diagnostics"]["strict"]["fn"] == 0 + assert all(attempt["stage"] in {"analyze", "interactions"} for attempt in prediction["rxnorm_attempts"]) + assert { + attempt["query"] + for attempt in prediction["rxnorm_attempts"] + if attempt["method"] == "get_rxcui" + } >= {"warfarin", "ibuprofen"} + interaction_attempt = prediction["interaction_attempts"][0] + assert interaction_attempt["drug_a"] == "warfarin" + assert interaction_attempt["drug_b"] == "ibuprofen" + assert interaction_attempt["rxcui_a"] == "11289" + assert interaction_attempt["rxcui_b"] == "5640" + assert interaction_attempt["ddinter_rxcui"]["status"] == "hit" + assert interaction_attempt["ddinter_rxcui"]["output"]["drug_a_id"] == "DDI00001" + assert interaction_attempt["final_source"] == "ddinter" + assert interaction_attempt["final_severity"] == "major" + assert prediction["pipeline_errors"] == [] assert set(prediction["elapsed_ms"]) == { "ocr_clean", "ner", @@ -103,6 +133,11 @@ async def no_openfda(_drug_a, _drug_b): "total", } assert output["results"]["ner"]["strict"]["f1"] == 1.0 + assert output["results"]["overall"]["records_completed"] == 1 + assert output["results"]["timing"]["components"]["total"]["p50_ms"] >= 0.0 + assert output["results"]["rxnorm"]["n_rxnorm_attempts"] >= 2 + assert "n_rxnorm_attempts" not in output["results"]["linking"] + assert output["results"]["interactions"]["descriptive"]["ddinter_rxcui_hit_rate"] == 1.0 assert output["results"]["interactions"]["accuracy"] is None artifacts = run_benchmark.write_local_outputs( @@ -119,12 +154,23 @@ async def no_openfda(_drug_a, _drug_b): def test_manifest_summary_contains_top_line_metric_scalars(): results = { + "overall": {"records_completed": 25, "records_errored": 0, "timeout_count": 0}, + "timing": { + "slowest_component": "rxnorm", + "components": {"total": {"p95_ms": 123.0}}, + }, "ner": {"strict": {"f1": 0.8}, "lenient": {"f1": 0.9}}, "linking": {"coverage": 0.75, "nil_rate": 0.1}, + "rxnorm": {"coverage": 0.75, "nil_rate": 0.1}, "interactions": { - "descriptive": {"total_pairs_checked": 4, "ddinter_hit_rate": 0.5}, + "descriptive": { + "total_pairs_checked": 4, + "ddinter_hit_rate": 0.5, + "ddinter_rxcui_hit_rate": 0.25, + }, "seed_smoke": {"recall": 1.0, "false_alarm_rate": 0.0}, }, + "errors": {"total": 0}, "fp_taxonomy": {"total_fp": 3}, } @@ -137,14 +183,57 @@ def test_manifest_summary_contains_top_line_metric_scalars(): "linking_nil_rate": 0.1, "interactions_total_pairs_checked": 4, "interactions_ddinter_hit_rate": 0.5, + "interactions_ddinter_rxcui_hit_rate": 0.25, "interactions_openfda_hit_rate": None, "interactions_unknown_rate": None, + "records_completed": 25, + "records_errored": 0, + "timeout_count": 0, + "slowest_component": "rxnorm", + "total_p95_ms": 123.0, "seed_smoke_recall": 1.0, "seed_smoke_false_alarm_rate": 0.0, "fp_total": 3, } +def test_trace_component_timings_use_leaf_components_for_bottlenecks(): + trace = run_benchmark.BenchmarkTrace() + trace.elapsed_ms.update({ + "ocr_clean": 2.0, + "ner": 3.0, + "rxnorm": 10.0, + "ddinter_rxcui": 4.0, + "ddinter_fts": 5.0, + "openfda": 6.0, + "severity": 7.0, + "analyze": 100.0, + "interactions": 200.0, + "total": 300.0, + }) + + timings = trace.component_timings() + + assert trace.slowest_component() == "rxnorm" + assert timings["slowest_component_ms"] == 10.0 + assert timings["critical_path"] == 37.0 + + +def test_trace_pipeline_errors_dedupe_same_exception(): + trace = run_benchmark.BenchmarkTrace() + exc = RuntimeError("same failure") + + trace.record_error("openfda", exc) + trace.record_error("interactions", exc) + + assert trace.pipeline_errors == [{ + "stage": "openfda", + "stages": ["openfda", "interactions"], + "error_class": "RuntimeError", + "message": "same failure", + }] + + @pytest.mark.asyncio async def test_run_benchmark_records_per_record_errors(monkeypatch): async def broken_analyze(_text): @@ -166,6 +255,7 @@ async def broken_analyze(_text): assert output["predictions"][0]["record_id"] == "case-error" assert output["predictions"][0]["drugs"] == [] + assert output["predictions"][0]["pipeline_errors"][0]["stage"] == "analyze" assert output["errors"] == [{ "record_id": "case-error", "stage": "analyze", @@ -197,6 +287,7 @@ async def hanging_analyze(_text): assert output["predictions"][0]["record_id"] == "case-timeout" assert output["predictions"][0]["drugs"] == [] + assert output["predictions"][0]["pipeline_errors"][0]["stage"] == "record_timeout" assert output["errors"] == [{ "record_id": "case-timeout", "stage": "record_timeout", @@ -253,6 +344,7 @@ def test_manifest_includes_ddinter_release_metadata(): dataset_path="data/benchmark.json", command="python -m eval.run_benchmark", sample_size=1, + concurrency=8, output_prefix="benchmark-results/2026-05-21/tier1-test/", results={"ner": {}, "linking": {}, "interactions": {}, "fp_taxonomy": {}}, random_seed=None, @@ -270,6 +362,8 @@ def test_manifest_includes_ddinter_release_metadata(): "asset": "ddinter.db", "sha256": "abc123", } + assert manifest["metric_schema_version"] == "benchmark-diagnostics-v1" + assert manifest["concurrency"] == 8 @pytest.mark.asyncio diff --git a/tests/eval/test_schemas.py b/tests/eval/test_schemas.py index 32dffa0..1fbf5a4 100644 --- a/tests/eval/test_schemas.py +++ b/tests/eval/test_schemas.py @@ -47,6 +47,23 @@ def test_interaction_label_candidates_schema_rejects_ground_truth_true(): def test_benchmark_results_schema_contains_required_metric_blocks(): schema = _load_schema("benchmark_results.schema.json") result = { + "overall": { + "records_total": 1, + "records_completed": 1, + "records_errored": 0, + "error_rate": 0.0, + "timeout_count": 0, + "concurrency": 1, + "wall_time_seconds": 0.1, + "records_per_second": 10.0, + }, + "timing": { + "components": { + "total": {"p50_ms": 1.0, "p95_ms": 1.0, "p99_ms": 1.0, "max_ms": 1.0, "mean_ms": 1.0} + }, + "slowest_component": "total", + "slowest_component_counts": {"total": 1}, + }, "ner": { "strict": {"precision": 1.0, "recall": 1.0, "f1": 1.0}, "lenient": {"precision": 1.0, "recall": 1.0, "f1": 1.0}, @@ -65,6 +82,19 @@ def test_benchmark_results_schema_contains_required_metric_blocks(): "acc_at_1": None, "incorrect_link_rate": None, }, + "rxnorm": { + "coverage": 1.0, + "fallback_rate": 0.0, + "nil_rate": None, + "n_link_attempts": 0, + "n_drugs_total": 1, + "acc_at_1": None, + "incorrect_link_rate": None, + "n_rxnorm_attempts": 0, + "rxnorm_by_method": {}, + "unresolved_queries": [], + "canonicalization_collisions": [], + }, "interactions": { "descriptive": { "total_pairs_checked": 0, @@ -74,6 +104,11 @@ def test_benchmark_results_schema_contains_required_metric_blocks(): "severity_distribution": {"minor": 0, "moderate": 0, "major": 0, "unknown": 0}, "uncertain_rate": 0.0, "records_with_any_interaction": 0, + "ddinter_rxcui_hit_rate": 0.0, + "ddinter_fts_rescue_rate": 0.0, + "openfda_rescue_rate": 0.0, + "source_counts": {"ddinter": 0, "openfda": 0, "unknown": 0}, + "top_unknown_pairs": [], }, "accuracy": None, "seed_smoke": { @@ -83,6 +118,12 @@ def test_benchmark_results_schema_contains_required_metric_blocks(): "missed_pairs": [], }, }, + "errors": { + "total": 0, + "by_stage": {}, + "by_class": {}, + "records": [], + }, "fp_taxonomy": { "brand": {"count": 0, "examples": []}, "salt": {"count": 0, "examples": []},