Skip to content

deepseek v4 rebase track#1181

Draft
yueming-yuan wants to merge 55 commits into
radixark:mainfrom
yueming-yuan:deepseek-v4-main-rebase
Draft

deepseek v4 rebase track#1181
yueming-yuan wants to merge 55 commits into
radixark:mainfrom
yueming-yuan:deepseek-v4-main-rebase

Conversation

@yueming-yuan
Copy link
Copy Markdown
Collaborator

Tracking draft PR for the DeepSeek V4 rebase onto Miles main.\n\nThis branch currently contains the full rebase track; follow-up work will split general Miles fixes/features into separate PRs so this PR can converge to DeepSeek V4 support only.

yueming-yuan and others added 30 commits May 21, 2026 22:28
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Avoids _prepare_cp Ray dependency when only --model-dir is overridden:
both fields now resolve to the same path unless user explicitly fans out
to per-node local NVMe via --model-local-dir.
Pinaster/DeepSeek-V4-Flash-FP8-4layer ships with model_type=deepseek_v4 in
config.json, but SGLang's get_config fallback (with SGLANG_APPLY_CONFIG_BACKUP=none)
only fires on deepseek_ref. Rewrite the local config.json in-place so SGLang's
_load_deepseek_temp_model gets reached. Idempotent; no-op for non-4-layer models.
yueming-yuan and others added 24 commits May 21, 2026 22:32
…epseek-v4 @ 8e1ef3c)

Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer
(Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11, all
loss=0 since 4-layer prune produces gibberish + reward=0, but pipeline
ran cleanly). Job raysubmit_WTh5sDDAaRcVrfTT.

Required environment additions on top of the upstream Dockerfile:
- sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled
  sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x)
- flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8)
- tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API)
- flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built)
- fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform
- Megatron-LM at radixark/Megatron-LM PR radixark#28 (mlm-pr28, commit 8455dbf,
  in both /workspace/Megatron-LM and /root/Megatron-LM editable install)

Run-time tweaks:
- bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to
  satisfy AutoConfig.from_pretrained on transformers 4.57.1)
- bf16 ckpt config.json: drop quantization_config (avoid sglang creating
  fp8 model params when our weights are bf16)
- SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the
  packaged 43-layer config and breaks miles' hf_validate_args)

dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune
(DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8
prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer) so the tool was not
exercised. Keeping it for the MXFP4 path.
Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` /
``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls
into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn,
pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the
fused inference ``pre_big_fuse`` for ``no_grad`` paths).

Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``)
unchanged so the radixark/Megatron-LM PR radixark#28 call sites in
``transformer_layer.py`` / ``transformer_block.py`` keep working.

Net effect:
- forward path now matches the canonical TileKernels kernels (which
  fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix)
- backward path is enabled (the original code asserted
  ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything
  in ``torch.no_grad``)
- ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``;
  PR's single ``hc_eps`` reused for both ``pre_eps`` and
  ``sinkhorn_eps``

Tweaks vs upstream TileKernels' high-level wrappers:
- inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to
  ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post``
  to have written ``grad_from_mhc_post`` onto the same residual
  storage during backward, but Megatron's call sites use independent
  ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for
  ``layer_pre`` and ``layer_post`` so the storage objects don't match.
- inline ``mhc_head`` body to ``.contiguous()`` the
  ``mixes[..., :mhc_mult]`` slice before feeding it to
  ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts
  ``strides[0] == mhc_mult`` but the slice keeps the
  ``mhc_mult * (mhc_mult + 2)`` parent stride).
- ``.contiguous()`` everywhere we hand bf16 tensors to TileLang
  kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip
  the ``view`` stride check.

Verified end-to-end: 3-rollout GRPO smoke test on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout
batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with
entropy_loss / logprob_abs_diff in the same band as the original
in-tree implementation (rl-smoke-pass tag) — see job
raysubmit_B3DgGam9vCYuDAVU.
- ``kernel/act_quant.py`` is now a thin wrapper around
  ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``;
  shape/dtype contract preserved (``(y_fp8, s_fp32)`` with
  ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in
  ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change.
- ``kernel/sinkhorn.py`` is removed: it was only consumed by the
  legacy ``hyper_connection.py`` path which now calls
  ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead.

End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``:
- act_quant.py             -> tile_kernels.quant.per_token_cast (this commit)
- sinkhorn.py              -> tile_kernels.modeling.mhc.ops (Batch 1, removed)
- tilelang_indexer*.py     -> no TileKernels equivalent (DSV4 DSA-specific)
- tilelang_sparse_mla*.py  -> no TileKernels equivalent (DSV4 sparse-MLA)

So every ``kernel/`` file that has a TileKernels analogue now routes
through TileKernels; the remaining files are V4-specific and have no
upstream replacement.

Verified end-to-end with the 12-rollout GRPO smoke harness on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id
260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps
(0..11) completed cleanly; entropy_loss / logprob_abs_diff land in
the same band as the rl-smoke-pass baseline (12-step run on the
unmodified PR @ a72ed84):

  step | baseline ent | TK ent
  -----+--------------+--------
   0   | 1.6858       | 1.7286
   1   | 1.7854       | 1.7321
   2   | 1.7360       | 1.7803
   3   | 1.7696       | 1.6750
   4   | 1.6815       | 1.7532
   5   | 1.7172       | 1.6865
   6   | 1.7394       | 1.6949
   7   | 1.7634       | 1.7217
   8   | 1.6674       | 1.6871
   9   | 1.7330       | 1.6943
  10   | 1.6319       | 1.6998
  11   | 1.6962       | 1.7321
- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back
  for the FP8→BF16 dequant step (was a manual unflatten/multiply).
- tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel
  with tile_kernels.quant.cast_back (128x128 block FP8 dequant).

Both substitutions are bit-exact against the prior implementations on
GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
…cision

TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode).  PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.

Setting allow_tf32 = True at deepseek_v4 plugin import time:

  HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
                       mean_abs    max_abs    notes
    layer_input        1.05e-5     1.56e-2    bf16 LSB output
    pre.post           1.52e-5     7.65e-5
    pre.comb           4.61e-6     2.87e-5
    hc_post out        1.15e-8     9.77e-4    fp32-equivalent
    hc_head out        1.14e-5     1.56e-2    bf16 LSB output

All ops <= 1.5e-5 mean-abs (matches the attention/indexer 1e-5 bar).
The max-abs values are 1 ULP of bf16 at the output magnitude (~6.0,
1 ULP = 2^-5 = 3.13e-2).  Other fp32 matmuls in the plugin (compressor,
indexer projections) get a free TF32 speed-up as a side effect.
…ware-aware scale dtype

Replace the TileKernels per_token_cast wrapper with a verbatim port of
deepseek-ai/DeepSeek-V4-Pro/inference/kernel.py:act_quant so this code
path is bit-exact with the upstream inference kernel.

The ported act_quant exposes scale_dtype and inplace, matching official:
  - scale_dtype=None auto-selects via SGLang should_deepgemm_weight_requant_ue8m0
    (Blackwell + DeepGEMM JIT -> float8_e8m0fnu; Hopper -> float32).
  - explicit scale_dtype overrides the auto path either way.
  - inplace=True wires through to the fused quant+dequant kernel.

Drops the tile_kernels import for this file; the package was not listed
in any requirements manifest. Caller (qat.py:fp8_simulate) is unaffected:
on Hopper the auto path resolves to float32, preserving prior behavior.

Verified on H200:
  - 26 kernel cases bit-exact vs official act_quant (4 shapes x 5 modes + 3D).
  - 4-layer iterated fp8_simulate stable, ~2.25% mean rel noise as expected for E4M3.
  - Hardware auto-resolve matches sglang.deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces full support for DeepSeek-V4 models, including Flash and Pro versions, by implementing specialized weight conversion, indexer replay, and a collection of TileLang kernels for optimized attention and hyper-connection operations. It also updates SGLang dependencies to v0.5.12 and adds a comprehensive training pipeline script. Technical feedback identifies several critical issues in the newly added kernels, such as missing boundary checks in the indexer and sparse MLA kernels that could cause out-of-bounds memory access. Furthermore, a bug in the FP8 simulation casting logic and a potential NameError in the weight update function were found, along with a recommendation to refactor model-specific padding detection for improved robustness.

Comment on lines +546 to +547
if isinstance(num_new_engines, tuple):
num_new_engines = num_new_engines[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable num_new_engines is used here but it does not appear to be defined within the scope of the update_weights function in the current diff. This will result in a NameError at runtime. Please ensure it is correctly initialized, likely from the return value of a weight update or engine connection call.

Comment on lines +107 to +109
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
))) * s_local[i],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When inplace=True, out_dtype is reassigned to in_dtype (e.g., BF16). The inner T.Cast(out_dtype, ...) call then casts to BF16 instead of the intended FP8 quantization format. This prevents the kernel from correctly simulating FP8 quantization noise. You should use a separate variable to hold the quantization target type (e.g., the original FP8 constant).

                        y_local[i, j] = T.Cast(
                            out_dtype,
                            T.Cast(compute_dtype, T.Cast(FP8, T.clamp(
                                x_local[i, j] / s_local[i], fp8_min, fp8_max
                            ))) * s_local[i],
                        )

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

for bq_i, bn_i in T.Parallel(block_Q, block_N):
Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The kernel writes to Logits without checking if the computed indices seq_len_i + bq_i and cu_k_s_min + nbn_i * block_N + bn_i are within the bounds of seq_len and seq_len_kv. This can lead to out-of-bounds memory access if the sequence lengths are not multiples of the block sizes.

Suggested change
Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]
for bq_i, bn_i in T.Parallel(block_Q, block_N):
if seq_len_i + bq_i < seq_len and cu_k_s_min + nbn_i * block_N + bn_i < seq_len_kv:
Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]

Comment on lines +231 to +235
T.atomic_add(
dAttnSink[bz * block_H + h_i],
-Delta[by, s_i, bz * block_H + h_i]
* T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The kernel performs atomic additions to dAttnSink using an index bz * block_H + h_i that can range up to padded_H - 1. If the number of heads H is not a power of 2 (e.g., H=43 or H=61), this will result in out-of-bounds memory access. A boundary check is required.

Suggested change
T.atomic_add(
dAttnSink[bz * block_H + h_i],
-Delta[by, s_i, bz * block_H + h_i]
* T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]),
)
for h_i in T.Parallel(block_H):
if bz * block_H + h_i < H:
T.atomic_add(
dAttnSink[bz * block_H + h_i],
-Delta[by, s_i, bz * block_H + h_i]
* T.exp2(AttnSink[bz * block_H + h_i] * 1.44269504 - Lse[by, s_i, bz * block_H + h_i]),
)

Comment on lines +136 to +137
for h_i in T.Parallel(H_per_block):
sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The kernel accesses AttnSink using H0 + h_i, which can exceed the actual number of heads H because H_per_block is based on padded_H. This will cause out-of-bounds memory access for models where H is not a multiple of the block size. Please add a boundary check.

Suggested change
for h_i in T.Parallel(H_per_block):
sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale)
for h_i in T.Parallel(H_per_block):
if H0 + h_i < H:
sumexp[h_i] += T.exp2(AttnSink[H0 + h_i] * 1.44269504 - m_i[h_i] * sm_scale)


# pad to reduce memory fragmentation and maybe make the computation faster
pad_size = parallel_state.tp.size * args.data_pad_size_multiplier
if getattr(args, "dsv4_hc_mult", 0) != 0: # TODO improve the way to detect needing this
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for detecting when to adjust pad_size for DeepSeek-V4 models relies on checking dsv4_hc_mult. As noted in the TODO, this should be refactored to a more robust detection mechanism, such as checking the model type or a specific configuration flag, to avoid relying on a specific attribute that might be missing or zero in other valid configurations.

References
  1. Model parameters should be retrieved from the model configuration rather than being hardcoded or inferred from potentially missing attributes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants