83 commits
e341194
[Bug fix] Fix binary/block quantizers for Linear layers
aahaidar01 Feb 5, 2026
e428369
[Add] Combined current progress in lab practicals.
aahaidar01 Feb 5, 2026
b2c79ac
[Modify] Modified md file lab 3 section
aahaidar01 Feb 6, 2026
d108887
[Add] Added a plot folder in the same parent path of labs.md so we ca…
aahaidar01 Feb 7, 2026
263e78f
[Add] Additional edits and results for lab 3 added to the .md file
aahaidar01 Feb 7, 2026
7750d21
[Add] More modifications on lab 3 section in .md file
aahaidar01 Feb 7, 2026
24ea8a8
added ADLS lab 4 results and pruning
hadesyash Feb 7, 2026
d1bc45e
added lab 2 to the .md file
aahaidar01 Feb 7, 2026
f98c580
final images for lab 4
hadesyash Feb 7, 2026
b2e02d6
added tutorial 4 pruning to repo
aahaidar01 Feb 7, 2026
a8bdd91
Added flex attention transform pass and score mods (diff score funtio…
aahaidar01 Mar 22, 2026
4594327
Fix FlexAttention CUDA forward pass: position_embeddings support, tor…
aahaidar01 Mar 22, 2026
7c46457
Added sdpa vs flex experiments and latency benchmark to assess sdpa v…
aahaidar01 Mar 22, 2026
631e0f1
Add block_mask support with create_block_mask() for FlexAttention spe…
aahaidar01 Mar 23, 2026
4011789
removed as_tensor from sliding_window_score_mod and causal_score_mod …
aahaidar01 Mar 23, 2026
0be0812
added series of experiments to benchmark flexattention and prepare sy…
aahaidar01 Mar 23, 2026
b77150c
fix flex attention compile mode to max-autotune-no-cudagraphs as per …
aahaidar01 Mar 23, 2026
286c471
add kernel profiling, document masking, compound masking (alibi+SW) e…
aahaidar01 Mar 24, 2026
f7b728b
Add experiment results to share with collaborators
aahaidar01 Mar 24, 2026
f5f7785
bug fix in alibi score mod when computing tensors by forcing CUDA dev…
aahaidar01 Mar 24, 2026
865e19d
bug fix in score_mods.py to hardcode device to CUDA for alibi tensor …
aahaidar01 Mar 24, 2026
98318fb
added experiments output .json files
aahaidar01 Mar 24, 2026
0dafa17
updated exp12 and uploaded exp12 .json results
aahaidar01 Mar 24, 2026
7a6d079
Added plot results script and generated plots for all conducted exper…
aahaidar01 Mar 25, 2026
42da521
Implemented Automated Search Pipeline: Ready to Test
MahmoudEletreby Mar 22, 2026
7ac7cf1
Fixed smoke test error
MahmoudEletreby Mar 23, 2026
adbf5f6
Fix PBS scripts: add module load and correct CLI entrypoint
MahmoudEletreby Mar 23, 2026
0b52579
Fix PBS: --save-dir -> --project-dir
MahmoudEletreby Mar 23, 2026
112cd09
Replace ch CLI with standalone search scripts for BERT and LLaMA
MahmoudEletreby Mar 23, 2026
e0282de
Fix module_fusion: remove 'default' key from quantize_by_type args
MahmoudEletreby Mar 23, 2026
03d59b0
Fix nlp_cls_forward: filter non-tensor batch keys before model(**batch)
MahmoudEletreby Mar 23, 2026
a11dead
Fix rebuild_model: move model to accelerator after quantization repla…
MahmoudEletreby Mar 23, 2026
7a4297c
Fix BertModel: restore embeddings, pooler, and tuple return
MahmoudEletreby Mar 23, 2026
3e39158
Fix BertForSequenceClassification: don't access .hidden_states/.atten…
MahmoudEletreby Mar 23, 2026
0504054
Fix nlp_cls_forward: squeeze labels to 1D for torchmetrics
MahmoudEletreby Mar 23, 2026
0916f2f
Fix runners: filter non-tensor batch keys in latency, reset metrics b…
MahmoudEletreby Mar 23, 2026
9abdb45
Fix optuna: guard visualizer.log_metrics call when visualizer is None
MahmoudEletreby Mar 23, 2026
f8a84af
Fix nlp_lm_forward: filter non-tensor batch keys before model(**model…
MahmoudEletreby Mar 23, 2026
7089ed9
Use fine-tuned SST-2 BERT checkpoint for meaningful accuracy results
MahmoudEletreby Mar 23, 2026
7583797
Update part3 implementation summary
MahmoudEletreby Mar 23, 2026
3fd0c78
Fix FlexAttention: remove none from mask_mod registry to skip block_m…
MahmoudEletreby Mar 23, 2026
78fda29
Fix modeling_bert: restore elementwise_affine=True on all LayerNorm i…
MahmoudEletreby Mar 23, 2026
83e2301
Fix BertModel: restore attention mask and position/token_type embeddi…
MahmoudEletreby Mar 23, 2026
3a8939f
Restrict BERT search to fusion_strategy=none — FlexAttention not appl…
MahmoudEletreby Mar 23, 2026
69926f9
Fix Mistral sliding window: reduce window_size 512→128 for real block…
MahmoudEletreby Mar 24, 2026
c468719
Fix Mistral PBS: add missing module load Python/3.12.3-GCCcore-13.3.0
MahmoudEletreby Mar 24, 2026
fdfe10b
Fix Mistral PBS: add search_mistral.py script, replace python -m chop…
MahmoudEletreby Mar 24, 2026
cca25b3
Fix Mistral: load in float16 so FlexAttention Triton tiles fit in L40…
MahmoudEletreby Mar 24, 2026
b0c8f78
Fix dtype mismatch: restore original model dtype after quantize pass …
MahmoudEletreby Mar 24, 2026
958e7d9
Add scaling benchmark script: seq_len, batch_size, peak memory (Exper…
MahmoudEletreby Mar 25, 2026
26d5e5f
Add scaling benchmark script: seq_len, batch_size, peak memory (Exper…
MahmoudEletreby Mar 25, 2026
e851d2e
Fix benchmark: guard del model in finally block against unbound variable
MahmoudEletreby Mar 25, 2026
f9ccbcd
Fix benchmark: deepcopy quant config to prevent mutation across calls
MahmoudEletreby Mar 25, 2026
bb331a2
Updated .md file with sequence length experiment results
MahmoudEletreby Mar 25, 2026
dd047de
Add int8_rmsnorm and int8_both configs to scaling benchmark (Part 2 i…
MahmoudEletreby Mar 26, 2026
32186ee
Add fused RMSNorm residual fusion pass (Part 2)
MahmoudEletreby Mar 26, 2026
65e8b20
Add fused RMSNorm residual fusion pass (Part 2)
MahmoudEletreby Mar 26, 2026
4bad2c5
Fix fused_rmsnorm import: add missing Triton kernel and guard __init_…
MahmoudEletreby Mar 26, 2026
4c563db
Fix wrong function name for fused_rmsnorm pass in search and benchmark
MahmoudEletreby Mar 26, 2026
cb6a982
Fix 2D hidden_states guard in fused_forward and FlexAttention variants
MahmoudEletreby Mar 27, 2026
ab4efc3
Free trial model from GPU after each Optuna trial to prevent OOM
MahmoudEletreby Mar 27, 2026
a4521a4
Added MD file as supplementary reference for all experiment results o…
aahaidar01 Mar 27, 2026
f4530e7
Changed experiment results MD file path
aahaidar01 Mar 27, 2026
c9d1dac
Load TinyLlama in float16 for search to enable fused_rmsnorm speedup
MahmoudEletreby Mar 27, 2026
5aa1a16
Add kernel launch profiling script (Part 3 experiment)
MahmoudEletreby Mar 27, 2026
c561f97
Fix profiler compat: use cuda_time_total fallback for newer PyTorch
MahmoudEletreby Mar 27, 2026
116129f
Fix profiler: add CPU activity for kernel attribution, fix div-by-zero
MahmoudEletreby Mar 27, 2026
bcd7b3f
Debug: print profiler event time attributes
MahmoudEletreby Mar 27, 2026
49ebbfd
Fix profiler: use device_time_total/self_device_time_total (PyTorch 2…
MahmoudEletreby Mar 27, 2026
dd43a3b
Fix Mistral OOM: reload model from disk per strategy instead of deepcopy
MahmoudEletreby Mar 27, 2026
413c086
Fix Mistral OOM: gc.collect() between strategies, free profiler buffers
MahmoudEletreby Mar 27, 2026
85059d3
Add plot_results.py: Pareto scatter, seq-len scaling, memory scaling
MahmoudEletreby Mar 27, 2026
052859f
Fix plots: Pareto labels, Mistral seqlen scaling, kernel dispatch bar…
MahmoudEletreby Mar 27, 2026
d31bb5c
Add O(n²) and O(n·w) reference slope lines to Fig 2 seqlen scaling plot
MahmoudEletreby Mar 27, 2026
6755e9c
Remove misleading reference slope lines; move Fig 2 legend to lower r…
MahmoudEletreby Mar 27, 2026
de8d09b
Fix Fig1/Fig2 aesthetics: smaller legend, annotation below title, P m…
MahmoudEletreby Mar 27, 2026
8ea5f60
Shrink Fig 2 legend further
MahmoudEletreby Mar 27, 2026
aa17278
Fig 3: shorten FusedRMSNorm label to RMSNorm, reduce x-tick fontsize
MahmoudEletreby Mar 27, 2026
d5f3ffb
Merge branch 'flexattention' into Automated_Search_Pipeline
aahaidar01 Mar 27, 2026
373bb78
Removed .pbs training scripts. Added RMS_NORM Kernel Fusion MD file
aahaidar01 Mar 27, 2026
f7a44e6
additional repo cleanup
aahaidar01 Mar 27, 2026
267e48a
Merge remote-tracking branch 'origin/Automated_Search_Pipeline'
aahaidar01 Mar 27, 2026
120f689
changed .md file name into Automated_Search_Pipeline for clear reference
aahaidar01 Mar 27, 2026
3 changes: 3 additions & 0 deletions .gitignore
@@ -27,6 +27,9 @@ vagrant/.vagrant/
/src/.machop_cache/
lightning_logs/

# /experiments/flex_attention/results
/experiments/flex_attention/results_old
*.pbs
# macOS
.DS_Store

382 changes: 382 additions & 0 deletions Automated_Search_Pipeline.md

Large diffs are not rendered by default.

289 changes: 289 additions & 0 deletions FLEXATTENTION_REPORT.md
@@ -0,0 +1,289 @@
# FlexAttention Integration into MASE

> **Project:** Hardware-Aware Transformer Optimisation: Integrating Programmable Attention, Triton Kernel Fusion, and Multi-Objective NAS
> **Author:** Ali Haidar, Imperial College London, ali.haidar25@imperial.ac.uk
> **Component:** FlexAttention Integration Pass
> **Hardware:** NVIDIA L40S (48 GB), PyTorch 2.6, bfloat16
> **Date:** March 2026

---

## 1. Overview

FlexAttention (`torch.nn.attention.flex_attention`) is PyTorch's newest attention API (v2.5+). The idea is simple: you write a small Python function that describes your attention pattern, and `torch.compile` fuses it into a Triton CUDA kernel — with performance on par with FlashAttention-2, but for *any* pattern, not just causal.

The real win is **block-sparse masking**. FlexAttention divides the attention matrix into 128x128 blocks and skips any block that's entirely masked out. For patterns like sliding-window attention, that means most of the matrix gets skipped at long sequences.

We built a MASE transform pass that swaps out SDPA attention modules for FlexAttention across Llama, Mistral, and BERT. The pass includes a library of 5 composable score modifications (causal, sliding window, ALiBi, ALiBi+SWA, document masking) and handles the various edge cases that come with integrating a fairly new API into an existing compiler framework.

| Component | Specification |
|-----------|---------------|
| GPU | NVIDIA L40S (48 GB VRAM) |
| PyTorch | 2.6.0 |
| Precision | bfloat16 (except Exp 2: float32) |
| Model | Medium LLaMA (~1.3B params) |
| Config | hidden=2048, layers=16, heads=16, kv_heads=4 |

---

## 2. Implementation

The core implementation lives in three files:

- **`flex_attention_transform.py`** (~720 LOC) — the main pass. Walks the model, identifies attention modules, and replaces them with FlexAttention subclasses. Weights are transferred via `load_state_dict` so everything (projections, RoPE, KV cache) is preserved. Only the kernel dispatch changes.
- **`score_mods.py`** (~250 LOC) — the score/mask modification library.
- **`test_flex_attention.py`** (~800 LOC) — 40 tests (26 CPU, 14 CUDA) covering score_mod logic, module replacement, forward/backward correctness, and bfloat16 support.

Usage is straightforward:
```python
model, stats = flex_attention_transform_pass(model, {
    "score_mod": "sliding_window",
    "score_mod_kwargs": {"window_size": 256},
    "use_block_mask": True,
})
```

The kernel is compiled once per process with `torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")`. We use `dynamic=False` because of a known symbolic-shape bug in PyTorch 2.6's `flex_decoding` kernel; this is a documented upstream issue, not a limitation of our pass. The trade-off is that the kernel specialises per sequence length, but block mask caching makes this a non-issue during fixed-length training.

### Score Modifications

Each score_mod follows the `(score, b, h, q_idx, kv_idx) -> score` signature. When `use_block_mask=True`, the pass auto-pairs each score_mod with a corresponding mask_mod for block-level sparsity.

| Name | Type | Description |
|------|------|-------------|
| `causal` | Direct | Standard autoregressive: `q_idx >= kv_idx` |
| `sliding_window` | Factory(`window_size`) | Causal + distance limit |
| `alibi` | Factory(`num_heads`) | Attention with Linear Biases |
| `alibi_sliding_window` | Factory(`num_heads`, `window_size`) | ALiBi + sliding window composed |
| `document_mask` | Factory(`doc_len`) | Attend only within same document (for sequence packing) |
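As an illustration of the score_mod/mask_mod pairing, here is a minimal sliding-window pair (a sketch of the pattern, not the exact code in `score_mods.py`; the factory name is hypothetical):

```python
import torch

def make_sliding_window_mods(window_size: int):
    """Illustrative factory: causal attention restricted to a trailing window."""
    def score_mod(score, b, h, q_idx, kv_idx):
        # Positions outside the window get -inf, i.e. zero attention weight.
        inside = (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size)
        return torch.where(inside, score, score.new_tensor(float("-inf")))

    def mask_mod(b, h, q_idx, kv_idx):
        # Boolean variant consumed by create_block_mask() for block sparsity.
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size)

    return score_mod, mask_mod
```

The score_mod biases individual scores, while the mask_mod lets `create_block_mask()` skip fully-masked 128x128 blocks entirely.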

---

## 3. Challenges

Getting FlexAttention to work reliably inside MASE involved several non-trivial issues:

- **ALiBi slopes on CUDA** — This was probably the most confusing bug. ALiBi computes per-head slope tensors, and if those live on CPU, TorchInductor crashes during Triton codegen with an unhelpful error. The fix is to materialise them on CUDA at construction time.

- **`dynamic=False` workaround** — PyTorch 2.6 has a known bug where `torch.compile` with `dynamic=True` fails on the `flex_decoding` kernel's `get_split_k` specialisation. We compile with `dynamic=False` to sidestep this. It's not ideal, but block mask caching means the recompilation cost is amortised in practice.

- **Block mask caching** — `create_block_mask()` is surprisingly expensive if called every forward pass. We cache by `(Q_LEN, KV_LEN, device)` and reuse across training steps with fixed sequence lengths.

- **Native GQA** — FlexAttention supports Grouped Query Attention directly via `enable_gqa=True`, so we don't need the `repeat_kv` head expansion that SDPA requires. This saves memory and bandwidth.

- **Short sequences** — Below 128 tokens there's only one block, so block masking can't help. We skip it and the kernel still runs fine.

- **Contiguity** — After the reshape and transpose for multi-head layout, Q/K/V aren't always contiguous. We add `.contiguous()` calls before the kernel.

- **Graceful fallbacks** — Some configs (e.g., `output_attentions=True`, BERT cross-attention) aren't compatible with FlexAttention. The subclasses detect these and fall back to the parent class's eager path.

---

## 4. Experiments

All experiments use the medium LLaMA config (hidden=2048, 16 layers, 16 heads, 4 KV heads, batch=2, bfloat16) on an NVIDIA L40S unless noted otherwise.

Experiment results are found in: **[experiments/flex_attention/results/](experiments/flex_attention/results/)**

---

### Exp 1a: Inference Latency

Baseline comparison — how does FlexAttention compare to SDPA for causal and sliding-window inference?

| Method | 256 | 512 | 1024 | 2048 | 4096 |
|--------|-----|-----|------|------|------|
| SDPA causal | 11.78 ms | 13.15 ms | 22.02 ms | 44.58 ms | 106.31 ms |
| SDPA SWA(256) | 11.63 ms | 13.51 ms | 24.01 ms | 52.39 ms | 137.77 ms |
| Flex causal | 12.30 ms | 13.24 ms | 22.11 ms | 44.49 ms | 104.45 ms |
| Flex SWA(256) | 12.24 ms | 13.21 ms | 21.61 ms | 42.15 ms | **94.65 ms** |

Flex matches SDPA on causal (1.02x) and is **1.46x faster** on SWA at seq=4096. Memory usage is essentially identical.

![Inference Latency](experiments/flex_attention/results/figures/fig1_inference_latency.png)
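Latencies like these are typically collected with CUDA events after warm-up iterations; a minimal harness in that style might look like this (hypothetical, not the exact benchmark script):

```python
import time
import torch

def time_forward(fn, *args, warmup=10, iters=50):
    """Median wall time per call in ms, using CUDA events when available."""
    for _ in range(warmup):
        fn(*args)
    times = []
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(iters):
            start.record()
            fn(*args)
            end.record()
            torch.cuda.synchronize()  # wait so elapsed_time() is valid
            times.append(start.elapsed_time(end))
    else:
        for _ in range(iters):
            t0 = time.perf_counter()
            fn(*args)
            times.append((time.perf_counter() - t0) * 1e3)
    times.sort()
    return times[len(times) // 2]
```

Warm-up matters doubly here: the first FlexAttention call pays the `torch.compile` cost, which must not be counted against steady-state latency.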

---

### Exp 1b: Training Latency

Same comparison but for full training steps (forward + backward).

| Method | 256 | 512 | 1024 | 2048 | 4096 |
|--------|-----|-----|------|------|------|
| SDPA causal | 37.96 ms | 42.55 ms | 75.08 ms | 156.08 ms | 370.57 ms |
| SDPA SWA(256) | 37.47 ms | 45.76 ms | 87.33 ms | 207.70 ms | 563.90 ms |
| Flex causal | 37.82 ms | 43.85 ms | 76.31 ms | 159.32 ms | 377.06 ms |
| Flex SWA(256) | 37.95 ms | 43.72 ms | 73.38 ms | 147.52 ms | **328.78 ms** |

The SWA speedup is even larger during training: **1.72x** at seq=4096. This makes sense — backprop through SDPA SWA still processes the full causal triangle.

![Training Latency](experiments/flex_attention/results/figures/figS1_training_latency.png)

---

### Exp 2: Training Equivalence

Before benchmarking performance, we needed to verify that FlexAttention actually produces the same results as SDPA. We ran 50 training steps on a tiny LLaMA (hidden=256, 4 layers, float32, seed=42) and compared the loss and gradient norms at every step.

- **Max loss difference:** 1.43 × 10⁻⁶
- **Max gradient norm difference:** 2.96 × 10⁻⁷

The two backends are numerically indistinguishable. The bottom panel shows the per-step absolute difference on a log scale — it's all noise-level.

![Training Equivalence](experiments/flex_attention/results/figures/fig2_training_equivalence.png)
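The comparison itself reduces to logging the scalar loss (or global gradient norm) per step under each backend and taking the largest per-step gap; a sketch of that check:

```python
import torch

def max_abs_diff(trace_a, trace_b):
    """Largest per-step absolute difference between two scalar metric traces."""
    a = torch.tensor(trace_a, dtype=torch.float64)
    b = torch.tensor(trace_b, dtype=torch.float64)
    return (a - b).abs().max().item()
```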

---

### Exp 3: Block Mask Ablation

How much does block-level sparsity actually matter? We ran FlexAttention with and without `block_mask` for both causal and SWA.

| Method | 256 | 512 | 1024 | 2048 | 4096 |
|--------|-----|-----|------|------|------|
| Causal + block_mask | 12.47 ms | 13.24 ms | 22.05 ms | 44.24 ms | 104.22 ms |
| Causal (no block_mask) | 12.96 ms | 13.56 ms | 22.26 ms | 46.21 ms | 113.00 ms |
| SWA + block_mask | 12.58 ms | 13.23 ms | 21.44 ms | 42.11 ms | **94.01 ms** |
| SWA (no block_mask) | 13.02 ms | 13.50 ms | 22.17 ms | 46.41 ms | 113.05 ms |

Block masking gives a 1.08x speedup for causal and **1.20x for SWA** at seq=4096. Without it, SWA degrades to causal-level performance because the kernel evaluates every block regardless.

![Block Mask Ablation](experiments/flex_attention/results/figures/figS2_block_mask_ablation.png)

---

### Exp 4: Mistral Sliding Window

We also tested on Mistral to confirm the speedup isn't Llama-specific.

| Method | 256 | 512 | 1024 | 2048 | 4096 |
|--------|-----|-----|------|------|------|
| Native Mistral SWA (SDPA) | 11.96 ms | 14.03 ms | 24.35 ms | 54.02 ms | 141.98 ms |
| Flex Mistral SWA | 12.53 ms | 13.45 ms | 21.84 ms | 42.58 ms | **94.68 ms** |

**1.50x speedup** at seq=4096 — consistent with the Llama numbers. The benefit is architecture-independent.

![Mistral SWA](experiments/flex_attention/results/figures/figS3_mistral_swa.png)

---

### Exp 6: Compound Masks (ALiBi + Sliding Window)

Can we compose ALiBi bias with sliding-window sparsity without paying extra? The idea is that `torch.compile` should fuse both into one kernel.

| Method | 256 | 1024 | 4096 |
|--------|-----|------|------|
| Flex Causal | 38.72 ms | 74.46 ms | 370.92 ms |
| Flex Sliding Window | 38.65 ms | 71.96 ms | 324.57 ms |
| Flex ALiBi + SWA | 39.25 ms | 72.15 ms | **324.55 ms** |

ALiBi + SWA (324.55 ms) matches plain SWA (324.57 ms) within noise. Composition is genuinely free — the compiler fuses both operations into one Triton kernel.

![Compound Masks](experiments/flex_attention/results/figures/figS5_compound_masks.png)

---

### Exp 7: Kernel Profiling

We generated PyTorch trace files for Chrome DevTools / Perfetto analysis. Output in `results/traces/` — useful for inspecting kernel launch counts, memory traffic, and fusion boundaries, but not something that fits in a static figure.

---

### Exp 8: Document Masking (Sequence Packing)

This is arguably FlexAttention's best use case. When you pack multiple documents into one sequence for training, you need to prevent attention across document boundaries. SDPA can't express this natively — you have to materialise a full N x N mask. FlexAttention handles it with block sparsity.

| Method | 1024 | 4096 | 8192 |
|--------|------|------|------|
| SDPA Causal (FA2 Baseline) | 74.57 ms | 369.12 ms | 895.22 ms |
| SDPA Document Mask | 87.10 ms | 561.89 ms | 1687.26 ms |
| Flex Document Mask | 75.89 ms | 333.77 ms | **748.31 ms** |

At seq=8192, Flex is **2.25x faster** than SDPA's manual document mask, and actually **1.20x faster than SDPA's causal baseline** — because the block mask prunes cross-document blocks.

![Document Masking](experiments/flex_attention/results/figures/fig3_document_masking.png)
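For reference, a document mask_mod for fixed-length packing needs only integer division (a sketch assuming equal `doc_len` segments, matching the `document_mask` factory listed in Section 2):

```python
def make_document_mask_mod(doc_len: int):
    """Causal attention restricted to tokens of the same packed document."""
    def mask_mod(b, h, q_idx, kv_idx):
        # Tokens attend only within their own doc_len-sized segment.
        same_doc = (q_idx // doc_len) == (kv_idx // doc_len)
        return same_doc & (q_idx >= kv_idx)
    return mask_mod
```

Because cross-document blocks evaluate to all-False, `create_block_mask()` prunes them wholesale, which is exactly where the speedup over the materialised SDPA mask comes from.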

---

### Exp 9: Batch Sensitivity

A quick sanity check — does FlexAttention hold up across batch sizes?

| Batch Size | SDPA causal | Flex causal | Ratio |
|-----------|------------|------------|-------|
| 1 | 13.44 ms | 13.76 ms | 0.98x |
| 2 | 22.45 ms | 22.68 ms | 0.99x |
| 4 | 43.22 ms | 43.13 ms | 1.00x |
| 8 | 97.73 ms | 96.41 ms | 1.01x |

Parity across the board. No surprises here.

![Batch Sensitivity](experiments/flex_attention/results/figures/figS6_batch_sensitivity.png)

---

### Exp 10: Decode Generation

This is where FlexAttention struggles. During autoregressive generation, each decode step has Q_LEN=1 — a single token. There's no block sparsity to exploit, and `torch.compile` overhead dominates.

| Method | Prompt=128 | Prompt=512 | Prompt=1024 |
|--------|-----------|-----------|------------|
| SDPA causal (per-token ms) | 11.51 | 11.47 | 11.46 |
| Flex causal (per-token ms) | 13.65 | 49.85 | 49.91 |
| Flex SWA (per-token ms) | 50.75 | 50.53 | 50.61 |

Flex is **4.3-4.4x slower** per token at decode for prompts of 512 tokens and beyond (about 1.2x at prompt=128). SDPA's hand-optimised CUDA kernels are much better here. This only affects token-by-token generation — prefill and training are unaffected. Future PyTorch releases with optimised `flex_decoding` kernels should help.

![Decode Caveat](experiments/flex_attention/results/figures/fig5_decode_caveat.png)

---

### Exp 11: Throughput

Translating latency into tokens/sec makes the scaling story clearer.

#### Training Throughput (tok/sec)

| Method | 256 | 512 | 1024 | 2048 | 4096 |
|--------|-----|-----|------|------|------|
| SDPA causal | 13,386 | 24,353 | 27,246 | 26,346 | 22,127 |
| SDPA SWA(256) | 13,461 | 22,362 | 23,469 | 19,782 | 14,498 |
| Flex causal | 13,219 | 23,296 | 26,936 | 25,744 | 21,808 |
| Flex SWA(256) | 13,330 | 23,428 | 28,061 | 27,903 | **25,070** |

This is the headline number: Flex SWA achieves **1.73x higher training throughput** than SDPA SWA at seq=4096 (25,070 vs 14,498 tok/sec). The key insight is that SDPA SWA throughput *collapses* beyond seq=1024 because it falls back to the full attention matrix, while Flex maintains near-causal throughput thanks to block sparsity.

![Training Throughput](experiments/flex_attention/results/figures/fig6_training_throughput.png)

---

### Exp 12: GQA Isolation

Finally, we tested whether the FlexAttention advantage depends on the head configuration. We swept across MHA, GQA-4, and MQA with three attention patterns.

#### Speedup Matrix (SDPA / Flex) at seq_len=4096

| Head Config | Causal | SWA(256) | ALiBi+SWA(256) |
|------------|--------|----------|----------------|
| MHA (16/16) | 0.98x | **1.63x** | **1.63x** |
| GQA-4 (16/4) | 0.99x | **1.72x** | **1.72x** |
| MQA (16/1) | 0.95x | **1.74x** | **1.73x** |

Three things stand out: (1) causal stays within a few percent of SDPA (0.95-0.99x); (2) SWA gives 1.63-1.74x regardless of head grouping; (3) ALiBi+SWA matches plain SWA exactly, confirming the free composition result from Exp 6 holds across all configs.

![GQA Heatmap](experiments/flex_attention/results/figures/fig4_gqa_heatmap.png)

---

## 5. Summary

| Finding | Value |
|---------|-------|
| Causal attention parity | 1.02x |
| SWA inference speedup (seq=4096) | **1.46x** |
| SWA training speedup (seq=4096) | **1.72x** |
| SWA training throughput (seq=4096) | **1.73x** (25,070 vs 14,498 tok/sec) |
| Document masking speedup (seq=8192) | **2.25x** |
| ALiBi+SWA composition overhead | Zero |
| Decode generation slowdown | 4.3-4.4x |

**The bottom line:** FlexAttention is a safe drop-in for causal attention and a significant win for anything sparse. Sliding window gets 1.46-1.74x depending on the workload, document masking gets 2.25x, and composing patterns like ALiBi+SWA is genuinely free. The one caveat is decode generation, where `torch.compile` overhead makes it 4.3x slower — but that's a known PyTorch limitation that only affects token-by-token generation, not training or prefill.