time shuffle traj ablation expms #13
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # CoT trace & pipeline ablations | ||
|
|
||
| Controlled experiments for reviewer-style confound checks on the CEBRA + SDS pipeline. | ||
|
|
||
| | Subfolder | Reviewer item | Description | | ||
| |-----------|---------------|-------------| | ||
| | [`time_shuffle/`](time_shuffle/) | Time-shuffled trajectory | Permute **sentence-aligned activation steps** before SDS/EM; CEBRA still trained on real CoT order | | ||
| | [`correct_only/`](correct_only/) | Correctness confound | **Paired-correct** (both base & RFT right on same problems) | | ||
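As a rough illustration of the `time_shuffle/` control, the sketch below permutes the order of sentence-aligned steps within one trajectory while leaving each step's vector untouched. The list-of-vectors layout and the `time_shuffle` helper name are assumptions made for this example; the PR's own shuffle code is not shown here.

```python
import random


def time_shuffle(trajectory, seed=0):
    """Return a copy of `trajectory` with its steps in a random order.

    `trajectory` is assumed to be a list of per-sentence activation vectors
    (hypothetical layout). Only the step order changes, not the vectors.
    """
    rng = random.Random(seed)  # seeded so the control is reproducible
    order = list(range(len(trajectory)))
    rng.shuffle(order)
    return [trajectory[i] for i in order]


# Toy trajectory: 4 sentence steps with 2-dimensional activations.
traj = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
shuffled = time_shuffle(traj, seed=1)
```

Under this setup the shuffled trajectory contains exactly the same steps, so any drop in SDS metrics would be attributable to temporal order rather than to the set of activations.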
|
|
||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # Ablation experiments on CoT trace/activations. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| # Ablation: correct-trace subsets only | ||
|
|
||
| Tests whether the base vs RFT SDS gap survives when restricting to **correct** trajectories—especially **paired correct** (same problems where **both** base and RFT are graded correct). | ||
|
|
||
| This ablation applies no length matching and no incorrect-trace controls. | ||
|
|
||
| ## Subsets | ||
|
|
||
| | Subset | Definition | | ||
| |--------|------------| | ||
| | `all` | All problems in the manifest (up to `--limit-problems`) | | ||
| | `within_correct` | Per-model: only trajectories with a correct final answer | | ||
| | `paired_correct` | `problem_id` where **base and RFT** are both correct | | ||
|
|
||
| CEBRA is retrained **within each subset** (only allowed `problem_id`s enter triplets and trajectories). | ||
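The subset definitions above can be sketched as plain set operations. This is a minimal illustration that assumes each manifest reduces to a `problem_id -> is_correct` mapping; `subset_problem_ids` is a hypothetical helper, not a function from the PR.

```python
def subset_problem_ids(base_correct, rft_correct):
    """Build the allowed problem_id sets for each subset.

    base_correct / rft_correct: dict mapping problem_id -> bool (assumed shape).
    """
    base_ok = {pid for pid, ok in base_correct.items() if ok}
    rft_ok = {pid for pid, ok in rft_correct.items() if ok}
    return {
        "base": {"all": set(base_correct), "within_correct": base_ok},
        "rft": {"all": set(rft_correct), "within_correct": rft_ok},
        # paired_correct is shared: both models must be correct on the problem
        "paired_correct": base_ok & rft_ok,
    }


# Illustrative labels only, not real grading results.
base = {"p1": True, "p2": False, "p3": True}
rft = {"p1": True, "p2": True, "p3": False}
subsets = subset_problem_ids(base, rft)
```

In this toy case only `p1` lands in `paired_correct`, while `within_correct` differs per model.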
|
|
||
| ## Manifest cache (correctness labels only) | ||
|
|
||
| Grading runs once per `(dataset, model folder, variant, layer, cot_data hash)` and is saved under **`manifest_cache/`** (default). No activations are copied—only JSON with `problem_id`, `is_correct`, answers, etc. | ||
|
|
||
| | Flag | Meaning | | ||
| |------|---------| | ||
| | `--manifest-cache-dir PATH` | Cache directory (default: `correct_only/manifest_cache`) | | ||
| | `--use-cached-manifest` / `--no-use-cached-manifest` | Read cache when valid (default: use) | | ||
| | `--rebuild-manifest` | Force re-grade and overwrite cache | | ||
|
|
||
| Paired-correct `problem_id` lists are cached under `manifest_cache/paired/`. Each run still copies manifests into its `results/...` folder for reproducibility. | ||
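One way to picture the cache key is a filename derived from the five components listed above. The sketch below is an assumption about how such a key could be built; the naming scheme, the SHA-256 digest, and the `manifest_cache_path` helper are all hypothetical and may differ from what the PR does.

```python
import hashlib
from pathlib import Path


def manifest_cache_path(cache_dir, dataset, folder, variant, layer, cot_data_bytes):
    """Map (dataset, model folder, variant, layer, cot_data hash) to a JSON path."""
    # Hash the raw cot_data contents so a changed file invalidates the cache.
    digest = hashlib.sha256(cot_data_bytes).hexdigest()[:12]
    name = f"{dataset}__{folder}__{variant}__layer{layer}__{digest}.json"
    return Path(cache_dir) / name


path_a = manifest_cache_path(
    "correct_only/manifest_cache", "gsm8k", "qwen_14B_base", "base", 28, b"v1"
)
path_b = manifest_cache_path(
    "correct_only/manifest_cache", "gsm8k", "qwen_14B_base", "base", 28, b"v2"
)
```

Because the content hash is part of the key, regenerating `cot_data.pkl` would produce a new cache entry instead of silently reusing stale correctness labels.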
|
|
||
| ## Scripts | ||
|
|
||
| | Script | Role | | ||
| |--------|------| | ||
| | `run_single_model.py` | One variant (`base` or `rft`) + one dataset → manifest + SDS metrics | | ||
| | `run_compare_base_rft.py` | Base + RFT for one dataset → all subsets + `comparison_report.json` | | ||
| | `run_full_analysis.py` | Sweep combos → `gap_summary.csv`, `gap_by_subset.png`, `outcome_summary.csv` | | ||
|
|
||
| ## Quick start | ||
|
|
||
| ```bash | ||
| python ablate_cot_trace_expms/correct_only/run_compare_base_rft.py \ | ||
| --dataset gsm8k --model-family qwen14 \ | ||
| --hf-cache-dir ./hf_cache --k-values 5 --k-focus 5 | ||
|
|
||
| python ablate_cot_trace_expms/correct_only/run_full_analysis.py \ | ||
| --hf-cache-dir ./hf_cache --datasets gsm8k --model-families qwen14 | ||
| ``` | ||
|
|
||
| ## Desired outcome (pass/fail) | ||
|
|
||
| At `--k-focus` (default **K=5**), **primary pass** requires on **`paired_correct`**: | ||
|
|
||
| - RFT > Base for `persistence`, `mean_self_transition`, `delta_r2`, `K_eff` | ||
| - Also checked, but counted toward `all_passed` rather than `primary_pass` in `criteria.py`: the paired RFT−Base gap retains ≥50% of the `all`-subset gap (persistence & ΔR²) | ||
|
|
||
| Exit code **0** when primary pass holds (`run_compare_base_rft.py`, `run_full_analysis.py`). | ||
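The gap-retention rule can be written as a small predicate. The sketch mirrors the branch structure of `evaluate_compare_summaries` in `criteria.py`; the standalone helper name `paired_retains_gap` and the numbers below are illustrative assumptions, not results.

```python
def paired_retains_gap(gap_all, gap_paired, min_fraction=0.5):
    """True when the paired_correct gap keeps enough of the all-subset gap."""
    if gap_all <= 0:
        # No positive gap on `all` to compare against: only require that
        # the paired gap is positive.
        return gap_paired > 0
    return gap_paired >= min_fraction * gap_all


# Illustrative RFT - Base gaps (made-up values).
keeps = paired_retains_gap(gap_all=0.10, gap_paired=0.06)   # 0.06 >= 0.05
loses = paired_retains_gap(gap_all=0.10, gap_paired=0.04)   # 0.04 <  0.05
fallback = paired_retains_gap(gap_all=-0.01, gap_paired=0.02)
```

With the default 50% threshold, a paired gap of 0.06 against an `all` gap of 0.10 passes, while 0.04 does not.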
|
|
||
| ## Outputs | ||
|
|
||
| ``` | ||
| results/compare/<run_name>/<dataset>_<model_family>/ | ||
| base_manifest.json | ||
| rft_manifest.json | ||
| base/{all,within_correct,paired_correct}/summary.json | ||
| rft/... | ||
| comparison_report.json | ||
| ``` | ||
|
|
||
| `gap_by_subset.png` compares RFT−Base gaps for `all`, `within_correct`, and `paired_correct`. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # Correct-trace ablation: paired-correct subset for base vs RFT SDS metrics. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| """Paths and defaults for correct-trace ablations on SDS HF datasets.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
|
|
||
| # Benchmark loader key -> HF dataset repo id (see grading.DATASET_CFG) | ||
| HF_REPOS: dict[str, str] = { | ||
| "gsm8k": "withmartian/SDS_train_gsm8k", | ||
| "svamp": "withmartian/SDS_train_svamp", | ||
| "mmlu-pro": "withmartian/SDS_train_mmlu-pro", | ||
| "math500": "withmartian/SDS_math500_test", | ||
| } | ||
|
|
||
| # Short model family key -> (base folder, rft folder); fallback for repos without an entry in MODEL_PAIRS_BY_REPO | ||
| MODEL_PAIRS: dict[str, tuple[str, str]] = { | ||
| "llama8": ("llama_8B_base", "llama_8b_reasoning"), | ||
| "qwen14": ("qwen_14B_base", "qwen_14b_reasoning"), | ||
| "qwen1.5": ("qwen_1.5B_base", "qwen1.5b_reasoning"), | ||
| } | ||
|
|
||
| # Per-repo folder names (train repos differ from SDS_math500_test layout) | ||
| MODEL_PAIRS_BY_REPO: dict[str, dict[str, tuple[str, str]]] = { | ||
| "withmartian/SDS_math500_test": { | ||
| "llama8": ("Llama_8B_base", "Llama_8B_reasoning"), | ||
| "qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"), | ||
| "qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"), | ||
| }, | ||
| "withmartian/SDS_train_svamp": { | ||
| "llama8": ("Llama_8B_base", "Llama_8B_reasoning"), | ||
| "qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"), | ||
| "qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"), | ||
| }, | ||
| "withmartian/SDS_train_mmlu-pro": { | ||
| "llama8": ("llama8b_base", "llama8b"), | ||
| "qwen14": ("qwen14b_base", "qwen14b"), | ||
| "qwen1.5": ("qwen1.5b_base", "qwen1.5b"), | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| def model_pair_folders( | ||
| model_family: str, | ||
| *, | ||
| hf_repo: str, | ||
| ) -> tuple[str, str]: | ||
| pairs = MODEL_PAIRS_BY_REPO.get(hf_repo, MODEL_PAIRS) | ||
| if model_family not in pairs: | ||
| raise KeyError(f"Unknown model_family {model_family!r} for repo {hf_repo!r}") | ||
| return pairs[model_family] | ||
|
|
||
| DEFAULT_LAYER: dict[str, int] = { | ||
| "llama8": 22, | ||
| "qwen14": 28, | ||
| "qwen1.5": 20, | ||
| } | ||
|
|
||
| # Default problem cap (match exploration scripts / paper subsets) | ||
| DEFAULT_LIMIT_PROBLEMS: dict[str, int] = { | ||
| "gsm8k": 2000, | ||
| "svamp": 800, | ||
| "mmlu-pro": 500, | ||
| "math500": 500, | ||
| } | ||
|
|
||
| SUBSET_NAMES = ( | ||
| "all", | ||
| "within_correct", | ||
| "paired_correct", | ||
| ) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class ModelDatasetPaths: | ||
| dataset_key: str | ||
| model_family: str | ||
| variant: str # "base" | "rft" | ||
| layer: int | ||
| hf_repo: str | ||
| folder: str | ||
| features_path: Path | ||
| cot_data_path: Path | ||
|
|
||
| @property | ||
| def model_tag(self) -> str: | ||
| return f"{self.model_family}_{self.variant}" | ||
|
|
||
|
|
||
| def resolve_paths( | ||
| dataset_key: str, | ||
| model_family: str, | ||
| variant: str, | ||
| *, | ||
| layer: int | None = None, | ||
| hf_root: Path | None = None, | ||
| hf_repo: str | None = None, | ||
| ) -> ModelDatasetPaths: | ||
| if dataset_key not in HF_REPOS: | ||
| raise KeyError(f"Unknown dataset_key {dataset_key!r}; choose from {list(HF_REPOS)}") | ||
| if model_family not in MODEL_PAIRS: | ||
| raise KeyError(f"Unknown model_family {model_family!r}; choose from {list(MODEL_PAIRS)}") | ||
| if variant not in ("base", "rft"): | ||
| raise ValueError("variant must be 'base' or 'rft'") | ||
|
|
||
| layer = layer if layer is not None else DEFAULT_LAYER[model_family] | ||
| repo = hf_repo or HF_REPOS[dataset_key] | ||
| base_folder, rft_folder = model_pair_folders(model_family, hf_repo=repo) | ||
| folder = base_folder if variant == "base" else rft_folder | ||
| layer_dir = f"layer_{layer}" | ||
| root = hf_root if hf_root is not None else Path(repo.replace("/", "__")) | ||
|
|
||
| base = root / folder / layer_dir | ||
| return ModelDatasetPaths( | ||
| dataset_key=dataset_key, | ||
| model_family=model_family, | ||
| variant=variant, | ||
| layer=layer, | ||
| hf_repo=repo, | ||
| folder=folder, | ||
| features_path=base / "all_sentences_features.pkl", | ||
| cot_data_path=base / "cot_data.pkl", | ||
| ) | ||
|
|
||
|
|
||
| def all_combo_keys() -> list[tuple[str, str]]: | ||
| return [(d, m) for d in HF_REPOS for m in MODEL_PAIRS] | ||
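To make the folder lookup concrete, here is a trimmed, standalone sketch of how `resolve_paths` composes a `cot_data.pkl` path when no `hf_root` is given. The folder names are copied from the tables above; the `FALLBACK`, `BY_REPO`, and `cot_data_path` names are hypothetical stand-ins for `MODEL_PAIRS`, `MODEL_PAIRS_BY_REPO`, and the real function.

```python
from pathlib import Path

# Reduced copies of the config tables, for illustration only.
FALLBACK = {"qwen14": ("qwen_14B_base", "qwen_14b_reasoning")}
BY_REPO = {
    "withmartian/SDS_math500_test": {
        "qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"),
    },
}


def cot_data_path(repo, family, variant, layer):
    """Compose <root>/<folder>/layer_<N>/cot_data.pkl as resolve_paths does."""
    # Repos without a per-repo entry fall back to the shared table.
    base_folder, rft_folder = BY_REPO.get(repo, FALLBACK)[family]
    folder = base_folder if variant == "base" else rft_folder
    root = Path(repo.replace("/", "__"))  # default root when hf_root is None
    return root / folder / f"layer_{layer}" / "cot_data.pkl"


gsm8k_rft = cot_data_path("withmartian/SDS_train_gsm8k", "qwen14", "rft", 28)
math500_base = cot_data_path("withmartian/SDS_math500_test", "qwen14", "base", 28)
```

The same family key resolves to differently cased folders depending on the repo, which is why the per-repo table takes precedence over the fallback.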
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| """Pass/fail criteria for correct-trace ablations.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from pipeline import metric_at_k | ||
|
|
||
| METRICS_RFT_SHOULD_WIN = ("persistence", "mean_self_transition", "delta_r2", "K_eff") | ||
|
|
||
|
|
||
| def evaluate_compare_summaries( | ||
| base_summaries: dict[str, dict], | ||
| rft_summaries: dict[str, dict], | ||
| k_focus: int, | ||
| *, | ||
| min_paired_gap_fraction: float = 0.5, | ||
| ) -> dict: | ||
| """ | ||
| Desired outcome: | ||
| - On paired_correct: RFT > base on key metrics | ||
| - paired_correct gap retains a substantial fraction of the all-trace gap | ||
| """ | ||
| checks: list[dict] = [] | ||
|
|
||
| def add(name: str, passed: bool, detail: str, values: dict | None = None) -> None: | ||
| checks.append({"name": name, "passed": passed, "detail": detail, "values": values or {}}) | ||
|
|
||
| for subset in ("paired_correct", "all"): | ||
| if subset not in base_summaries or subset not in rft_summaries: | ||
| continue | ||
| for metric in METRICS_RFT_SHOULD_WIN: | ||
| b = metric_at_k(base_summaries[subset], k_focus, metric) | ||
| r = metric_at_k(rft_summaries[subset], k_focus, metric) | ||
| if b is None or r is None: | ||
| add(f"rft_gt_base_{subset}_{metric}", False, f"missing K={k_focus}") | ||
| continue | ||
| add( | ||
| f"rft_gt_base_{subset}_{metric}", | ||
| r > b, | ||
| f"{subset}: RFT={r:.4f} vs Base={b:.4f}", | ||
| {"base": b, "rft": r, "subset": subset}, | ||
| ) | ||
|
|
||
| for metric in ("persistence", "delta_r2"): | ||
| if not all(s in base_summaries for s in ("all", "paired_correct")): | ||
| continue | ||
| if not all(s in rft_summaries for s in ("all", "paired_correct")): | ||
| continue | ||
| gap_all = metric_at_k(rft_summaries["all"], k_focus, metric) - metric_at_k( | ||
| base_summaries["all"], k_focus, metric | ||
| ) | ||
| gap_paired = metric_at_k(rft_summaries["paired_correct"], k_focus, metric) - metric_at_k( | ||
| base_summaries["paired_correct"], k_focus, metric | ||
|
Comment on lines +48 to +52
||
| ) | ||
| if gap_all <= 0: | ||
| add( | ||
| f"paired_retains_gap_{metric}", | ||
| gap_paired > 0, | ||
| f"all gap non-positive ({gap_all:.4f}); require paired gap>0 ({gap_paired:.4f})", | ||
| ) | ||
| else: | ||
| frac = gap_paired / gap_all if gap_all else 0.0 | ||
| add( | ||
| f"paired_retains_gap_{metric}", | ||
| gap_paired >= min_paired_gap_fraction * gap_all, | ||
| f"paired_gap={gap_paired:.4f} all_gap={gap_all:.4f} frac={frac:.2f}", | ||
| {"gap_paired": gap_paired, "gap_all": gap_all}, | ||
| ) | ||
|
|
||
| primary = [ | ||
| c for c in checks | ||
| if "paired_correct" in c["name"] and c["name"].startswith("rft_gt_base") | ||
| ] | ||
| primary_pass = all(c["passed"] for c in primary) if primary else False | ||
| all_passed = all(c["passed"] for c in checks) | ||
|
|
||
| return { | ||
| "k_focus": k_focus, | ||
| "checks": checks, | ||
| "all_passed": all_passed, | ||
| "primary_pass": primary_pass, | ||
| "desired_outcome": ( | ||
| "RFT > Base on paired_correct; paired RFT−Base gap retains " | ||
| f"≥{min_paired_gap_fraction:.0%} of all-trace gap (persistence, ΔR²)" | ||
| ), | ||
| } | ||
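The first loop in `evaluate_compare_summaries` records a failed check when `metric_at_k` returns `None`, whereas the gap computation subtracts the two results directly. If `metric_at_k` can also return `None` on that path, the subtraction would raise a `TypeError`. A minimal sketch of a guard, where `safe_gap` is a hypothetical helper and not part of the PR:

```python
def safe_gap(rft_value, base_value):
    """Return rft_value - base_value, or None when either metric is missing."""
    if rft_value is None or base_value is None:
        return None
    return rft_value - base_value


present = safe_gap(0.8, 0.5)   # both metrics available
missing = safe_gap(None, 0.5)  # e.g. no entry at the requested K
```

A caller could then record `paired_retains_gap_<metric>` as failed with a "missing K" detail whenever either gap is `None`, matching how the first loop reports missing values.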
`MODEL_PAIRS` hardcodes one `(base_folder, rft_folder)` per family, but the HF artifact folders vary by dataset, so `resolve_paths()` / `ensure_model_artifacts()` can request nonexistent filenames for some dataset × family combos. Should we make the mapping dataset-aware or add per-dataset overrides?
Finding type: Breaking Changes | Severity: 🔴 High