Skip to content

Optimizations for Decode Attn fp8 kernel MI350#75

Closed
mycpuorg wants to merge 12 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8_sept_19
Closed

Optimizations for Decode Attn fp8 kernel MI350#75
mycpuorg wants to merge 12 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8_sept_19

Conversation

@mycpuorg
Copy link
Copy Markdown

Optimizations for Decode Attn fp8 kernel

 - Baseline (MI350 FP8 split‑K forward)
      - Runtime ~125 µs (BF16 path was 110 µs).
      - rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront.
      - The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
  - Streaming K and On‑Demand V ( xformers/ops/fmha/_triton/splitk_kernels.py )
      - Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator.
      - Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD.
      - Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
  - HIP Autotune Guardrails (same file)
      - Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
      - Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
  - Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
      - Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
      - Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.

scxiao and others added 12 commits August 21, 2025 17:21
 - Baseline (MI350 FP8 split‑K forward)
      - Runtime ~125 µs (BF16 path was 110 µs).
      - rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront.
      - The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
  - Streaming K and On‑Demand V ( xformers/ops/fmha/_triton/splitk_kernels.py )
      - Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator.
      - Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD.
      - Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
  - HIP Autotune Guardrails (same file)
      - Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
      - Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
  - Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
      - Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
      - Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
- HIP Autotune Guardrails (same file)
- Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
- Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
- Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
- Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
- Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
@mycpuorg
Copy link
Copy Markdown
Author

dup of #74

@mycpuorg mycpuorg closed this Sep 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants