From 5b221800a4e58aca5283c155f6484619ed5d42e7 Mon Sep 17 00:00:00 2001 From: wlg1 Date: Sun, 24 May 2026 20:41:24 -0700 Subject: [PATCH 1/4] time shuffle traj ablation expms --- ablate_cot_length_expms/README.md | 9 + ablate_cot_length_expms/__init__.py | 1 + .../time_shuffle/README.md | 158 +++++++++ .../time_shuffle/__init__.py | 1 + .../time_shuffle/run_compare_base_rft.py | 273 ++++++++++++++++ .../time_shuffle/run_time_shuffle.py | 308 ++++++++++++++++++ .../time_shuffle/shuffle.py | 104 ++++++ 7 files changed, 854 insertions(+) create mode 100644 ablate_cot_length_expms/README.md create mode 100644 ablate_cot_length_expms/__init__.py create mode 100644 ablate_cot_length_expms/time_shuffle/README.md create mode 100644 ablate_cot_length_expms/time_shuffle/__init__.py create mode 100644 ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py create mode 100644 ablate_cot_length_expms/time_shuffle/run_time_shuffle.py create mode 100644 ablate_cot_length_expms/time_shuffle/shuffle.py diff --git a/ablate_cot_length_expms/README.md b/ablate_cot_length_expms/README.md new file mode 100644 index 0000000..092fba8 --- /dev/null +++ b/ablate_cot_length_expms/README.md @@ -0,0 +1,9 @@ +# CoT trace & pipeline ablations + +Controlled experiments for reviewer-style confound checks on the CEBRA + SDS pipeline. + +| Subfolder | Reviewer item | Description | +|-----------|---------------|-------------| +| [`time_shuffle/`](time_shuffle/) | Time-shuffled trajectory | Permute **sentence-aligned activation steps** before SDS/EM; CEBRA still trained on real CoT order | + + diff --git a/ablate_cot_length_expms/__init__.py b/ablate_cot_length_expms/__init__.py new file mode 100644 index 0000000..c4b946a --- /dev/null +++ b/ablate_cot_length_expms/__init__.py @@ -0,0 +1 @@ +# Ablation experiments on CoT trace/activations. 
diff --git a/ablate_cot_length_expms/time_shuffle/README.md b/ablate_cot_length_expms/time_shuffle/README.md new file mode 100644 index 0000000..a8d2ba5 --- /dev/null +++ b/ablate_cot_length_expms/time_shuffle/README.md @@ -0,0 +1,158 @@ +# Ablation (2): Time-shuffled trajectories before SDS + +**Goal:** Sanity-check that SDS metrics depend on real temporal order. CEBRA is trained on **real** CoT order; only the **sequences passed to EM** are permuted. + +**Good outcome:** On real order, RFT beats base on persistence, self-transition, ΔR², and similar metrics. + +Under shuffle, both models’ metrics drop sharply, and the RFT − base gap is smaller than on real order. + +This means that not just any ordering produces the same sticky regimes and transitions: the real ("logical") order is required to obtain the desired gap between base and RFT. + +## EXAMPLE: Non-shuffled vs shuffled trajectory (same problem) + +One GSM8K-style problem produces four sentence-level steps. CEBRA embeddings are written as `z_t`; SDS infers a discrete regime `s_t` at each timestep (example regimes shown for intuition only). + +### Non-shuffled (real CoT order) — input to SDS + +| Timestep | Sentence in original CoT | Embedding row | Example regime (SDS decode) | +|----------|--------------------------|---------------|-----------------------------| +| 0 | "Let me define the variables." | `z₀` | PLAN | +| 1 | "Substituting into the equation…" | `z₁` | COMPUTE | +| 2 | "Wait, let me check that step." | `z₂` | VERIFY | +| 3 | "So the final answer is 42." | `z₃` | ANSWER | + +**Trajectory passed to EM:** `[z₀, z₁, z₂, z₃]` +**Timeline matches how the model actually wrote the trace** — plan → compute → verify → answer. + +SDS can learn transitions like PLAN→COMPUTE→VERIFY→ANSWER and long runs in the same regime when the reasoning stays in one mode for several sentences. 
+ +### Shuffled (same vectors, permuted order) — input to SDS + +Same four sentences and the **same** `z₀…z₃` (each row still belongs to its original sentence). Only the **order of rows** changes, e.g. `perm = [2, 0, 3, 1]`: + +| Timestep | Sentence (still tied to that row) | Embedding row | Example regime (SDS decode) | +|----------|-----------------------------------|---------------|-----------------------------| +| 0 | "Wait, let me check that step." | `z₂` | (varies) | +| 1 | "Let me define the variables." | `z₀` | (varies) | +| 2 | "So the final answer is 42." | `z₃` | (varies) | +| 3 | "Substituting into the equation…" | `z₁` | (varies) | + +**Trajectory passed to EM:** `[z₂, z₀, z₃, z₁]` +**Timeline does not match the model’s reasoning flow** — verify appears before setup; answer before substitution. + +CEBRA was **not** retrained on this order; only SDS is refit on the scrambled sequence. + +### Side-by-side summary + +``` +Non-shuffled (real): z₀ → z₁ → z₂ → z₃ (setup → compute → verify → answer) +Shuffled (example): z₂ → z₀ → z₃ → z₁ (verify → setup → answer → compute) + └── same four vectors, different positions ──┘ +``` + +## What gets shuffled (what happens in code) + +Each problem yields one trajectory from `all_sentences_features.pkl`. Each timestep is one **sentence-aligned step**: + +| Index | Sentence (from CoT) | Stored in pickle | In trajectory | +|-------|---------------------|------------------|---------------| +| 0 | "Let me set up the equation." | `hidden_state_last`₀, `stage`₀ | row 0 of `z_seq` | +| 1 | "We substitute x = 3." | `hidden_state_last`₁, `stage`₁ | row 1 | +| 2 | "Therefore the answer is 12." 
| `hidden_state_last`₂, `stage`₂ | row 2 | + +`shuffle_sequence_lists` applies one `perm` to `cebra_seqs`, `pca_seqs`, and `labels` in parallel: + +``` +Real: z = [z₀, z₁, z₂] labels = [L₀, L₁, L₂] +Shuffled: z = [z₂, z₀, z₁] labels = [L₂, L₀, L₁] (perm = [2, 0, 1]) +``` + +### Block shuffle (optional) + +With `block_size=2` on the four-step example, blocks are `[0,1]` and `[2,3]`. If block order is reversed: + +``` +Non-shuffled: [z₀, z₁, z₂, z₃] +Block-shuffled: [z₂, z₃, z₀, z₁] # (verify+answer) block before (setup+compute) block +``` + +Local order inside each block is preserved; global story order is still wrong. + +## What this is meant to test + +| Question | If real ≫ shuffled | If real ≈ shuffled | +|----------|-------------------|-------------------| +| Does SDS need sequential structure? | Yes — readout uses time order | Suspect order-invariant artifacts | +| Is persistence just static clustering? | Less likely | Investigate further | + + +## Design (pipeline) + +1. Load `all_sentences_features.pkl` (from `generate_data/create_dataset.py`). +2. Train CEBRA → `cebra_seqs`, `pca_seqs`, `labels` (real order). +3. **Real order:** fit SDS (EM); report persistence, \(K_{\mathrm{eff}}\), self-transition, \(\Delta R^2\), BIC. +4. **Shuffled order:** `shuffle_sequence_lists(...)` then refit SDS (same \(K\), multiple shuffle seeds). +5. Compare real vs shuffled in `summary.json` / `summary.csv`. + +Run the script separately on base and reasoning feature pickles and compare gaps. 
+ +## Usage + +### Single model (real + shuffled in one run) + +From repo root: + +```bash +python ablate_cot_length_expms/time_shuffle/run_time_shuffle.py \ + --features-path /path/to/all_sentences_features.pkl \ + --limit-problems 500 \ + --shuffle-seeds 0 1 2 3 4 \ + --out-dir ablate_cot_length_expms/time_shuffle/results +``` + +### Base + RFT: four conditions, auto pass/fail + +Runs base and RFT pickles (each: real-order SDS + shuffled-order SDS), then checks: + +- Real ≫ shuffled per model (persistence, self-transition, ΔR²) +- RFT_real > Base_real +- (RFT − Base)_real > (RFT − Base)_shuffled + +```bash +python ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py \ + --base-features-path /path/to/base/all_sentences_features.pkl \ + --rft-features-path /path/to/rft/all_sentences_features.pkl \ + --limit-problems 500 \ + --k-values 4 5 6 \ + --k-focus 5 \ + --out-dir ablate_cot_length_expms/time_shuffle/results/compare_base_rft +``` + +Writes `results/compare_base_rft/RUN_NAME/base/summary.json`, `rft/summary.json`, and `comparison_report.json`. Exit code **0** if all checks pass at `--k-focus`, **1** otherwise. + +Re-evaluate without retraining: + +```bash +python ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py \ + --base-features-path ... --rft-features-path ... \ + --skip-run --run-name RUN_NAME \ + --out-dir ablate_cot_length_expms/time_shuffle/results/compare_base_rft +``` + +Optional: cache CEBRA embeddings to skip retraining: + +```bash +python ablate_cot_length_expms/time_shuffle/run_time_shuffle.py \ + --features-path /path/to/all_sentences_features.pkl \ + --cebra-cache ablate_cot_length_expms/time_shuffle/results/cebra_cache.pkl +``` + +## Outputs + +- `results/RUN_NAME/summary.json` — per-\(K\) metrics for `real` and `shuffled` (mean/std over shuffle seeds) +- `results/RUN_NAME/summary.csv` — flat table for plotting + +## Shuffle modes + +- `full` (default): uniform random permutation of all sentence indices in a trajectory. 
+- `block`: permute blocks of `block_size` consecutive sentences (preserves local adjacency). diff --git a/ablate_cot_length_expms/time_shuffle/__init__.py b/ablate_cot_length_expms/time_shuffle/__init__.py new file mode 100644 index 0000000..5e3a0e9 --- /dev/null +++ b/ablate_cot_length_expms/time_shuffle/__init__.py @@ -0,0 +1 @@ +# Time-shuffle ablation before SDS fitting. diff --git a/ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py b/ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py new file mode 100644 index 0000000..8a9d7ae --- /dev/null +++ b/ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +""" +Run time-shuffle ablation on base and RFT feature pickles, then check desired outcomes. + +Each pickle run produces real-order and shuffled-order SDS metrics (two conditions per +model). This script runs both models and evaluates: + + 1. Analyzer: real >> shuffled for each model (persistence, delta_r2, self-transition). + 2. Paper: RFT_real > Base_real on key metrics. + 3. Temporal gap: (RFT - Base)_real > (RFT - Base)_shuffled. + +Exit code 0 if all checks pass at --k-focus (default 5); 1 otherwise. 
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from datetime import datetime, timezone +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from run_time_shuffle import execute_ablation # noqa: E402 + +METRICS_HIGHER_IS_RFT_BETTER = ("persistence", "mean_self_transition", "delta_r2", "K_eff") +METRICS_COLLAPSE_UNDER_SHUFFLE = ("persistence", "mean_self_transition", "delta_r2") + + +def _metric_at_k(summary: dict, k: int, order: str, field: str) -> float | None: + """order: 'real' or shuffled mean field suffix _mean.""" + if order == "real": + for row in summary["per_run"]: + if row["order"] == "real" and row["K"] == k: + return float(row[field]) + return None + for comp in summary["comparison"]: + if comp["K"] == k: + return float(comp["shuffled"][f"{field}_mean"]) + return None + + +def evaluate_pair( + base_summary: dict, + rft_summary: dict, + k: int, + *, + min_real_vs_shuf_ratio: float = 1.05, + min_gap_shrink_ratio: float = 1.05, + epsilon: float = 1e-9, +) -> dict: + """ + Check desired outcomes at regime count K. + + min_real_vs_shuf_ratio: real metric must be >= this * shuffled (per model). + min_gap_shrink_ratio: gap_real must be >= this * gap_shuffled (strict > if ratio 1.0). 
+ """ + checks: list[dict] = [] + + def add_check(name: str, passed: bool, detail: str, values: dict | None = None) -> None: + checks.append({"name": name, "passed": passed, "detail": detail, "values": values or {}}) + + for model_tag, summary in (("base", base_summary), ("rft", rft_summary)): + for metric in METRICS_COLLAPSE_UNDER_SHUFFLE: + real_v = _metric_at_k(summary, k, "real", metric) + shuf_v = _metric_at_k(summary, k, "shuffled", metric) + if real_v is None or shuf_v is None: + add_check( + f"{model_tag}_{metric}_collapse", + False, + f"missing data for K={k}", + {"real": real_v, "shuffled": shuf_v}, + ) + continue + # delta_r2 can be negative; use absolute drop or ratio only when shuf near zero + if metric == "delta_r2": + passed = real_v > shuf_v + 0.01 + else: + passed = real_v >= min_real_vs_shuf_ratio * (shuf_v + epsilon) + add_check( + f"{model_tag}_{metric}_real_gt_shuffled", + passed, + f"{model_tag} real={real_v:.4f} vs shuffled={shuf_v:.4f}", + {"real": real_v, "shuffled": shuf_v}, + ) + + for metric in METRICS_HIGHER_IS_RFT_BETTER: + b_real = _metric_at_k(base_summary, k, "real", metric) + r_real = _metric_at_k(rft_summary, k, "real", metric) + if b_real is None or r_real is None: + add_check(f"rft_gt_base_real_{metric}", False, f"missing real-order data K={k}") + continue + add_check( + f"rft_gt_base_real_{metric}", + r_real > b_real, + f"RFT_real={r_real:.4f} vs Base_real={b_real:.4f}", + {"rft": r_real, "base": b_real}, + ) + + for metric in ("persistence", "delta_r2", "mean_self_transition"): + gap_real = None + gap_shuf = None + b_r = _metric_at_k(base_summary, k, "real", metric) + r_r = _metric_at_k(rft_summary, k, "real", metric) + b_s = _metric_at_k(base_summary, k, "shuffled", metric) + r_s = _metric_at_k(rft_summary, k, "shuffled", metric) + if None not in (b_r, r_r, b_s, r_s): + gap_real = r_r - b_r + gap_shuf = r_s - b_s + if metric == "delta_r2": + passed = gap_real > gap_shuf + 0.01 + else: + passed = gap_real >= 
min_gap_shrink_ratio * (gap_shuf + epsilon) + add_check( + f"gap_real_gt_gap_shuffled_{metric}", + passed, + f"gap_real={gap_real:.4f} vs gap_shuffled={gap_shuf:.4f}", + {"gap_real": gap_real, "gap_shuffled": gap_shuf}, + ) + else: + add_check(f"gap_real_gt_gap_shuffled_{metric}", False, "missing metrics for gap comparison") + + all_passed = all(c["passed"] for c in checks) + return {"K": k, "all_passed": all_passed, "checks": checks} + + +def parse_args(): + p = argparse.ArgumentParser( + description="Run base+RFT time-shuffle ablations and evaluate desired outcomes.", + ) + p.add_argument("--base-features-path", type=Path, required=True) + p.add_argument("--rft-features-path", type=Path, required=True) + p.add_argument("--limit-problems", type=int, default=500) + p.add_argument("--shuffle-mode", default="full", choices=["full", "block"]) + p.add_argument("--block-size", type=int, default=3) + p.add_argument("--shuffle-seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4]) + p.add_argument( + "--k-values", + type=int, + nargs="+", + default=[4, 5, 6], + help="K values to fit; checks run at --k-focus", + ) + p.add_argument("--k-focus", type=int, default=5, help="K used for pass/fail criteria") + p.add_argument( + "--out-dir", + type=Path, + default=Path(__file__).parent / "results" / "compare_base_rft", + ) + p.add_argument("--run-name", type=str, default=None) + p.add_argument( + "--skip-run", + action="store_true", + help="Only evaluate existing summaries under out-dir/run-name/{base,rft}/", + ) + p.add_argument( + "--min-real-vs-shuf-ratio", + type=float, + default=1.05, + help="Real metric must be at least this times shuffled (per model)", + ) + p.add_argument( + "--min-gap-shrink-ratio", + type=float, + default=1.0, + help="Require gap_real >= ratio * gap_shuffled (use 1.05 for strict)", + ) + return p.parse_args() + + +def main() -> int: + args = parse_args() + run_name = args.run_name or datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + run_dir = 
args.out_dir / run_name + base_out = run_dir / "base" + rft_out = run_dir / "rft" + + if not args.skip_run: + run_dir.mkdir(parents=True, exist_ok=True) + print("=== Base model (real + shuffled SDS) ===", flush=True) + base_summary = execute_ablation( + args.base_features_path, + limit_problems=args.limit_problems, + shuffle_mode=args.shuffle_mode, + block_size=args.block_size, + shuffle_seeds=args.shuffle_seeds, + k_values=args.k_values, + cebra_cache=run_dir / "cebra_cache_base.pkl", + out_dir=base_out, + model_tag="base", + ) + print("\n=== RFT / reasoning model (real + shuffled SDS) ===", flush=True) + rft_summary = execute_ablation( + args.rft_features_path, + limit_problems=args.limit_problems, + shuffle_mode=args.shuffle_mode, + block_size=args.block_size, + shuffle_seeds=args.shuffle_seeds, + k_values=args.k_values, + cebra_cache=run_dir / "cebra_cache_rft.pkl", + out_dir=rft_out, + model_tag="rft", + ) + else: + with open(base_out / "summary.json", encoding="utf-8") as f: + base_summary = json.load(f) + with open(rft_out / "summary.json", encoding="utf-8") as f: + rft_summary = json.load(f) + + if args.k_focus not in args.k_values: + print( + f"Warning: --k-focus {args.k_focus} not in --k-values {args.k_values}; " + "evaluation may use missing K.", + flush=True, + ) + + evaluation = evaluate_pair( + base_summary, + rft_summary, + args.k_focus, + min_real_vs_shuf_ratio=args.min_real_vs_shuf_ratio, + min_gap_shrink_ratio=args.min_gap_shrink_ratio, + ) + + # Optional: evaluate all K in k_values for reporting + by_k = { + str(k): evaluate_pair( + base_summary, + rft_summary, + k, + min_real_vs_shuf_ratio=args.min_real_vs_shuf_ratio, + min_gap_shrink_ratio=args.min_gap_shrink_ratio, + ) + for k in args.k_values + } + + report = { + "run_name": run_name, + "base_features_path": str(args.base_features_path), + "rft_features_path": str(args.rft_features_path), + "k_focus": args.k_focus, + "k_values": args.k_values, + "desired_outcome_summary": ( + "RFT_real > 
Base_real; both real >> shuffled; (RFT-Base)_real > (RFT-Base)_shuffled" + ), + "evaluation_at_k_focus": evaluation, + "evaluation_by_k": by_k, + "overall_pass": evaluation["all_passed"], + } + + report_path = run_dir / "comparison_report.json" + with open(report_path, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2) + + print(f"\nWrote {report_path}", flush=True) + print(f"\n=== Evaluation at K={args.k_focus} ===", flush=True) + for c in evaluation["checks"]: + status = "PASS" if c["passed"] else "FAIL" + print(f" [{status}] {c['name']}: {c['detail']}", flush=True) + + print( + f"\nOverall: {'PASS' if evaluation['all_passed'] else 'FAIL'} " + f"(desired outcome {'met' if evaluation['all_passed'] else 'not met'})", + flush=True, + ) + return 0 if evaluation["all_passed"] else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/ablate_cot_length_expms/time_shuffle/run_time_shuffle.py b/ablate_cot_length_expms/time_shuffle/run_time_shuffle.py new file mode 100644 index 0000000..a88c8b1 --- /dev/null +++ b/ablate_cot_length_expms/time_shuffle/run_time_shuffle.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +""" +Ablation (2): time-shuffle sentence-level trajectories before SDS/EM fitting. + +What gets shuffled +------------------ +After CEBRA, each problem has a trajectory of embedding rows in real CoT order:: + + cebra_seqs[p] = [z_0, z_1, z_2, ...] # z_t = embedding for sentence t + +shuffle_sequence_lists applies the same perm to cebra_seqs, pca_seqs, and labels. +Example:: + + Real: [z_0, z_1, z_2] # sentences 0 -> 1 -> 2 + Shuffled: [z_2, z_0, z_1] # same sentence-linked vectors, wrong timeline + +Activations stay tied to the sentence they were extracted from; we do not shuffle +text in the LM or decouple h_t from sentence t. + +What is held fixed vs changed +----------------------------- +- Fixed: CEBRA training (temporal positive pairs on real order), number of sentences T. 
+- Changed: only the order of rows fed to SDS/EM (forward-backward + m_step). + +What this tests +--------------- +Whether SDS metrics (persistence, self-transition, regime R^2, BIC) require real +temporal order. Expected if the readout is valid: real_order >> shuffled_order. +Does not test base-vs-RFT (run twice on different feature pickles). +""" + +from __future__ import annotations + +import argparse +import csv +import json +import pickle +import sys +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from exploration.cebra_EM import ( # noqa: E402 + K_SWEEP, + fit_and_evaluate, + linear_ar_r2, + load_and_prepare_cebra, + train_cebra_projection, +) + +from shuffle import shuffle_sequence_lists # noqa: E402 + + +def parse_args(): + p = argparse.ArgumentParser( + description=( + "Time-shuffle ablation: permute sentence-aligned embedding rows before SDS. " + "Example: real [z0,z1,z2] vs shuffled [z2,z0,z1]; CEBRA unchanged. " + "Tests whether SDS metrics need temporal order (sanity check, not length control)." 
+ ), + ) + p.add_argument( + "--features-path", + type=Path, + required=True, + help="Pickle from create_dataset.py (all_sentences_features.pkl)", + ) + p.add_argument("--limit-problems", type=int, default=500) + p.add_argument("--cebra-mode", default="temporal", choices=["temporal"]) + p.add_argument( + "--shuffle-mode", + default="full", + choices=["full", "block"], + help="full: permute all sentences; block: permute blocks of sentences", + ) + p.add_argument("--block-size", type=int, default=3) + p.add_argument( + "--shuffle-seeds", + type=int, + nargs="+", + default=[0, 1, 2, 3, 4], + help="Independent shuffle seeds; metrics averaged over shuffled runs", + ) + p.add_argument( + "--k-values", + type=int, + nargs="+", + default=None, + help="Regime counts to evaluate (default: exploration.cebra_EM.K_SWEEP)", + ) + p.add_argument("--out-dir", type=Path, default=Path(__file__).parent / "results") + p.add_argument("--run-name", type=str, default=None) + p.add_argument( + "--cebra-cache", + type=Path, + default=None, + help="Save/load (cebra_seqs, pca_seqs, labels, ar_r2) to skip CEBRA retraining", + ) + return p.parse_args() + + +def _metrics_row(order: str, k: int, seed: int | None, fit_tuple) -> dict: + persist, mean_st, spec, _, _, score, k_eff, r2, bic = fit_tuple + return { + "order": order, + "K": k, + "shuffle_seed": seed, + "persistence": float(persist), + "mean_self_transition": float(mean_st), + "specialization": float(spec), + "sss": float(score), + "K_eff": int(k_eff), + "regime_r2": float(r2), + "bic": float(bic), + } + + +def _aggregate_shuffled(rows: list[dict], k: int) -> dict: + sub = [r for r in rows if r["order"] == "shuffled" and r["K"] == k] + if not sub: + return {} + keys = ["persistence", "mean_self_transition", "specialization", "sss", "K_eff", "regime_r2", "bic"] + out = {"K": k, "order": "shuffled_mean", "n_seeds": len(sub)} + for key in keys: + vals = [r[key] for r in sub] + out[f"{key}_mean"] = float(np.mean(vals)) + out[f"{key}_std"] = 
float(np.std(vals)) + return out + + +def execute_ablation( + features_path: Path, + *, + limit_problems: int = 500, + cebra_mode: str = "temporal", + shuffle_mode: str = "full", + block_size: int = 3, + shuffle_seeds: list[int] | None = None, + k_values: list[int] | None = None, + cebra_cache: Path | None = None, + out_dir: Path | None = None, + model_tag: str = "model", + verbose: bool = True, +) -> dict: + """Train CEBRA, fit SDS on real and shuffled trajectories; return summary dict.""" + if shuffle_seeds is None: + shuffle_seeds = [0, 1, 2, 3, 4] + k_values = list(k_values) if k_values is not None else list(K_SWEEP) + + def log(msg: str) -> None: + if verbose: + print(msg, flush=True) + + if cebra_cache and cebra_cache.exists(): + log(f"[{model_tag}] Loading CEBRA cache from {cebra_cache}") + with open(cebra_cache, "rb") as f: + cache = pickle.load(f) + cebra_seqs = cache["cebra_seqs"] + pca_seqs = cache["pca_seqs"] + labels = cache["labels"] + ar_r2 = cache["ar_r2"] + else: + log(f"[{model_tag}] Loading features: {features_path}") + all_f, triplets = load_and_prepare_cebra( + str(features_path), + mode=cebra_mode, + limit_problems=limit_problems, + ) + log(f"[{model_tag}] Training CEBRA (real temporal order)...") + cebra_seqs, pca_seqs, labels = train_cebra_projection(all_f, triplets) + ar_r2 = linear_ar_r2(pca_seqs) + log(f" Linear AR baseline R² (PCA): {ar_r2:.4f}") + log(f" Trajectories: {len(cebra_seqs)}") + lengths = [len(s) for s in cebra_seqs] + log(f" Seq lengths: mean={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}") + + if cebra_cache: + cebra_cache.parent.mkdir(parents=True, exist_ok=True) + with open(cebra_cache, "wb") as f: + pickle.dump( + { + "cebra_seqs": cebra_seqs, + "pca_seqs": pca_seqs, + "labels": labels, + "ar_r2": ar_r2, + "features_path": str(features_path), + "limit_problems": limit_problems, + }, + f, + ) + log(f" Saved CEBRA cache to {cebra_cache}") + + detail_rows: list[dict] = [] + + log(f"\n[{model_tag}] === Real 
temporal order (SDS) ===") + for k in k_values: + fit = fit_and_evaluate(cebra_seqs, pca_seqs, labels, k) + row = _metrics_row("real", k, None, fit) + row["delta_r2"] = row["regime_r2"] - ar_r2 + detail_rows.append(row) + log( + f" K={k}: persist={row['persistence']:.2f} self_trans={row['mean_self_transition']:.3f} " + f"K_eff={row['K_eff']} R²={row['regime_r2']:.4f} ΔR²={row['delta_r2']:+.4f} BIC={row['bic']:.1f}" + ) + + log(f"\n[{model_tag}] === Shuffled order ({shuffle_mode}) ===") + for shuffle_seed in shuffle_seeds: + z_shuf, p_shuf, lab_shuf = shuffle_sequence_lists( + cebra_seqs, + pca_seqs, + labels, + seed=shuffle_seed, + mode=shuffle_mode, + block_size=block_size, + ) + for k in k_values: + fit = fit_and_evaluate(z_shuf, p_shuf, lab_shuf, k) + row = _metrics_row("shuffled", k, shuffle_seed, fit) + row["delta_r2"] = row["regime_r2"] - ar_r2 + row["shuffle_mode"] = shuffle_mode + detail_rows.append(row) + log(f" seed={shuffle_seed}: done K={k_values[0]}..{k_values[-1]}") + + summary = { + "model_tag": model_tag, + "features_path": str(features_path), + "limit_problems": limit_problems, + "shuffle_mode": shuffle_mode, + "block_size": block_size, + "shuffle_seeds": shuffle_seeds, + "k_values": k_values, + "ar_r2": float(ar_r2), + "n_trajectories": len(cebra_seqs), + "per_run": detail_rows, + "comparison": [], + } + + for k in k_values: + real = next(r for r in detail_rows if r["order"] == "real" and r["K"] == k) + agg = _aggregate_shuffled(detail_rows, k) + if not agg: + continue + comp = { + "K": k, + "real": {key: real[key] for key in real if key not in ("order", "shuffle_seed", "shuffle_mode")}, + "shuffled": agg, + "delta_real_minus_shuffled_mean": { + "persistence": real["persistence"] - agg["persistence_mean"], + "mean_self_transition": real["mean_self_transition"] - agg["mean_self_transition_mean"], + "regime_r2": real["regime_r2"] - agg["regime_r2_mean"], + "delta_r2": real["delta_r2"] - (agg["regime_r2_mean"] - ar_r2), + "bic": real["bic"] - 
agg["bic_mean"], + }, + } + summary["comparison"].append(comp) + + if out_dir is not None: + out_dir.mkdir(parents=True, exist_ok=True) + summary_path = out_dir / "summary.json" + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + csv_path = out_dir / "summary.csv" + if detail_rows: + fieldnames = list(detail_rows[0].keys()) + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + w.writeheader() + w.writerows(detail_rows) + log(f"\n[{model_tag}] Wrote {summary_path}") + log(f"[{model_tag}] Wrote {csv_path}") + log(f"\n[{model_tag}] Real vs shuffled (mean over seeds):") + for comp in summary["comparison"]: + d = comp["delta_real_minus_shuffled_mean"] + log( + f" K={comp['K']}: Δpersist={d['persistence']:+.3f} " + f"Δself_trans={d['mean_self_transition']:+.3f} ΔR²={d['regime_r2']:+.4f}" + ) + + return summary + + +def main(): + args = parse_args() + k_values = list(args.k_values) if args.k_values is not None else list(K_SWEEP) + run_name = args.run_name or datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + out_dir = args.out_dir / run_name + execute_ablation( + args.features_path, + limit_problems=args.limit_problems, + cebra_mode=args.cebra_mode, + shuffle_mode=args.shuffle_mode, + block_size=args.block_size, + shuffle_seeds=args.shuffle_seeds, + k_values=k_values, + cebra_cache=args.cebra_cache, + out_dir=out_dir, + model_tag="single", + verbose=True, + ) + + +if __name__ == "__main__": + main() diff --git a/ablate_cot_length_expms/time_shuffle/shuffle.py b/ablate_cot_length_expms/time_shuffle/shuffle.py new file mode 100644 index 0000000..8e2c7e0 --- /dev/null +++ b/ablate_cot_length_expms/time_shuffle/shuffle.py @@ -0,0 +1,104 @@ +""" +Permute sentence-level trajectory order before SDS fitting. 
+ +What gets shuffled +------------------ +Each trajectory is a list of rows (T timesteps), one row per CoT sentence from +``create_dataset.py``: + + timestep t -> (CEBRA embedding z_t, PCA row p_t, stage label L_t) + +All three are permuted with the **same** index permutation ``perm``. Row t always +belongs to the **same sentence** as in the original trace; only its **position** +in the sequence fed to SDS changes. + +Example (one problem, full shuffle):: + + Real order (sentence text -> row index): + 0: "Let me set up..." -> z_0 + 1: "We compute..." -> z_1 + 2: "Therefore answer=12" -> z_2 + + perm = [2, 0, 1] + Shuffled sequence passed to SDS: [z_2, z_0, z_1] + +What this tests +--------------- +Whether SDS metrics (persistence, transition structure, predictive R^2) depend on +**genuine temporal order** of reasoning steps. If metrics collapse under shuffle, +the readout is using sequential structure rather than order-invariant clusters alone. +""" + +from __future__ import annotations + +import numpy as np + + +def permute_trajectory(seq: np.ndarray, rng: np.random.Generator, mode: str = "full", block_size: int = 3) -> np.ndarray: + """Return a row-permuted copy of ``seq`` (shape [T, D]). + + Row ``i`` is the activation (or embedding) for sentence ``i`` in the original + CoT; after permutation it sits at a new timestep in the SDS input sequence. 
+ """ + n = len(seq) + if n <= 1: + return seq.copy() + + if mode == "full": + perm = rng.permutation(n) + elif mode == "block": + if block_size < 1: + raise ValueError("block_size must be >= 1") + blocks = [np.arange(i, min(i + block_size, n)) for i in range(0, n, block_size)] + block_order = rng.permutation(len(blocks)) + perm = np.concatenate([blocks[i] for i in block_order]) + else: + raise ValueError(f"Unknown shuffle mode: {mode!r}") + + return seq[perm] + + +def shuffle_sequence_lists( + cebra_seqs: list[np.ndarray], + pca_seqs: list[np.ndarray], + labels: list[list], + seed: int, + mode: str = "full", + block_size: int = 3, +) -> tuple[list[np.ndarray], list[np.ndarray], list[list]]: + """Apply one permutation per trajectory, shared across CEBRA / PCA / labels. + + Parameters + ---------- + cebra_seqs, pca_seqs + One array per problem, shape ``[T, D]``, rows in real CoT sentence order. + labels + Per-timestep stage strings aligned with rows (e.g. PLAN_GENERATION). + + Returns + ------- + Same structure with rows reordered. Example: if problem p has T=4 and + ``perm = [3, 1, 0, 2]``, then ``cebra_out[p] = cebra_seqs[p][perm]`` and + ``labels_out[p][t]`` is the stage that belonged to original sentence ``perm[t]``. 
+ """ + rng = np.random.default_rng(seed) + cebra_out, pca_out, labels_out = [], [], [] + for z_seq, p_seq, lab in zip(cebra_seqs, pca_seqs, labels): + n = len(z_seq) + if n <= 1: + cebra_out.append(z_seq.copy()) + pca_out.append(p_seq.copy()) + labels_out.append(list(lab)) + continue + if mode == "full": + perm = rng.permutation(n) + elif mode == "block": + blocks = [np.arange(i, min(i + block_size, n)) for i in range(0, n, block_size)] + block_order = rng.permutation(len(blocks)) + perm = np.concatenate([blocks[i] for i in block_order]) + else: + raise ValueError(f"Unknown shuffle mode: {mode!r}") + cebra_out.append(z_seq[perm]) + pca_out.append(p_seq[perm]) + labels_out.append([lab[i] for i in perm]) + return cebra_out, pca_out, labels_out From 308e4c54c43045cf632f3e085b63df2d557f8edd Mon Sep 17 00:00:00 2001 From: wlg1 Date: Sun, 24 May 2026 20:43:48 -0700 Subject: [PATCH 2/4] time shuffle traj ablation expms --- {ablate_cot_length_expms => ablate_cot_trace_expms}/README.md | 0 {ablate_cot_length_expms => ablate_cot_trace_expms}/__init__.py | 0 .../time_shuffle/README.md | 0 .../time_shuffle/__init__.py | 0 .../time_shuffle/run_compare_base_rft.py | 0 .../time_shuffle/run_time_shuffle.py | 0 .../time_shuffle/shuffle.py | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/README.md (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/__init__.py (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/time_shuffle/README.md (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/time_shuffle/__init__.py (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/time_shuffle/run_compare_base_rft.py (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/time_shuffle/run_time_shuffle.py (100%) rename {ablate_cot_length_expms => ablate_cot_trace_expms}/time_shuffle/shuffle.py (100%) diff --git a/ablate_cot_length_expms/README.md 
b/ablate_cot_trace_expms/README.md similarity index 100% rename from ablate_cot_length_expms/README.md rename to ablate_cot_trace_expms/README.md diff --git a/ablate_cot_length_expms/__init__.py b/ablate_cot_trace_expms/__init__.py similarity index 100% rename from ablate_cot_length_expms/__init__.py rename to ablate_cot_trace_expms/__init__.py diff --git a/ablate_cot_length_expms/time_shuffle/README.md b/ablate_cot_trace_expms/time_shuffle/README.md similarity index 100% rename from ablate_cot_length_expms/time_shuffle/README.md rename to ablate_cot_trace_expms/time_shuffle/README.md diff --git a/ablate_cot_length_expms/time_shuffle/__init__.py b/ablate_cot_trace_expms/time_shuffle/__init__.py similarity index 100% rename from ablate_cot_length_expms/time_shuffle/__init__.py rename to ablate_cot_trace_expms/time_shuffle/__init__.py diff --git a/ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py b/ablate_cot_trace_expms/time_shuffle/run_compare_base_rft.py similarity index 100% rename from ablate_cot_length_expms/time_shuffle/run_compare_base_rft.py rename to ablate_cot_trace_expms/time_shuffle/run_compare_base_rft.py diff --git a/ablate_cot_length_expms/time_shuffle/run_time_shuffle.py b/ablate_cot_trace_expms/time_shuffle/run_time_shuffle.py similarity index 100% rename from ablate_cot_length_expms/time_shuffle/run_time_shuffle.py rename to ablate_cot_trace_expms/time_shuffle/run_time_shuffle.py diff --git a/ablate_cot_length_expms/time_shuffle/shuffle.py b/ablate_cot_trace_expms/time_shuffle/shuffle.py similarity index 100% rename from ablate_cot_length_expms/time_shuffle/shuffle.py rename to ablate_cot_trace_expms/time_shuffle/shuffle.py From 56cbd16723142b1df76132632a5a7d29c66830ee Mon Sep 17 00:00:00 2001 From: wlg1 Date: Mon, 25 May 2026 17:43:22 -0700 Subject: [PATCH 3/4] correct-only expms --- ablate_cot_trace_expms/README.md | 1 + ablate_cot_trace_expms/correct_only/README.md | 68 ++++ .../correct_only/__init__.py | 1 + 
ablate_cot_trace_expms/correct_only/config.py | 96 ++++++ .../correct_only/evaluate.py | 85 +++++ .../correct_only/grading.py | 133 ++++++++ .../correct_only/hf_utils.py | 44 +++ .../correct_only/manifest.py | 289 ++++++++++++++++ .../correct_only/manifest_cache/.gitignore | 2 + .../correct_only/manifest_cache/.gitkeep | 0 .../correct_only/pipeline.py | 155 +++++++++ .../correct_only/run_compare_base_rft.py | 286 ++++++++++++++++ .../correct_only/run_full_analysis.py | 321 ++++++++++++++++++ .../correct_only/run_single_model.py | 142 ++++++++ .../correct_only/subsets.py | 36 ++ .../time_shuffle/run_compare_base_rft.py | 12 +- .../time_shuffle/run_remaining_batch.sh | 111 ++++++ exploration/cebra_EM.py | 11 +- 18 files changed, 1790 insertions(+), 3 deletions(-) create mode 100644 ablate_cot_trace_expms/correct_only/README.md create mode 100644 ablate_cot_trace_expms/correct_only/__init__.py create mode 100644 ablate_cot_trace_expms/correct_only/config.py create mode 100644 ablate_cot_trace_expms/correct_only/evaluate.py create mode 100644 ablate_cot_trace_expms/correct_only/grading.py create mode 100644 ablate_cot_trace_expms/correct_only/hf_utils.py create mode 100644 ablate_cot_trace_expms/correct_only/manifest.py create mode 100644 ablate_cot_trace_expms/correct_only/manifest_cache/.gitignore create mode 100644 ablate_cot_trace_expms/correct_only/manifest_cache/.gitkeep create mode 100644 ablate_cot_trace_expms/correct_only/pipeline.py create mode 100644 ablate_cot_trace_expms/correct_only/run_compare_base_rft.py create mode 100644 ablate_cot_trace_expms/correct_only/run_full_analysis.py create mode 100644 ablate_cot_trace_expms/correct_only/run_single_model.py create mode 100644 ablate_cot_trace_expms/correct_only/subsets.py create mode 100644 ablate_cot_trace_expms/time_shuffle/run_remaining_batch.sh diff --git a/ablate_cot_trace_expms/README.md b/ablate_cot_trace_expms/README.md index 092fba8..8e2d374 100644 --- a/ablate_cot_trace_expms/README.md +++ 
b/ablate_cot_trace_expms/README.md @@ -5,5 +5,6 @@ Controlled experiments for reviewer-style confound checks on the CEBRA + SDS pip | Subfolder | Reviewer item | Description | |-----------|---------------|-------------| | [`time_shuffle/`](time_shuffle/) | Time-shuffled trajectory | Permute **sentence-aligned activation steps** before SDS/EM; CEBRA still trained on real CoT order | +| [`correct_only/`](correct_only/) | Correctness confound | **Paired-correct** (both base & RFT right on same problems); no length controls | diff --git a/ablate_cot_trace_expms/correct_only/README.md b/ablate_cot_trace_expms/correct_only/README.md new file mode 100644 index 0000000..3f8718a --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/README.md @@ -0,0 +1,68 @@ +# Ablation: correct-trace subsets only + +Tests whether the base vs RFT SDS gap survives when restricting to **correct** trajectories—especially **paired correct** (same problems where **both** base and RFT are graded correct). + +No length matching or incorrect-trace controls. + +## Subsets + +| Subset | Definition | +|--------|------------| +| `all` | All problems in the manifest (up to `--limit-problems`) | +| `within_correct` | Per-model: only trajectories with a correct final answer | +| `paired_correct` | `problem_id` where **base and RFT** are both correct | + +CEBRA is retrained **within each subset** (only allowed `problem_id`s enter triplets and trajectories). + +## Manifest cache (correctness labels only) + +Grading runs once per `(dataset, model folder, variant, layer, cot_data hash)` and is saved under **`manifest_cache/`** (default). No activations are copied—only JSON with `problem_id`, `is_correct`, answers, etc. 
+ +| Flag | Meaning | +|------|---------| +| `--manifest-cache-dir PATH` | Cache directory (default: `correct_only/manifest_cache`) | +| `--use-cached-manifest` / `--no-use-cached-manifest` | Read cache when valid (default: use) | +| `--rebuild-manifest` | Force re-grade and overwrite cache | + +Paired-correct `problem_id` lists are cached under `manifest_cache/paired/`. Each run still copies manifests into its `results/...` folder for reproducibility. + +## Scripts + +| Script | Role | +|--------|------| +| `run_single_model.py` | One variant (`base` or `rft`) + one dataset → manifest + SDS metrics | +| `run_compare_base_rft.py` | Base + RFT for one dataset → all subsets + `comparison_report.json` | +| `run_full_analysis.py` | Sweep combos → `gap_summary.csv`, `gap_by_subset.png`, `outcome_summary.csv` | + +## Quick start + +```bash +python ablate_cot_trace_expms/correct_only/run_compare_base_rft.py \ + --dataset gsm8k --model-family qwen14 \ + --hf-cache-dir ./hf_cache --k-values 5 --k-focus 5 + +python ablate_cot_trace_expms/correct_only/run_full_analysis.py \ + --hf-cache-dir ./hf_cache --datasets gsm8k --model-families qwen14 +``` + +## Desired outcome (pass/fail) + +At `--k-focus` (default **K=5**), **primary pass** requires on **`paired_correct`**: + +- RFT > Base for `persistence`, `mean_self_transition`, `delta_r2`, `K_eff` +- Paired RFT−Base gap retains ≥50% of the `all`-subset gap (persistence & ΔR²) — as implemented in `evaluate.py`, this check feeds `all_passed`; `primary_pass` is computed from the RFT > Base checks alone +Exit code **0** when primary pass holds (`run_compare_base_rft.py`, `run_full_analysis.py`). + +## Outputs + +``` +results/compare/<run_name>/<dataset>_<model_family>/ + base_manifest.json + rft_manifest.json + base/{all,within_correct,paired_correct}/summary.json + rft/... + comparison_report.json +``` + +`gap_by_subset.png` compares RFT−Base gaps for `all`, `within_correct`, and `paired_correct`. 
diff --git a/ablate_cot_trace_expms/correct_only/__init__.py b/ablate_cot_trace_expms/correct_only/__init__.py new file mode 100644 index 0000000..f6b8fe3 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/__init__.py @@ -0,0 +1 @@ +# Correct-trace ablation: paired-correct subset for base vs RFT SDS metrics. diff --git a/ablate_cot_trace_expms/correct_only/config.py b/ablate_cot_trace_expms/correct_only/config.py new file mode 100644 index 0000000..a0dc368 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/config.py @@ -0,0 +1,96 @@ +"""Paths and defaults for correct-trace ablations on SDS HF datasets.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +# HF dataset repo id -> benchmark loader key (see grading.DATASET_CFG) +HF_REPOS: dict[str, str] = { + "gsm8k": "withmartian/SDS_train_gsm8k", + "svamp": "withmartian/SDS_train_svamp", + "mmlu-pro": "withmartian/SDS_train_mmlu-pro", + "math500": "withmartian/SDS_math500_test", +} + +# Short model family key -> (base folder, rft folder) under each HF repo +MODEL_PAIRS: dict[str, tuple[str, str]] = { + "llama8": ("llama_8B_base", "llama_8b_reasoning"), + "qwen14": ("qwen_14B_base", "qwen_14b_reasoning"), + "qwen1.5": ("qwen_1.5B_base", "qwen1.5b_reasoning"), +} + +DEFAULT_LAYER: dict[str, int] = { + "llama8": 22, + "qwen14": 28, + "qwen1.5": 20, +} + +# Default problem cap (match exploration scripts / paper subsets) +DEFAULT_LIMIT_PROBLEMS: dict[str, int] = { + "gsm8k": 2000, + "svamp": 800, + "mmlu-pro": 500, + "math500": 500, +} + +SUBSET_NAMES = ( + "all", + "within_correct", + "paired_correct", +) + + +@dataclass(frozen=True) +class ModelDatasetPaths: + dataset_key: str + model_family: str + variant: str # "base" | "rft" + layer: int + hf_repo: str + folder: str + features_path: Path + cot_data_path: Path + + @property + def model_tag(self) -> str: + return f"{self.model_family}_{self.variant}" + + +def resolve_paths( + dataset_key: str, + 
model_family: str, + variant: str, + *, + layer: int | None = None, + hf_root: Path | None = None, + hf_repo: str | None = None, +) -> ModelDatasetPaths: + if dataset_key not in HF_REPOS: + raise KeyError(f"Unknown dataset_key {dataset_key!r}; choose from {list(HF_REPOS)}") + if model_family not in MODEL_PAIRS: + raise KeyError(f"Unknown model_family {model_family!r}; choose from {list(MODEL_PAIRS)}") + if variant not in ("base", "rft"): + raise ValueError("variant must be 'base' or 'rft'") + + layer = layer if layer is not None else DEFAULT_LAYER[model_family] + repo = hf_repo or HF_REPOS[dataset_key] + folder = MODEL_PAIRS[model_family][0 if variant == "base" else 1] + layer_dir = f"layer_{layer}" + root = hf_root if hf_root is not None else Path(repo.replace("/", "__")) + + base = root / folder / layer_dir + return ModelDatasetPaths( + dataset_key=dataset_key, + model_family=model_family, + variant=variant, + layer=layer, + hf_repo=repo, + folder=folder, + features_path=base / "all_sentences_features.pkl", + cot_data_path=base / "cot_data.pkl", + ) + + +def all_combo_keys() -> list[tuple[str, str]]: + return [(d, m) for d in HF_REPOS for m in MODEL_PAIRS] diff --git a/ablate_cot_trace_expms/correct_only/evaluate.py b/ablate_cot_trace_expms/correct_only/evaluate.py new file mode 100644 index 0000000..657a358 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/evaluate.py @@ -0,0 +1,85 @@ +"""Pass/fail criteria for correct-trace ablations.""" + +from __future__ import annotations + +from pipeline import metric_at_k + +METRICS_RFT_SHOULD_WIN = ("persistence", "mean_self_transition", "delta_r2", "K_eff") + + +def evaluate_compare_summaries( + base_summaries: dict[str, dict], + rft_summaries: dict[str, dict], + k_focus: int, + *, + min_paired_gap_fraction: float = 0.5, +) -> dict: + """ + Desired outcome: + - On paired_correct: RFT > base on key metrics + - paired_correct gap retains a substantial fraction of the all-trace gap + """ + checks: list[dict] = [] + + 
def add(name: str, passed: bool, detail: str, values: dict | None = None) -> None: + checks.append({"name": name, "passed": passed, "detail": detail, "values": values or {}}) + + for subset in ("paired_correct", "all"): + if subset not in base_summaries or subset not in rft_summaries: + continue + for metric in METRICS_RFT_SHOULD_WIN: + b = metric_at_k(base_summaries[subset], k_focus, metric) + r = metric_at_k(rft_summaries[subset], k_focus, metric) + if b is None or r is None: + add(f"rft_gt_base_{subset}_{metric}", False, f"missing K={k_focus}") + continue + add( + f"rft_gt_base_{subset}_{metric}", + r > b, + f"{subset}: RFT={r:.4f} vs Base={b:.4f}", + {"base": b, "rft": r, "subset": subset}, + ) + + for metric in ("persistence", "delta_r2"): + if not all(s in base_summaries for s in ("all", "paired_correct")): + continue + if not all(s in rft_summaries for s in ("all", "paired_correct")): + continue + gap_all = metric_at_k(rft_summaries["all"], k_focus, metric) - metric_at_k( + base_summaries["all"], k_focus, metric + ) + gap_paired = metric_at_k(rft_summaries["paired_correct"], k_focus, metric) - metric_at_k( + base_summaries["paired_correct"], k_focus, metric + ) + if gap_all <= 0: + add( + f"paired_retains_gap_{metric}", + gap_paired > 0, + f"all gap non-positive ({gap_all:.4f}); require paired gap>0 ({gap_paired:.4f})", + ) + else: + frac = gap_paired / gap_all if gap_all else 0.0 + add( + f"paired_retains_gap_{metric}", + gap_paired >= min_paired_gap_fraction * gap_all, + f"paired_gap={gap_paired:.4f} all_gap={gap_all:.4f} frac={frac:.2f}", + {"gap_paired": gap_paired, "gap_all": gap_all}, + ) + + primary = [ + c for c in checks + if "paired_correct" in c["name"] and c["name"].startswith("rft_gt_base") + ] + primary_pass = all(c["passed"] for c in primary) if primary else False + all_passed = all(c["passed"] for c in checks) + + return { + "k_focus": k_focus, + "checks": checks, + "all_passed": all_passed, + "primary_pass": primary_pass, + 
"desired_outcome": ( + "RFT > Base on paired_correct; paired RFT−Base gap retains " + f"≥{min_paired_gap_fraction:.0%} of all-trace gap (persistence, ΔR²)" + ), + } diff --git a/ablate_cot_trace_expms/correct_only/grading.py b/ablate_cot_trace_expms/correct_only/grading.py new file mode 100644 index 0000000..99427b0 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/grading.py @@ -0,0 +1,133 @@ +"""Grade CoT traces against benchmark ground truth (self-contained, no external imports).""" + +from __future__ import annotations + +import re +from typing import Any + +DATASET_CFG: dict[str, dict[str, Any]] = { + "gsm8k": { + "hf_id": "openai/gsm8k", + "config": "main", + "split": "train", + "type": "gsm8k", + }, + "svamp": { + "hf_id": "garrethlee/svamp", + "config": "default", + "split": "train", + "type": "svamp", + }, + "mmlu-pro": { + "hf_id": "TIGER-Lab/MMLU-Pro", + "config": "default", + "split": "test", + "type": "mmlu_pro", + }, + "math500": { + "hf_id": "HuggingFaceH4/MATH-500", + "config": "default", + "split": "test", + "type": "math", + }, +} + + +def _normalize_number(s: str) -> str: + s = str(s).strip().replace(",", "") + s = re.sub(r"^\$|\$$", "", s) + try: + v = float(s) + if abs(v - round(v)) < 1e-9: + return str(int(round(v))) + return str(v) + except ValueError: + return s.strip().lower() + + +def _extract_answer_from_cot(cot: str, kind: str) -> str: + text = cot.strip() + if kind == "gsm8k": + matches = re.findall(r"####\s*([\-\d,\.]+)", text) + if matches: + return _normalize_number(matches[-1]) + nums = re.findall(r"[-+]?\d*\.?\d+", text) + return _normalize_number(nums[-1]) if nums else "" + if kind == "mmlu_pro": + for line in reversed(text.splitlines()): + line = line.strip() + m = re.search(r"\b([A-J])\b\s*$", line, re.I) + if m: + return m.group(1).upper() + m = re.search(r"(?:answer|option)\s*[:is]?\s*([A-J])\b", line, re.I) + if m: + return m.group(1).upper() + m = re.search(r"\b([A-J])\b", text[-200:], re.I) + return m.group(1).upper() 
if m else "" + if kind == "math": + boxed = re.findall(r"\\boxed\{([^}]*)\}", text) + if boxed: + inner = boxed[-1] + inner = re.sub(r"\\[a-zA-Z]+", "", inner) + inner = inner.replace("{", "").replace("}", "").strip() + return _normalize_number(inner) if inner else "" + nums = re.findall(r"[-+]?\d*\.?\d+", text[-400:]) + return _normalize_number(nums[-1]) if nums else "" + # svamp / generic numeric + nums = re.findall(r"[-+]?\d*\.?\d+", text) + return _normalize_number(nums[-1]) if nums else "" + + +def is_correct_answer(predicted: str, ground_truth: str, kind: str) -> bool: + pred = (predicted or "").strip() + gold = (ground_truth or "").strip() + if not pred or not gold: + return False + if kind == "mmlu_pro": + return pred.upper() == gold.upper() + if kind in ("gsm8k", "svamp", "math"): + return _normalize_number(pred) == _normalize_number(gold) + return pred.lower() == gold.lower() + + +def ground_truth_from_row(row: dict, kind: str) -> str: + if kind == "svamp": + return str(row.get("Answer", row.get("answer", ""))).strip() + if kind == "mmlu_pro": + return str(row.get("answer", "")).strip().upper() + if kind == "gsm8k": + ans = row.get("answer", "") + m = re.search(r"####\s*([\-\d,\.]+)", str(ans)) + if m: + return _normalize_number(m.group(1)) + return str(ans).strip() + return str(row.get("answer", row.get("solution", ""))).strip() + + +def load_benchmark_rows( + dataset_key: str, + *, + limit: int | None = None, + hf_cache_dir: str | None = None, +) -> dict[int, str]: + """problem_id -> ground_truth string (aligned with create_dataset row order).""" + from datasets import load_dataset + + cfg = DATASET_CFG[dataset_key] + base_split = cfg["split"].split("[")[0] + kwargs: dict[str, Any] = { + "split": f"{base_split}[:{limit}]" if limit is not None else base_split, + } + if hf_cache_dir: + kwargs["cache_dir"] = hf_cache_dir + + if cfg.get("config"): + ds = load_dataset(cfg["hf_id"], cfg["config"], **kwargs) + else: + ds = load_dataset(cfg["hf_id"], **kwargs) 
+ + kind = cfg["type"] + out: dict[int, str] = {} + for i, row in enumerate(ds): + out[i] = ground_truth_from_row(row, kind) + return out diff --git a/ablate_cot_trace_expms/correct_only/hf_utils.py b/ablate_cot_trace_expms/correct_only/hf_utils.py new file mode 100644 index 0000000..22bb4c4 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/hf_utils.py @@ -0,0 +1,44 @@ +"""Download SDS artifacts from Hugging Face dataset repos.""" + +from __future__ import annotations + +from pathlib import Path + + +def ensure_hf_file( + repo_id: str, + filename: str, + local_dir: Path | None = None, +) -> Path: + from huggingface_hub import hf_hub_download + + path = hf_hub_download( + repo_id=repo_id, + filename=filename, + repo_type="dataset", + local_dir=str(local_dir) if local_dir else None, + ) + return Path(path) + + +def ensure_model_artifacts( + paths, + *, + cache_dir: Path | None = None, +) -> tuple[Path, Path]: + """Return (features_path, cot_data_path), downloading from HF if missing locally.""" + features = paths.features_path + cot = paths.cot_data_path + if features.is_file() and cot.is_file(): + return features, cot + + layer_dir = f"layer_{paths.layer}" + prefix = f"{paths.folder}/{layer_dir}" + repo = paths.hf_repo + root = cache_dir or Path("hf_cache") / repo.replace("/", "__") + + if not cot.is_file(): + cot = ensure_hf_file(repo, f"{prefix}/cot_data.pkl", local_dir=root) + if not features.is_file(): + features = ensure_hf_file(repo, f"{prefix}/all_sentences_features.pkl", local_dir=root) + return Path(features), Path(cot) diff --git a/ablate_cot_trace_expms/correct_only/manifest.py b/ablate_cot_trace_expms/correct_only/manifest.py new file mode 100644 index 0000000..c195444 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/manifest.py @@ -0,0 +1,289 @@ +"""Build, cache, and load per-model correctness manifests (no activation duplication).""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import pickle 
+import shutil +from pathlib import Path +from typing import Any + +from .grading import DATASET_CFG, _extract_answer_from_cot, is_correct_answer, load_benchmark_rows + +DEFAULT_MANIFEST_CACHE_DIR = Path(__file__).resolve().parent / "manifest_cache" + + +def add_manifest_cache_args(parser) -> None: + """Register CLI flags shared by run_single_model and run_compare_base_rft.""" + parser.add_argument( + "--manifest-cache-dir", + type=Path, + default=DEFAULT_MANIFEST_CACHE_DIR, + help="Directory for reusable correctness manifests (JSON, no activations).", + ) + parser.add_argument( + "--use-cached-manifest", + action=argparse.BooleanOptionalAction, + default=True, + help="Load manifest from cache when cot_data fingerprint matches (default: true).", + ) + parser.add_argument( + "--rebuild-manifest", + action="store_true", + help="Force re-grade cot_data and overwrite cached manifest.", + ) + + +def trajectory_length(cot_entry: dict) -> int: + sents = cot_entry.get("sentences") or [] + if sents: + return len(sents) + cot = cot_entry.get("cot") or "" + parts = [p for p in cot.replace("\n", " ").split(". 
") if len(p.strip()) > 10] + return max(len(parts), 1) + + +def cot_data_fingerprint(cot_data_path: Path) -> str: + """SHA-256 of cot_data.pkl; cache invalidates when the file changes.""" + h = hashlib.sha256() + with open(cot_data_path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + return h.hexdigest() + + +def manifest_cache_path( + cache_dir: Path, + *, + dataset_key: str, + model_folder: str, + variant: str, + layer: int, + limit_problems: int, + cot_data_path: Path, +) -> Path: + fp = cot_data_fingerprint(cot_data_path) + name = ( + f"{dataset_key}__{model_folder}__{variant}__layer{layer}" + f"__n{limit_problems}__{fp[:16]}.json" + ) + return cache_dir / name + + +def _manifest_matches_cache( + manifest: dict[str, Any], + *, + dataset_key: str, + limit_problems: int, + cot_data_path: Path, + model_folder: str, + variant: str, + layer: int, +) -> bool: + fp = cot_data_fingerprint(cot_data_path) + return ( + manifest.get("dataset_key") == dataset_key + and manifest.get("limit_problems") == limit_problems + and manifest.get("cot_data_fingerprint") == fp + and manifest.get("cot_data_path") == str(cot_data_path.resolve()) + and manifest.get("model_folder") == model_folder + and manifest.get("variant") == variant + and manifest.get("layer") == layer + and "records" in manifest + ) + + +def build_manifest( + cot_data_path: Path, + dataset_key: str, + *, + limit_problems: int, + hf_cache_dir: str | None = None, + model_folder: str | None = None, + variant: str | None = None, + layer: int | None = None, +) -> dict[str, Any]: + cot_data_path = cot_data_path.resolve() + with open(cot_data_path, "rb") as f: + cot_data: dict = pickle.load(f) + + gold = load_benchmark_rows(dataset_key, limit=limit_problems, hf_cache_dir=hf_cache_dir) + kind = DATASET_CFG[dataset_key]["type"] + + records = [] + n_correct = 0 + for pid in sorted(cot_data.keys(), key=lambda x: int(x) if str(x).isdigit() else x): + pid_int = int(pid) + if pid_int >= 
limit_problems: + continue + entry = cot_data[pid] + cot = entry.get("cot", "") + pred = _extract_answer_from_cot(cot, kind) + gt = gold.get(pid_int, "") + ok = is_correct_answer(pred, gt, kind) + if ok: + n_correct += 1 + records.append( + { + "problem_id": pid_int, + "is_correct": ok, + "T": trajectory_length(entry), + "extracted_answer": pred, + "ground_truth": gt, + } + ) + + manifest: dict[str, Any] = { + "dataset_key": dataset_key, + "cot_data_path": str(cot_data_path), + "cot_data_fingerprint": cot_data_fingerprint(cot_data_path), + "limit_problems": limit_problems, + "n_problems": len(records), + "n_correct": n_correct, + "accuracy": n_correct / len(records) if records else 0.0, + "records": records, + } + if model_folder is not None: + manifest["model_folder"] = model_folder + if variant is not None: + manifest["variant"] = variant + if layer is not None: + manifest["layer"] = layer + return manifest + + +def save_manifest(manifest: dict[str, Any], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(manifest, f, indent=2) + + +def load_manifest(path: Path) -> dict[str, Any]: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def get_or_build_manifest( + cot_data_path: Path, + dataset_key: str, + *, + limit_problems: int, + model_folder: str, + variant: str, + layer: int, + cache_dir: Path | None = None, + hf_cache_dir: str | None = None, + use_cached: bool = True, + rebuild: bool = False, +) -> tuple[dict[str, Any], Path, bool]: + """ + Load a cached correctness manifest or build and save one. + + Returns (manifest, cache_path, from_cache). 
+ """ + cache_dir = cache_dir or DEFAULT_MANIFEST_CACHE_DIR + cache_dir.mkdir(parents=True, exist_ok=True) + cot_data_path = cot_data_path.resolve() + cache_path = manifest_cache_path( + cache_dir, + dataset_key=dataset_key, + model_folder=model_folder, + variant=variant, + layer=layer, + limit_problems=limit_problems, + cot_data_path=cot_data_path, + ) + + if use_cached and not rebuild and cache_path.is_file(): + manifest = load_manifest(cache_path) + if _manifest_matches_cache( + manifest, + dataset_key=dataset_key, + limit_problems=limit_problems, + cot_data_path=cot_data_path, + model_folder=model_folder, + variant=variant, + layer=layer, + ): + return manifest, cache_path, True + print( + f" Manifest cache stale or mismatched ({cache_path.name}); rebuilding.", + flush=True, + ) + + manifest = build_manifest( + cot_data_path, + dataset_key, + limit_problems=limit_problems, + hf_cache_dir=hf_cache_dir, + model_folder=model_folder, + variant=variant, + layer=layer, + ) + save_manifest(manifest, cache_path) + return manifest, cache_path, False + + +def copy_manifest_to_run_dir(cache_path: Path, run_path: Path) -> None: + """Copy cached manifest into a run folder for reproducibility.""" + run_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(cache_path, run_path) + + +def records_by_pid(manifest: dict[str, Any]) -> dict[int, dict]: + return {r["problem_id"]: r for r in manifest["records"]} + + +def get_or_build_paired_index( + base_manifest: dict[str, Any], + rft_manifest: dict[str, Any], + *, + dataset_key: str, + model_family: str, + layer: int, + limit_problems: int, + cache_dir: Path | None = None, + use_cached: bool = True, + rebuild: bool = False, +) -> tuple[list[int], Path | None, bool]: + """ + Cache the list of problem_ids where both models are correct. + + Returns (problem_ids, cache_path or None, from_cache). 
+ """ + from .subsets import pids_paired_correct + + cache_dir = cache_dir or DEFAULT_MANIFEST_CACHE_DIR + paired_dir = cache_dir / "paired" + paired_dir.mkdir(parents=True, exist_ok=True) + + base_fp = base_manifest["cot_data_fingerprint"][:8] + rft_fp = rft_manifest["cot_data_fingerprint"][:8] + cache_path = paired_dir / ( + f"{dataset_key}__{model_family}__layer{layer}__n{limit_problems}" + f"__base_{base_fp}__rft_{rft_fp}.json" + ) + + if use_cached and not rebuild and cache_path.is_file(): + with open(cache_path, encoding="utf-8") as f: + data = json.load(f) + if data.get("base_fingerprint", "").startswith(base_fp) and data.get( + "rft_fingerprint", "" + ).startswith(rft_fp): + return data["problem_ids"], cache_path, True + + pids = sorted(pids_paired_correct(base_manifest, rft_manifest)) + payload = { + "dataset_key": dataset_key, + "model_family": model_family, + "layer": layer, + "limit_problems": limit_problems, + "base_fingerprint": base_manifest["cot_data_fingerprint"], + "rft_fingerprint": rft_manifest["cot_data_fingerprint"], + "n_paired": len(pids), + "problem_ids": pids, + } + save_manifest(payload, cache_path) + return pids, cache_path, False diff --git a/ablate_cot_trace_expms/correct_only/manifest_cache/.gitignore b/ablate_cot_trace_expms/correct_only/manifest_cache/.gitignore new file mode 100644 index 0000000..0827618 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/manifest_cache/.gitignore @@ -0,0 +1,2 @@ +*.json +!.gitignore diff --git a/ablate_cot_trace_expms/correct_only/manifest_cache/.gitkeep b/ablate_cot_trace_expms/correct_only/manifest_cache/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/ablate_cot_trace_expms/correct_only/pipeline.py b/ablate_cot_trace_expms/correct_only/pipeline.py new file mode 100644 index 0000000..7a0f66d --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/pipeline.py @@ -0,0 +1,155 @@ +"""CEBRA + SDS on filtered problem subsets.""" + +from __future__ import annotations + +import csv 
+import json +import pickle +import sys +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) + +from exploration.cebra_EM import ( # noqa: E402 + K_SWEEP, + fit_and_evaluate, + linear_ar_r2, + load_and_prepare_cebra, + train_cebra_projection, +) + + +def _metrics_row(subset: str, k: int, fit_tuple) -> dict: + persist, mean_st, spec, _, _, score, k_eff, r2, bic = fit_tuple + return { + "subset": subset, + "K": k, + "persistence": float(persist), + "mean_self_transition": float(mean_st), + "specialization": float(spec), + "sss": float(score), + "K_eff": int(k_eff), + "regime_r2": float(r2), + "bic": float(bic), + } + + +def execute_correct_subset( + features_path: Path, + allowed_problem_ids: set[int], + *, + subset_name: str = "custom", + limit_problems: int = 500, + k_values: list[int] | None = None, + cebra_cache: Path | None = None, + out_dir: Path | None = None, + model_tag: str = "model", + verbose: bool = True, +) -> dict: + """Train CEBRA on allowed problems only; fit SDS; return summary dict.""" + if k_values is None: + k_values = list(K_SWEEP) + + def log(msg: str) -> None: + if verbose: + print(msg, flush=True) + + allowed_list = sorted(allowed_problem_ids) + cache_key = (subset_name, tuple(allowed_list), limit_problems) + + cache = None + if cebra_cache and cebra_cache.exists(): + log(f"[{model_tag}] Loading CEBRA cache from {cebra_cache}") + with open(cebra_cache, "rb") as f: + loaded = pickle.load(f) + if loaded.get("cache_key") != cache_key: + log(f"[{model_tag}] Cache key mismatch; retraining CEBRA.") + else: + cache = loaded + cebra_seqs = cache["cebra_seqs"] + pca_seqs = cache["pca_seqs"] + labels = cache["labels"] + ar_r2 = cache["ar_r2"] + + if cache is None: + log( + f"[{model_tag}] Loading features (subset={subset_name}, " + f"n_pids={len(allowed_list)})..." 
+ ) + all_f, triplets = load_and_prepare_cebra( + str(features_path), + limit_problems=limit_problems, + allowed_problem_ids=allowed_list, + ) + if len(all_f) < 10: + raise ValueError( + f"Too few features ({len(all_f)}) for subset {subset_name}; " + f"check allowed_problem_ids and limit_problems." + ) + log(f"[{model_tag}] Training CEBRA ({len(all_f)} steps, {len(triplets)} triplets)...") + cebra_seqs, pca_seqs, labels = train_cebra_projection(all_f, triplets) + ar_r2 = linear_ar_r2(pca_seqs) + log(f" Trajectories: {len(cebra_seqs)} AR R²: {ar_r2:.4f}") + if cebra_cache: + cebra_cache.parent.mkdir(parents=True, exist_ok=True) + with open(cebra_cache, "wb") as f: + pickle.dump( + { + "cache_key": cache_key, + "cebra_seqs": cebra_seqs, + "pca_seqs": pca_seqs, + "labels": labels, + "ar_r2": ar_r2, + "features_path": str(features_path), + "subset_name": subset_name, + "n_allowed": len(allowed_list), + }, + f, + ) + + detail_rows: list[dict] = [] + log(f"\n[{model_tag}] === SDS on subset '{subset_name}' ===") + for k in k_values: + fit = fit_and_evaluate(cebra_seqs, pca_seqs, labels, k) + row = _metrics_row(subset_name, k, fit) + row["delta_r2"] = row["regime_r2"] - ar_r2 + detail_rows.append(row) + log( + f" K={k}: persist={row['persistence']:.2f} self_trans={row['mean_self_transition']:.3f} " + f"K_eff={row['K_eff']} ΔR²={row['delta_r2']:+.4f}" + ) + + summary = { + "model_tag": model_tag, + "subset_name": subset_name, + "features_path": str(features_path), + "limit_problems": limit_problems, + "n_allowed_pids": len(allowed_list), + "n_trajectories": len(cebra_seqs), + "k_values": k_values, + "ar_r2": float(ar_r2), + "per_run": detail_rows, + } + + if out_dir is not None: + out_dir.mkdir(parents=True, exist_ok=True) + with open(out_dir / "summary.json", "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + if detail_rows: + with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, 
fieldnames=list(detail_rows[0].keys())) + w.writeheader() + w.writerows(detail_rows) + + return summary + + +def metric_at_k(summary: dict, k: int, field: str) -> float | None: + for row in summary["per_run"]: + if row["K"] == k: + return float(row[field]) + return None diff --git a/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py b/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py new file mode 100644 index 0000000..3ad16fd --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +Run correct-trace ablations for base + RFT on one dataset; evaluate desired outcomes. + +Subsets per model: + - all + - within_correct + - paired_correct (same problem_ids where both base and RFT are correct) + +Example: + python run_compare_base_rft.py --dataset gsm8k --model-family qwen14 \\ + --hf-cache-dir ./hf_cache +""" + +from __future__ import annotations + +import argparse +import json +import sys +from datetime import datetime, timezone +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from config import DEFAULT_LIMIT_PROBLEMS, resolve_paths # noqa: E402 +from evaluate import evaluate_compare_summaries # noqa: E402 +from hf_utils import ensure_model_artifacts # noqa: E402 +from manifest import ( # noqa: E402 + DEFAULT_MANIFEST_CACHE_DIR, + add_manifest_cache_args, + copy_manifest_to_run_dir, + get_or_build_manifest, + get_or_build_paired_index, +) +from pipeline import execute_correct_subset # noqa: E402 +from subsets import resolve_subset_pids # noqa: E402 + +COMPARE_SUBSETS = ("all", "within_correct", "paired_correct") + + +def parse_args(): + p = argparse.ArgumentParser(description="Base vs RFT correct-trace ablation.") + p.add_argument("--dataset", required=True, choices=list(DEFAULT_LIMIT_PROBLEMS)) + p.add_argument("--model-family", required=True, 
def _run_model_subsets(
    variant: str,
    features_path: Path,
    cot_path: Path,
    manifest: dict,
    base_manifest: dict,
    rft_manifest: dict,
    *,
    limit: int,
    k_values: list[int],
    run_dir: Path,
) -> dict[str, dict]:
    """Run CEBRA + SDS for one model variant on every subset in ``COMPARE_SUBSETS``.

    Args:
        variant: Label for this model; used as ``model_tag``, in log lines and
            as the sub-directory name under ``run_dir``.
        features_path: Features file handed to ``execute_correct_subset``.
        cot_path: Accepted for call-site symmetry; not read in this function.
        manifest: Correctness manifest of the model being run.
        base_manifest: Manifest of the base model (needed for ``paired_correct``).
        rft_manifest: Manifest of the RFT model (needed for ``paired_correct``).
        limit: Forwarded as ``limit_problems`` to subset resolution and the run.
        k_values: Regime counts to fit.
        run_dir: Root output directory; results go to ``run_dir/variant/subset``.

    Returns:
        Mapping of subset name to the summary returned by
        ``execute_correct_subset``. Subsets that resolve to fewer than three
        problem ids are skipped and therefore absent from the mapping.
    """
    summaries: dict[str, dict] = {}
    for subset in COMPARE_SUBSETS:
        pids = resolve_subset_pids(
            subset,
            manifest=manifest,
            limit_problems=limit,
            base_manifest=base_manifest,
            rft_manifest=rft_manifest,
        )
        print(f" [{variant}] subset={subset}: n_pids={len(pids)}", flush=True)
        if len(pids) < 3:
            # Too few problems for a meaningful fit; leave this subset out.
            print(" SKIP (too few)", flush=True)
            continue
        sub_out = run_dir / variant / subset
        summaries[subset] = execute_correct_subset(
            features_path,
            pids,
            subset_name=subset,
            limit_problems=limit,
            k_values=k_values,
            cebra_cache=sub_out / "cebra_cache.pkl",
            out_dir=sub_out,
            model_tag=variant,
        )
    return summaries
