Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ablate_cot_trace_expms/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# CoT trace & pipeline ablations

Controlled experiments for reviewer-style confound checks on the CEBRA + SDS pipeline.

| Subfolder | Reviewer item | Description |
|-----------|---------------|-------------|
| [`time_shuffle/`](time_shuffle/) | Time-shuffled trajectory | Permute **sentence-aligned activation steps** before SDS/EM; CEBRA still trained on real CoT order |
| [`correct_only/`](correct_only/) | Correctness confound | **Paired-correct** (both base & RFT right on same problems) |


1 change: 1 addition & 0 deletions ablate_cot_trace_expms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Ablation experiments on CoT trace/activations.
68 changes: 68 additions & 0 deletions ablate_cot_trace_expms/correct_only/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Ablation: correct-trace subsets only

Tests whether the base vs RFT SDS gap survives when restricting to **correct** trajectories—especially **paired correct** (same problems where **both** base and RFT are graded correct).

No length matching or incorrect-trace controls.

## Subsets

| Subset | Definition |
|--------|------------|
| `all` | All problems in the manifest (up to `--limit-problems`) |
| `within_correct` | Per-model: only trajectories with a correct final answer |
| `paired_correct` | `problem_id` where **base and RFT** are both correct |

CEBRA is retrained **within each subset** (only allowed `problem_id`s enter triplets and trajectories).

## Manifest cache (correctness labels only)

Grading runs once per `(dataset, model folder, variant, layer, cot_data hash)` and is saved under **`manifest_cache/`** (default). No activations are copied—only JSON with `problem_id`, `is_correct`, answers, etc.

| Flag | Meaning |
|------|---------|
| `--manifest-cache-dir PATH` | Cache directory (default: `correct_only/manifest_cache`) |
| `--use-cached-manifest` / `--no-use-cached-manifest` | Read cache when valid (default: use) |
| `--rebuild-manifest` | Force re-grade and overwrite cache |

Paired-correct `problem_id` lists are cached under `manifest_cache/paired/`. Each run still copies manifests into its `results/...` folder for reproducibility.

## Scripts

| Script | Role |
|--------|------|
| `run_single_model.py` | One variant (`base` or `rft`) + one dataset → manifest + SDS metrics |
| `run_compare_base_rft.py` | Base + RFT for one dataset → all subsets + `comparison_report.json` |
| `run_full_analysis.py` | Sweep combos → `gap_summary.csv`, `gap_by_subset.png`, `outcome_summary.csv` |

## Quick start

```bash
python ablate_cot_trace_expms/correct_only/run_compare_base_rft.py \
--dataset gsm8k --model-family qwen14 \
--hf-cache-dir ./hf_cache --k-values 5 --k-focus 5

python ablate_cot_trace_expms/correct_only/run_full_analysis.py \
--hf-cache-dir ./hf_cache --datasets gsm8k --model-families qwen14
```

## Desired outcome (pass/fail)

At `--k-focus` (default **K=5**), **primary pass** requires on **`paired_correct`**:

- RFT > Base for `persistence`, `mean_self_transition`, `delta_r2`, `K_eff`
- Paired RFT−Base gap retains ≥50% of the `all`-subset gap (persistence & ΔR²)

Exit code **0** when primary pass holds (`run_compare_base_rft.py`, `run_full_analysis.py`).

## Outputs

```
results/compare/<run_name>/<dataset>_<model_family>/
base_manifest.json
rft_manifest.json
base/{all,within_correct,paired_correct}/summary.json
rft/...
comparison_report.json
```

`gap_by_subset.png` compares RFT−Base gaps for `all`, `within_correct`, and `paired_correct`.
1 change: 1 addition & 0 deletions ablate_cot_trace_expms/correct_only/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Correct-trace ablation: paired-correct subset for base vs RFT SDS metrics.
127 changes: 127 additions & 0 deletions ablate_cot_trace_expms/correct_only/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Paths and defaults for correct-trace ablations on SDS HF datasets."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

# HF dataset repo id -> benchmark loader key (see grading.DATASET_CFG)
HF_REPOS: dict[str, str] = {
"gsm8k": "withmartian/SDS_train_gsm8k",
"svamp": "withmartian/SDS_train_svamp",
"mmlu-pro": "withmartian/SDS_train_mmlu-pro",
"math500": "withmartian/SDS_math500_test",
}

