Group 1 - Hardware-Aware Transformer Optimisation: Integrating Programmable Attention, Triton Kernel Fusion, and Multi-Objective NAS#315
Open
aahaidar01 wants to merge 83 commits into DeepWok:main from
Conversation
- Patch to support 2D inputs in Binary quantizers (previously hardcoded for 4D).
- Fix signatures in and to match PyTorch autograd requirements.
…n add our plots with relative paths to that folder and prepare it for .zip submission. Added further results to .md file including lab 1.
…nalities) into mase /src files. Added pytest test scripts to test the implemented functionalities.
…ch.compile compatibility, return tuple fixes, and add training/bf16/seq512 tests
…s flex throughput
…edup, fix GQA via enable_gqa=True, add block_mask caching, and expand test suite to 39 tests
…as it was failing silently and falling back to eager, causing OOM.
…stematic ablation results
…PyTorch 2.x documented issue
…xperiments. Add ALiBi score mod and a compound ALiBi + sliding-window mod into score_mods.py. Uploaded results temporarily to share with collaborators.
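For context on what such score/mask mods look like: FlexAttention expresses attention variants as element-wise callbacks over batch, head, and position indices. A minimal sketch of an ALiBi score mod and a sliding-window-causal mask mod, written so it runs on plain Python ints as well as the tensor indices FlexAttention traces (the head count `H` and `WINDOW` size are illustrative, not values from this PR):

```python
import math

H = 8          # number of attention heads (illustrative)
WINDOW = 256   # sliding-window size (illustrative)

def alibi_score_mod(score, b, h, q_idx, kv_idx):
    """Add ALiBi's linear distance penalty; later heads get steeper slopes."""
    slope = 2.0 ** (-8.0 * (h + 1) / H)
    # kv_idx - q_idx is <= 0 under a causal mask, so this is a penalty.
    return score + slope * (kv_idx - q_idx)

def sliding_window_causal(b, h, q_idx, kv_idx):
    """mask_mod: True keeps the (q, kv) pair. `&` works on bools and tensors."""
    causal = q_idx >= kv_idx
    in_window = (q_idx - kv_idx) < WINDOW
    return causal & in_window
```

In real use these would be passed to `torch.nn.attention.flex_attention.flex_attention` as `score_mod=` and, via `create_block_mask`, as the block-sparse mask.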
…ice instead of CPU.
…computations. Added 3 more experiments: decode generation, throughput (tokens/s), gqa isolation.
…_ import
- Checkout fused_rmsnorm Triton kernel files from feature/fused-rmsnorm-residual branch (triton_fused_add_rmsnorm.py was missing, causing ImportError at chop startup and crashing all search jobs)
- Wrap fused_ops import in transforms/__init__.py with try/except so a missing Triton kernel never cascades to crash the entire chop framework
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both module_fusion.py and benchmark_seqlen.py imported rmsnorm_residual_fusion_pass which does not exist. The actual function is fused_rmsnorm_residual_transform_pass. In the search space the error was silently swallowed by except ImportError, meaning fused_rmsnorm was never applied in any trial. In the benchmark it caused [ERROR] for int8_rmsnorm and int8_both strategies. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Without explicit cleanup, each trial's deep-copied model accumulates in GPU memory across 100 trials. Move to CPU, delete, gc.collect(), and empty_cache() after metrics are computed each trial. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
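A rough sketch of that cleanup sequence (the function name is hypothetical; it assumes a plain `torch.nn.Module`):

```python
import gc
import torch

def release_trial_model(model: torch.nn.Module) -> None:
    """Detach a finished trial's deep-copied model from the GPU.

    Moving parameters to CPU first lets the CUDA caching allocator
    actually reclaim their blocks when empty_cache() runs. The caller
    should `del` its own reference afterwards.
    """
    model.to("cpu")
    gc.collect()                   # collect cycles still pinning tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # hand cached blocks back to the driver
```

In the search loop this runs once per trial, after metrics are computed: `release_trial_model(trial_model)` followed by `del trial_model`.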
…f flex_attention. Removed redundant .pbs job scripts not required for PR.
FP32 shows no benefit from the fused RMSNorm kernel (per Section III Fig 5). Loading in float16 matches production inference dtype and surfaces the ~1.03x model-level latency improvement.
Anchors reference lines to baseline latency at seq=1024, making the sub-quadratic FlexAttention-SWA scaling claim visually explicit.
Authors:
Ali Haidar, Dorijan Donaj Magasic, Yash Agarwal, Mahmoud El Etreby
Summary
FlexAttention integration pass — module-level transform replacing SDPA with torch.nn.attention.flex_attention, supporting causal, sliding-window, ALiBi, and document masking patterns with block-sparse acceleration. Up to 1.72x training speedup for sliding-window attention at seq=4096.

Fused Add+RMSNorm Triton kernel — custom forward/backward Triton kernel fusing the residual-add → RMSNorm pattern in transformer decoder layers. 2.98x faster than unfused PyTorch, 1.42x faster than Liger-Kernel, with 60% peak memory reduction per fusion site. Includes both FX graph-level and module-level MASE passes.
Automated multi-objective search pipeline — fills MASE's LatencyRunner stub with GPU timing, adds a ModuleSearchSpaceQuantizationFusion search space covering bit-width × fusion strategy, and wires everything into Optuna NSGA-II search. Produces Pareto frontiers over accuracy/perplexity, latency, and average bitwidth across BERT, TinyLlama, and Mistral.

Key results
Files changed
New passes
src/chop/passes/module/transforms/attention/flex_attention_transform.py — FlexAttention pass
src/chop/passes/module/transforms/attention/score_mods.py — score/mask modification library
src/chop/passes/graph/transforms/fused_rmsnorm/ — Triton kernel + FX graph pass
src/chop/passes/module/transforms/fused_ops/rmsnorm_residual_fusion.py — module-level fusion pass

Search infrastructure

src/chop/actions/search/strategies/runners/hardware/latency.py — GPU latency runner
src/chop/actions/search/search_space/quantization/module_fusion.py — fusion search space
src/chop/pipelines/optimization.py — pass-chain wrapper
configs/search/quantization_fusion_{bert,llama,mistral}.toml — search configs

Experiments & benchmarks

experiments/flex_attention/ — 12 experiments with JSON results and figures
scripts/ — search runners, benchmarks, kernel profiling
test/ — FlexAttention tests (40 tests), LatencyRunner tests

Documentation
Models tested
Hardware
NVIDIA L40S (48GB), PyTorch 2.6.0, Triton 3.3.1
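Latency figures like those above are sensitive to measurement protocol. A minimal CPU-side sketch of the warmup-then-median timing a latency runner typically performs (on GPU you would additionally synchronize the device, e.g. torch.cuda.synchronize, around each timed call so queued kernels are included):

```python
import statistics
import time

def measure_latency_ms(fn, warmup=10, iters=50):
    """Median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):      # let caches, allocators, and JIT settle
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    # Median is robust to scheduler spikes that skew the mean upward.
    return statistics.median(samples)
```

Reporting the median over many iterations, after discarding warmup runs, is what makes speedup ratios like 1.72x and 2.98x reproducible across runs.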