
Group 1 - Hardware-Aware Transformer Optimisation: Integrating Programmable Attention, Triton Kernel Fusion, and Multi-Objective NAS #315

Open
aahaidar01 wants to merge 83 commits into DeepWok:main from aahaidar01:main

aahaidar01 commented Mar 27, 2026

Authors:

Ali Haidar, Dorijan Donaj Magasic, Yash Agarwal, Mahmoud El Etreby

Summary

  • FlexAttention integration pass — module-level transform replacing SDPA with torch.nn.attention.flex_attention, supporting causal, sliding-window, ALiBi, and document masking patterns with block-sparse acceleration. Up to 1.72x training speedup for sliding-window attention at seq=4096.

  • Fused Add+RMSNorm Triton kernel — custom forward/backward Triton kernel fusing the residual-add → RMSNorm pattern in transformer decoder layers. 2.98x faster than unfused PyTorch, 1.42x faster than Liger-Kernel, with 60% peak memory reduction per fusion site. Includes both FX graph-level and module-level MASE passes.

  • Automated multi-objective search pipeline — fills MASE's LatencyRunner stub with GPU timing, adds a ModuleSearchSpaceQuantizationFusion search space covering bit-width × fusion strategy, and wires everything into Optuna NSGA-II search. Produces Pareto frontiers over accuracy/perplexity, latency, and average bitwidth across BERT, TinyLlama, and Mistral.
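The masking patterns listed above are expressed in FlexAttention as small predicate functions. As a minimal sketch of the idea (pure Python rather than the real `torch.nn.attention.flex_attention` callable, and the `window` default is a hypothetical value, not one taken from this PR), a sliding-window causal mask combines two conditions per (query, key) pair:

```python
def sliding_window_causal(b, h, q_idx, kv_idx, window=256):
    """FlexAttention-style mask_mod predicate: returns True iff query
    position q_idx may attend to key position kv_idx. The (b, h) batch and
    head indices are unused here but are part of the mask_mod signature.
    Combines causality with a fixed-size sliding window."""
    causal = q_idx >= kv_idx           # no attending to future positions
    in_window = q_idx - kv_idx < window  # only the last `window` keys
    return causal and in_window
```

In the real pass, such a predicate is compiled into a block-sparse `BlockMask` via `create_block_mask`, which is what lets FlexAttention skip fully-masked tiles and deliver the speedups reported above.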

Key results

| Optimization | Best speedup | Condition |
| --- | --- | --- |
| FlexAttention SWA (inference) | 1.46x | seq=4096, Llama |
| FlexAttention SWA (training throughput) | 1.73x | seq=4096, 25K vs 14.5K tok/s |
| FlexAttention document masking | 2.25x | seq=8192 vs SDPA mask |
| Fused RMSNorm kernel vs PyTorch | 2.98x | L40S, BF16 |
| Fused RMSNorm kernel vs Liger-Kernel | 1.42x | L40S, BF16 |
| Fused RMSNorm memory reduction | 60% | per fusion site (forward) |
| Mistral sliding-window (search pipeline) | 7% | consistent from seq≥512 |

Files changed

New passes

  • src/chop/passes/module/transforms/attention/flex_attention_transform.py — FlexAttention pass
  • src/chop/passes/module/transforms/attention/score_mods.py — score/mask modification library
  • src/chop/passes/graph/transforms/fused_rmsnorm/ — Triton kernel + FX graph pass
  • src/chop/passes/module/transforms/fused_ops/rmsnorm_residual_fusion.py — module-level fusion pass
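The pattern the Triton kernel in `fused_rmsnorm/` fuses has simple reference semantics: add the residual stream into the hidden state, then RMS-normalize and scale. A minimal pure-Python sketch (illustration only; the function name and the choice to also return the summed residual, as fused add+norm kernels commonly do, are assumptions, not this PR's API):

```python
import math

def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    """Reference semantics of the residual-add -> RMSNorm pattern:
    h = x + residual;  y = h / rms(h) * weight.
    Operates on flat lists of floats for clarity; the fused kernel does
    this in one pass to avoid materialising h in global memory twice."""
    h = [a + b for a, b in zip(x, residual)]              # residual add
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)  # root mean square
    y = [v / rms * w for v, w in zip(h, weight)]           # normalise + scale
    return y, h  # h is the updated residual stream carried to the next layer
```

Fusing these two steps is what yields the per-site memory reduction: the intermediate sum never round-trips through global memory between the add and the norm.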

Search infrastructure

  • src/chop/actions/search/strategies/runners/hardware/latency.py — GPU latency runner
  • src/chop/actions/search/search_space/quantization/module_fusion.py — fusion search space
  • src/chop/pipelines/optimization.py — pass-chain wrapper
  • configs/search/quantization_fusion_{bert,llama,mistral}.toml — search configs
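The search wiring above reports Pareto frontiers over competing objectives. As a sketch of what that frontier computation does (a plain non-dominated filter, not the PR's actual code or Optuna's internals; objective ordering in the tuples is a hypothetical choice):

```python
def pareto_front(points):
    """Return the Pareto-optimal subset under minimisation of every
    objective. Each point is a tuple of objective values, e.g.
    (perplexity, latency_ms, avg_bitwidth)."""
    def dominates(a, b):
        # a dominates b: no worse on every objective, strictly better on one
        return (all(x <= y for x, y in zip(a, b))
                and any(x < y for x, y in zip(a, b)))
    return [p for p in points if not any(dominates(q, p) for q in points)]
```

NSGA-II maintains exactly this kind of non-dominated ranking generation by generation, which is why the pipeline can hand back a frontier rather than a single "best" configuration.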

Experiments & benchmarks

  • experiments/flex_attention/ — 12 experiments with JSON results and figures
  • scripts/ — search runners, benchmarks, kernel profiling
  • test/ — FlexAttention tests (40 tests), LatencyRunner tests

Documentation

Models tested

  • BERT-base (SST-2 classification) — FlexAttention + quantization search
  • TinyLlama-1.1B (WikiText-2 LM) — full pipeline with causal FlexAttention
  • Mistral-7B (WikiText-2 LM) — sliding-window FlexAttention in float16

Hardware

NVIDIA L40S (48GB), PyTorch 2.6.0, Triton 3.3.1

aahaidar01 and others added 30 commits February 5, 2026 23:23
- Patch  to support 2D inputs in Binary quantizers (previously hardcoded for 4D).
- Fix  signatures in  and  to match PyTorch autograd requirements.
…n add our plots with relative paths to that folder and prepare it for .zip submission. Added further results to .md file including lab 1.
…nalities) into mase /src files. Added pytest test scripts to test the implemented functionalities.
…ch.compile compatibility, return tuple fixes, and add training/bf16/seq512 tests
…edup, fix GQA via enable_gqa=True, add block_mask caching, and expand test suite to 39 tests
…as it was. failing silently and falling back to eager causing OOM.
…xperiments. Add alibi score mod and compound alibi and sw into score_mods.py. Uploaded results temporarily to share with collaborators.
…computations. Added 3 more experiments: decode generation, throughput (tokens/s), gqa isolation.
MahmoudEletreby and others added 30 commits March 25, 2026 16:01
…_ import

- Checkout fused_rmsnorm Triton kernel files from feature/fused-rmsnorm-residual
  branch (triton_fused_add_rmsnorm.py was missing, causing ImportError at chop
  startup and crashing all search jobs)
- Wrap fused_ops import in transforms/__init__.py with try/except so a missing
  Triton kernel never cascades to crash the entire chop framework

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both module_fusion.py and benchmark_seqlen.py imported rmsnorm_residual_fusion_pass
which does not exist. The actual function is fused_rmsnorm_residual_transform_pass.
In the search space the error was silently swallowed by except ImportError, meaning
fused_rmsnorm was never applied in any trial. In the benchmark it caused [ERROR]
for int8_rmsnorm and int8_both strategies.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Without explicit cleanup, each trial's deep-copied model accumulates on
GPU memory across 100 trials. Move to CPU, delete, gc.collect(), and
empty_cache() after metrics are computed each trial.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
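The cleanup sequence described in that commit (move to CPU, delete, collect, empty the cache) can be sketched as a small helper. This is an illustration of the pattern, not the PR's code; the function name is hypothetical, and the device-cache hook is passed in (e.g. `torch.cuda.empty_cache`) so the sketch stays framework-agnostic:

```python
import gc

def release_trial_model(model, device_cache_empty=None):
    """Free one trial's deep-copied model between search trials so GPU
    memory does not accumulate across ~100 trials."""
    if hasattr(model, "to"):
        model.to("cpu")        # move parameters off the GPU first
    del model                  # drop this function's reference; the caller
                               # must also drop theirs for the object to die
    gc.collect()               # break any reference cycles promptly
    if device_cache_empty is not None:
        device_cache_empty()   # e.g. torch.cuda.empty_cache(): return cached
                               # allocator blocks to the device
```

Calling this after each trial's metrics are computed keeps peak GPU memory bounded by one model copy rather than the whole trial history.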
…f flex_attention. Removed redundant .pbs job scripts not required for PR.
FP32 shows no benefit from the fused RMSNorm kernel (per Section III
Fig 5). Loading in float16 matches production inference dtype and
surfaces the ~1.03x model-level latency improvement.
Anchors reference lines to baseline latency at seq=1024, making the
sub-quadratic FlexAttention-SWA scaling claim visually explicit.
