
feat(EP): decouple fp8_blockwise combine scale_dim from user config #315

Merged
maning00 merged 3 commits into main from feat/ep-fp8-combine-internal-scaledim on May 13, 2026

Conversation

@maning00 (Contributor) commented on May 12, 2026

Summary

  • Add a dedicated fp8BlockwiseCombineScaleDim on EpDispatchCombineHandle, plumbed through EpDispatchCombineArgs<T> / ArgsRaw to the kernel. It is driven by a new env var MORI_FP8_COMBINE_SCALE_DIM (default 56, matching block_elems = 7168 / 56 = 128 for the AccumNum=8 + VecBytes=8 dequant specialization); see the sketch after this list.
  • EpDispatchCombineConfig.scaleDim / scaleTypeSize keep their original semantics (caller-provided dispatch scales, e.g. FP4 input); the fp8_blockwise combine path no longer consults them.
  • shmemOutScalesMemObj is allocated only when the caller provides user scales; shmemInpScalesMemObj is sized as max(userScaleSize, fp8BlockwiseScaleSize). The pybind get_dispatch_output_ptrs accessor guards with IsValid().
  • The Python EpDispatchCombineOp reads the effective fp8 combine scale_dim from handle_info into a private attribute; self.config.scale_dim is no longer mutated during op construction.
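
A minimal Python sketch of the resolution and sizing logic described above. The helper names (`resolve_fp8_combine_scale_dim`, `input_scales_bytes`) and the 4-byte fp32 scale element are illustrative assumptions, not the actual mori API; the real plumbing lives on the C++ side in EpDispatchCombineHandle.

```python
import os

# Default matches block_elems = 7168 / 56 = 128 for the
# AccumNum=8 + VecBytes=8 dequant specialization.
DEFAULT_FP8_COMBINE_SCALE_DIM = 56

def resolve_fp8_combine_scale_dim(hidden_dim: int) -> tuple[int, int]:
    """Resolve the internal fp8_blockwise combine scale_dim from
    MORI_FP8_COMBINE_SCALE_DIM, independent of the caller-facing
    EpDispatchCombineConfig.scale_dim. Returns (scale_dim, block_elems)."""
    scale_dim = int(os.environ.get("MORI_FP8_COMBINE_SCALE_DIM",
                                   DEFAULT_FP8_COMBINE_SCALE_DIM))
    assert hidden_dim % scale_dim == 0, "scale_dim must divide hidden_dim"
    # block_elems = number of hidden elements sharing one fp8 scale.
    return scale_dim, hidden_dim // scale_dim

def input_scales_bytes(max_tokens: int, hidden_dim: int,
                       user_scale_dim: int, user_scale_type_size: int) -> int:
    """Size the input-scales buffer (shmemInpScalesMemObj) as
    max(userScaleSize, fp8BlockwiseScaleSize)."""
    fp8_scale_dim, _ = resolve_fp8_combine_scale_dim(hidden_dim)
    user_size = max_tokens * user_scale_dim * user_scale_type_size
    fp8_size = max_tokens * fp8_scale_dim * 4  # fp32 scales assumed
    return max(user_size, fp8_size)

# Defaults: hidden_dim=7168 resolves to scale_dim=56, block_elems=128.
assert resolve_fp8_combine_scale_dim(7168) == (56, 128)
```

The max() lets a single buffer serve both the caller's dispatch scales and the internal combine scales, instead of allocating two; the output-scales buffer needs no such treatment since it is only allocated when user scales exist.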

Test plan

Verified on MI300X, EP=8, hidden_dim=7168, max-tokens=4096, zero-copy=0:

  • pytest tests/python/ops/test_dispatch_combine_intranode.py::test_dispatch_combine -k "fp8_blockwise and data_type0 and True": 16/16 pass
  • Sanity: MORI_FP8_COMBINE_SCALE_DIM=112 correctly drives the internal combine scale_dim to 112 while config.scaleDim (user value) stays untouched.
  • bench_dispatch_combine.py --cmd bench --quant-type fp8_blockwise: combine ~910 us (~19.5% faster than the bf16 no-quant baseline of ~1131 us), matching PR #311 (feat(EP): FP8 blockwise quantization for IntraNode combine).
  • Same combine latency irrespective of caller-supplied --scale-dim ∈ {0, 32, 56}.
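
For example, the override semantics from the sanity check can be exercised against the hypothetical resolver sketched under Summary above (112 divides 7168 into 64-element blocks):

```python
import os

# Override the internal combine scale_dim; the user-facing
# config.scaleDim is never consulted on this path, which is also
# why combine latency is unchanged across --scale-dim values.
os.environ["MORI_FP8_COMBINE_SCALE_DIM"] = "112"
assert resolve_fp8_combine_scale_dim(7168) == (112, 64)
```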

maning00 added 2 commits May 12, 2026 12:30
Introduce a dedicated `fp8BlockwiseCombineScaleDim` on
EpDispatchCombineHandle, plumbed through EpDispatchCombineArgs<T> /
ArgsRaw to the kernel. The value is driven by a new env var
`MORI_FP8_COMBINE_SCALE_DIM` (default 56, matching block_elems = 7168/56
= 128 for the AccumNum=8 + VecBytes=8 dequant specialization).

EpDispatchCombineConfig.scaleDim/scaleTypeSize remain reserved for
caller-provided dispatch scales (e.g. FP4 input) and are no longer
consulted by fp8_blockwise combine. shmemOutScalesMemObj is only
allocated when the caller provides user scales; shmemInpScalesMemObj
takes max(userScaleSize, fp8BlockwiseScaleSize).

Verified on MI300X / EP8 / hidden_dim=7168:
- pytest test_dispatch_combine_intranode fp8_blockwise: 16/16 pass
- bench_precision_fp8_blockwise: SNR/cos_sim unchanged across normal /
  lognormal / two_bucket distributions
- bench_dispatch_combine: combine ~910 us (~19.5% faster than bf16
  no-quant baseline), matching PR #311
maning00 force-pushed the feat/ep-fp8-combine-internal-scaledim branch from 73f7834 to eaf4958 on May 12, 2026 12:53
maning00 merged commit 96ffa16 into main on May 13, 2026 (39 of 43 checks passed)
@maning00 (Contributor, Author) commented:

cc @billishyahao
