Running with seqlen=8192
Input shapes: q=torch.Size([8192, 32, 128]), k=torch.Size([8192, 4, 128]), v=torch.Size([8192, 4, 128])
cu_seqlens: tensor([ 0, 8192], device='cuda:0', dtype=torch.int32)
block_size: 64, topk: 16
Computing topk_idx using compressed_attention...
topk_idx shape: torch.Size([4, 8192, 16]), dtype: torch.int32
topk_idx range: [-1, 127]
H, N, TopK: 4, 8192, 16
num_blocks: 128
causal: True
Warming up FSA optimized fwd kernel...
Benchmarking FSA optimized fwd kernel...
[topk_sparse_attention_fwd_opt] Time: 123.353 ms
Warming up reference fwd kernel...
Benchmarking reference fwd kernel...
[topk_sparse_attention_fwd_ref] Time: 31.077 ms
Warming up reference bwd kernel...
[--bwd dq] Time: 23.799 ms
Warming up FSA optimized bwd kernel...
Benchmarking FSA optimized bwd kernel...
[topk_sparse_attention_bwd_opt] Time: 110.593 ms
✅ [Fwd] Output accuracy test PASSED (atol=1e-2, rtol=1e-2)
✅ [Fwd] LSE accuracy test PASSED (atol=1e-5, rtol=1e-5)
✅ [Bwd] Output accuracy test PASSED (atol=9e-1, rtol=1e-3)
==================================================
BENCHMARK SUMMARY
==================================================
Configuration:
- Sequence length: 8192
- Num q heads: 32
- Num kv heads: 4
- Head dim: 128
- Block size: 64
- TopK: 16
- Kernel size: 32
- Kernel stride: 16
- Causal: True
Accuracy:
- [Fwd] Output diff: 1.56e-02 (relative: 5.04e-03)
- [Fwd] LSE diff: 5.72e-06
- [Bwd] dQ diff: 7.81e-03 (relative: 9.84e-04)
Performance:
- [Fwd] Reference: 3.108 ms
- [Fwd] Optimized: 12.335 ms
- [Fwd] Ratio (ref/opt): 0.252x
- [Bwd] Reference: 2.380 ms
- [Bwd] Optimized: 11.059 ms
- [Bwd] Ratio (ref/opt): 0.215x
🎉 Benchmark completed successfully!
Running with seqlen=16384
Input shapes: q=torch.Size([16384, 32, 128]), k=torch.Size([16384, 4, 128]), v=torch.Size([16384, 4, 128])
cu_seqlens: tensor([ 0, 16384], device='cuda:0', dtype=torch.int32)
block_size: 64, topk: 16
Computing topk_idx using compressed_attention...
topk_idx shape: torch.Size([4, 16384, 16]), dtype: torch.int32
topk_idx range: [-1, 255]
H, N, TopK: 4, 16384, 16
num_blocks: 256
causal: True
Warming up FSA optimized fwd kernel...
Benchmarking FSA optimized fwd kernel...
[topk_sparse_attention_fwd_opt] Time: 164.061 ms
Warming up reference fwd kernel...
Benchmarking reference fwd kernel...
[topk_sparse_attention_fwd_ref] Time: 63.484 ms
Warming up reference bwd kernel...
[--bwd dq] Time: 48.687 ms
Warming up FSA optimized bwd kernel...
Benchmarking FSA optimized bwd kernel...
[topk_sparse_attention_bwd_opt] Time: 132.298 ms
✅ [Fwd] Output accuracy test PASSED (atol=1e-2, rtol=1e-2)
✅ [Fwd] LSE accuracy test PASSED (atol=1e-5, rtol=1e-5)
✅ [Bwd] Output accuracy test PASSED (atol=9e-1, rtol=1e-3)
==================================================
BENCHMARK SUMMARY
==================================================
Configuration:
- Sequence length: 16384
- Num q heads: 32
- Num kv heads: 4
- Head dim: 128
- Block size: 64
- TopK: 16
- Kernel size: 32
- Kernel stride: 16
- Causal: True
Accuracy:
- [Fwd] Output diff: 1.56e-02 (relative: 5.04e-03)
- [Fwd] LSE diff: 5.72e-06
- [Bwd] dQ diff: 1.56e-02 (relative: 1.81e-03)
Performance:
- [Fwd] Reference: 6.348 ms
- [Fwd] Optimized: 16.406 ms
- [Fwd] Ratio (ref/opt): 0.387x
- [Bwd] Reference: 4.869 ms
- [Bwd] Optimized: 13.230 ms
- [Bwd] Ratio (ref/opt): 0.368x
🎉 Benchmark completed successfully!
Running with seqlen=32768
Input shapes: q=torch.Size([32768, 32, 128]), k=torch.Size([32768, 4, 128]), v=torch.Size([32768, 4, 128])
cu_seqlens: tensor([ 0, 32768], device='cuda:0', dtype=torch.int32)
block_size: 64, topk: 16
Computing topk_idx using compressed_attention...
topk_idx shape: torch.Size([4, 32768, 16]), dtype: torch.int32
topk_idx range: [-1, 16]
H, N, TopK: 4, 32768, 16
num_blocks: 17
causal: True
Warming up FSA optimized fwd kernel...
Benchmarking FSA optimized fwd kernel...
[topk_sparse_attention_fwd_opt] Time: 282.449 ms
Warming up reference fwd kernel...
Benchmarking reference fwd kernel...
[topk_sparse_attention_fwd_ref] Time: 128.365 ms
Warming up reference bwd kernel...
[--bwd dq] Time: 111.638 ms
Warming up FSA optimized bwd kernel...
Benchmarking FSA optimized bwd kernel...
[topk_sparse_attention_bwd_opt] Time: 195.669 ms
✅ [Fwd] Output accuracy test PASSED (atol=1e-2, rtol=1e-2)
✅ [Fwd] LSE accuracy test PASSED (atol=1e-5, rtol=1e-5)
✅ [Bwd] Output accuracy test PASSED (atol=9e-1, rtol=1e-3)
==================================================
BENCHMARK SUMMARY
==================================================
Configuration:
- Sequence length: 32768
- Num q heads: 32
- Num kv heads: 4
- Head dim: 128
- Block size: 64
- TopK: 16
- Kernel size: 32
- Kernel stride: 16
- Causal: True
Accuracy:
- [Fwd] Output diff: 1.56e-02 (relative: 4.39e-03)
- [Fwd] LSE diff: 6.68e-06
- [Bwd] dQ diff: 7.81e-03 (relative: 1.14e-03)
Performance:
- [Fwd] Reference: 12.837 ms
- [Fwd] Optimized: 28.245 ms
- [Fwd] Ratio (ref/opt): 0.454x
- [Bwd] Reference: 11.164 ms
- [Bwd] Optimized: 19.567 ms
- [Bwd] Ratio (ref/opt): 0.571x
🎉 Benchmark completed successfully!
Running with seqlen=65536
Input shapes: q=torch.Size([65536, 32, 128]), k=torch.Size([65536, 4, 128]), v=torch.Size([65536, 4, 128])
cu_seqlens: tensor([ 0, 65536], device='cuda:0', dtype=torch.int32)
block_size: 64, topk: 16
Computing topk_idx using compressed_attention...
topk_idx shape: torch.Size([4, 65536, 16]), dtype: torch.int32
topk_idx range: [-1, 1023]
H, N, TopK: 4, 65536, 16
num_blocks: 1024
causal: True
Warming up FSA optimized fwd kernel...
Benchmarking FSA optimized fwd kernel...
[topk_sparse_attention_fwd_opt] Time: 564.152 ms
Warming up reference fwd kernel...
Benchmarking reference fwd kernel...
[topk_sparse_attention_fwd_ref] Time: 256.047 ms
Warming up reference bwd kernel...
[--bwd dq] Time: 223.466 ms
Warming up FSA optimized bwd kernel...
Benchmarking FSA optimized bwd kernel...
[topk_sparse_attention_bwd_opt] Time: 353.789 ms
✅ [Fwd] Output accuracy test PASSED (atol=1e-2, rtol=1e-2)
✅ [Fwd] LSE accuracy test PASSED (atol=1e-5, rtol=1e-5)
✅ [Bwd] Output accuracy test PASSED (atol=9e-1, rtol=1e-3)
==================================================
BENCHMARK SUMMARY
==================================================
Configuration:
- Sequence length: 65536
- Num q heads: 32
- Num kv heads: 4
- Head dim: 128
- Block size: 64
- TopK: 16
- Kernel size: 32
- Kernel stride: 16
- Causal: True
Accuracy:
- [Fwd] Output diff: 1.56e-02 (relative: 4.67e-03)
- [Fwd] LSE diff: 7.63e-06
- [Bwd] dQ diff: 7.81e-03 (relative: 1.13e-03)
Performance:
- [Fwd] Reference: 25.605 ms
- [Fwd] Optimized: 56.415 ms
- [Fwd] Ratio (ref/opt): 0.454x
- [Bwd] Reference: 22.347 ms
- [Bwd] Optimized: 35.379 ms
- [Bwd] Ratio (ref/opt): 0.632x
## Summary
While reproducing the H200 results, I’m seeing the opposite of the advertised speedups: the optimized FSA kernels run slower than the reference kernels across 8k–64k sequence lengths. In addition, at seqlen=32768 the runtime prints an anomalous `num_blocks: 17`, and the reported `topk_idx` range is `[-1, 16]`, unlike the other sequence lengths. This conflicts with the figure for `GQA=8, (64, 16), H200` (screenshot attached), which suggests the optimized path should be significantly faster (e.g., ~3–5× at larger sequence lengths).

## What I ran
- Hardware: H200 (Hopper)
- Config used by the benchmark (from the log):
  - `num_q_heads=32`, `num_kv_heads=4`, `head_dim=128`
  - `block_size=64`, `TopK=16`
  - `kernel_size=32`, `kernel_stride=16`
  - `causal=True`
- Input shapes (example): `q=[N, 32, 128]`, `k=[N, 4, 128]`, `v=[N, 4, 128]` with `cu_seqlens=[0, N]` on CUDA

## Observed behavior (timings)
All timings are what the benchmark printed.
| seqlen | fwd ref (ms) | fwd opt (ms) | fwd ratio (ref/opt) | bwd ref (ms) | bwd opt (ms) | bwd ratio (ref/opt) |
|---|---|---|---|---|---|---|
| 8192 | 3.108 | 12.335 | 0.252x | 2.380 | 11.059 | 0.215x |
| 16384 | 6.348 | 16.406 | 0.387x | 4.869 | 13.230 | 0.368x |
| 32768 ⚠️ | 12.837 | 28.245 | 0.454x | 11.164 | 19.567 | 0.571x |
| 65536 | 25.605 | 56.415 | 0.454x | 22.347 | 35.379 | 0.632x |

(⚠️ = see anomaly below)
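For reference, the per-call times in the summaries can be reproduced from the raw `[...] Time:` lines in the log. A sketch, under the assumption (not confirmed anywhere in the output) that each printed total covers 10 timed iterations, since total/10 matches every summary figure:

```python
ITERS = 10  # assumed iteration count; not printed by the benchmark itself

# (seqlen) -> (fwd_ref, fwd_opt, bwd_ref, bwd_opt) raw totals in ms, from the log
raw_totals_ms = {
    8192:  (31.077, 123.353, 23.799, 110.593),
    16384: (63.484, 164.061, 48.687, 132.298),
    32768: (128.365, 282.449, 111.638, 195.669),
    65536: (256.047, 564.152, 223.466, 353.789),
}

for n, (fwd_ref, fwd_opt, bwd_ref, bwd_opt) in raw_totals_ms.items():
    # Per-iteration times and ref/opt ratios, matching the BENCHMARK SUMMARY blocks.
    print(f"seqlen={n}: fwd ref={fwd_ref / ITERS:.3f} ms, opt={fwd_opt / ITERS:.3f} ms, "
          f"ratio={fwd_ref / fwd_opt:.3f}x; bwd ratio={bwd_ref / bwd_opt:.3f}x")
```

Every derived value agrees with the summaries (e.g. 123.353 / 10 = 12.335 ms, 31.077 / 123.353 = 0.252x), so the raw and summary numbers are at least internally consistent.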
All accuracy checks passed:

- [Fwd] Output: `atol=1e-2, rtol=1e-2`
- [Fwd] LSE: `atol=1e-5, rtol=1e-5`
- [Bwd] dQ: `atol=9e-1, rtol=1e-3`
## Anomaly at seqlen=32768
Only for N=32768 the log shows:

```
topk_idx range: [-1, 16]
num_blocks: 17
```
For the other lengths it prints:

- N=8192: `topk_idx range: [-1, 127]`, `num_blocks: 128`
- N=16384: `topk_idx range: [-1, 255]`, `num_blocks: 256`
- N=65536: `topk_idx range: [-1, 1023]`, `num_blocks: 1024`

So 32k appears to be collapsing to 17 blocks (range 0..16), which seems wrong given `block_size=64` (expected 512 blocks). This might indicate an indexing/typing or compression issue specific to that length boundary.

## Expected behavior
Based on the repo’s H200 figure for `GQA=8, (64, 16)` (attached), I expected the optimized path to outperform the reference path, especially at longer sequences.

## Environment
Repo lists:

- `flash-attn==2.6.3`

My key versions (from `pip list`):

- `flash_attn==2.7.3`
- `flash_attn_3==3.0.0b1`

## Full benchmark logs
The complete per-run output is reproduced at the top of this report.
## Questions / hypotheses
1. The repo lists `flash-attn==2.6.3`, but my environment has `flash_attn==2.7.3` (and `flash_attn_3==3.0.0b1`). Could newer FA versions route a different kernel path or change layouts in a way that hurts FSA perf?
2. Are specific build flags (e.g. `sm_90a`) or Triton settings needed to hit the intended kernels? If so, could you share the exact environment used for the H200 figure?
3. What could produce `num_blocks: 17` at N=32768? The `topk_idx` dtype prints as `int32`, so it doesn’t look like an int16 overflow, but it does look like a discretization/wrapping.

Thanks
Happy to run any additional diagnostics (env vars, debug prints, Nsight Systems/Compute, etc.) and share more logs to help narrow this down.
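As a first diagnostic, here is a quick sanity check on the 32k boundary. This is only an illustration of the wrapping hypothesis (the helper `to_int16` is mine, not from the repo); it does not explain the specific value 17, but it shows why N=32768 is a special length:

```python
import math

# With block_size=64, N=32768 should yield ceil(32768 / 64) = 512 blocks,
# not the logged "num_blocks: 17".
N, BLOCK_SIZE = 32768, 64
expected_num_blocks = math.ceil(N / BLOCK_SIZE)
print(expected_num_blocks)  # 512

# N=32768 is exactly 2**15, the first value that overflows a signed 16-bit
# integer. If any intermediate quantity (not topk_idx itself, which prints
# as int32) were stored or cast to int16, the sequence length would wrap
# to -32768 at precisely this boundary and nowhere else in the sweep.
def to_int16(x: int) -> int:
    """Interpret the low 16 bits of x as a signed 16-bit value."""
    x &= 0xFFFF
    return x - 0x10000 if x >= 0x8000 else x

print(to_int16(32768))  # -32768: wraps exactly at this sequence length
print(to_int16(32767))  # 32767: one token shorter is unaffected
```

Running the benchmark at N=32704 or N=32832 (just either side of 2**15) might confirm or rule out this boundary effect.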