base_paths = resolve_paths(args.dataset, args.model_family, "base", layer=args.layer) + rft_paths = resolve_paths(args.dataset, args.model_family, "rft", layer=args.layer) + + run_name = args.run_name or datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + run_dir = args.out_dir / run_name / f"{args.dataset}_{args.model_family}" + + base_summaries: dict[str, dict] = {} + rft_summaries: dict[str, dict] = {} + cache_dir = args.manifest_cache_dir or DEFAULT_MANIFEST_CACHE_DIR + base_cache_path: str | None = None + rft_cache_path: str | None = None + paired_cache_path: str | None = None + + if not args.skip_run: + if args.base_features_path and args.base_cot_data_path: + base_feat, base_cot = args.base_features_path, args.base_cot_data_path + else: + base_feat, base_cot = ensure_model_artifacts(base_paths, cache_dir=args.hf_cache_dir) + + if args.rft_features_path and args.rft_cot_data_path: + rft_feat, rft_cot = args.rft_features_path, args.rft_cot_data_path + else: + rft_feat, rft_cot = ensure_model_artifacts(rft_paths, cache_dir=args.hf_cache_dir) + + run_dir.mkdir(parents=True, exist_ok=True) + mkwargs = dict( + cache_dir=cache_dir, + use_cached=args.use_cached_manifest, + rebuild=args.rebuild_manifest, + ) + + print("=== Base manifest ===", flush=True) + base_manifest, base_cache, base_cached = get_or_build_manifest( + base_cot, + args.dataset, + limit_problems=limit, + model_folder=base_paths.folder, + variant="base", + layer=base_paths.layer, + **mkwargs, + ) + base_cache_path = str(base_cache) + copy_manifest_to_run_dir(base_cache, run_dir / "base_manifest.json") + print( + f" base accuracy={base_manifest['accuracy']:.3f} " + f"({base_manifest['n_correct']}/{base_manifest['n_problems']}) " + f"[{'cache' if base_cached else 'built'}]", + flush=True, + ) + + print("=== RFT manifest ===", flush=True) + rft_manifest, rft_cache, rft_cached = get_or_build_manifest( + rft_cot, + args.dataset, + limit_problems=limit, + model_folder=rft_paths.folder, + variant="rft", 
+ layer=rft_paths.layer, + **mkwargs, + ) + rft_cache_path = str(rft_cache) + copy_manifest_to_run_dir(rft_cache, run_dir / "rft_manifest.json") + print( + f" rft accuracy={rft_manifest['accuracy']:.3f} " + f"({rft_manifest['n_correct']}/{rft_manifest['n_problems']}) " + f"[{'cache' if rft_cached else 'built'}]", + flush=True, + ) + + paired_list, paired_cache, paired_cached = get_or_build_paired_index( + base_manifest, + rft_manifest, + dataset_key=args.dataset, + model_family=args.model_family, + layer=base_paths.layer, + limit_problems=limit, + cache_dir=cache_dir, + use_cached=args.use_cached_manifest, + rebuild=args.rebuild_manifest, + ) + paired_cache_path = str(paired_cache) if paired_cache else None + print( + f" paired_correct n={len(paired_list)} " + f"[{'cache' if paired_cached else 'built'} -> {paired_cache}]", + flush=True, + ) + + print("\n=== Base model SDS ===", flush=True) + base_summaries = _run_model_subsets( + "base", + base_feat, + base_cot, + base_manifest, + base_manifest, + rft_manifest, + limit=limit, + k_values=args.k_values, + run_dir=run_dir, + ) + + print("\n=== RFT model SDS ===", flush=True) + rft_summaries = _run_model_subsets( + "rft", + rft_feat, + rft_cot, + rft_manifest, + base_manifest, + rft_manifest, + limit=limit, + k_values=args.k_values, + run_dir=run_dir, + ) + else: + with open(run_dir / "base_manifest.json", encoding="utf-8") as f: + base_manifest = json.load(f) + with open(run_dir / "rft_manifest.json", encoding="utf-8") as f: + rft_manifest = json.load(f) + for variant in ("base", "rft"): + store = base_summaries if variant == "base" else rft_summaries + for subset in COMPARE_SUBSETS: + p = run_dir / variant / subset / "summary.json" + if p.is_file(): + with open(p, encoding="utf-8") as f: + store[subset] = json.load(f) + + evaluation = evaluate_compare_summaries( + base_summaries, + rft_summaries, + args.k_focus, + min_paired_gap_fraction=args.min_paired_gap_fraction, + ) + + report = { + "dataset": args.dataset, + 
"model_family": args.model_family, + "layer": base_paths.layer, + "limit_problems": limit, + "k_focus": args.k_focus, + "k_values": args.k_values, + "base_manifest_stats": { + "accuracy": base_manifest.get("accuracy"), + "n_correct": base_manifest.get("n_correct"), + }, + "rft_manifest_stats": { + "accuracy": rft_manifest.get("accuracy"), + "n_correct": rft_manifest.get("n_correct"), + }, + "evaluation": evaluation, + "overall_pass": evaluation["primary_pass"], + "all_checks_pass": evaluation["all_passed"], + "manifest_cache_dir": str(cache_dir), + "base_manifest_cache": base_cache_path, + "rft_manifest_cache": rft_cache_path, + "paired_index_cache": paired_cache_path, + } + + report_path = run_dir / "comparison_report.json" + with open(report_path, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2) + + print(f"\nWrote {report_path}", flush=True) + print(f"\n=== Evaluation at K={args.k_focus} ===", flush=True) + for c in evaluation["checks"]: + status = "PASS" if c["passed"] else "FAIL" + print(f" [{status}] {c['name']}: {c['detail']}", flush=True) + + print( + f"\nPrimary outcome (paired_correct RFT>Base): " + f"{'PASS' if evaluation['primary_pass'] else 'FAIL'}", + flush=True, + ) + print( + f"All checks (incl. paired vs all gap): {'PASS' if evaluation['all_passed'] else 'FAIL'}", + flush=True, + ) + return 0 if evaluation["primary_pass"] else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/ablate_cot_trace_expms/correct_only/run_full_analysis.py b/ablate_cot_trace_expms/correct_only/run_full_analysis.py new file mode 100644 index 0000000..612fac2 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/run_full_analysis.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +""" +Run base+RFT correct-trace ablations for all (or selected) model×dataset combos; +aggregate metrics into CSV + plot; report whether desired outcomes are met. 
+ +Example: + python run_full_analysis.py --hf-cache-dir ./hf_cache --datasets gsm8k math500 \\ + --model-families qwen14 --k-values 5 --skip-run # evaluate existing only +""" + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from datetime import datetime, timezone +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] +SCRIPT_DIR = Path(__file__).resolve().parent +COMPARE_SCRIPT = SCRIPT_DIR / "run_compare_base_rft.py" + +sys.path.insert(0, str(REPO_ROOT)) +sys.path.insert(0, str(SCRIPT_DIR)) + +from config import DEFAULT_LIMIT_PROBLEMS, HF_REPOS, MODEL_PAIRS, all_combo_keys # noqa: E402 +from evaluate import evaluate_compare_summaries # noqa: E402 +from pipeline import metric_at_k # noqa: E402 + + +def parse_args(): + p = argparse.ArgumentParser(description="Full correct-trace ablation sweep + summary plot.") + p.add_argument("--datasets", nargs="+", default=None, choices=list(HF_REPOS)) + p.add_argument("--model-families", nargs="+", default=None, choices=list(MODEL_PAIRS)) + p.add_argument("--hf-cache-dir", type=Path, default=Path(__file__).parent / "hf_cache") + p.add_argument("--limit-problems", type=int, default=None) + p.add_argument("--k-values", type=int, nargs="+", default=[5]) + p.add_argument("--k-focus", type=int, default=5) + p.add_argument("--out-dir", type=Path, default=Path(__file__).parent / "results" / "full_analysis") + p.add_argument("--run-name", type=str, default=None) + p.add_argument("--skip-run", action="store_true", help="Only aggregate; expect compare runs under out-dir") + p.add_argument("--dry-run", action="store_true", help="Print commands only") + p.add_argument( + "--manifest-cache-dir", + type=Path, + default=None, + help="Passed to run_compare_base_rft (default: correct_only/manifest_cache)", + ) + p.add_argument( + "--rebuild-manifest", + action="store_true", + help="Force re-grade and refresh 
def run_compare_subprocess(
    dataset: str,
    model_family: str,
    *,
    hf_cache_dir: Path,
    limit: int | None,
    k_values: list[int],
    k_focus: int,
    out_dir: Path,
    run_name: str,
    dry_run: bool,
    manifest_cache_dir: Path | None = None,
    rebuild_manifest: bool = False,
    use_cached_manifest: bool = True,
) -> int:
    """Launch ``run_compare_base_rft.py`` for one dataset / model-family combo.

    The command line is always echoed to stdout before anything is executed.

    Args:
        dataset: Value for ``--dataset``.
        model_family: Value for ``--model-family``.
        hf_cache_dir: Value for ``--hf-cache-dir``.
        limit: Value for ``--limit-problems``; the flag is omitted when ``None``.
        k_values: Values for ``--k-values``.
        k_focus: Value for ``--k-focus``.
        out_dir: Analysis directory; the child writes under ``out_dir / "compare"``.
        run_name: Value for ``--run-name``.
        dry_run: When true, only print the command and return ``0``.
        manifest_cache_dir: Optional value for ``--manifest-cache-dir``.
        rebuild_manifest: Adds ``--rebuild-manifest`` when true.
        use_cached_manifest: Adds ``--no-use-cached-manifest`` when false.

    Returns:
        ``0`` for a dry run, otherwise the child process's exit code.
    """
    import shlex  # local import: only needed to render the echoed command

    cmd = [
        sys.executable,
        str(COMPARE_SCRIPT),
        "--dataset",
        dataset,
        "--model-family",
        model_family,
        "--hf-cache-dir",
        str(hf_cache_dir),
        "--k-focus",
        str(k_focus),
        "--k-values",
        *[str(k) for k in k_values],
        "--out-dir",
        str(out_dir / "compare"),
        "--run-name",
        run_name,
    ]
    if manifest_cache_dir is not None:
        cmd.extend(["--manifest-cache-dir", str(manifest_cache_dir)])
    if rebuild_manifest:
        cmd.append("--rebuild-manifest")
    if not use_cached_manifest:
        cmd.append("--no-use-cached-manifest")
    if limit is not None:
        cmd.extend(["--limit-problems", str(limit)])

    # Quote each argument so the echoed line can be pasted into a shell even
    # when a path contains spaces; the argv list itself is executed unchanged.
    print(shlex.join(cmd), flush=True)
    if dry_run:
        return 0
    return subprocess.call(cmd)
