feat(EP): decouple fp8_blockwise combine scale_dim from user config#315
Merged
Conversation
Introduce a dedicated `fp8BlockwiseCombineScaleDim` on EpDispatchCombineHandle, plumbed through EpDispatchCombineArgs<T> / ArgsRaw to the kernel. The value is driven by a new env var `MORI_FP8_COMBINE_SCALE_DIM` (default 56, matching block_elems = 7168/56 = 128 for the AccumNum=8 + VecBytes=8 dequant specialization). EpDispatchCombineConfig.scaleDim/scaleTypeSize remain reserved for caller-provided dispatch scales (e.g. FP4 input) and are no longer consulted by fp8_blockwise combine. shmemOutScalesMemObj is only allocated when the caller provides user scales; shmemInpScalesMemObj takes max(userScaleSize, fp8BlockwiseScaleSize). Verified on MI300X / EP8 / hidden_dim=7168: - pytest test_dispatch_combine_intranode fp8_blockwise: 16/16 pass - bench_precision_fp8_blockwise: SNR/cos_sim unchanged across normal / lognormal / two_bucket distributions - bench_dispatch_combine: combine ~910 us (~19.5% faster than bf16 no-quant baseline), matching PR #311
73f7834 to
eaf4958
Compare
Contributor
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
fp8BlockwiseCombineScaleDimonEpDispatchCombineHandle, plumbed throughEpDispatchCombineArgs<T>/ArgsRawto the kernel. Driven by a new env varMORI_FP8_COMBINE_SCALE_DIM(default56, matchingblock_elems = 7168 / 56 = 128for theAccumNum=8 + VecBytes=8dequant specialization).EpDispatchCombineConfig.scaleDim/scaleTypeSizekeep their original semantics (caller-provided dispatch scales, e.g. FP4 input). fp8_blockwise combine no longer consults them.shmemOutScalesMemObjis allocated only when the caller provides user scales;shmemInpScalesMemObjis sized asmax(userScaleSize, fp8BlockwiseScaleSize). Pybindget_dispatch_output_ptrsguards withIsValid().EpDispatchCombineOpreads the effective fp8 combinescale_dimfromhandle_infointo a private attribute;self.config.scale_dimis no longer mutated by op construction.Test plan
Verified on MI300X, EP=8,
hidden_dim=7168,max-tokens=4096,zero-copy=0:pytest tests/python/ops/test_dispatch_combine_intranode.py::test_dispatch_combine -k "fp8_blockwise and data_type0 and True"— 16/16 passMORI_FP8_COMBINE_SCALE_DIM=112correctly drives the internal combine scale_dim to 112 whileconfig.scaleDim(user value) stays untouched.bench_dispatch_combine.py --cmd bench --quant-type fp8_blockwise: combine ~910 us (~19.5% faster than bf16 no-quant baseline ~1131 us), matching PR feat(EP): FP8 blockwise quantization for IntraNode combine #311.--scale-dim∈ {0, 32, 56}.