# Short model family key -> (base folder, rft folder) under train HF repos
MODEL_PAIRS: dict[str, tuple[str, str]] = {
"llama8": ("llama_8B_base", "llama_8b_reasoning"),
"qwen14": ("qwen_14B_base", "qwen_14b_reasoning"),
"qwen1.5": ("qwen_1.5B_base", "qwen1.5b_reasoning"),
Comment on lines +17 to +20

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MODEL_PAIRS hardcodes one (base_folder, rft_folder) per family, but the HF artifact folders vary by dataset, so resolve_paths()/ensure_model_artifacts() can request nonexistent filenames for some dataset×family combos — should we make the mapping dataset-aware or add per-dataset overrides?

Finding type: Breaking Changes | Severity: 🔴 High


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents
Before applying, verify this suggestion against the current code. In
ablate_cot_trace_expms/correct_only/config.py around lines 17-21 where `MODEL_PAIRS` is
defined and in `resolve_paths()` around lines 60-82, the code assumes a single
(base_folder, rft_folder) per `model_family`, but HF artifact folder names vary by
`dataset_key`, causing nonexistent `<folder>/layer_<n>/cot_data.pkl` and
`all_sentences_features.pkl` requests. Refactor by making the folder mapping
dataset-aware (e.g., change `MODEL_PAIRS` to a nested mapping keyed by dataset_key then
model_family, or introduce a per-dataset override dict) and update `resolve_paths()` to
derive `folder` from both `dataset_key` and `model_family` for the requested `variant`
('base' vs 'rft'). Add explicit error messages when a dataset×family override is
missing so failures are actionable.

}

# Per-repo folder names (train repos differ from SDS_math500_test layout)
MODEL_PAIRS_BY_REPO: dict[str, dict[str, tuple[str, str]]] = {
"withmartian/SDS_math500_test": {
"llama8": ("Llama_8B_base", "Llama_8B_reasoning"),
"qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"),
"qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"),
},
"withmartian/SDS_train_svamp": {
"llama8": ("Llama_8B_base", "Llama_8B_reasoning"),
"qwen14": ("Qwen_14B_base", "Qwen_14B_reasoning"),
"qwen1.5": ("Qwen_1_5B_base", "Qwen_1_5B_reasoning"),
},
"withmartian/SDS_train_mmlu-pro": {
"llama8": ("llama8b_base", "llama8b"),
"qwen14": ("qwen14b_base", "qwen14b"),
"qwen1.5": ("qwen1.5b_base", "qwen1.5b"),
},
}


def model_pair_folders(
model_family: str,
*,
hf_repo: str,
) -> tuple[str, str]:
pairs = MODEL_PAIRS_BY_REPO.get(hf_repo, MODEL_PAIRS)
if model_family not in pairs:
raise KeyError(f"Unknown model_family {model_family!r} for repo {hf_repo!r}")
return pairs[model_family]

DEFAULT_LAYER: dict[str, int] = {
"llama8": 22,
"qwen14": 28,
"qwen1.5": 20,
}

# Default problem cap (match exploration scripts / paper subsets)
DEFAULT_LIMIT_PROBLEMS: dict[str, int] = {
"gsm8k": 2000,
"svamp": 800,
"mmlu-pro": 500,
"math500": 500,
}

SUBSET_NAMES = (
"all",
"within_correct",
"paired_correct",
)


@dataclass(frozen=True)
class ModelDatasetPaths:
dataset_key: str
model_family: str
variant: str # "base" | "rft"
layer: int
hf_repo: str
folder: str
features_path: Path
cot_data_path: Path

@property
def model_tag(self) -> str:
return f"{self.model_family}_{self.variant}"


def resolve_paths(
dataset_key: str,
model_family: str,
variant: str,
*,
layer: int | None = None,
hf_root: Path | None = None,
hf_repo: str | None = None,
) -> ModelDatasetPaths:
if dataset_key not in HF_REPOS:
raise KeyError(f"Unknown dataset_key {dataset_key!r}; choose from {list(HF_REPOS)}")
if model_family not in MODEL_PAIRS:
raise KeyError(f"Unknown model_family {model_family!r}; choose from {list(MODEL_PAIRS)}")
if variant not in ("base", "rft"):
raise ValueError("variant must be 'base' or 'rft'")

layer = layer if layer is not None else DEFAULT_LAYER[model_family]
repo = hf_repo or HF_REPOS[dataset_key]
base_folder, rft_folder = model_pair_folders(model_family, hf_repo=repo)
folder = base_folder if variant == "base" else rft_folder
layer_dir = f"layer_{layer}"
root = hf_root if hf_root is not None else Path(repo.replace("/", "__"))