def plot_gaps(df: pd.DataFrame, out_path: Path, k_focus: int) -> None:
    """Save a bar chart of the RFT − Base gap per subset, one panel per combo.

    Only rows with ``metric`` in ``("persistence", "delta_r2")`` and
    ``K == k_focus`` are drawn. Nothing is written when ``df`` is empty or no
    row survives that filter.

    Args:
        df: Gap table with the columns read below: ``dataset``,
            ``model_family``, ``subset``, ``metric``, ``K`` and
            ``gap_rft_minus_base``.
        out_path: Image path to write; its parent directory is created.
        k_focus: Regime count to plot; also shown in the figure title.
    """
    if df.empty:
        return
    plot_metrics = ["persistence", "delta_r2"]
    subdf = df[df["metric"].isin(plot_metrics) & (df["K"] == k_focus)]
    if subdf.empty:
        return

    # One row per distinct (dataset, model_family) pair present in the data.
    combos = subdf.groupby(["dataset", "model_family"]).size().reset_index()[
        ["dataset", "model_family"]
    ]
    n = len(combos)
    # squeeze=False keeps ``axes`` two-dimensional even for a single panel.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    subsets = ["all", "within_correct", "paired_correct"]
    colors = {"all": "#4C72B0", "within_correct": "#DD8452", "paired_correct": "#55A868"}

    for ax, (_, row) in zip(axes[0], combos.iterrows()):
        d, m = row["dataset"], row["model_family"]
        block = subdf[(subdf["dataset"] == d) & (subdf["model_family"] == m)]
        x = list(range(len(subsets)))
        for i, subset in enumerate(subsets):
            # The two metrics sit side by side around the subset's x slot.
            for metric, offset in zip(plot_metrics, (-0.15, 0.15)):
                val = block[(block["subset"] == subset) & (block["metric"] == metric)]
                if val.empty:
                    continue
                # If several rows match, the first one is used.
                gap = float(val["gap_rft_minus_base"].iloc[0])
                ax.bar(
                    i + offset,
                    gap,
                    width=0.25,
                    color=colors.get(subset, "gray"),
                    alpha=0.85 if metric == "persistence" else 0.55,
                    # Label only the persistence bar so each subset appears
                    # once in the legend.
                    label=f"{subset}" if metric == "persistence" else None,
                )
        ax.axhline(0, color="black", linewidth=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels([s.replace("_", "\n") for s in subsets], fontsize=8)
        ax.set_title(f"{d}\n{m}")
        ax.set_ylabel("RFT − Base")

    # Legend entries are taken from the first panel only.
    handles, labels = axes[0][0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.02))
    fig.suptitle(f"Correct-trace ablation: RFT−Base gap (K={k_focus})", y=1.08)
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    # Close the figure explicitly once it has been written.
    plt.close(fig)
