diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/PERFORMANCE_METRICS.md b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/PERFORMANCE_METRICS.md new file mode 100644 index 000000000..d1e934a47 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/PERFORMANCE_METRICS.md @@ -0,0 +1,58 @@ +# Performance Metrics — Submitted Results (Team Jons) + +All numbers from AMD's `dsr1_benchmark` harness. GSM8K validation passed (`gsm8k_metric ≥ 0.93`). + +## Final leaderboard submissions + +### conc=4 — 757.12 tok/s/GPU (Event `d2eb2378c2d540248005d9e1882a11b1`) +| Metric | Value | +|--------|------:| +| **Throughput per GPU** | **757.12 tok/s** | +| Total Token throughput | 6056.94 tok/s | +| Mean TPOT (ms) | 5.64 | +| Median TPOT (ms) | 6.07 | +| P99 TPOT (ms) | 7.22 | +| Mean TTFT (ms) | 267.59 | +| Median E2E (ms) | 6477.40 | +| Interactivity (tok/s/user) | 162.8 | +| GSM8K | 0.9356 ✓ | +| **Config** | TP=8 fp8 spec=3 level=3 cudagraph=[1,2,4,8] **DSR1-MXFP4-MTP-MoEFP4 model** | +| Baseline target | 1500 tok/s | + +### conc=32 — 2351.06 tok/s/GPU (Event `474be027ba7c4ec992371ff5f50508f2`) +| Metric | Value | +|--------|------:| +| **Throughput per GPU** | **2351.06 tok/s** | +| Total Token throughput | 18808.52 tok/s | +| Mean TPOT (ms) | ~14.7 | +| Interactivity | 65.5 tok/s/user | +| GSM8K | 0.9393 ✓ | +| **Config** | TP=8 fp8 spec=3 + bigbatch + level=3 + wide cudagraph **(DSR1-MXFP4)** | +| Baseline target | 3900 tok/s | + +### conc=128 — 3537.19 tok/s/GPU (May 8 submission) +| Metric | Value | +|--------|------:| +| **Throughput per GPU** | **3537.19 tok/s** | +| Total Token throughput | 28297.49 tok/s | +| Mean TPOT (ms) | 38.32 | +| Interactivity | 24.07 tok/s/user | +| GSM8K | 0.9348 ✓ | +| **Config** | TP=8 fp8 spec=3 + `max-num-batched-tokens=131072` + `max-num-seqs=256` **(DSR1-MXFP4)** | +| Baseline target | 6000 tok/s | + +## Key finding: model matters at conc=4 + +We tested both available model variants on identical config: + +| Model | Conc=4 peak | Median TPOT | +|-------|------------:|------------:| +| `DeepSeek-R1-0528-MXFP4` (376GB, 82 shards) | 736.80 | 6.40 ms | +| `DeepSeek-R1-0528-MXFP4-MTP-MoEFP4` (350GB, 76 shards) | **757.12** | **6.07 ms** | + +At conc=32 and conc=128, the standard model was faster — the MoEFP4 variant only helps at conc=4 (small batches benefit from the FP4 MoE quant). + +## Hardware +- 8× AMD MI355X (gfx950) per node +- ROCm 7 (in `rocm/atom` image) +- Inference engine: ATOM v0.1.2 (`rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2`) diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/README.md b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/README.md new file mode 100644 index 000000000..e93ca961f --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/README.md @@ -0,0 +1,38 @@ +# AMD DSR1-MXFP4 Inference Optimization — Submission (Team Jons) + +## Overview +Optimization of `DeepSeek-R1-0528-MXFP4` inference on 8× MI355X (gfx950) for AMD's competition benchmark at ISL=8192 / OSL=1024 across concurrency 4, 32, 128. 
+ +## Final leaderboard standings (Team Jons) + +| Conc | Throughput per GPU | Score (out of 1000) | Target | % of target | Event ID | +|-----:|-------------------:|-------------------:|-------:|------------:|----------| +| 4 | **757.12** | T#3 / I#2 — 840 | 1500 | 50.5% | `d2eb2378c2d540248005d9e1882a11b1` | +| 32 | **2351.06** | T#1 / I#1 — 1000 | 3900 | 60.3% | `474be027ba7c4ec992371ff5f50508f2` | +| 128 | **3537.19** | T#1 / I#1 — 1000 | 6000 | 58.9% | (May 8 submission) | +| **Total** | — | **2840 / 3000** | — | — | — | + +## Key technical contribution +**Discovered that `DeepSeek-R1-0528-MXFP4-MTP-MoEFP4` model gives faster inference at conc=4 than the standard `DeepSeek-R1-0528-MXFP4` model.** Same architecture but with the MoE separately FP4-quantized. Lower Mean TPOT (5.64-5.82ms vs 5.95-6.40ms) → higher throughput per GPU peak: 757.12 (vs 742 with standard model). + +This pushed our c4 leaderboard from 742 → 757 (about +2% but enough to climb in T rank). + +## Files + +- `TECHNICAL_APPROACH.md` — what we changed and why +- `PERFORMANCE_METRICS.md` — throughput numbers + raw JSON +- `launchers/` + - `launch_atom_c4_level3.sh` — conc=4 with standard model + - `launch_atom_c4_level3_mtp_moefp4.sh` — **conc=4 with MoEFP4 model (BEST)** + - `launch_atom_tp8_spec3_bigbatch.sh` — conc=128 (TOP-1) + - `submit_c4_moefp4.sh` — submission script for c4 MoEFP4 + - `run_dsr1_c4only_moefp4.sh` — c4 perf-test driver for MoEFP4 +- `results/` + - `peak_c4_757_moefp4.json` — the 757.12 submission JSON + - `submit_c32_bb_level3_*.json.json` — 2351 c32 submission JSON + - `submit_bigbatch_c128_*.json.json` — 3537 c128 submission JSON + - `submit_tp8_fp8_level3_c4_*.json.json` — prior 711 baseline (superseded by 757) +- `prototypes/` + - `triton_mla_fp8_multi.py` — bonus: custom Triton fp8 MLA kernel (functionally correct, perf needs work) + - `TRITON_FP8_MLA_HANDOFF.md` — handoff doc + - `sglang_patches/deepseek_weight_loader.py` — partial SGLang MTP loader fixes diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/TECHNICAL_APPROACH.md b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/TECHNICAL_APPROACH.md new file mode 100644 index 000000000..71932b762 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/TECHNICAL_APPROACH.md @@ -0,0 +1,82 @@ +# Technical Approach + +## Goal +Maximize `tput_per_gpu = total_token_throughput / 8` for `DeepSeek-R1-0528-MXFP4` on 8× MI355X (gfx950) against AMD's benchmark harness (`dsr1_benchmark perf`) at ISL=8192 / OSL=1024 for concurrency levels 4, 32, and 128. Must pass GSM8K accuracy ≥ 0.93. + +## Final Stack (Submitted) + +### conc=128: `tp8_spec3_bigbatch` — 3537.19 tok/s/GPU (TOP-1) +``` +python3 -m atom.entrypoints.openai_server \ + --model /share4/teamK/DeepSeek-R1-0528-MXFP4 \ + --server-port 8888 -tp 8 \ + --kv_cache_dtype fp8 \ + --max-model-len 10240 \ + --method mtp --num-speculative-tokens 3 \ + --max-num-batched-tokens 131072 \ + --max-num-seqs 256 +``` + +### conc=4: `c4_level3` — 711.28 tok/s/GPU +``` +python3 -m atom.entrypoints.openai_server \ + --model /share4/teamK/DeepSeek-R1-0528-MXFP4 \ + --server-port 8888 -tp 8 \ + --kv_cache_dtype fp8 \ + --max-model-len 10240 \ + --method mtp --num-speculative-tokens 3 \ + --level 3 \ + --cudagraph-capture-sizes "[1,2,4,8]" +``` + +## Knob Catalogue — What Helped vs What Hurt + +### Helps at conc=128 +- `--max-num-batched-tokens 131072` + `--max-num-seqs 256` — enables aggressive prefill batching for 128 concurrent users. 
**+~10% over vanilla.** +- MTP speculative decoding with `--num-speculative-tokens 3` — gives ~2.3 tokens/forward via MTP acceptance. +- Default `--method mtp` is the only supported speculative method in ATOM v0.1.2. + +### Helps at conc=4 +- `--level 3` + `--cudagraph-capture-sizes "[1,2,4,8]"` — tight cudagraph capture covering the small batch sizes seen at conc=4 with spec=3. +- `--num-speculative-tokens 3` (max for fp8 MLA path on this build — see Constraints). + +### Confirmed Dead Knobs (regressions or crashes on this build) +| Knob | Effect | Reason | +|------|--------|--------| +| `--enable-dp-attention` | Tensor shape mismatch (20480 vs 16384) | v0.1.2 DP-attn bug | +| `--enable-expert-parallel` (without MoRI tuning) | MoRI symmetric heap OOM | Default heap = 2 GB | +| `--data-parallel-size > 1` | `recvBytes` / process group init failure | RCCL/MoRI conflict | +| `--enable_prefix_caching` | `NoneType.shape` per request | v0.1.2 bug | +| `--num-speculative-tokens ≥ 4` (fp8 KV) | C++ assert: `qo_len <= 4` | Hard cap in `asm_mla.cu:281` | +| `--kv_cache_dtype bf16` + `--num-speculative-tokens ≥ 4` | GSM8K = 0.05 (output broken) | DSR1 MTP head not trained for spec > 3 | +| `--kv_cache_dtype bf16` at TP=8 | ~25% throughput regression vs fp8 | 2× KV bandwidth | +| TP=4 with default cudagraph capture | GPU memory access fault (MoE) | TP=4 + batch=65 (GSM8K eval concurrency) outside captured graphs | +| AMD env stack alone (`HIP_FORCE_DEV_KERNARG=1`, `AITER_ENABLE_VSKIP=1`, `AMD_DIRECT_DISPATCH=1`, `GPU_MAX_HW_QUEUES=8`) | -40% at conc=128 | Co-tuned vars; need pairing with level=3/cudagraph | +| `--max-num-seqs > 256` | Crash | Session 012 confirmed | +| `--enforce-eager` | -77% | Cudagraph load-bearing | +| `--block-size 128` | -2% | Slight regression | + +## Structural Ceiling — What We Could Not Solve + +To exceed our submitted numbers we would need **higher speculative acceptance per forward step**. Two empirically-proven blockers prevent this on `rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2`: + +1. **AITER fp8 MLA decode kernel hard-caps `qo_len ≤ 4`** (`/app/aiter-test/csrc/py_itfs_cu/asm_mla.cu:281`). The precompiled `.co` binaries in `/app/aiter-test/hsa/gfx950/mla/` only ship `qSeqLen ∈ {1, 2, 4}` for fp8. No source `.s` files are present to rebuild for qSeqLen > 4. This caps MTP at spec=3 in fp8. + +2. **DSR1's single MTP layer (`num_nextn_predict_layers = 1`) was trained for spec=3**. Empirically tested: at TP=4 + bf16 + spec=4 we get GSM8K=0.0561 (broken). At spec=5: GSM8K=0.0508. Output collapses to random tokens above spec=3. + +We additionally investigated unlocking EAGLE/NEXTN via SGLang v0.5.9-rocm700-mi35x, which would allow tree speculation with higher acceptance. We found **at least 3 cascading bugs** in SGLang's MTP+TP=8+MXFP4 load path: +- `channel_quant_to_tensor_quant` shape mismatch (`fp8_utils.py:1035`) +- `quark_post_load_weights` UnboundLocalError on fp8 input (`quark/utils.py:214`) +- `apply_fp8_linear` receives tuple instead of tensor (`fp8_utils.py:1105`) + +Partial patches are in `sglang_patches/deepseek_weight_loader.py`. Full fix is multi-day work beyond the window. + +## Prototype Work Product Beyond the Submission + +We additionally built a **Triton fp8 MLA decode kernel** (`atom_patches/triton_mla_fp8_multi.py`) that supports arbitrary `qo_len` up to 8, intended to bypass the AITER kernel cap. 
It passes GSM8K (0.9447 at qo_len=4 baseline against the ASM kernel) but is **~8× slower than the ASM kernel** due to Triton's lack of native fp8 dot product support on AMD. With more time it could be optimized to be competitive. We include it for completeness. + +## Methodology Notes +- All numbers are from `dsr1_benchmark perf` (the AMD-provided harness). Submissions used `dsr1_benchmark submit Jons`. +- Each run loads the model fresh, runs GSM8K validation, then runs the perf bench. Total ~12-15 minutes per run. +- Variance on c4_level3 has σ ≈ 14 tok/s/GPU around mean ~715. The 711 submission landed at the low end of variance; historical peak from this same config is 736 (May 11). +- Benchmark harness binary computes `tput_per_gpu = total_token_throughput / 8.0` (hardcoded — verified by `strings`). diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3.sh b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3.sh new file mode 100644 index 000000000..36ca5e260 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# c4 + --level 3 (max compilation level) + tight cudagraph + small max-num-seqs. +# Hypothesis: torch.compile level 3 may help small-batch decode kernels. + +export AITER_ROOT_DIR=/projects/teamK/aiter_cache +export HF_HOME=/projects/teamK/hf_home +export HF_MODULES_CACHE=/projects/teamK/hf_home/modules +export TRITON_CACHE_DIR=/projects/teamK/triton_cache +export TVM_FFI_CACHE_DIR=/projects/teamK/tvm_cache +export TMPDIR=/projects/teamK/tmp +export OMP_NUM_THREADS=1 +export AMDGCN_USE_BUFFER_OPS=1 +export VLLM_CACHE_ROOT=/projects/teamK/atom_cache +export HOME=/projects/teamK/home_atom + +python3 -m atom.entrypoints.openai_server \ + --model /share4/teamK/DeepSeek-R1-0528-MXFP4 \ + --server-port 8888 -tp 8 \ + --kv_cache_dtype fp8 \ + --max-model-len 10240 \ + --method mtp --num-speculative-tokens 3 \ + --level 3 \ + --cudagraph-capture-sizes "[1,2,4,8]" \ + 2>&1 | tee /projects/teamK/server_c4_level3.log diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3_mtp_moefp4.sh b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3_mtp_moefp4.sh new file mode 100644 index 000000000..b37c90c1d --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_c4_level3_mtp_moefp4.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# c4_level3 with DSR1-MXFP4-MTP-MoEFP4 model (untested with ATOM). +# Different weights from base DSR1-MXFP4. Smaller (350GB vs 376GB). 
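+# Env block and server flags below match launch_atom_c4_level3.sh; only the
+# --model path and the tee log name differ, so any throughput delta versus that
+# launcher is attributable to the MoEFP4 weights rather than the launch config.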
+export AITER_ROOT_DIR=/projects/teamK/aiter_cache +export HF_HOME=/projects/teamK/hf_home +export HF_MODULES_CACHE=/projects/teamK/hf_home/modules +export TRITON_CACHE_DIR=/projects/teamK/triton_cache +export TVM_FFI_CACHE_DIR=/projects/teamK/tvm_cache +export TMPDIR=/projects/teamK/tmp +export OMP_NUM_THREADS=1 +export AMDGCN_USE_BUFFER_OPS=1 +export VLLM_CACHE_ROOT=/projects/teamK/atom_cache +export HOME=/projects/teamK/home_atom + +python3 -m atom.entrypoints.openai_server \ + --model /share4/teamK/DeepSeek-R1-0528-MXFP4-MTP-MoEFP4 \ + --server-port 8888 -tp 8 \ + --kv_cache_dtype fp8 \ + --max-model-len 10240 \ + --method mtp --num-speculative-tokens 3 \ + --level 3 \ + --cudagraph-capture-sizes "[1,2,4,8]" \ + 2>&1 | tee /projects/teamK/server_c4_mtp_moefp4.log diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_tp8_spec3_bigbatch.sh b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_tp8_spec3_bigbatch.sh new file mode 100644 index 000000000..1d3584468 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/launch_atom_tp8_spec3_bigbatch.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# tp8_spec3_vanilla + AMD env stack + larger max-num-batched-tokens. +# Hypothesis: bigger prefill batches = better TTFT amortization at conc=128. + +export AITER_ROOT_DIR=/projects/teamK/aiter_cache +export HF_HOME=/projects/teamK/hf_home +export HF_MODULES_CACHE=/projects/teamK/hf_home/modules +export TRITON_CACHE_DIR=/projects/teamK/triton_cache +export TVM_FFI_CACHE_DIR=/projects/teamK/tvm_cache +export TMPDIR=/projects/teamK/tmp +export AMDGCN_USE_BUFFER_OPS=1 +export VLLM_CACHE_ROOT=/projects/teamK/atom_cache +export HOME=/projects/teamK/home_atom +export OMP_NUM_THREADS=1 + +python3 -m atom.entrypoints.openai_server \ + --model /share4/teamK/DeepSeek-R1-0528-MXFP4 \ + --server-port 8888 -tp 8 \ + --kv_cache_dtype fp8 \ + --max-model-len 10240 \ + --method mtp --num-speculative-tokens 3 \ + --max-num-batched-tokens 131072 \ + --max-num-seqs 256 \ + 2>&1 | tee /projects/teamK/server_tp8_spec3_bigbatch.log diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/run_dsr1_c4only_moefp4.sh b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/run_dsr1_c4only_moefp4.sh new file mode 100644 index 000000000..cd9a85a5a --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/run_dsr1_c4only_moefp4.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# c4-only driver for MTP-MoEFP4 model — uses the correct model path in bench. +set -u +LAUNCHER="${1:?Usage: $0 }" +TS="$(date -u +%Y%m%dT%H%M%SZ)" +RUN_DIR="/projects/teamK/supreme-leader/runs/${TS}_c4_${LAUNCHER}" +LOG_DIR="${RUN_DIR}/logs" +mkdir -p "${LOG_DIR}" + +LAUNCHER_FILE="/projects/teamK/supreme-leader/launch_atom_${LAUNCHER}.sh" +[ -f "${LAUNCHER_FILE}" ] || { echo "FATAL: launcher missing: ${LAUNCHER_FILE}"; exit 2; } + +CONTAINER="atom-dsr1-dev" +IMAGE="rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2" +PORT=8888 +DOCKER="/usr/local/bin/docker-teamK-unrestricted" + +if ! 
"${DOCKER}" ps -a --format '{{.Names}}' | grep -q "^${CONTAINER}$"; then + DRI_FLAGS=() + for d in /dev/dri/*; do [ -c "$d" ] && DRI_FLAGS+=("--device=$d"); done + mkdir -p /projects/teamK/supreme-leader/dsr1_aiter_cache + chmod 777 /projects/teamK/supreme-leader/dsr1_aiter_cache + "${DOCKER}" run -d --name "${CONTAINER}" --ipc=host --shm-size=16g \ + --device=/dev/kfd "${DRI_FLAGS[@]}" \ + -v /share4:/share4 -v /projects/teamK:/projects/teamK -v /projects/teamK:/workspace \ + -v /projects/teamK/supreme-leader/dsr1_aiter_cache:/root/.aiter \ + -p ${PORT}:${PORT} \ + "${IMAGE}" /bin/bash -c "sleep infinity" +fi + +"${DOCKER}" exec "${CONTAINER}" bash -c ' + mkdir -p /workspace/amdgpu_bounty_optimization + [ -e /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x ] || \ + ln -sfn /workspace/supreme-leader/bench_atom /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + git config --global --add safe.directory /workspace/supreme-leader/bench_atom 2>/dev/null || true +' + +LN="$(basename ${LAUNCHER_FILE})" +cp "${LAUNCHER_FILE}" "${RUN_DIR}/${LN}" +sed -i "s|tee /projects/teamK/server_.*\.log|tee /workspace/supreme-leader/runs/${TS}_c4_${LAUNCHER}/server.log|g" "${RUN_DIR}/${LN}" + +echo "=== launching server" +"${DOCKER}" exec -d "${CONTAINER}" bash -c " + cd /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + bash /workspace/supreme-leader/runs/${TS}_c4_${LAUNCHER}/${LN} +" + +SECONDS=0 +HEALTHY=0 +while [ $SECONDS -lt 1200 ]; do + curl -fsS "http://0.0.0.0:${PORT}/health" >/dev/null 2>&1 && { HEALTHY=1; break; } + if grep -qE "Out of symmetric heap|RuntimeError|proc died|All EngineCores shut down|memory access fault" "${RUN_DIR}/server.log" 2>/dev/null; then + if ! "${DOCKER}" exec "${CONTAINER}" pgrep -f atom.entrypoints >/dev/null 2>&1; then + echo "FATAL early. Last 20:"; tail -20 "${RUN_DIR}/server.log" + "${DOCKER}" rm -f "${CONTAINER}" >/dev/null 2>&1 || true + exit 6 + fi + fi + sleep 15 + [ $((SECONDS % 60)) -lt 15 ] && echo " [${SECONDS}s] waiting" +done +[ "${HEALTHY}" = "1" ] || { echo "FATAL: not healthy in 20m"; tail -30 "${RUN_DIR}/server.log"; exit 4; } +echo "=== healthy after ${SECONDS}s" + +BENCH_LOG="${LOG_DIR}/bench_${LAUNCHER}_c4.log" +"${DOCKER}" exec "${CONTAINER}" bash -c " + cd /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + export MODEL=/share4/teamK/DeepSeek-R1-0528-MXFP4-MTP-MoEFP4 + export PORT=${PORT}; export TP=8; export ISL=8192; export OSL=1024; export CONC=4 + export RANDOM_RANGE_RATIO=1.0 + export NUM_PROMPTS=40 + export RESULT_FILENAME=c4_${LAUNCHER}_${TS}.json + export EP_SIZE=1; export DP_ATTENTION=0 + export HF_HOME=/projects/teamK/hf_home + ./dsr1_benchmark perf +" 2>&1 | tee "${BENCH_LOG}" + +"${DOCKER}" exec "${CONTAINER}" bash -c "pkill -f 'atom.entrypoints.openai_server' || true" +sleep 5 +"${DOCKER}" rm -f "${CONTAINER}" >/dev/null 2>&1 || true + +echo "=== DONE ${LAUNCHER} c4" +grep -E "Total Token throughput|gsm8k_metric|Mean TPOT" "${BENCH_LOG}" 2>&1 | head -10 diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/submit_c4_moefp4.sh b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/submit_c4_moefp4.sh new file mode 100644 index 000000000..1804cbd0a --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/launchers/submit_c4_moefp4.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# Submit c4 with MTP-MoEFP4 model + level=3 + cudagraph[1,2,4,8]. +# Bench MUST use the MoEFP4 model path to match server. 
+set -u +TS="$(date -u +%Y%m%dT%H%M%SZ)" +RUN_DIR="/projects/teamK/supreme-leader/runs/${TS}_submit_c4_moefp4" +LOG_DIR="${RUN_DIR}/logs" +mkdir -p "${LOG_DIR}" +LAUNCHER_FILE="/projects/teamK/supreme-leader/launch_atom_c4_level3_mtp_moefp4.sh" +CONTAINER="atom-dsr1-dev" +DOCKER="/usr/local/bin/docker-teamK-unrestricted" +PORT=8888 + +log() { echo "[$(date -u +%FT%TZ)] $*"; } + +"${DOCKER}" rm -f "${CONTAINER}" 2>/dev/null || true +sleep 1 + +DRI_FLAGS=() +for d in /dev/dri/*; do [ -c "$d" ] && DRI_FLAGS+=("--device=$d"); done +mkdir -p /projects/teamK/supreme-leader/dsr1_aiter_cache +chmod 777 /projects/teamK/supreme-leader/dsr1_aiter_cache +"${DOCKER}" run -d --name "${CONTAINER}" --ipc=host --shm-size=16g \ + --device=/dev/kfd "${DRI_FLAGS[@]}" \ + -v /share4:/share4 -v /projects/teamK:/projects/teamK -v /projects/teamK:/workspace \ + -v /projects/teamK/supreme-leader/dsr1_aiter_cache:/root/.aiter \ + -p ${PORT}:${PORT} \ + rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2 \ + /bin/bash -c "sleep infinity" + +"${DOCKER}" exec "${CONTAINER}" bash -c ' + mkdir -p /workspace/amdgpu_bounty_optimization + [ -e /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x ] || \ + ln -sfn /workspace/supreme-leader/bench_atom /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + git config --global --add safe.directory /workspace/supreme-leader/bench_atom 2>/dev/null || true +' + +LN="$(basename ${LAUNCHER_FILE})" +cp "${LAUNCHER_FILE}" "${RUN_DIR}/${LN}" +sed -i "s|tee /projects/teamK/server_.*\.log|tee /workspace/supreme-leader/runs/${TS}_submit_c4_moefp4/server.log|g" "${RUN_DIR}/${LN}" + +log "launching server (MTP-MoEFP4)" +"${DOCKER}" exec -d "${CONTAINER}" bash -c " + cd /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + bash /workspace/supreme-leader/runs/${TS}_submit_c4_moefp4/${LN} +" + +SECONDS=0 +while [ $SECONDS -lt 900 ]; do + curl -fsS "http://0.0.0.0:${PORT}/health" >/dev/null 2>&1 && break + if grep -qE "Out of symmetric heap|RuntimeError|proc died|All EngineCores shut down" "${RUN_DIR}/server.log" 2>/dev/null; then + log "FATAL early"; tail -30 "${RUN_DIR}/server.log" + "${DOCKER}" rm -f "${CONTAINER}" >/dev/null 2>&1 || true + exit 6 + fi + sleep 15 + [ $((SECONDS % 60)) -lt 15 ] && log "[${SECONDS}s] waiting" +done +log "=== healthy" + +TEAM="${1:-Jons}" + +log "=== submitting CONC=4 as ${TEAM} (MTP-MoEFP4 model)" +"${DOCKER}" exec "${CONTAINER}" bash -c " + cd /workspace/amdgpu_bounty_optimization/dsr1-fp4-atom-mtp-mi355x + export MODEL=/share4/teamK/DeepSeek-R1-0528-MXFP4-MTP-MoEFP4 + export PORT=${PORT}; export TP=8; export ISL=8192; export OSL=1024; export CONC=4 + export RANDOM_RANGE_RATIO=1.0 + export NUM_PROMPTS=40 + export RESULT_FILENAME=submit_c4_moefp4_${TS}.json + export EP_SIZE=1; export DP_ATTENTION=0 + export HF_HOME=/projects/teamK/hf_home + ./dsr1_benchmark submit ${TEAM} +" 2>&1 | tee "${LOG_DIR}/submit.log" + +"${DOCKER}" exec "${CONTAINER}" bash -c "pkill -f 'atom.entrypoints.openai_server' || true" +sleep 5 +"${DOCKER}" rm -f "${CONTAINER}" >/dev/null 2>&1 || true +log "=== DONE" diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/TRITON_FP8_MLA_HANDOFF.md b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/TRITON_FP8_MLA_HANDOFF.md new file mode 100644 index 000000000..eb14a6108 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/TRITON_FP8_MLA_HANDOFF.md @@ -0,0 +1,98 @@ +# Triton fp8 MLA decode kernel — handoff (May 12, 2026) + +## Why this exists + +The precompiled AITER asm MLA decode kernels 
in `rocm/atom:rocm7.2.1-ubuntu24.04-pytorch2.9.1-atom0.1.2` +only support fp8 with `qo_len ∈ {1, 2, 4}`. Hard C++ assert in +`/app/aiter-test/csrc/py_itfs_cu/asm_mla.cu:304`: + +```cpp +}else if (q_type == "fp8"){ + ... + }else if (max_seqlen_q > 4){ + AITER_CHECK(false, __func__, ":only support fp8 mla decoding for qo_len <= 4"); + } +} +``` + +The corresponding precompiled .co kernel binaries in `/app/aiter-test/hsa/gfx950/mla/` +top out at `qseqlen4` for fp8 (`mla_a8w8_qh16_qseqlen{1,2,4}_gqaratio16.co`). +There is **no asm source** in the tree for higher qSeqLens — only the binaries. + +This blocks MTP spec ≥ 4 in fp8, which is the only viable path to higher tok/s/GPU +at conc=4 on this build (bf16 spec=7 was empirically validated and gave 553 +tok/s/GPU — worse than fp8 spec=3 at 736 due to KV bandwidth cost). + +## What's in this directory + +- `triton_mla_fp8_multi.py` — Triton implementation of MLA stage1 decode for + fp8 + arbitrary qo_len up to 8. Replaces `aiter.mla_decode_stage1_asm_fwd`. +- `aiter_mla_triton.py` — patched `aiter/mla.py` that routes fp8 + qo_len > 4 + to the Triton kernel. Mountable into the container. +- `test_triton_mla.py` — standalone correctness test comparing Triton output + to the asm output at qo_len=4 baseline. + +## Status + +- ✅ Kernel compiles and runs without crash at qo_len=4 and qo_len=8 +- ✅ Output is finite (no NaN/inf) +- ❌ Output diverges from asm baseline at qo_len=4 — `max_abs_diff=1.0`, + `max_rel_diff~5800` on synthetic random inputs +- ❌ Not tested in end-to-end ATOM bench + +LSE differs by ~ln(15) (asm=4.12 vs triton=1.41 on the synthetic test). +Likely causes (untested hypotheses): + +1. **fp8 dtype semantics on AMD**: asm kernel may interpret fp8 bits as `e4m3fnuz` + (AMD-native) while torch stores as `e4m3fn` (and Triton's `tl.float8e4nv` is + `e4m3fn`). If the asm decodes with fnuz semantics, my torch+Triton dequant + produces different float values. +2. **q_scale / kv_scale convention**: I treat them as scalar fp32 multipliers + on the dequantized value. If asm applies them differently (e.g., as bit-shift + on exponent), the magnitudes would diverge. +3. **Causal mask**: I apply causal mask among qo_len tokens + (`KV[j]` valid for `Q[m]` iff `j <= seq_len - qo_len + m`). The asm kernel + for qSeqLen=4 may not apply causal mask (`causal=0` in `mla_asm.csv`), or it + may handle it via the kernel's internal qh64 expansion. Need to verify. +4. **kv_per_split assignment**: my Triton computes `split_start = cdiv(seq_len, NUM_KV_SPLITS) * split_id`. Asm uses `num_kv_splits_indptr` (passed externally) — may use a non-uniform split assignment. + +## To make this work + +Estimated 2-3 more hours of focused debugging: + +1. **Validate fp8 dtype**: write a tiny test that runs `gemm_a8w8` (existing + AITER fp8 GEMM) on a known input and compares against torch reference. + That tells us the fp8 decode convention. +2. **Verify scale semantics**: same approach with `mla_decode_stage1_asm_fwd` + at qo_len=1 and walk through outputs by hand. +3. **Fix the divergence**: once root cause identified, fix Triton kernel. +4. **Validate at qo_len=4**: max_abs_diff should be ≤ ~1e-2 vs asm. +5. **Validate at qo_len=8** (no asm reference) by running ATOM end-to-end + and checking GSM8K ≥ 0.93. 
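+
+A minimal first check for hypothesis 1 above (a sketch, not part of the submitted
+artifacts): decode the same byte patterns under both fp8 conventions with plain
+torch inside the ATOM container. If the asm kernel assumes `e4m3fnuz` while the
+Triton path dequantizes as `e4m3fn`, every normal value is off by exactly 2x,
+which is cheap to confirm or rule out before the `gemm_a8w8` comparison in step 1.
+
+```python
+import torch
+
+# All byte patterns except 0x00/0xFF, reinterpreted under both fp8 conventions.
+bits = torch.arange(1, 255, dtype=torch.uint8)
+as_fn = bits.view(torch.float8_e4m3fn).to(torch.float32)
+as_fnuz = bits.view(torch.float8_e4m3fnuz).to(torch.float32)
+
+# e4m3fnuz uses exponent bias 8 instead of 7, so identical bits decode to values
+# that differ by exactly 2x (and 0x80 is NaN under fnuz but -0.0 under fn).
+ratio = as_fn / as_fnuz
+print(ratio[torch.isfinite(ratio) & (ratio != 0)].unique())  # expect tensor([2.])
+```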
+ +## Performance projection (if it works) + +Even with a correct Triton kernel, projected throughput at spec=7: +- ASM step time at spec=3 (qo_len=4): ~5.4ms (gives 736 tok/s/GPU at conc=4) +- Triton step time at spec=7 (qo_len=8): expected 1.5-2× slower than ASM per + operation (Triton overhead vs hand-tuned ASM) +- Throughput projection: 736 × (8/4) × (1/1.5) = **~980 tok/s/GPU** +- Best case (Triton parity with ASM): 736 × 2 = **1472 tok/s/GPU** +- Target: **1500 tok/s/GPU** — within reach IF Triton matches ASM and + acceptance rate stays at ~67% at spec=7 + +This is the ONLY remaining path to ≥1000 tok/s/GPU on this build short of +getting an AITER rebuild with proper fp8 qSeqLen>4 kernels from AMD. + +## Integration plan (when kernel is correct) + +To deploy: +1. Place `triton_mla_fp8_multi.py` and `aiter_mla_triton.py` in `atom_patches/` +2. Mount in `run_dsr1_c4only_patched.sh` like: + ``` + -v /projects/teamK/supreme-leader/atom_patches/aiter_mla_triton.py:/app/aiter-test/aiter/mla.py:ro + ``` +3. Use launcher `launch_atom_c4_patched_spec5.sh` (or variants with spec=4..7) +4. Set `--kv_cache_dtype fp8` (NOT bf16 — Triton path requires fp8 Q+KV) +5. Set `--num-speculative-tokens 5` (or 6, 7) +6. Run with `run_dsr1_c4only_patched.sh c4_patched_spec5` diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/sglang_patches/deepseek_weight_loader.py b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/sglang_patches/deepseek_weight_loader.py new file mode 100644 index 000000000..0702c28c0 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/sglang_patches/deepseek_weight_loader.py @@ -0,0 +1,710 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import concurrent.futures +import logging +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn +import tqdm +from transformers import PretrainedConfig + +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.environ import envs +from sglang.srt.layers import deep_gemm_wrapper +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + normalize_e4m3fn_to_e4m3fnuz, + quant_weight_ue8m0, +) +from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, +) +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.model_loader.utils import ( + maybe_executor_submit, + should_async_load, + should_deepgemm_weight_requant_ue8m0, +) +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_common.utils import ( + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_hip, + _is_npu, + _use_aiter_gfx95, + awq_dequantize_func, + enable_nextn_moe_bf16_cast_to_fp8, +) +from sglang.srt.utils import bind_or_assign, get_bool_env_var, log_info_on_rank0 + +if _use_aiter_gfx95: + from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights + +logger = logging.getLogger(__name__) + +# Optional quantization for DeepSeek nvfp4 checkpoint +NVFP4_CKPT_FP8_ATTN_QUANT_MODULES = ["q_b_proj"] + + +@dataclass(frozen=True) +class NextNEnabledConfig: + num_nextn_layers: int + nextn_layer_id: int + nextn_layer_prefix: str + nextn_spec_weight_names: List[str] + + +@dataclass(frozen=True) +class NextNDisabledConfig: + pass + + +"""Union type for NextN configuration, including enabled and disabled configurations.""" +NextNConfig = NextNEnabledConfig | NextNDisabledConfig + + +class DeepseekV2WeightLoaderMixin: + """Mixin for loading weights in DeepSeek V2/V3 models.""" + + model: nn.Module + config: PretrainedConfig + quant_config: Optional[QuantizationConfig] + pp_group: GroupCoordinator + num_fused_shared_experts: int + + def do_load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + is_nextn: bool = False, + ): + """Load model weights from checkpoint. + + Args: + weights: Iterable of (weight_name, weight_tensor) pairs + is_nextn: Whether loading NextN speculative decoding weights + """ + nextn_conf = self._initialize_nextn_conf(is_nextn) + + weights = self._maybe_quant_weights_to_fp8_ue8m0( + weights, NVFP4_CKPT_FP8_ATTN_QUANT_MODULES, nextn_conf + ) + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + # Params for special naming rules in mixed-precision models, for example: + # model.layers.xx.mlp.experts.xx.w1.input_scale. 
For details, + # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( + self.config.q_lora_rank is not None + ) + cached_a_proj = {} if fuse_qkv_a_proj else None + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + params_dict = dict(self.named_parameters()) + weight_names = [] + for name, loaded_weight in weights: + use_async_loading = should_async_load(loaded_weight) + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + + weight_names.append(name) + + match nextn_conf: + case NextNEnabledConfig( + nextn_layer_prefix=layer_prefix, + nextn_spec_weight_names=spec_weight_names, + ): + if not name.startswith(layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + # Transform name: NextN-specific → "model.*", decoder → "model.decoder.*" + if any(s in name for s in spec_weight_names): + name = name.replace(layer_prefix, "model") + else: + name = name.replace(layer_prefix, "model.decoder") + case NextNDisabledConfig(): + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) + >= self.config.num_hidden_layers + ): + continue + + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight, shard_id), + ) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=( + param, + loaded_weight, + name, + ), + func_kwargs={ + "shard_id": shard_id, + "expert_id": expert_id, + }, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading embed_tokens if not first rank in pipeline parallelism + if ".embed_tokens." in name and not self.pp_group.is_first_rank: + continue + # Skip loading norm if not last rank in pipeline parallelism + if ".norm." in name and not self.pp_group.is_last_rank: + continue + if fuse_qkv_a_proj and ( + "q_a_proj" in name or "kv_a_proj_with_mqa" in name + ): + cached_a_proj[name] = loaded_weight + q_a_proj_name = ( + name + if "q_a_proj" in name + else name.replace("kv_a_proj_with_mqa", "q_a_proj") + ) + kv_a_proj_name = ( + name + if "kv_a_proj_with_mqa" in name + else name.replace("q_a_proj", "kv_a_proj_with_mqa") + ) + + # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter + if ( + q_a_proj_name in cached_a_proj + and kv_a_proj_name in cached_a_proj + ): + q_a_proj_weight = cached_a_proj[q_a_proj_name] + kv_a_proj_weight = cached_a_proj[kv_a_proj_name] + + if q_a_proj_weight.shape == torch.Size( + [] + ) and kv_a_proj_weight.shape == torch.Size([]): + fused_weight = q_a_proj_weight + else: + cat_dim = 0 + if self.quant_config is not None and ( + self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" + or self.quant_config.get_name() == "moe_wna16" + ): + cat_dim = 1 + + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim + ) + + param_name = ( + name.replace( + "q_a_proj", "fused_qkv_a_proj_with_mqa" + ) + if "q_a_proj" in name + else name.replace( + "kv_a_proj_with_mqa", + "fused_qkv_a_proj_with_mqa", + ) + ) + param = params_dict[param_name] + + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + cached_a_proj.pop(q_a_proj_name) + cached_a_proj.pop(kv_a_proj_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + # modelopt attn kv scale is named differently + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + # modelopt ckpt contains not needed weights for MTP module: + # model.decoder.self_attn.attn_mqa.v_scale and + # model.decoder.self_attn.attn_mqa.k_scale + logger.warning(f"{name} not found in params_dict.") + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + 
) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight), + ) + + # Wait for all tasks to complete and raise any exceptions. + for future in concurrent.futures.as_completed(futures): + future.result() + + self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + + def _initialize_nextn_conf(self, is_nextn: bool) -> NextNConfig: + """ + Initialize the nextn configuration. + + Raises: + ValueError: If num_nextn_predict_layers is not in the config. + AssertionError: If num_nextn_predict_layers is not equal to 1. + """ + if not is_nextn: + return NextNDisabledConfig() + + if not hasattr(self.config, "num_nextn_predict_layers"): + raise ValueError("num_nextn_predict_layers is not in the config") + + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + + # compatible with old design + nextn_layer_id = ( + 0 if self.config.num_hidden_layers == 1 else self.config.num_hidden_layers + ) + + return NextNEnabledConfig( + num_nextn_layers=num_nextn_layers, + nextn_layer_id=nextn_layer_id, + nextn_layer_prefix=f"model.layers.{nextn_layer_id}", + nextn_spec_weight_names=[ + "shared_head.norm", + "eh_proj", + "enorm", + "hnorm", + ], + ) + + def post_load_weights( + self, + is_nextn: bool = False, + weight_names: Optional[Iterable[str]] = None, + ) -> None: + """Post-process weights after loading. + + Handles kv_b_proj weight processing including: + - AWQ dequantization + - FP8/INT8 requantization and block-wise to tensor-wise conversion + - Splitting weights into w_kc and w_vc components for MLA + + Args: + is_nextn: Whether processing NextN weights + weight_names: Optional list of loaded weight names to determine which layers to process + """ + if is_nextn: + layer_ids = [self.config.num_hidden_layers] + else: + if weight_names is None: + layer_ids = range(self.model.start_layer, self.model.end_layer) + else: + layer_ids = set() + for name in weight_names: + if "kv_b_proj" in name: + layer_id = int(name.split(".")[2]) + if layer_id < self.config.num_hidden_layers: + layer_ids.add(layer_id) + + for layer_id in layer_ids: + self_attn = ( + self.model.layers[layer_id].self_attn + if not is_nextn + else self.model.decoder.self_attn + ) + + if hasattr(self_attn.kv_b_proj, "qweight"): + # awq compatible, dequantize the weight if supported + awq_dequantize_f = awq_dequantize_func() + if awq_dequantize_f is not None: + w = awq_dequantize_f( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + raise ValueError( + "AWQ dequantize function is not supported for the current device" + ) + else: + w = self_attn.kv_b_proj.weight + + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. 
+ # Fix deepseek v3 blockwise bmm by using deep_gemm + use_deep_gemm_bmm = False + + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + # For mixed quantization (experts int4, linear fp8), use linear_fp8_config + selected_quant_config = getattr( + self.quant_config, "linear_fp8_config", None + ) + if selected_quant_config is None: + selected_quant_config = self.quant_config + weight_block_size = getattr( + selected_quant_config, "weight_block_size", None + ) + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") or hasattr( + self_attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + self_attn.kv_b_proj.weight_scale + if hasattr(self_attn.kv_b_proj, "weight_scale") + else self_attn.kv_b_proj.weight_scale_inv + ) + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=weight_scale, + input_scale=None, + ) + else: + weight = w + + # In multiple weight loading scenarios (e.g. RL), we need to inverse the scale of the weights after the requantization happened at the first loading. + if ( + should_deepgemm_weight_requant_ue8m0( + weight_block_size=getattr( + self.quant_config, "weight_block_size", None + ) + ) + and weight_scale.format_ue8m0 + ): + weight_scale = inverse_transform_scale_ue8m0( + weight_scale, mn=weight.shape[-2] + ) + + if ( + _is_cuda + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + ): + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + torch.bfloat16, + ) + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + + # PATCH: handle MTP head shape mismatch at TP > 1. 
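+                    # Observed with the MXFP4 checkpoint at TP=8: the per-channel scale can
+                    # arrive transposed, flattened, or 1-D. The branches below only reshape it
+                    # to the [rows, 1] column-vector layout assumed by the
+                    # channel_quant_to_tensor_quant call; scale values are not modified.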
+ if weight_scale.dim() == 2 and weight.dim() == 2: + if weight_scale.shape[1] != weight.shape[1]: + if weight_scale.shape[0] == weight.shape[0] and weight_scale.shape[1] == 1: + pass + elif weight_scale.numel() == weight.shape[0]: + weight_scale = weight_scale.view(-1, 1) + elif weight_scale.shape[0] == weight.shape[1]: + weight_scale = weight_scale.t() + elif weight_scale.dim() == 1 and weight_scale.numel() == weight.shape[0]: + weight_scale = weight_scale.view(-1, 1) + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + # PATCH: quark_post_load_weights only handles bf16/uint8 — cast fp8 back + if w.dtype == torch.float8_e4m3fn or w.dtype == torch.float8_e4m3fnuz: + w = w.to(torch.float32) * scale + w = w.to(torch.bfloat16) + + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + + if ( + _use_aiter_gfx95 + and self.quant_config is not None + and self.quant_config.get_name() == "quark" + ): + w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = ( + quark_post_load_weights(self_attn, w, "mxfp4") + ) + + if not use_deep_gemm_bmm: + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + w_vc = w_vc.contiguous().transpose(1, 2) + if _is_npu: + w_vc = w_vc.contiguous() + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = bind_or_assign( + self_attn.w_scale, self_attn.kv_b_proj.weight_scale + ) + if _is_hip: + self_attn.w_scale *= 2.0 + # TODO: remove this after adding FP8 support in bmm cpu kernel + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + self_attn.w_kc = ( + self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale + ) + self_attn.w_vc = ( + self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale + ) + else: + num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self_attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + self_attn.w_scale_k = bind_or_assign( + self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() + ) + self_attn.w_scale_v = bind_or_assign( + self_attn.w_scale_v, ws_vc.contiguous() + ) + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous() + ) + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) + self_attn.use_deep_gemm_bmm = True + + def _maybe_quant_weights_to_fp8_ue8m0( + self, + weights, + attn_quant_modules, + nextn_conf: NextNConfig, + ): + """Optionally quantize weights to FP8 UE8M0 format for DeepSeek nvfp4 checkpoints. 
+ + Args: + weights: Iterable of (name, tensor) weight pairs + attn_quant_modules: List of attention module names to quantize + nextn_conf: NextN configuration + + Returns: + Original weights iterator if no quantization needed, + otherwise list of (name, tensor) pairs with quantized weights + """ + weight_block_size = [128, 128] + partial_names = [] + + match nextn_conf: + case NextNEnabledConfig(nextn_layer_id=layer_id): + if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get(): + for stem in attn_quant_modules: + partial_names.append( + f"model.layers.{layer_id}.self_attn.{stem}" + ) + + if enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): + expert_sub_names = ["shared_experts"] + [ + f"experts.{i}" for i in range(self.config.n_routed_experts) + ] + for expert_sub_name in expert_sub_names: + for stem in ["gate_proj", "up_proj", "down_proj"]: + partial_names.append( + f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}" + ) + + case NextNDisabledConfig(): + if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get(): + for layer_id in range(self.config.num_hidden_layers): + for stem in attn_quant_modules: + partial_names.append( + f"model.layers.{layer_id}.self_attn.{stem}" + ) + + # Early return if no quantization needed - avoid materializing all weights into memory + if not partial_names: + return weights + + # Only materialize weights dict when quantization is actually needed + weights_dict = dict(weights) + + for partial_name in tqdm.tqdm(partial_names, desc="quant weights to fp8 ue8m0"): + original_weight = weights_dict[f"{partial_name}.weight"] + out_w, out_s = quant_weight_ue8m0( + original_weight, weight_block_size=weight_block_size + ) + weights_dict[f"{partial_name}.weight"] = out_w + weights_dict[f"{partial_name}.weight_scale_inv"] = out_s + + if isinstance( + nextn_conf, NextNEnabledConfig + ) and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): + self._mark_nextn_moe_weights_as_ue8m0() + + return list(weights_dict.items()) + + def _mark_nextn_moe_weights_as_ue8m0(self): + """Mark NextN MoE weight scales as UE8M0 format to avoid requantization.""" + experts = self.model.decoder.mlp.experts + w13_scale = ( + experts.w13_weight_scale_inv + if hasattr(experts, "w13_weight_scale_inv") + else experts.w13_weight_scale + ) + w2_scale = ( + experts.w2_weight_scale_inv + if hasattr(experts, "w2_weight_scale_inv") + else experts.w2_weight_scale + ) + w13_scale.format_ue8m0 = True + w2_scale.format_ue8m0 = True diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/triton_mla_fp8_multi.py b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/triton_mla_fp8_multi.py new file mode 100644 index 000000000..5dbefd09a --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/prototypes/triton_mla_fp8_multi.py @@ -0,0 +1,241 @@ +""" +Triton MLA stage1 decode kernel for fp8 + multi-token (qo_len > 4). + +Strategy: + - The wrapper dequantizes Q and KV from fp8 to bf16 in PyTorch (handles AMD's + fp8 semantics correctly via torch). + - The Triton kernel then operates on bf16 inputs (which tl.dot supports natively + on AMD MFMA). + - This avoids the fp8 dtype semantics mismatch that caused output divergence + in the earlier all-fp8-internal version. + +Replaces aiter.mla_decode_stage1_asm_fwd when q.dtype == fp8 and max_seqlen_q > 4. 
+ +Status: PROTOTYPE v2 (bf16-internal) +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _mla_decode_stage1_bf16_multi_kernel( + Q_ptr, # [total_s, nhead, kv_lora_rank + rope_dim] bf16 + KV_ptr, # [num_pages, kv_lora_rank + rope_dim] bf16 (flattened) + qo_indptr_ptr, # [batch+1] int32 + kv_indptr_ptr, # [batch+1] int32 + kv_indices_ptr, # [total_kv_indices] int32 + Logits_ptr, # [total_s, NUM_KV_SPLITS, nhead, kv_lora_rank] fp32 + LSE_ptr, # [total_s, NUM_KV_SPLITS, nhead, 1] fp32 + sm_scale, # scalar fp32 softmax scale + # strides (Q has 3 dims; KV is 2D flat; outputs as documented) + stride_q_s, + stride_q_h, + stride_kv_p, + stride_lo_s, + stride_lo_sp, + stride_lo_h, + stride_lse_s, + stride_lse_sp, + stride_lse_h, + # constants + KV_LORA_RANK: tl.constexpr, # 512 + ROPE_DIM: tl.constexpr, # 64 + NUM_HEADS: tl.constexpr, # 16 + BLOCK_M: tl.constexpr, # M-chunk size (4) + BLOCK_N: tl.constexpr, # KV block (32) + NUM_KV_SPLITS: tl.constexpr, # 16 + M_START: tl.constexpr, # starting M-offset within qo_len +): + pid = tl.program_id(0) + cur_batch = pid // NUM_KV_SPLITS + split_kv_id = pid % NUM_KV_SPLITS + + qo_start = tl.load(qo_indptr_ptr + cur_batch) + qo_end = tl.load(qo_indptr_ptr + cur_batch + 1) + qo_len = qo_end - qo_start + + kv_start_idx = tl.load(kv_indptr_ptr + cur_batch) + kv_end_idx = tl.load(kv_indptr_ptr + cur_batch + 1) + seq_len = kv_end_idx - kv_start_idx + + if seq_len == 0: + return + + kv_per_split = tl.cdiv(seq_len, NUM_KV_SPLITS) + split_start = kv_per_split * split_kv_id + split_end = tl.minimum(split_start + kv_per_split, seq_len) + + if split_end <= split_start: + return + + offs_m = M_START + tl.arange(0, BLOCK_M) + offs_h = tl.arange(0, NUM_HEADS) + offs_c = tl.arange(0, KV_LORA_RANK) + offs_r = tl.arange(0, ROPE_DIM) + mask_m = offs_m < qo_len + + # Load Q (bf16) + q_nope_offs = ( + (qo_start + offs_m[:, None, None]) * stride_q_s + + offs_h[None, :, None] * stride_q_h + + offs_c[None, None, :] + ) + q_nope = tl.load( + Q_ptr + q_nope_offs, + mask=mask_m[:, None, None], + other=0.0, + ) # [BLOCK_M, NUM_HEADS, KV_LORA_RANK] bf16 + q_pe_offs = ( + (qo_start + offs_m[:, None, None]) * stride_q_s + + offs_h[None, :, None] * stride_q_h + + (KV_LORA_RANK + offs_r[None, None, :]) + ) + q_pe = tl.load( + Q_ptr + q_pe_offs, + mask=mask_m[:, None, None], + other=0.0, + ) # [BLOCK_M, NUM_HEADS, ROPE_DIM] bf16 + + # Position of Q[m]: drafts occupy last qo_len positions + q_pos = seq_len - qo_len + offs_m # [BLOCK_M] + + e_max = tl.full([BLOCK_M, NUM_HEADS], float("-inf"), dtype=tl.float32) + e_sum = tl.zeros([BLOCK_M, NUM_HEADS], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, NUM_HEADS, KV_LORA_RANK], dtype=tl.float32) + + # Reshape Q for matmul: [M*H, C/R] + q_nope_2d = tl.reshape(q_nope, (BLOCK_M * NUM_HEADS, KV_LORA_RANK)) + q_pe_2d = tl.reshape(q_pe, (BLOCK_M * NUM_HEADS, ROPE_DIM)) + + for start_n in range(split_start, split_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_end + + kv_loc = tl.load( + kv_indices_ptr + kv_start_idx + offs_n, + mask=mask_n, + other=0, + ) + + k_nope_offs = kv_loc[:, None] * stride_kv_p + offs_c[None, :] + k_pe_offs = kv_loc[:, None] * stride_kv_p + (KV_LORA_RANK + offs_r[None, :]) + + k_nope = tl.load(KV_ptr + k_nope_offs, mask=mask_n[:, None], other=0.0) + k_pe = tl.load(KV_ptr + k_pe_offs, mask=mask_n[:, None], other=0.0) + + # qk_nope = q_nope @ k_nope^T → [M*H, N] + qk_nope = tl.dot(q_nope_2d, tl.trans(k_nope)) + qk_pe = tl.dot(q_pe_2d, tl.trans(k_pe)) + qk = 
(qk_nope + qk_pe) * sm_scale + qk = tl.reshape(qk, (BLOCK_M, NUM_HEADS, BLOCK_N)) + + # Causal + bounds mask + causal_mask = offs_n[None, :] <= q_pos[:, None] + valid = mask_m[:, None] & mask_n[None, :] & causal_mask + qk = tl.where(valid[:, None, :], qk, float("-inf")) + + # Online softmax + cur_max = tl.max(qk, axis=2) + new_max = tl.maximum(cur_max, e_max) + new_max_safe = tl.where(new_max == float("-inf"), 0.0, new_max) + rescale = tl.exp(e_max - new_max_safe) + rescale = tl.where(new_max == float("-inf"), 1.0, rescale) + + p = tl.exp(qk - new_max_safe[:, :, None]) + p = tl.where(valid[:, None, :], p, 0.0) + + acc = acc * rescale[:, :, None] + p_2d = tl.reshape(p, (BLOCK_M * NUM_HEADS, BLOCK_N)) + # V = k_nope (MLA absorbs W_uV into the output proj) + delta = tl.dot(p_2d.to(k_nope.dtype), k_nope) + delta = tl.reshape(delta, (BLOCK_M, NUM_HEADS, KV_LORA_RANK)) + acc = acc + delta + + e_sum = e_sum * rescale + tl.sum(p, axis=2) + e_max = new_max + + # Normalize + e_sum_safe = tl.where(e_sum == 0.0, 1.0, e_sum) + out = acc / e_sum_safe[:, :, None] + + out_offs = ( + (qo_start + offs_m[:, None, None]) * stride_lo_s + + split_kv_id * stride_lo_sp + + offs_h[None, :, None] * stride_lo_h + + offs_c[None, None, :] + ) + tl.store(Logits_ptr + out_offs, out, mask=mask_m[:, None, None]) + + lse_val = e_max + tl.log(e_sum_safe) + lse_val = tl.where(e_sum == 0.0, float("-inf"), lse_val) + lse_offs = ( + (qo_start + offs_m[:, None]) * stride_lse_s + + split_kv_id * stride_lse_sp + + offs_h[None, :] * stride_lse_h + ) + tl.store(LSE_ptr + lse_offs, lse_val, mask=mask_m[:, None]) + + +def mla_decode_stage1_fp8_multi( + q, # fp8 [total_s, nhead, kv_lora_rank + rope_dim] + kv_buffer, # fp8 [N, 1, 1, kv_lora_rank + rope_dim] + qo_indptr, # [batch+1] int32 + kv_indptr, # [batch+1] int32 + kv_indices, # int32 + num_kv_splits, # int + sm_scale, # float + logits, # fp32 [total_s, num_kv_splits, nhead, kv_lora_rank] + attn_lse, # fp32 [total_s, num_kv_splits, nhead, 1] + q_scale, # scalar tensor (assumed 1.0 in ATOM) + kv_scale, # scalar tensor (assumed 1.0 in ATOM) + max_seqlen_q, # int (passed by caller — no .item() needed) +): + """Wrapper: dequant fp8 → bf16, then launch Triton kernel. + Cudagraph-friendly: no .item() calls (would break stream capture). + """ + total_s, nhead, head_size = q.shape + kv_lora_rank = logits.shape[-1] + rope_dim = head_size - kv_lora_rank + batch = qo_indptr.shape[0] - 1 + + # Direct fp8 → bf16 cast (skip fp32, avoids OOM). + # ATOM's scales are always scalar 1.0 — skip mul to be cudagraph-safe. + q_bf16 = q.to(torch.bfloat16) + # KV: only dequant the entire pool to bf16 (2x memory but no sync) + kv_bf16 = kv_buffer.reshape(-1, head_size).to(torch.bfloat16) + + BLOCK_M = 4 + BLOCK_N = 32 + + grid = (batch * num_kv_splits,) + + # Iterate M-chunks. max_seqlen_q is a static int from caller. 
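+    # With BLOCK_M=4, qo_len up to 8 is covered by at most two launches
+    # (M_START=0 and M_START=4); each launch masks rows beyond qo_len and writes
+    # disjoint rows of logits/attn_lse, so no cross-launch reduction is needed.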
+ for m_start in range(0, max_seqlen_q, BLOCK_M): + _mla_decode_stage1_bf16_multi_kernel[grid]( + q_bf16, + kv_bf16, + qo_indptr, + kv_indptr, + kv_indices, + logits, + attn_lse, + sm_scale, + q_bf16.stride(0), + q_bf16.stride(1), + kv_bf16.stride(0), + logits.stride(0), + logits.stride(1), + logits.stride(2), + attn_lse.stride(0), + attn_lse.stride(1), + attn_lse.stride(2), + KV_LORA_RANK=kv_lora_rank, + ROPE_DIM=rope_dim, + NUM_HEADS=nhead, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + M_START=m_start, + ) diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/peak_c4_757_moefp4.json b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/peak_c4_757_moefp4.json new file mode 100644 index 000000000..30be7782e --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/peak_c4_757_moefp4.json @@ -0,0 +1,47 @@ +{ + "total_input_tokens": 327680, + "request_throughput": 0.6574007275918786, + "output_throughput": 671.5184082169143, + "total_token_throughput": 6056.945168649584, + "mean_ttft_ms": 270.4857417149469, + "median_ttft_ms": 234.62767153978348, + "p99_ttft_ms": 778.3269615005702, + "mean_tpot_ms": 5.638050398616726, + "median_tpot_ms": 6.076577704902922, + "p99_tpot_ms": 6.907007610094518, + "mean_itl_ms": 14.447542221437494, + "median_itl_ms": 13.072218745946884, + "p99_itl_ms": 52.53850895911447, + "mean_e2el_ms": 6023.13564599026, + "median_e2el_ms": 6448.968381620944, + "p99_e2el_ms": 7557.125618392602, + "benchmark_args": { + "model": "/share4/teamK/DeepSeek-R1-0528-MXFP4-MTP-MoEFP4", + "backend": "vllm", + "base_url": "http://0.0.0.0:8888", + "dataset_name": "random", + "random_input_len": 8192, + "random_output_len": 1024, + "random_range_ratio": 1.0, + "num_prompts": 40, + "max_concurrency": 4, + "request_rate": "inf" + }, + "tput_per_gpu": 757.118146081198, + "interactivity": 164.56631488364647, + "baseline": { + "baseline_median_e2e": 5000, + "baseline_tput_per_gpu": 1500, + "baseline_median_intvty": 165 + }, + "accuracy": { + "task": "gsm8k", + "gsm8k_metric": 0.9356 + }, + "accuracy_validation": { + "status": "PASSED", + "baseline_gsm8k_metric": 0.38, + "gsm8k_tol": 0.0, + "minimum_accepted": 0.38 + } +} \ No newline at end of file diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_bigbatch_c128_20260508T172456Z.json.json b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_bigbatch_c128_20260508T172456Z.json.json new file mode 100644 index 000000000..309cad3e4 --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_bigbatch_c128_20260508T172456Z.json.json @@ -0,0 +1,47 @@ +{ + "total_input_tokens": 10485760, + "request_throughput": 3.0714175179109326, + "output_throughput": 3136.433187948274, + "total_token_throughput": 28297.48549467463, + "mean_ttft_ms": 1899.4958611445327, + "median_ttft_ms": 425.0979139469564, + "p99_ttft_ms": 24958.320217244327, + "mean_tpot_ms": 38.32154737461625, + "median_tpot_ms": 41.54025027458266, + "p99_tpot_ms": 54.956891835074366, + "mean_itl_ms": 101.00990866562904, + "median_itl_ms": 41.269132401794195, + "p99_itl_ms": 784.8582692537459, + "mean_e2el_ms": 40987.4894070279, + "median_e2el_ms": 43978.565476834774, + "p99_e2el_ms": 65878.4762391169, + "benchmark_args": { + "model": "/share4/teamK/DeepSeek-R1-0528-MXFP4", + "backend": "vllm", + "base_url": "http://0.0.0.0:8888", + "dataset_name": "random", + "random_input_len": 8192, + "random_output_len": 1024, + "random_range_ratio": 1.0, + "num_prompts": 1280, + "max_concurrency": 128, + "request_rate": "inf" + }, + 
"tput_per_gpu": 3537.185686834329, + "interactivity": 24.073037436942755, + "baseline": { + "baseline_median_e2e": 22000, + "baseline_tput_per_gpu": 6000, + "baseline_median_intvty": 48 + }, + "accuracy": { + "task": "gsm8k", + "gsm8k_metric": 0.9348 + }, + "accuracy_validation": { + "status": "PASSED", + "baseline_gsm8k_metric": 0.38, + "gsm8k_tol": 0.0, + "minimum_accepted": 0.38 + } +} \ No newline at end of file diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_c32_bb_level3_20260513T092735Z.json.json b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_c32_bb_level3_20260513T092735Z.json.json new file mode 100644 index 000000000..eb058fabe --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_c32_bb_level3_20260513T092735Z.json.json @@ -0,0 +1,47 @@ +{ + "total_input_tokens": 2621440, + "request_throughput": 2.041483538703479, + "output_throughput": 2084.6864340912916, + "total_token_throughput": 18808.519583150195, + "mean_ttft_ms": 734.0679752844153, + "median_ttft_ms": 268.71410105377436, + "p99_ttft_ms": 6127.636799840257, + "mean_tpot_ms": 14.244444768612727, + "median_tpot_ms": 15.34001882808294, + "p99_tpot_ms": 20.40315175209723, + "mean_itl_ms": 38.23034486036415, + "median_itl_ms": 23.143738508224487, + "p99_itl_ms": 226.56817093491554, + "mean_e2el_ms": 15262.674037579563, + "median_e2el_ms": 16460.214231628925, + "p99_e2el_ms": 22907.652122741565, + "benchmark_args": { + "model": "/share4/teamK/DeepSeek-R1-0528-MXFP4", + "backend": "vllm", + "base_url": "http://0.0.0.0:8888", + "dataset_name": "random", + "random_input_len": 8192, + "random_output_len": 1024, + "random_range_ratio": 1.0, + "num_prompts": 320, + "max_concurrency": 32, + "request_rate": "inf" + }, + "tput_per_gpu": 2351.0649478937744, + "interactivity": 65.1889682279465, + "baseline": { + "baseline_median_e2e": 18000, + "baseline_tput_per_gpu": 3900, + "baseline_median_intvty": 50 + }, + "accuracy": { + "task": "gsm8k", + "gsm8k_metric": 0.9393 + }, + "accuracy_validation": { + "status": "PASSED", + "baseline_gsm8k_metric": 0.38, + "gsm8k_tol": 0.0, + "minimum_accepted": 0.38 + } +} \ No newline at end of file diff --git a/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_tp8_fp8_level3_c4_20260512T173100Z.json.json b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_tp8_fp8_level3_c4_20260512T173100Z.json.json new file mode 100644 index 000000000..76d4b772e --- /dev/null +++ b/recipes/DeepSeek-R1-MXFP4-MI355X-Jons/results/submit_tp8_fp8_level3_c4_20260512T173100Z.json.json @@ -0,0 +1,47 @@ +{ + "total_input_tokens": 327680, + "request_throughput": 0.6176255129197001, + "output_throughput": 630.6574112423059, + "total_token_throughput": 5690.245613080489, + "mean_ttft_ms": 282.6445827726275, + "median_ttft_ms": 241.40197597444057, + "p99_ttft_ms": 796.159868352115, + "mean_tpot_ms": 5.987549479996434, + "median_tpot_ms": 6.465269571202922, + "p99_tpot_ms": 7.37491288906355, + "mean_itl_ms": 15.253988993494898, + "median_itl_ms": 13.838349841535091, + "p99_itl_ms": 46.62781701423226, + "mean_e2el_ms": 6388.816124852747, + "median_e2el_ms": 6836.589991580695, + "p99_e2el_ms": 7772.20810091123, + "benchmark_args": { + "model": "/share4/teamK/DeepSeek-R1-0528-MXFP4", + "backend": "vllm", + "base_url": "http://0.0.0.0:8888", + "dataset_name": "random", + "random_input_len": 8192, + "random_output_len": 1024, + "random_range_ratio": 1.0, + "num_prompts": 40, + "max_concurrency": 4, + "request_rate": "inf" + }, + "tput_per_gpu": 711.2807016350612, + "interactivity": 
154.672591604551, + "baseline": { + "baseline_median_e2e": 5000, + "baseline_tput_per_gpu": 1500, + "baseline_median_intvty": 165 + }, + "accuracy": { + "task": "gsm8k", + "gsm8k_metric": 0.9356 + }, + "accuracy_validation": { + "status": "PASSED", + "baseline_gsm8k_metric": 0.38, + "gsm8k_tol": 0.0, + "minimum_accepted": 0.38 + } +} \ No newline at end of file