base = root / folder / layer_dir
return ModelDatasetPaths(
dataset_key=dataset_key,
model_family=model_family,
variant=variant,
layer=layer,
hf_repo=repo,
folder=folder,
features_path=base / "all_sentences_features.pkl",
cot_data_path=base / "cot_data.pkl",
)


def all_combo_keys() -> list[tuple[str, str]]:
return [(d, m) for d in HF_REPOS for m in MODEL_PAIRS]
85 changes: 85 additions & 0 deletions ablate_cot_trace_expms/correct_only/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Pass/fail criteria for correct-trace ablations."""

from __future__ import annotations

from pipeline import metric_at_k

METRICS_RFT_SHOULD_WIN = ("persistence", "mean_self_transition", "delta_r2", "K_eff")


def evaluate_compare_summaries(
base_summaries: dict[str, dict],
rft_summaries: dict[str, dict],
k_focus: int,
*,
min_paired_gap_fraction: float = 0.5,
) -> dict:
"""
Desired outcome:
- On paired_correct: RFT > base on key metrics
- paired_correct gap retains a substantial fraction of the all-trace gap
"""
checks: list[dict] = []

def add(name: str, passed: bool, detail: str, values: dict | None = None) -> None:
checks.append({"name": name, "passed": passed, "detail": detail, "values": values or {}})

for subset in ("paired_correct", "all"):
if subset not in base_summaries or subset not in rft_summaries:
continue
for metric in METRICS_RFT_SHOULD_WIN:
b = metric_at_k(base_summaries[subset], k_focus, metric)
r = metric_at_k(rft_summaries[subset], k_focus, metric)
if b is None or r is None:
add(f"rft_gt_base_{subset}_{metric}", False, f"missing K={k_focus}")
continue
add(
f"rft_gt_base_{subset}_{metric}",
r > b,
f"{subset}: RFT={r:.4f} vs Base={b:.4f}",
{"base": b, "rft": r, "subset": subset},
)

for metric in ("persistence", "delta_r2"):
if not all(s in base_summaries for s in ("all", "paired_correct")):
continue
if not all(s in rft_summaries for s in ("all", "paired_correct")):
continue
gap_all = metric_at_k(rft_summaries["all"], k_focus, metric) - metric_at_k(
base_summaries["all"], k_focus, metric
)
gap_paired = metric_at_k(rft_summaries["paired_correct"], k_focus, metric) - metric_at_k(
base_summaries["paired_correct"], k_focus, metric
Comment on lines +48 to +52

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gap_all / gap_paired subtract metric_at_k() results unconditionally; should we mirror the earlier None check before the gap arithmetic, otherwise missing K can raise TypeError and abort evaluation?

Finding type: Logical Bugs | Severity: 🟢 Low


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents
Before applying, verify this suggestion against the current code. In
ablate_cot_trace_expms/correct_only/evaluate.py around lines 48-67 inside
evaluate_compare_summaries’ gap-retention logic for metrics (“persistence”,
“delta_r2”), mirror the defensive handling used above: after computing gap_all and
gap_paired, first check whether metric_at_k(...) for both base and rft returned None
(e.g., missing K=k_focus) and in that case add a failed check entry (include subset/all
vs paired_correct and the missing K) and skip the subtraction/arithmetic. Then only
compute gaps and the min_paired_gap_fraction comparison when all required values are
non-None, so evaluation never raises TypeError when K is absent.

)
if gap_all <= 0:
add(
f"paired_retains_gap_{metric}",
gap_paired > 0,
f"all gap non-positive ({gap_all:.4f}); require paired gap>0 ({gap_paired:.4f})",
)
else:
frac = gap_paired / gap_all if gap_all else 0.0
add(
f"paired_retains_gap_{metric}",
gap_paired >= min_paired_gap_fraction * gap_all,
f"paired_gap={gap_paired:.4f} all_gap={gap_all:.4f} frac={frac:.2f}",
{"gap_paired": gap_paired, "gap_all": gap_all},
)

primary = [
c for c in checks
if "paired_correct" in c["name"] and c["name"].startswith("rft_gt_base")
]
primary_pass = all(c["passed"] for c in primary) if primary else False
all_passed = all(c["passed"] for c in checks)

return {
"k_focus": k_focus,
"checks": checks,
"all_passed": all_passed,
"primary_pass": primary_pass,
"desired_outcome": (
"RFT > Base on paired_correct; paired RFT−Base gap retains "
f"≥{min_paired_gap_fraction:.0%} of all-trace gap (persistence, ΔR²)"
),
}
Loading