summary_rows = [] + for report in eval_reports: + ev = report.get("evaluation", {}) + summary_rows.append( + { + "combo": f"{report['dataset']}_{report['model_family']}", + "primary_pass": report.get("overall_pass", False), + "all_checks_pass": report.get("all_checks_pass", False), + "base_acc": report.get("base_manifest_stats", {}).get("accuracy"), + "rft_acc": report.get("rft_manifest_stats", {}).get("accuracy"), + } + ) + + summary_df = pd.DataFrame(summary_rows) + summary_path = analysis_dir / "outcome_summary.csv" + summary_df.to_csv(summary_path, index=False) + + master = { + "run_name": run_name, + "k_focus": args.k_focus, + "n_combos": len(summary_rows), + "n_primary_pass": int(summary_df["primary_pass"].sum()) if not summary_df.empty else 0, + "desired_outcome": ( + "RFT > Base on paired_correct for persistence and delta_r2 at k_focus" + ), + "combos": summary_rows, + "eval_reports": [ + { + "combo": f"{r['dataset']}_{r['model_family']}", + "overall_pass": r.get("overall_pass"), + "checks": r.get("evaluation", {}).get("checks", []), + } + for r in eval_reports + ], + } + master_path = analysis_dir / "master_report.json" + with open(master_path, "w", encoding="utf-8") as f: + json.dump(master, f, indent=2) + + print(f"\nWrote {summary_path}", flush=True) + print(f"Wrote {master_path}", flush=True) + + if summary_df.empty: + print("No comparison reports found.", flush=True) + return 1 + + all_primary = bool(summary_df["primary_pass"].all()) + print( + f"\nPrimary outcome across combos: " + f"{int(summary_df['primary_pass'].sum())}/{len(summary_df)} passed", + flush=True, + ) + print(summary_df.to_string(index=False), flush=True) + return 0 if all_primary else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/ablate_cot_trace_expms/correct_only/run_single_model.py b/ablate_cot_trace_expms/correct_only/run_single_model.py new file mode 100644 index 0000000..d198f08 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/run_single_model.py @@ 
def parse_args():
    """Build the CLI for the single-model correct-trace ablation and parse argv.

    Options are registered in a fixed order from a table, after which the
    manifest-cache options are attached through ``add_manifest_cache_args``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Correct-trace SDS ablation (single model).")
    results_dir = Path(__file__).parent / "results" / "single_model"

    # (flag, add_argument keyword arguments), in registration order.
    option_table = [
        ("--dataset", {"required": True, "choices": list(DEFAULT_LIMIT_PROBLEMS)}),
        ("--model-family", {"required": True, "choices": ["llama8", "qwen14", "qwen1.5"]}),
        ("--variant", {"required": True, "choices": ["base", "rft"]}),
        ("--layer", {"type": int, "default": None}),
        ("--features-path", {"type": Path, "default": None}),
        ("--cot-data-path", {"type": Path, "default": None}),
        (
            "--hf-cache-dir",
            {
                "type": Path,
                "default": None,
                "help": "Download HF files here if paths missing",
            },
        ),
        ("--limit-problems", {"type": int, "default": None}),
        (
            "--subsets",
            {
                "nargs": "+",
                "default": ["all", "within_correct"],
                "choices": ["all", "within_correct"],
            },
        ),
        ("--k-values", {"type": int, "nargs": "+", "default": [4, 5, 6]}),
        ("--out-dir", {"type": Path, "default": results_dir}),
        ("--run-name", {"type": str, "default": None}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    add_manifest_cache_args(parser)
    return parser.parse_args()
def pids_within_correct(manifest: dict) -> set[int]:
    """Return the ``problem_id`` of every manifest record flagged ``is_correct``."""
    correct_ids: set[int] = set()
    for record in manifest["records"]:
        if record["is_correct"]:
            correct_ids.add(record["problem_id"])
    return correct_ids


def pids_all(manifest: dict, limit_problems: int) -> set[int]:
    """Return every ``problem_id`` in the manifest that is below ``limit_problems``."""
    every_id = (record["problem_id"] for record in manifest["records"])
    return {pid for pid in every_id if pid < limit_problems}


def pids_paired_correct(base_manifest: dict, rft_manifest: dict) -> set[int]:
    """Return the problem ids that are correct in both the base and RFT manifests."""
    base_correct = pids_within_correct(base_manifest)
    rft_correct = pids_within_correct(rft_manifest)
    return base_correct.intersection(rft_correct)


def resolve_subset_pids(
    subset: str,
    *,
    manifest: dict,
    limit_problems: int,
    base_manifest: dict | None = None,
    rft_manifest: dict | None = None,
) -> set[int]:
    """Map a subset name to the set of problem ids it selects.

    Args:
        subset: One of ``"all"``, ``"within_correct"`` or ``"paired_correct"``.
        manifest: Manifest of the model being analysed; used for ``"all"`` and
            ``"within_correct"``.
        limit_problems: Exclusive upper bound on ``problem_id``; applied only
            for ``"all"``.
        base_manifest: Base-model manifest; required for ``"paired_correct"``.
        rft_manifest: RFT-model manifest; required for ``"paired_correct"``.

    Returns:
        The selected problem ids.

    Raises:
        ValueError: If ``subset`` is unknown, or ``"paired_correct"`` is
            requested without both ``base_manifest`` and ``rft_manifest``.
    """
    if subset == "paired_correct":
        have_both = base_manifest is not None and rft_manifest is not None
        if not have_both:
            raise ValueError("paired_correct requires base_manifest and rft_manifest")
        return pids_paired_correct(base_manifest, rft_manifest)
    if subset == "within_correct":
        return pids_within_correct(manifest)
    if subset == "all":
        return pids_all(manifest, limit_problems)
    raise ValueError(f"Unknown subset {subset!r}; use all, within_correct, or paired_correct")
# Run one base-vs-RFT time-shuffle comparison end to end.
#
# Args: tag repo local_name base_f rft_f limit out_name
#   tag         label used in log lines
#   repo        Hugging Face dataset repo id to download from
#   local_name  directory name under /workspace/data for the download
#   base_f      repo-relative path of the base model's features pickle
#   rft_f       repo-relative path of the RFT model's features pickle
#   limit       value for --limit-problems
#   out_name    sub-directory of $OUT_ROOT that receives the results
#
# Returns the runner's exit code, or 1 when the download fails.
run_job() {
  local tag="$1" repo="$2" local_name="$3" base_f="$4" rft_f="$5" limit="$6" out_name="$7"
  echo ""
  echo "========== JOB: $tag =========="
  cleanup_data

  # Call sites use `run_job ... || true`, and bash ignores errexit inside a
  # function invoked on the left of `||`. Check the download explicitly so a
  # failed fetch does not fall through to the runner with missing inputs.
  if ! download_pair "$repo" "$base_f" "$rft_f" "$local_name"; then
    echo "JOB $tag download failed; skipping runner" >&2
    cleanup_data
    return 1
  fi

  local base_p="/workspace/data/${local_name}/${base_f}"
  local rft_p="/workspace/data/${local_name}/${rft_f}"
  local out_dir="${OUT_ROOT}/${out_name}"
  mkdir -p "$out_dir"

  # Capture the runner's status without toggling the shell-wide errexit flag,
  # so the cleanup below always runs.
  local ec=0
  $PY "$RUNNER" \
    --base-features-path "$base_p" \
    --rft-features-path "$rft_p" \
    --limit-problems "$limit" \
    --k-values 4 5 6 \
    --k-focus 5 \
    --shuffle-seeds 0 1 2 3 4 \
    --out-dir "$out_dir" || ec=$?

  # Newest run directory, if the runner created one. `|| true` keeps pipefail
  # from turning "no directory yet" into a failure.
  local run_dir=""
  run_dir=$(ls -td "${out_dir}"/*/ 2>/dev/null | head -1) || true
  if [[ -n "$run_dir" ]]; then
    cleanup_run_artifacts "$run_dir"
  fi
  cleanup_data
  echo "JOB $tag exit_code=$ec run_dir=$run_dir"
  return "$ec"
}
"withmartian/SDS_train_mmlu-pro" "SDS_train_mmlu-pro" \ + "llama8b_base/layer_22/all_sentences_features.pkl" \ + "llama8b/layer_22/all_sentences_features.pkl" \ + 500 "compare_llama8_mmlu-pro" || true + +echo "" +echo "========== BATCH DONE ==========" diff --git a/exploration/cebra_EM.py b/exploration/cebra_EM.py index 4c9f60b..3f17334 100644 --- a/exploration/cebra_EM.py +++ b/exploration/cebra_EM.py @@ -343,13 +343,22 @@ "ACTIVE_COMPUTATION", "FINAL_ANSWER_EMISSION" ] -def load_and_prepare_cebra(path, mode='temporal', limit_problems=500, max_triplets=25): +def load_and_prepare_cebra( + path, + mode='temporal', + limit_problems=500, + max_triplets=25, + allowed_problem_ids=None, +): if not os.path.exists(path): raise FileNotFoundError(f"Could not find {path}.") print(f"Loading data for CEBRA-{mode}...", flush=True) with open(path, 'rb') as f: all_features = pickle.load(f) all_features = [f for f in all_features if f['problem_id'] < limit_problems] + if allowed_problem_ids is not None: + allowed = set(allowed_problem_ids) + all_features = [f for f in all_features if f['problem_id'] in allowed] p_map = defaultdict(list) for i, f in enumerate(all_features): p_map[f['problem_id']].append(i) From 6753aedfb8ba80a139006774a1ac530b32bb1816 Mon Sep 17 00:00:00 2001 From: wlg1 Date: Mon, 25 May 2026 19:48:46 -0700 Subject: [PATCH 4/4] correct-only expms update --- ablate_cot_trace_expms/README.md | 2 +- ablate_cot_trace_expms/correct_only/config.py | 35 +++++- .../correct_only/grading.py | 5 +- .../correct_only/manifest.py | 4 +- .../correct_only/pipeline.py | 16 ++- .../correct_only/run_compare_base_rft.py | 1 - .../correct_only/run_remaining_sweep.sh | 103 ++++++++++++++++++ .../correct_only/summarize_sweep.py | 83 ++++++++++++++ 8 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 ablate_cot_trace_expms/correct_only/run_remaining_sweep.sh create mode 100644 ablate_cot_trace_expms/correct_only/summarize_sweep.py diff --git 
a/ablate_cot_trace_expms/README.md b/ablate_cot_trace_expms/README.md index 8e2d374..79fdf20 100644 --- a/ablate_cot_trace_expms/README.md +++ b/ablate_cot_trace_expms/README.md @@ -5,6 +5,6 @@ Controlled experiments for reviewer-style confound checks on the CEBRA + SDS pip | Subfolder | Reviewer item | Description | |-----------|---------------|-------------| | [`time_shuffle/`](time_shuffle/) | Time-shuffled trajectory | Permute **sentence-aligned activation steps** before SDS/EM; CEBRA still trained on real CoT order | -| [`correct_only/`](correct_only/) | Correctness confound | **Paired-correct** (both base & RFT right on same problems); no length controls | +| [`correct_only/`](correct_only/) | Correctness confound | **Paired-correct** (both base & RFT right on same problems) | diff --git a/ablate_cot_trace_expms/correct_only/config.py b/ablate_cot_trace_expms/correct_only/config.py index a0dc368..4994409 100644 --- a/ablate_cot_trace_expms/correct_only/config.py +++ b/ablate_cot_trace_expms/correct_only/config.py @@ -13,13 +13,43 @@ "math500": "withmartian/SDS_math500_test", } -# Short model family key -> (base folder, rft folder) under each HF repo +# Short model family key -> (base folder, rft folder) under train HF repos MODEL_PAIRS: dict[str, tuple[str, str]] = { "llama8": ("llama_8B_base", "llama_8b_reasoning"), "qwen14": ("qwen_14B_base", "qwen_14b_reasoning"), "qwen1.5": ("qwen_1.5B_base", "qwen1.5b_reasoning"), } +# Per-repo folder names (train repos differ from SDS_math500_test layout) +MODEL_PAIRS_BY_REPO: dict[str, dict[str, tuple[str, str]]] = { + "withmartian/SDS_math500_test": { + "llama8": ("Llama_8B_base", "Llama_8B_reasoning"), + "qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"), + "qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"), + }, + "withmartian/SDS_train_svamp": { + "llama8": ("Llama_8B_base", "Llama_8B_reasoning"), + "qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"), + "qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"), + }, + 
"withmartian/SDS_train_mmlu-pro": { + "llama8": ("llama8b_base", "llama8b"), + "qwen14": ("qwen14b_base", "qwen14b"), + "qwen1.5": ("qwen1.5b_base", "qwen1.5b"), + }, +} + + +def model_pair_folders( + model_family: str, + *, + hf_repo: str, +) -> tuple[str, str]: + pairs = MODEL_PAIRS_BY_REPO.get(hf_repo, MODEL_PAIRS) + if model_family not in pairs: + raise KeyError(f"Unknown model_family {model_family!r} for repo {hf_repo!r}") + return pairs[model_family] + DEFAULT_LAYER: dict[str, int] = { "llama8": 22, "qwen14": 28, @@ -75,7 +105,8 @@ def resolve_paths( layer = layer if layer is not None else DEFAULT_LAYER[model_family] repo = hf_repo or HF_REPOS[dataset_key] - folder = MODEL_PAIRS[model_family][0 if variant == "base" else 1] + base_folder, rft_folder = model_pair_folders(model_family, hf_repo=repo) + folder = base_folder if variant == "base" else rft_folder layer_dir = f"layer_{layer}" root = hf_root if hf_root is not None else Path(repo.replace("/", "__")) diff --git a/ablate_cot_trace_expms/correct_only/grading.py b/ablate_cot_trace_expms/correct_only/grading.py index 99427b0..e9e75c9 100644 --- a/ablate_cot_trace_expms/correct_only/grading.py +++ b/ablate_cot_trace_expms/correct_only/grading.py @@ -3,6 +3,7 @@ from __future__ import annotations import re +import math from typing import Any DATASET_CFG: dict[str, dict[str, Any]] = { @@ -38,10 +39,12 @@ def _normalize_number(s: str) -> str: s = re.sub(r"^\$|\$$", "", s) try: v = float(s) + if not math.isfinite(v): + return s.strip().lower() if abs(v - round(v)) < 1e-9: return str(int(round(v))) return str(v) - except ValueError: + except (ValueError, OverflowError): return s.strip().lower() diff --git a/ablate_cot_trace_expms/correct_only/manifest.py b/ablate_cot_trace_expms/correct_only/manifest.py index c195444..ba6e696 100644 --- a/ablate_cot_trace_expms/correct_only/manifest.py +++ b/ablate_cot_trace_expms/correct_only/manifest.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any -from 
.grading import DATASET_CFG, _extract_answer_from_cot, is_correct_answer, load_benchmark_rows +from grading import DATASET_CFG, _extract_answer_from_cot, is_correct_answer, load_benchmark_rows DEFAULT_MANIFEST_CACHE_DIR = Path(__file__).resolve().parent / "manifest_cache" @@ -253,7 +253,7 @@ def get_or_build_paired_index( Returns (problem_ids, cache_path or None, from_cache). """ - from .subsets import pids_paired_correct + from subsets import pids_paired_correct cache_dir = cache_dir or DEFAULT_MANIFEST_CACHE_DIR paired_dir = cache_dir / "paired" diff --git a/ablate_cot_trace_expms/correct_only/pipeline.py b/ablate_cot_trace_expms/correct_only/pipeline.py index 7a0f66d..b8393bb 100644 --- a/ablate_cot_trace_expms/correct_only/pipeline.py +++ b/ablate_cot_trace_expms/correct_only/pipeline.py @@ -16,6 +16,7 @@ from exploration.cebra_EM import ( # noqa: E402 K_SWEEP, + PCA_DIM, fit_and_evaluate, linear_ar_r2, load_and_prepare_cebra, @@ -85,11 +86,18 @@ def log(msg: str) -> None: limit_problems=limit_problems, allowed_problem_ids=allowed_list, ) - if len(all_f) < 10: - raise ValueError( - f"Too few features ({len(all_f)}) for subset {subset_name}; " - f"check allowed_problem_ids and limit_problems." 
+ if len(all_f) < PCA_DIM: + log( + f"[{model_tag}] SKIP subset {subset_name}: " + f"{len(all_f)} sentence features < PCA_DIM={PCA_DIM}" ) + return { + "subset": subset_name, + "skipped": True, + "reason": f"insufficient_features ({len(all_f)} < {PCA_DIM})", + "n_allowed_pids": len(allowed_list), + "per_k": [], + } log(f"[{model_tag}] Training CEBRA ({len(all_f)} steps, {len(triplets)} triplets)...") cebra_seqs, pca_seqs, labels = train_cebra_projection(all_f, triplets) ar_r2 = linear_ar_r2(pca_seqs) diff --git a/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py b/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py index 3ad16fd..977567b 100644 --- a/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py +++ b/ablate_cot_trace_expms/correct_only/run_compare_base_rft.py @@ -81,7 +81,6 @@ def _run_model_subsets( limit_problems=limit, base_manifest=base_manifest, rft_manifest=rft_manifest, - reference_manifest=manifest, ) print(f" [{variant}] subset={subset}: n_pids={len(pids)}", flush=True) if len(pids) < 3: diff --git a/ablate_cot_trace_expms/correct_only/run_remaining_sweep.sh b/ablate_cot_trace_expms/correct_only/run_remaining_sweep.sh new file mode 100644 index 0000000..aad37a8 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/run_remaining_sweep.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +# Run remaining model×dataset combos (middle layers); cleanup large pkls after each run. +set -euo pipefail + +MI_COT_ROOT="$(cd "$(dirname "$0")/../.." 
&& pwd)" +CORRECT_ONLY="$(cd "$(dirname "$0")" && pwd)" +cd "$MI_COT_ROOT" + +RUN_NAME="${RUN_NAME:-middle_layer_sweep}" +HF_CACHE="${HF_CACHE:-$MI_COT_ROOT/hf_cache}" +RESULTS="$CORRECT_ONLY/results/compare/$RUN_NAME" + +# dataset model_family layer +COMBOS=( + "math500 qwen14 28" + "math500 llama8 22" + "gsm8k qwen14 28" + "gsm8k qwen1.5 20" + "gsm8k llama8 22" + "svamp qwen14 28" + "svamp qwen1.5 20" + "svamp llama8 22" + "mmlu-pro qwen14 28" + "mmlu-pro qwen1.5 20" + "mmlu-pro llama8 22" +) + +cleanup_combo() { + local dataset="$1" + local family="$2" + local layer="$3" + local run_subdir="$RESULTS/${dataset}_${family}" + + # CEBRA caches (reproducible from summaries) + find "$run_subdir" -name 'cebra_cache.pkl' -delete 2>/dev/null || true + + # HF activation pickles for this combo's folders + local -a folders=() + case "$dataset" in + math500|svamp) + case "$family" in + qwen14) folders=(Qwen_14B_base Qwen_14B_reasoning) ;; + qwen1.5) folders=(Qwen_1_5B_base Qwen_1_5B_reasoning) ;; + llama8) folders=(Llama_8B_base Llama_8B_reasoning) ;; + esac + ;; + gsm8k) + case "$family" in + qwen14) folders=(qwen_14B_base qwen_14b_reasoning) ;; + qwen1.5) folders=(qwen_1.5B_base qwen1.5b_reasoning) ;; + llama8) folders=(llama_8B_base llama_8b_reasoning) ;; + esac + ;; + mmlu-pro) + case "$family" in + qwen14) folders=(qwen14b_base qwen14b) ;; + qwen1.5) folders=(qwen1.5b_base qwen1.5b) ;; + llama8) folders=(llama8b_base llama8b) ;; + esac + ;; + esac + + for folder in "${folders[@]}"; do + rm -rf "$HF_CACHE/$folder/layer_${layer}" 2>/dev/null || true + rm -rf "$HF_CACHE/withmartian__"*/"$folder"/"layer_${layer}" 2>/dev/null || true + done + + # HuggingFace hub blob cache (safe to drop between combos) + rm -rf "$HF_CACHE/.cache" 2>/dev/null || true + + echo " [cleanup] freed pkls for ${dataset}/${family} layer_${layer}" +} + +for spec in "${COMBOS[@]}"; do + read -r dataset family layer <<< "$spec" + echo "" + echo "========== ${dataset} / ${family} (layer ${layer}) 
==========" + if [[ -f "$RESULTS/${dataset}_${family}/comparison_report.json" ]]; then + echo " Skipping (comparison_report.json exists)" + cleanup_combo "$dataset" "$family" "$layer" + continue + fi + + python3 ablate_cot_trace_expms/correct_only/run_compare_base_rft.py \ + --dataset "$dataset" \ + --model-family "$family" \ + --layer "$layer" \ + --hf-cache-dir "$HF_CACHE" \ + --k-values 5 \ + --k-focus 5 \ + --run-name "$RUN_NAME" \ + || true + + if [[ ! -f "$RESULTS/${dataset}_${family}/comparison_report.json" ]]; then + echo "ERROR: no comparison_report.json for ${dataset} ${family}" + exit 1 + fi + + cleanup_combo "$dataset" "$family" "$layer" +done + +echo "" +echo "========== SWEEP DONE ==========" diff --git a/ablate_cot_trace_expms/correct_only/summarize_sweep.py b/ablate_cot_trace_expms/correct_only/summarize_sweep.py new file mode 100644 index 0000000..0d075e5 --- /dev/null +++ b/ablate_cot_trace_expms/correct_only/summarize_sweep.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""Aggregate comparison_report.json files into a summary table.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent + + +def load_report(path: Path) -> dict: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument( + "--compare-dirs", + nargs="+", + type=Path, + default=[ + SCRIPT_DIR / "results" / "compare" / "middle_layer_sweep", + SCRIPT_DIR / "results" / "compare" / "qwen15_math500_first_test", + ], + ) + p.add_argument( + "--out-csv", + type=Path, + default=SCRIPT_DIR / "results" / "compare" / "middle_layer_sweep_summary.csv", + ) + args = p.parse_args() + + rows = [] + for compare_dir in args.compare_dirs: + if not compare_dir.is_dir(): + continue + for report_path in sorted(compare_dir.glob("*/comparison_report.json")): + r = load_report(report_path) + ev = r.get("evaluation", 
{}) + checks = {c["name"]: c for c in ev.get("checks", [])} + + def metric_check(subset: str, metric: str) -> str | None: + key = f"rft_gt_base_{subset}_{metric}" + c = checks.get(key) + if not c: + return None + return "PASS" if c["passed"] else "FAIL" + + rows.append( + { + "dataset": r["dataset"], + "model_family": r["model_family"], + "layer": r["layer"], + "run_dir": str(report_path.parent.name), + "base_acc": r.get("base_manifest_stats", {}).get("accuracy"), + "rft_acc": r.get("rft_manifest_stats", {}).get("accuracy"), + "primary_pass": ev.get("primary_pass"), + "all_pass": ev.get("all_passed"), + "paired_persist": metric_check("paired_correct", "persistence"), + "paired_self_trans": metric_check("paired_correct", "mean_self_transition"), + "paired_delta_r2": metric_check("paired_correct", "delta_r2"), + "paired_k_eff": metric_check("paired_correct", "K_eff"), + "gap_persist": metric_check("paired_correct", "persistence") + and checks.get("paired_retains_gap_persistence", {}).get("passed"), + } + ) + + df = pd.DataFrame(rows) + df = df.sort_values(["dataset", "model_family"]).reset_index(drop=True) + args.out_csv.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(args.out_csv, index=False) + print(df.to_string(index=False)) + print(f"\nWrote {args.out_csv}") + print(f"Primary pass: {df['primary_pass'].sum()}/{len(df)}") + + +if __name__ == "__main__": + main()