Skip to content

decode attn fp8#76

Draft
mycpuorg wants to merge 2 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8_develop
Draft

decode attn fp8#76
mycpuorg wants to merge 2 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8_develop

Conversation

@mycpuorg
Copy link
Copy Markdown

What does this PR do?

Fixes # (issue).

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

 - Baseline (MI350 FP8 split‑K forward)
      - Runtime ~125 µs (BF16 path was 110 µs).
      - rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront.
      - The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
  - Streaming K and On‑Demand V ( xformers/ops/fmha/_triton/splitk_kernels.py )
      - Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator.
      - Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD.
      - Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
  - HIP Autotune Guardrails (same file)
      - Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
      - Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
  - Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
      - Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
      - Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
- HIP Autotune Guardrails (same file)
- Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
- Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
- Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
- Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
- Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
@mycpuorg mycpuorg changed the title Manrao/decode attn fp8 develop decode attn fp8 Sep 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant