deepseek v4 rebase track#1181
Conversation
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Avoids _prepare_cp Ray dependency when only --model-dir is overridden: both fields now resolve to the same path unless user explicitly fans out to per-node local NVMe via --model-local-dir.
Pinaster/DeepSeek-V4-Flash-FP8-4layer ships with model_type=deepseek_v4 in config.json, but SGLang's get_config fallback (with SGLANG_APPLY_CONFIG_BACKUP=none) only fires on deepseek_ref. Rewrite the local config.json in-place so SGLang's _load_deepseek_temp_model gets reached. Idempotent; no-op for non-4-layer models.
This reverts commit fda4195.
…epseek-v4 @ 8e1ef3c) Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer (Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11, all loss=0 since 4-layer prune produces gibberish + reward=0, but pipeline ran cleanly). Job raysubmit_WTh5sDDAaRcVrfTT. Required environment additions on top of the upstream Dockerfile: - sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x) - flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8) - tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API) - flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built) - fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform - Megatron-LM at radixark/Megatron-LM PR radixark#28 (mlm-pr28, commit 8455dbf, in both /workspace/Megatron-LM and /root/Megatron-LM editable install) Run-time tweaks: - bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to satisfy AutoConfig.from_pretrained on transformers 4.57.1) - bf16 ckpt config.json: drop quantization_config (avoid sglang creating fp8 model params when our weights are bf16) - SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the packaged 43-layer config and breaks miles' hf_validate_args) dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune (DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8 prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer) so the tool was not exercised. Keeping it for the MXFP4 path.
Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` / ``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn, pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the fused inference ``pre_big_fuse`` for ``no_grad`` paths). Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``) unchanged so the radixark/Megatron-LM PR radixark#28 call sites in ``transformer_layer.py`` / ``transformer_block.py`` keep working. Net effect: - forward path now matches the canonical TileKernels kernels (which fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix) - backward path is enabled (the original code asserted ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything in ``torch.no_grad``) - ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``; PR's single ``hc_eps`` reused for both ``pre_eps`` and ``sinkhorn_eps`` Tweaks vs upstream TileKernels' high-level wrappers: - inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post`` to have written ``grad_from_mhc_post`` onto the same residual storage during backward, but Megatron's call sites use independent ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for ``layer_pre`` and ``layer_post`` so the storage objects don't match. - inline ``mhc_head`` body to ``.contiguous()`` the ``mixes[..., :mhc_mult]`` slice before feeding it to ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts ``strides[0] == mhc_mult`` but the slice keeps the ``mhc_mult * (mhc_mult + 2)`` parent stride). - ``.contiguous()`` everywhere we hand bf16 tensors to TileLang kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip the ``view`` stride check. Verified end-to-end: 3-rollout GRPO smoke test on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with entropy_loss / logprob_abs_diff in the same band as the original in-tree implementation (rl-smoke-pass tag) — see job raysubmit_B3DgGam9vCYuDAVU.
- ``kernel/act_quant.py`` is now a thin wrapper around ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``; shape/dtype contract preserved (``(y_fp8, s_fp32)`` with ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change. - ``kernel/sinkhorn.py`` is removed: it was only consumed by the legacy ``hyper_connection.py`` path which now calls ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead. End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``: - act_quant.py -> tile_kernels.quant.per_token_cast (this commit) - sinkhorn.py -> tile_kernels.modeling.mhc.ops (Batch 1, removed) - tilelang_indexer*.py -> no TileKernels equivalent (DSV4 DSA-specific) - tilelang_sparse_mla*.py -> no TileKernels equivalent (DSV4 sparse-MLA) So every ``kernel/`` file that has a TileKernels analogue now routes through TileKernels; the remaining files are V4-specific and have no upstream replacement. Verified end-to-end with the 12-rollout GRPO smoke harness on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id 260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps (0..11) completed cleanly; entropy_loss / logprob_abs_diff land in the same band as the rl-smoke-pass baseline (12-step run on the unmodified PR @ a72ed84): step | baseline ent | TK ent -----+--------------+-------- 0 | 1.6858 | 1.7286 1 | 1.7854 | 1.7321 2 | 1.7360 | 1.7803 3 | 1.7696 | 1.6750 4 | 1.6815 | 1.7532 5 | 1.7172 | 1.6865 6 | 1.7394 | 1.6949 7 | 1.7634 | 1.7217 8 | 1.6674 | 1.6871 9 | 1.7330 | 1.6943 10 | 1.6319 | 1.6998 11 | 1.6962 | 1.7321
- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back for the FP8→BF16 dequant step (was a manual unflatten/multiply). - tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel with tile_kernels.quant.cast_back (128x128 block FP8 dequant). Both substitutions are bit-exact against the prior implementations on GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
…cision
TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode). PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.
Setting allow_tf32 = True at deepseek_v4 plugin import time:
HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
mean_abs max_abs notes
layer_input 1.05e-5 1.56e-2 bf16 LSB output
pre.post 1.52e-5 7.65e-5
pre.comb 4.61e-6 2.87e-5
hc_post out 1.15e-8 9.77e-4 fp32-equivalent
hc_head out 1.14e-5 1.56e-2 bf16 LSB output
All ops <= 1.5e-5 mean-abs (matches the attention/indexer 1e-5 bar).
The max-abs values are 1 ULP of bf16 at the output magnitude (~6.0,
1 ULP = 2^-5 = 3.13e-2). Other fp32 matmuls in the plugin (compressor,
indexer projections) get a free TF32 speed-up as a side effect.
…ware-aware scale dtype
Replace the TileKernels per_token_cast wrapper with a verbatim port of
deepseek-ai/DeepSeek-V4-Pro/inference/kernel.py:act_quant so this code
path is bit-exact with the upstream inference kernel.
The ported act_quant exposes scale_dtype and inplace, matching official:
- scale_dtype=None auto-selects via SGLang should_deepgemm_weight_requant_ue8m0
(Blackwell + DeepGEMM JIT -> float8_e8m0fnu; Hopper -> float32).
- explicit scale_dtype overrides the auto path either way.
- inplace=True wires through to the fused quant+dequant kernel.
Drops the tile_kernels import for this file; the package was not listed
in any requirements manifest. Caller (qat.py:fp8_simulate) is unaffected:
on Hopper the auto path resolves to float32, preserving prior behavior.
Verified on H200:
- 26 kernel cases bit-exact vs official act_quant (4 shapes x 5 modes + 3D).
- 4-layer iterated fp8_simulate stable, ~2.25% mean rel noise as expected for E4M3.
- Hardware auto-resolve matches sglang.deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0.
There was a problem hiding this comment.
Code Review
This pull request introduces full support for DeepSeek-V4 models, including Flash and Pro versions, by implementing specialized weight conversion, indexer replay, and a collection of TileLang kernels for optimized attention and hyper-connection operations. It also updates SGLang dependencies to v0.5.12 and adds a comprehensive training pipeline script. Technical feedback identifies several critical issues in the newly added kernels, such as missing boundary checks in the indexer and sparse MLA kernels that could cause out-of-bounds memory access. Furthermore, a bug in the FP8 simulation casting logic and a potential NameError in the weight update function were found, along with a recommendation to refactor model-specific padding detection for improved robustness.
| if isinstance(num_new_engines, tuple): | ||
| num_new_engines = num_new_engines[0] |
There was a problem hiding this comment.
The variable num_new_engines is used here but it does not appear to be defined within the scope of the update_weights function in the current diff. This will result in a NameError at runtime. Please ensure it is correctly initialized, likely from the return value of a weight update or engine connection call.
| T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp( | ||
| x_local[i, j] / s_local[i], fp8_min, fp8_max | ||
| ))) * s_local[i], |
There was a problem hiding this comment.
When inplace=True, out_dtype is reassigned to in_dtype (e.g., BF16). The inner T.Cast(out_dtype, ...) call then casts to BF16 instead of the intended FP8 quantization format. This prevents the kernel from correctly simulating FP8 quantization noise. You should use a separate variable to hold the quantization target type (e.g., the original FP8 constant).
y_local[i, j] = T.Cast(
out_dtype,
T.Cast(compute_dtype, T.Cast(FP8, T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
))) * s_local[i],
)| T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) | ||
|
|
||
| for bq_i, bn_i in T.Parallel(block_Q, block_N): | ||
| Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] |
There was a problem hiding this comment.
The kernel writes to Logits without checking if the computed indices seq_len_i + bq_i and cu_k_s_min + nbn_i * block_N + bn_i are within the bounds of seq_len and seq_len_kv. This can lead to out-of-bounds memory access if the sequence lengths are not multiples of the block sizes.
| Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] | |
| for bq_i, bn_i in T.Parallel(block_Q, block_N): | |
| if seq_len_i + bq_i < seq_len and cu_k_s_min + nbn_i * block_N + bn_i < seq_len_kv: | |
| Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] |
| T.atomic_add( | ||
| dAttnSink[bz * block_H + h_i], | ||
| -Delta[by, s_i, bz * block_H + h_i] | ||
| * T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]), | ||
| ) |
There was a problem hiding this comment.
The kernel performs atomic additions to dAttnSink using an index bz * block_H + h_i that can range up to padded_H - 1. If the number of heads H is not a power of 2 (e.g., H=43 or H=61), this will result in out-of-bounds memory access. A boundary check is required.
| T.atomic_add( | |
| dAttnSink[bz * block_H + h_i], | |
| -Delta[by, s_i, bz * block_H + h_i] | |
| * T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]), | |
| ) | |
| for h_i in T.Parallel(block_H): | |
| if bz * block_H + h_i < H: | |
| T.atomic_add( | |
| dAttnSink[bz * block_H + h_i], | |
| -Delta[by, s_i, bz * block_H + h_i] | |
| * T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]), | |
| ) |
| for h_i in T.Parallel(H_per_block): | ||
| sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale) |
There was a problem hiding this comment.
The kernel accesses AttnSink using H0 + h_i, which can exceed the actual number of heads H because H_per_block is based on padded_H. This will cause out-of-bounds memory access for models where H is not a multiple of the block size. Please add a boundary check.
| for h_i in T.Parallel(H_per_block): | |
| sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale) | |
| for h_i in T.Parallel(H_per_block): | |
| if H0 + h_i < H: | |
| sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale) |
|
|
||
| # pad to reduce memory fragmentation and maybe make the computation faster | ||
| pad_size = parallel_state.tp.size * args.data_pad_size_multiplier | ||
| if getattr(args, "dsv4_hc_mult", 0) != 0: # TODO improve the way to detect needing this |
There was a problem hiding this comment.
The logic for detecting when to adjust pad_size for DeepSeek-V4 models relies on checking dsv4_hc_mult. As noted in the TODO, this should be refactored to a more robust detection mechanism, such as checking the model type or a specific configuration flag, to avoid relying on a specific attribute that might be missing or zero in other valid configurations.
References
- Model parameters should be retrieved from the model configuration rather than being hardcoded or inferred from potentially missing attributes.
Tracking draft PR for the DeepSeek V4 rebase onto Miles main.\n\nThis branch currently contains the full rebase track; follow-up work will split general Miles fixes/features into separate PRs so this PR can converge to DeepSeek V4 support only.