Inference-time, training-free structured pruning for LLMs via probe-based representation–parameter coupling.
- Paper: OpenReview (ICML 2026)
- Project Page: https://gongzhiren.github.io/SubspacePathPruner-website/
SubspacePath Pruner compiles a scenario-specific pruned subnetwork of a frozen large language model at inference time — no fine-tuning, no gradient updates, no scenario training data. It exploits a simple observation: representation subspaces that align with semantic domains in embedding space are coupled with sparse, reusable attention-head pathways in parameter space. A handful of lightweight linear probes, trained once offline, are enough to map a scenario's domain mixture onto the heads that matter, and to prune the rest under a budget.
The released code lets you:
- Run pruned inference out of the box using the pretrained probe/importance/whitelist artifacts shipped for four models (Qwen2.5-7B, Qwen2.5-14B, Llama-3.1-8B, Llama-2-13B-chat).
- Reproduce the offline stage (probe training → calibration → head importance → whitelist) on your own model with the included input pools.
The framework has two components, both run offline once; the online stage is purely a compilation step.
A compact set of quasi-orthogonal domain axes is constructed in embedding space
to serve as a stable coordinate system. In this release the domains are
pre-selected, so the axes reduce to a one-hot basis over the selected domains
(src/preorientation/domain_axes.py). You do not need to re-run domain
selection to use the code.
Offline (training-free for the model):
- Layer-wise linear probes: for each layer ℓ and domain k, a 1-vs-rest probe is trained on the post-attention residual stream to score domain relevance. Only the probes are trained; the base model is frozen. (src/preorientation/linear_probe.py)
- Temperature calibration: probe logits are calibrated with temperature scaling for single-domain, OOD, and cross-domain regimes. (src/preorientation/probe_calibration.py)
- Axis-aligned head importance I_{ℓ,h,k}: the expected squared projection of each head's residual write onto domain axis k, normalized by the squared norm of that write:

  I_{ℓ,h,k} = E_{x ~ P_k} [ (u_k^T w̃_{ℓ,h}(x))^2 / (||w̃_{ℓ,h}(x)||^2 + ε) ]

  (src/probe/head_importance.py)
- Whitelist: domain-invariant "backbone" heads (low importance variance across domains, high mean importance, confirmed by a statistical test) that are always kept. (src/probe/whitelist_identification.py)
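
The importance formula above can be read as "what fraction of a head's write energy lands on domain axis k, averaged over that domain's inputs". The following is a minimal pure-Python sketch of that average for a single head and a single axis. It is not the code in src/probe/head_importance.py; the function name `head_importance`, the list-of-vectors input format, and the toy vectors are assumptions made for illustration.

```python
def head_importance(writes, axis, eps=1e-8):
    """Mean of (u_k . w)^2 / (||w||^2 + eps) over a batch of head writes.

    writes: list of vectors, one per input x sampled from domain k (assumed format)
    axis:   the domain axis u_k as a list of floats
    """
    total = 0.0
    for w in writes:
        proj = sum(u * v for u, v in zip(axis, w))   # u_k^T w
        norm_sq = sum(v * v for v in w)              # ||w||^2
        total += proj * proj / (norm_sq + eps)
    return total / len(writes)


# Toy data (hypothetical): one head writes along the axis, another orthogonally to it.
u_k = [1.0, 0.0, 0.0]
aligned_writes = [[2.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
orthogonal_writes = [[0.0, 3.0, 0.0], [0.0, 0.0, 1.0]]

print(head_importance(aligned_writes, u_k))      # close to 1
print(head_importance(orthogonal_writes, u_k))   # 0.0
```

Because each term is normalized by the write's own squared norm, a head that writes large vectors is not favored over one that writes small vectors in the same direction; for a unit-norm axis the value stays between 0 and 1.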
Online (per scenario, training-free):
- Diagnose the scenario's domain mixture by evaluating the probes on the first turn(s); compute a normalized-entropy scenario breadth c(s) ∈ [0,1]. (src/probe/domain_inference.py)
- Score each head by weighting its cached importance with the scenario's calibrated domain relevance s_k^eff: score_{ℓ,h}(s) = Σ_k s_k^eff · I_{ℓ,h,k}.
- Compile a binary head-pruning mask under a budget governed by pruning_strength (higher prunes more), always keeping the whitelist, and reuse the same mask for every turn of the scenario, so there is zero per-turn overhead. (src/probe/session_pruning.py)
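
The three online steps can be sketched in a few lines of pure Python. This is an illustration under stated assumptions, not the logic in src/probe/session_pruning.py: the function names are hypothetical, the importance table is a plain dict keyed by (layer, head), and the budget rule (prune the lowest-scoring fraction `pruning_strength` of all heads, skipping whitelisted ones) is a simplification. In particular, the README does not say how the breadth c(s) enters the budget, so the sketch only computes it.

```python
import math


def scenario_breadth(mixture):
    """Normalized entropy of a domain mixture: 0 for one domain, 1 for uniform."""
    if len(mixture) < 2:
        return 0.0
    entropy = -sum(p * math.log(p) for p in mixture if p > 0)
    return entropy / math.log(len(mixture))


def head_scores(mixture, importance):
    """score[(l, h)] = sum_k s_k * I[(l, h)][k]."""
    return {lh: sum(s * i for s, i in zip(mixture, imp))
            for lh, imp in importance.items()}


def compile_mask(scores, whitelist, pruning_strength):
    """Return {(layer, head): keep?}. Assumed budget: a fixed fraction of all heads."""
    n_prune = int(pruning_strength * len(scores))
    # Whitelisted heads are never candidates for pruning.
    candidates = sorted((lh for lh in scores if lh not in whitelist),
                        key=lambda lh: scores[lh])
    pruned = set(candidates[:n_prune])
    return {lh: lh not in pruned for lh in scores}


# Toy scenario (hypothetical numbers): 2 layers x 2 heads, 2 domains.
importance = {(0, 0): [0.9, 0.1], (0, 1): [0.1, 0.9],
              (1, 0): [0.05, 0.2], (1, 1): [0.5, 0.5]}
mixture = [1.0, 0.0]          # scenario is entirely in domain 0
whitelist = {(1, 0)}          # always kept, even though it scores lowest

scores = head_scores(mixture, importance)
mask = compile_mask(scores, whitelist, pruning_strength=0.5)
print(scenario_breadth(mixture), mask)
```

The mask is computed once per scenario and then reused as-is, which is why later turns add no pruning cost.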
Because nothing is optimized online, scenario compilation takes tens of milliseconds and the pruned model is just a structured head mask over the original frozen weights.
subspacepath-pruner/
├── src/
│ ├── model/ # BaseModel wrapper (load, quantize, head masking)
│ ├── preorientation/ # DBS axes, linear probes, temperature calibration, foundation layers
│ ├── probe/ # head importance, whitelist, domain inference, session pruning
│ ├── reasoning/ # single-step inference + official chat templates
│ ├── evaluation/ # exact-match answer scoring + result logging
│ └── utils/ # logging, config, device helpers
├── scripts/
│ ├── run_inference.py # ONLINE: prune + infer + evaluate (history OFF by default)
│ └── train_offline.py # OFFLINE: train probes → calibrate → importance → whitelist
├── data/
│ ├── train/ val/ # per-domain input pools (input-only, no labels)
│ └── test/ # evaluation scenarios (see "Data" below)
├── outputs/<model>/ppd_pipeline/ # pretrained artifacts for 4 models (see below)
├── requirements.txt
└── README.md
git clone <your-fork-url> subspacepath-pruner
cd subspacepath-pruner
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

Tested with Python 3.10+, PyTorch ≥ 2.0, Transformers ≥ 4.40, on a single CUDA GPU.
Models are not bundled. Download the HuggingFace weights you want and place them
under models/, e.g.:
models/Qwen2.5-7B-Instruct/ # standard HF model directory
models/llama-3.1-8b/
models/Qwen14B/ # Qwen2.5-14B-Instruct
models/Llama-2-13b-chat-hf/
The offline artifacts for the four models are already in
outputs/<model>/ppd_pipeline/, so you can run pruned inference immediately:
python scripts/run_inference.py \
--model Qwen2.5-7B-Instruct \
--gpu 0 \
--pruning_strengths 0.2 0.4 0.6 \
--num_samples 20

This loads the probes, head importance, and whitelist, sweeps the requested pruning
strengths over a 20-scenario sample from each test category, and prints a summary
of accuracy vs. average pruned-head fraction. Per-run results are written to
outputs/<model>/pruning_strength/.
Useful flags:
| Flag | Default | Meaning |
|---|---|---|
| --model | Qwen2.5-7B-Instruct | directory name under models/ |
| --pruning_strengths | 0.2 0.4 0.6 | strengths to sweep (higher → more pruning) |
| --datasets | selected_domain out_of_domain cross_domain | which data/test/* categories to run |
| --num_samples | 20 | scenarios sampled per category (-1 = all) |
| --use_history | false | include previous turns as context |
| --use_calibration | false | use the temperature-calibrated probe system |
To also evaluate the cross-dataset splits, add e.g.
--datasets selected_domain out_of_domain cross_domain commonsenseqa natural_questions arc.
The defaults above (--num_samples 20, three strengths) are tuned for a quick
first run. To reproduce the full experiment, evaluate the entire test set over
the full pruning-strength sweep:
python scripts/run_inference.py \
--model Qwen2.5-7B-Instruct --gpu 0 \
--datasets selected_domain out_of_domain cross_domain \
--num_samples -1 \
--pruning_strengths 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8

--num_samples -1 uses every scenario in each category (no sub-sampling). Repeat
with the cross-dataset categories (commonsenseqa natural_questions arc) for the
cross-dataset table. Swap --model for Qwen14B, llama-3.1-8b, or
Llama-2-13b-chat-hf to cover the other models. Accuracy is reported as the
exact-match score described under Evaluation protocol.
To regenerate the artifacts for a model from the included input pools:
python scripts/train_offline.py \
--model Qwen2.5-7B-Instruct \
--gpu 0 \
--selected_domains chemistry finance history math philosophy technology \
--final_probe_epochs 20

Steps performed (base model stays frozen throughout):
- train the layer-wise linear probes,
- temperature-calibrate them (4-probe system),
- compute axis-aligned head importance I_{ℓ,h,k},
- identify the domain-invariant head whitelist.
Outputs are written to outputs/<model>/ppd_pipeline/:
probe1_base.pt, probe_temperatures.json, calibration/, head_importance.pt,
whitelist.json — exactly the files run_inference.py consumes.
--head_importance single (default) computes one importance set from the base
probe and is fast; --head_importance multi computes scenario-specific
(single/OOD/cross-domain) sets and uses the cross-domain one.
The default --selected_domains order matches the per-domain files in data/train/. The online pruner is index-based and internally consistent with whatever order you train on, so domain names never affect inference; they only label the probe indices.
All data ships with the repository.
- data/train/, data/val/: input-only pools, one JSON per selected domain (chemistry, finance, history, math, philosophy, technology). Each file is a list of raw input strings; no labels are needed to train the probes.
- data/test/: multi-turn evaluation scenarios. Each scenario has a topic_description and a list of turns, where every turn has prompt, answer, and task_type (multiple_choice / factual / code / reasoning). Categories:
  - selected_domain/: in-distribution domains,
  - out_of_domain/: held-out domains (OOD),
  - cross_domain/: multi-domain mixtures,
  - commonsenseqa/, natural_questions/, arc/: cross-dataset splits.
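
To make the scenario layout concrete, here is a small sketch that checks a scenario against the fields listed above. The field names (topic_description, turns, prompt, answer, task_type) and the four task types come from the description; the example scenario text and the `validate_scenario` helper are hypothetical, and the real files may carry additional fields that this check ignores.

```python
import json

TASK_TYPES = {"multiple_choice", "factual", "code", "reasoning"}
TURN_FIELDS = ("prompt", "answer", "task_type")


def validate_scenario(scenario):
    """Return a list of problems; an empty list means the scenario looks well-formed."""
    problems = []
    if not isinstance(scenario.get("topic_description"), str):
        problems.append("topic_description must be a string")
    turns = scenario.get("turns")
    if not isinstance(turns, list) or not turns:
        return problems + ["turns must be a non-empty list"]
    for i, turn in enumerate(turns):
        for field in TURN_FIELDS:
            if field not in turn:
                problems.append(f"turn {i}: missing {field}")
        if turn.get("task_type") not in TASK_TYPES:
            problems.append(f"turn {i}: unknown task_type {turn.get('task_type')!r}")
    return problems


# Made-up scenario, shaped like the description above (not taken from data/test/).
example = json.loads("""
{
  "topic_description": "Balancing chemical equations",
  "turns": [
    {"prompt": "Which coefficient balances O2 in 2 H2 + ? O2 -> 2 H2O? (A) 1 (B) 2",
     "answer": "A", "task_type": "multiple_choice"},
    {"prompt": "Name the product of that reaction.",
     "answer": "water", "task_type": "factual"}
  ]
}
""")

print(validate_scenario(example))   # an empty list if every field is present
```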
outputs/<model>/ppd_pipeline/ ships, for each of the four models:
| File | Contents |
|---|---|
| probe1_base.pt | trained layer-wise linear probes (the base probe) |
| final_probes.pt | probes from the final training pass |
| probe_temperatures.json | per-regime / per-layer temperature-scaling parameters |
| calibration/ | cross-dataset calibration temperatures |
| head_importance.pt | axis-aligned head importance I_{ℓ,h,k} |
| whitelist.json | list of always-kept (layer, head) pairs |
These are the only learned parameters in the method (a few MB per model). The base model weights themselves are never modified.
Answer accuracy is exact-match (EM) keyword overlap: after stopword removal,
the fraction of the expected answer's keywords matched by the prediction (for
multiple-choice, the extracted option letter must match exactly). There is no
semantic-similarity scoring and no LLM judge — see src/evaluation/answer_evaluator.py.
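
As a rough illustration of this scoring rule, the sketch below computes the keyword-overlap fraction and the option-letter check. It is not the code in src/evaluation/answer_evaluator.py: the stopword list, the tokenizer regex, the A to E option range, and the function names are all assumptions chosen to keep the example short.

```python
import re

# Tiny illustrative stopword list; the evaluator's actual list is not shown in this README.
STOPWORDS = {"a", "an", "the", "is", "are", "of", "in", "it", "to", "and"}


def keywords(text):
    """Lowercase alphanumeric tokens with stopwords removed."""
    return {t for t in re.findall(r"[a-z0-9]+", text.lower()) if t not in STOPWORDS}


def keyword_overlap(prediction, expected):
    """Fraction of the expected answer's keywords that appear in the prediction."""
    expected_kw = keywords(expected)
    if not expected_kw:
        return 0.0
    return len(expected_kw & keywords(prediction)) / len(expected_kw)


def multiple_choice_match(prediction, expected_letter):
    """True if the first standalone capital letter A-E in the prediction matches."""
    found = re.search(r"\b([A-E])\b", prediction)
    return found is not None and found.group(1) == expected_letter


print(keyword_overlap("Paris is the capital of France", "The capital of France is Paris"))
print(keyword_overlap("Paris", "Paris, France"))
print(multiple_choice_match("The answer is B.", "B"))
```

Note that the overlap is measured against the expected answer only, so extra words in the prediction do not lower the score.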
- Single-GPU by design; select the device with --gpu. For large models, load with --quantization int4 in train_offline.py.
- History (multi-turn conversation context) is off by default; enable with --use_history true.
- peft is only needed if your base model path is a LoRA/PEFT adapter; it is an optional dependency.
Released under the MIT License — see LICENSE.