vLLM fork for Tesla V100 (SM70) extending 1CatAI/1Cat-vLLM's AWQ support with compressed-tensors, MoE, and improved kernel accuracy.
1CatAI's fork provides AWQ 4-bit inference on V100 via hand-tuned TurboMind SM70 CUDA kernels. This fork extends that foundation with:
- Compressed-tensors W4A16 on V100 -- lowers `min_capability` from 75 to 70 (from vLLM PR #32597)
- TurboMindLinearKernel -- uses 1Cat's `awq_gemm_sm70` for dense linear layers instead of the Triton GPTQ kernel, which has ~2% mean relative error per matmul on V100 (this compounds into garbage output across deep networks). TurboMind achieves <0.1% error.
- MoE compressed-tensors fix -- `CompressedTensorsSM70WNA16MoEMethod` was missing ~20 layer attributes needed by the AWQ apply path. Fixed by delegating to `AWQSM70MoEMethod` after CT-to-AWQ weight conversion.
- `_DEFAULT_MAX_TOKENS` naming fix -- alias for a renamed constant that broke the CT MoE import chain
- DeepSeek-V4-Flash on V100 -- runnable model class for Intel's W4A16 AutoRound quant of V4-Flash (290B / ~37B active, 256 experts, MLA + sparse attention + Hyper-Connections). Includes a V100 fp16 sparse-attention kernel port, a `_hc_post` clamp that prevents fp16 residual overflow at position 0, an Obstacle-1 fix (a CPU-mirrored `start_pos` in the attention metadata that drops the per-forward host sync), and a paged main-window KV cache (single-request scope; multi-request support via paged compressor/indexer caches is the natural Stage-2 follow-up).
- Mistral-Small-4 119B GGUF on V100 -- runnable model class for Bartowski's Q4_K_M GGUF of Mistral-Small-4-119B-2603 (MoE, MLA, fused `ffn_gate_up_exps` + split `attn_k_b`/`attn_v_b` tensor layouts). Ships with three latent fixes that affect any GGUF or MLA user on V100: a 4-site fp16 overflow clamp in the GGUF csrc kernels (the kernels keep their internal fp32 accumulator and clamp to ±65504 at the implicit fp32→fp16 write-back), an MMQ kernel alignment dispatch in `gguf.py` (small dense models like Qwen2.5-0.5B with hidden=896 now correctly fall back to dequantize instead of reading past the qweight buffer), and a manual fp32 LSE-returning fallback in `mla_attention.py`, so MLA models with prefix caching / chunked prefill no longer crash `merge_attn_states` on V100.
- MiMo-V2.5 310B GGUF on V100 -- runnable model class for Bartowski's Q3_K_M GGUF of XiaomiMiMo/MiMo-V2.5 (310B / 15B active, hybrid SWA + full attention with asymmetric head dims Q/K=192 / V=128, fused `attn_qkv` + 3D `ffn_*_exps` tensor layouts, MTP blocks that we skip). Ships with two additional fixes that affect any V100 user, not just MiMo: an HDIM=192 template instantiation in the `flash_attn_v100` kernels (previously only 64/80/96/112/128/256 were instantiated, so any model with head_dim=192 hit the default-case `TORCH_CHECK`), and an MMQ alignment guard mirrored from the dense path into `_fused_moe_gguf` (MoE models with K-quant experts whose per-rank `w2` input isn't aligned silently crashed with an illegal memory access once the batch crossed the MMVQ→MMQ threshold). Also pins `triton==3.5.1` in `requirements/cuda.txt` to match torch 2.9.1+cu128's wheel metadata, since triton 3.6.0's MLA decode codegen is ~2.5× slower on V100 sm_70 at long context (verified on the Mistral4 T2 stress test: 23.5 → 9.4 tok/s with 3.6.0, restored with the pin).
- Qwen3.6-35B-A3B GGUF on V100 -- runnable model class for Bartowski's Q8_0 GGUF of Qwen3.6-35B-A3B (35B / 3B active, 256-expert MoE, hybrid Gated-DeltaNet + full attention every 4th layer, interleaved M-RoPE; text backbone `Qwen3_5MoeForCausalLM`). transformers/vLLM have no GGUF support for arch `qwen35moe`, so the loader binds the text backbone via `--hf-config-path` + `--hf-overrides` (stripping `vision_config`). Three GGUF-interpretation fixes were needed, each affecting any GGUF user of this arch: (a) Gemma-style RMSNorm double `+1` -- Qwen3.5/3.6 use `y=(1+w)*x` and llama.cpp bakes the `+1` into the GGUF norm weights, so vLLM re-adding it applied the `+1` twice in every norm (the loader subtracts 1 at load, excluding the gated `linear_attn.norm`); (b) Gated-DeltaNet `A_log` double exponentiation -- GGUF `ssm_a` already stores the decay `A=-exp(A_log)`, so the loader stores `log(-ssm_a)` to keep `-exp(A_log)==ssm_a` instead of re-applying `-exp()` (which collapsed the recurrence in all 30 GDN layers); (c) GDN value-head tile vs. repeat_interleave order -- llama.cpp pairs value head `i` with key head `i % num_k`, while vLLM's FLA kernel uses `i // r`, so the loader permutes the value heads before sharding (head boundaries align to Q8_0 blocks, so the packed-byte permute stays byte-clean and TP-safe). A follow-on change loads the model's native MTP (nextn) head from GGUF block 40 as a speculative-decode draft (~60% acceptance). It loads, but is not recommended on V100: spec-decode is net-negative single-stream here because the fast `flash_attn_v100` backend can't keep CUDA graphs under spec-decode (forced to PIECEWISE -> ~46 tok/s), and the `triton_attn` backend that can keep them still loses to no-spec (~77 vs ~100 tok/s) once the draft + 2-token verify overhead is counted. The same draft-loader path would likely pay off on Ampere/Hopper.
- Qwen3.5-122B-A10B GGUF on V100 -- runnable model class for Bartowski's Q6_K_L GGUF of Qwen3.5-122B-A10B (122B / 10B active, 256-expert MoE, the same hybrid Gated-DeltaNet + full-attention `qwen3_5_moe` arch as the 35B above, but with 48 layers and different head counts: GDN 64 value / 16 key, full attention 32 query / 2 KV heads / head_dim 256). It rides on the merged 35B loader fixes, but the bigger Q6_K_L quant + 8-GPU sharding exposed two more loader bugs, each affecting any GGUF user of this arch with a K-quant GDN `out_proj` or running at TP > `num_kv_heads`: (a) Q6_K GDN `out_proj` value-head column permute corrupts super-blocks -- the value-head reorder acts on the packed input (column) dim, which is byte-clean for Q8_0 (head_dim = 4 whole 32-element blocks), but Q6_K's 256-element super-blocks span two value heads, so a per-head column permute splits super-blocks -> corrupt scales -> Inf weights -> NaN logits ("!!!!" garbage). The loader now dequantizes `out_proj` (reference-correct gguf-py), permutes the value-head columns in float, and emits it as an unquantized F16 weight (~1.8 GiB extra total, uniform across quants; the previously merged Q8_0 35B was re-verified coherent on this path); (b) full-attention KV-head replication for TP > `num_kv_heads` -- with only 2 KV heads but TP=4/8, `QKVParallelLinear` replicates each KV head across `tp // nkv` ranks, but the GGUF weight loader divides k/v rows naively by TP with no replication, so the qkv output width stops matching the forward split (`[q, k, v]`) and the model crashes. The loader pre-replicates the GGUF `k_proj`/`v_proj` rows by `repeat_interleave(tp // nkv)` so that contiguous TP sharding reproduces the expected per-rank head layout (mirroring the MiMo fused-qkv KV replication). Both fixes are config-driven, so they are no-ops when not needed (e.g. the 35B at TP=2, where `nkv == tp`).
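The index bookkeeping behind the Qwen3.5/3.6 GGUF loader fixes listed above can be sketched on plain Python lists. This is a minimal illustration, not the fork's loader code: every function name is hypothetical, tensors are modeled as lists of rows, and it assumes the GDN value-head count is an exact multiple of the key-head count, with vLLM grouping value slot `j` under key head `j // r` as described.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def unbake_norm_weight(w_gguf: float) -> float:
    """35B fix (a): llama.cpp stores (1 + w) for Gemma-style RMSNorm; recover w.

    (Not applied to the gated linear_attn.norm.)
    """
    return w_gguf - 1.0


def a_log_from_ssm_a(ssm_a: float) -> float:
    """35B fix (b): GGUF ssm_a already holds A = -exp(A_log), so invert it once."""
    return math.log(-ssm_a)


def gdn_value_head_perm(num_v_heads: int, num_k_heads: int) -> List[int]:
    """35B fix (c): for each vLLM value-head slot j, the GGUF head to place there.

    GGUF / llama.cpp ("tile"):       value head i -> key head i % num_k
    vLLM FLA ("repeat_interleave"):  value slot j -> key head j // r
    Slot j wants the (j % r)-th GGUF head of key head j // r, i.e.
    GGUF head (j % r) * num_k + j // r.
    """
    assert num_v_heads % num_k_heads == 0
    r = num_v_heads // num_k_heads
    return [(j % r) * num_k_heads + j // r for j in range(num_v_heads)]


def replicate_kv_rows(rows: Sequence[T], num_kv_heads: int, tp: int) -> List[T]:
    """122B fix (b): repeat each KV head's row block tp // num_kv_heads times.

    No-op when tp <= num_kv_heads (e.g. the 35B at TP=2).
    """
    reps = max(1, tp // num_kv_heads)
    per_head = len(rows) // num_kv_heads
    out: List[T] = []
    for h in range(num_kv_heads):
        block = list(rows[h * per_head:(h + 1) * per_head])
        for _ in range(reps):
            out.extend(block)
    return out


def shard_rows(rows: Sequence[T], tp: int, rank: int) -> List[T]:
    """Contiguous TP sharding: rank gets the rank-th equal slice of rows."""
    n = len(rows) // tp
    return list(rows[rank * n:(rank + 1) * n])


if __name__ == "__main__":
    print(gdn_value_head_perm(4, 2))  # [0, 2, 1, 3]
    kv = ["k0.a", "k0.b", "k1.a", "k1.b"]  # 2 KV heads, 2 rows each
    rep = replicate_kv_rows(kv, num_kv_heads=2, tp=4)
    # ranks 0-1 get KV head 0, ranks 2-3 get KV head 1
    print([shard_rows(rep, 4, r) for r in range(4)])
```

For the 122B's GDN layout (64 value / 16 key heads, so `r = 4`), `gdn_value_head_perm(64, 16)` starts `[0, 16, 32, 48, 1, ...]`. Note that this permutes whole heads, which is exactly why it is only byte-clean when quant blocks do not straddle head boundaries.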
| Model | Params | Quant | Architecture | TP | Status |
|---|---|---|---|---|---|
| cyankiwi/MiniMax-M2.7-AWQ-4bit | 240B (11B active) | compressed-tensors W4A16 | MoE (256 experts) | 8 | Working |
| cyankiwi/Qwen3.6-27B-AWQ-INT4 | 27B | compressed-tensors W4A16 (asymmetric) | Hybrid Gated DeltaNet | 4 | Working (greedy + tool-calling smoke) |
| cyankiwi/granite-4.1-8b-AWQ-INT4 | 8B | compressed-tensors W4A16 group_size=32 (asymmetric) | Dense (GraniteForCausalLM) | 2 | Working (cudagraph; ~127 tok/s single-stream, ~587 tok/s aggregate batch=8) |
| Intel/DeepSeek-V4-Flash-W4A16-AutoRound | 290B (37B active) | auto-round W4A16 | MoE (256 experts) + MLA + sparse-attn + Hyper-Connections | 8 | Working (single-request, ~5.66 tok/s decode-only) |
| bartowski/mistralai_Mistral-Small-4-119B-2603-GGUF (Q4_K_M) | 119B | GGUF Q4_K_M | MoE + MLA (Mistral4ForCausalLM) | 8 | Working (cudagraph; ~82 tok/s short prompt, ~24 tok/s @ 6k-tok prompt, ~26 tok/s prefix-cache replay) |
| bartowski/MiMo-V2.5-GGUF (Q3_K_M) | 310B (15B active) | GGUF Q3_K_M | MoE + hybrid SWA + asymmetric head_dim (MiMoV2FlashForCausalLM) | 8 | Working (cudagraph; ~42 tok/s single-stream, ~64 tok/s aggregate batch=8) |
| bartowski/Qwen_Qwen3.6-35B-A3B-GGUF (Q8_0) | 35B (3B active) | GGUF Q8_0 | MoE (256 experts) + hybrid Gated-DeltaNet (Qwen3_5MoeForCausalLM) | 2 | Working (cudagraph; ~100 tok/s single-stream, ~1900 tok/s aggregate across 4×TP=2 replicas) |
| bartowski/Qwen_Qwen3.5-122B-A10B-GGUF (Q6_K_L) | 122B (10B active) | GGUF Q6_K_L | MoE (256 experts) + hybrid Gated-DeltaNet (Qwen3_5MoeForCausalLM) | 8 (PP2×TP4) | Working (cudagraph; coherent greedy + chat; throughput not yet benchmarked) |
- 8x Tesla V100 SXM2 32GB (TP=8, or PP=2×TP=4 across the two 4-GPU NVLink islands; no expert parallel)
- Expert parallel corrupts MoE output for MiniMax M2.7 on this fork. Use tensor parallelism without `--enable-expert-parallel`. The root cause is likely in the EP code path for 256-expert models.
- V100 Triton JIT compilation takes 30-90 minutes on the first request. Set `VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000` to avoid pod kills.
- Do NOT use `--quantization gptq_marlin` or `CUDA_LAUNCH_BLOCKING=1` on V100.
- Custom all-reduce is unsupported on SM70. V100 (compute capability 7.0) has no symmetric-memory communicator (`SymmMemCommunicator: Device capability 7.0 not supported`). For multi-GPU TP, pass `--disable-custom-all-reduce` so vLLM uses the NCCL all-reduce path; this also lets cudagraph (non-eager) capture succeed, so `--enforce-eager` is usually unnecessary.
- `--kv-cache-auto-trim-ratio` collapses the cache for hybrid models. The default (1.05) sizes the cache to `per_request_KV x max_num_seqs x 1.05`. For hybrid Gated-DeltaNet / Mamba models the per-request KV estimate is dominated by tiny fixed-size recurrent state, so the cache is trimmed to a few hundred tokens (e.g. 21.8 GiB -> 0.05 GiB / 784 tokens on Qwen3.6-27B). Pass `--kv-cache-auto-trim-ratio 0` (or `--kv-cache-memory-bytes`) to disable trimming.
- GGUF models with new-arch model classes (`mistral4`, `qwen3_5_moe`) need `transformers >= 5`. These GGUFs load via a path that builds a dummy HF model (`AutoModelForCausalLM.from_config`) to derive the tensor-name map, which needs the native `Mistral4Config` / `Qwen3_5MoeConfig` classes, present only in `transformers >= 5.x`. On `transformers < 5` (4.57.x) loading fails (`Unrecognized configuration class Qwen3_5MoeTextConfig` / `architecture mistral4 not supported`). Resolved: the pin is now `transformers >= 5.12.1, < 6` (validated on 5.12.1). AWQ/compressed-tensors models are unaffected either way (own model class, no GGUF dummy-model path). Note: install this fork's prebuilt wheel directly (it pins transformers correctly) -- copying only the `.py` overlays onto a base wheel leaves the compiled `_C` without the SM70 GGUF/MLA csrc patches and breaks long-context MLA (`gather_and_maybe_dequant_cache only support head_dim 576`).
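A rough model of why the default auto-trim starves hybrid models. This is a hypothetical sketch of the sizing rule quoted above (`per_request_KV x max_num_seqs x ratio`, capped at the memory actually available); the function name and the example per-request estimate are illustrative assumptions, not values read from the fork.

```python
GiB = 1 << 30
MiB = 1 << 20


def auto_trim_kv_bytes(available: int, per_request_kv: int,
                       max_num_seqs: int, ratio: float = 1.05) -> int:
    """Hypothetical model of --kv-cache-auto-trim-ratio.

    ratio <= 0 disables trimming and keeps all available KV memory.
    """
    if ratio <= 0:
        return available
    return min(available, int(per_request_kv * max_num_seqs * ratio))


if __name__ == "__main__":
    available = int(21.8 * GiB)
    # Illustrative per-request estimate for a hybrid model, where most
    # layers only carry small fixed-size recurrent state.
    tiny_estimate = 12 * MiB
    trimmed = auto_trim_kv_bytes(available, tiny_estimate, max_num_seqs=4)
    print(f"default ratio: {trimmed / GiB:.2f} GiB")
    kept = auto_trim_kv_bytes(available, tiny_estimate, max_num_seqs=4, ratio=0)
    print(f"ratio 0:       {kept / GiB:.2f} GiB")
```

The point of the sketch is the shape of the rule: when the per-request estimate underestimates real long-context KV needs, the product caps the cache far below what the GPU could hold, and `ratio=0` removes the cap.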
Build the image from the included Dockerfile:
```bash
docker build -f docker/Dockerfile.v100 -t vllm-v100:latest .
```

Dockerfile.v100 builds this fork's own wheels from source for SM70
(TORCH_CUDA_ARCH_LIST=7.0) in a builder stage, then installs them into a
slim runtime image: PyTorch (cu128), the vllm wheel (vendored TurboMind
SM70 AWQ GEMM included), and the flash_attn_v100 wheel. Building from
source keeps the image in lockstep with this repo's Python and CUDA
patches -- the GGUF csrc fp16 clamps and the flash_attn_v100 HDIM
templates live in compiled code and cannot be delivered by copying .py
files over a prebuilt wheel. The FA-V100 wheel unlocks
--attention-backend FLASH_ATTN_V100 (the SM70 FlashAttention-2 path);
without it the registered backend silently falls back to Triton.
Note: the older `docker/Dockerfile.sm70-wheel` installed 1Cat's prebuilt `v0.0.2` wheel and overlaid only a handful of patched `.py` files. That overlay list drifted out of date (it was missing ~35 of the 44 changed Python files, including `turbomind_asym.py`) and could not ship the fork's `csrc/` kernel fixes at all. Prefer `Dockerfile.v100`.
If you need a different Python or CUDA combo than the published wheel
(cp312-cp312-linux_x86_64, cu128), build the extension from the
vendored source under flash-attention-v100/:
```bash
# Requires nvcc on PATH and the same torch already in the venv.
PATH=/usr/local/cuda-12.8/bin:$PATH \
TORCH_CUDA_ARCH_LIST="7.0" \
pip install -e flash-attention-v100/ --no-build-isolation
```

`--no-build-isolation` is important: it ensures the build picks up the
torch you already have installed instead of pulling a different version.
Example launch for MiniMax-M2.7 (compressed-tensors W4A16, TP=8):

```bash
docker run --rm --gpus all --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/cyankiwi/MiniMax-M2.7-AWQ-4bit \
  -e VLLM_SERVED_MODEL_NAME=MiniMax-M2.7 \
  -e VLLM_QUANTIZATION=compressed-tensors \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=8 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.90 \
  -e VLLM_MAX_MODEL_LEN=32768 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest
```

Qwen3.6-27B-AWQ-INT4: hybrid Gated DeltaNet, asymmetric compressed-tensors W4A16. Requires the
TurboMindAsymLinearKernel for dense Linear (already in this fork).
Three V100-specific flags matter here (all verified on 4x V100 SXM2):

- `--disable-custom-all-reduce` (not `--enforce-eager`) -- cudagraphs capture fine on this model; the real blocker is the custom/symmetric-memory all-reduce, which is unsupported on SM70 (`SymmMemCommunicator: Device capability 7.0 not supported`). Disable it and the non-eager path runs with the `FLASH_ATTN_V100` decode kernel (CUDA-graph safe). Eager is ~3x slower and is not needed: the previously documented `causal_conv1d` cudagraph assertion does not reproduce once custom all-reduce is off.
- `--kv-cache-auto-trim-ratio 0` -- the hybrid Gated-DeltaNet per-request KV estimate is tiny (most layers are fixed-size recurrent state), so the default auto-trim (1.05) collapses the cache to ~784 tokens. Disabling the trim restores the full cache (~356k tokens here at `--gpu-memory-utilization 0.92`).
- `--reasoning-parser deepseek_r1` -- routes `<think>...</think>` reasoning into `reasoning_content`. The Qwen3.5/3.6 chat templates inject the opening `<think>` into the prompt, so the model emits only the closing `</think>`. `deepseek_r1` handles this in both streaming and non-streaming mode. This fork's `qwen3` parser is also fixed to key off the closing `</think>` (issue #16) and works for non-streaming; use `deepseek_r1` if you stream responses.
Tool-calling uses the qwen3_coder parser.
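The closing-tag-only convention can be illustrated with a simplified, non-streaming splitter. This is a hypothetical sketch of the behavior described for `deepseek_r1`, not vLLM's parser code; it assumes the chat template has already injected the opening `<think>`, so any text before `</think>` is reasoning and, if no closing tag appears, the whole output is treated as reasoning.

```python
from typing import Optional, Tuple

START, END = "<think>", "</think>"


def split_reasoning(output: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning, content) for a completion whose <think> was pre-injected."""
    # Tolerate models that echo the opening tag anyway.
    if output.startswith(START):
        output = output[len(START):]
    if END not in output:
        # No closing tag: everything generated so far is reasoning.
        return output, None
    reasoning, _, content = output.partition(END)
    return reasoning, (content or None)


if __name__ == "__main__":
    print(split_reasoning("2+2 is basic addition.</think>4"))
    print(split_reasoning("still thinking..."))
```

A streaming parser has to make the same decision incrementally across deltas (the closing tag can arrive split over several chunks), which is the harder case the note above says `deepseek_r1` handles.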
```bash
docker run --rm --gpus '"device=0,1,2,3"' --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/cyankiwi/Qwen3.6-27B-AWQ-INT4 \
  -e VLLM_SERVED_MODEL_NAME=Qwen3.6-27B-AWQ-INT4 \
  -e VLLM_QUANTIZATION=compressed-tensors \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=4 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.92 \
  -e VLLM_MAX_MODEL_LEN=262144 \
  -e VLLM_MAX_NUM_SEQS=4 \
  -e VLLM_MAX_NUM_BATCHED_TOKENS=4096 \
  -e VLLM_ATTENTION_BACKEND=FLASH_ATTN_V100 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --disable-custom-all-reduce \
  --kv-cache-auto-trim-ratio 0 \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder \
  --reasoning-parser deepseek_r1 \
  --default-chat-template-kwargs '{"enable_thinking":true}'
```

The equivalent native (non-Docker) invocation, for running from a source checkout or a pip-installed wheel:
```bash
python -m vllm.entrypoints.openai.api_server \
  --model cyankiwi/Qwen3.6-27B-AWQ-INT4 \
  --served-model-name Qwen3.6-27B-AWQ-INT4 \
  --host 0.0.0.0 --port 8000 \
  --tensor-parallel-size 4 \
  --disable-custom-all-reduce \
  --gpu-memory-utilization 0.92 \
  --max-model-len 262144 \
  --max-num-seqs 4 \
  --max-num-batched-tokens 4096 \
  --attention-backend FLASH_ATTN_V100 \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder \
  --reasoning-parser deepseek_r1 \
  --default-chat-template-kwargs '{"enable_thinking":true}' \
  --kv-cache-auto-trim-ratio 0
```

granite-4.1-8b-AWQ-INT4: pure dense GraniteForCausalLM, asymmetric compressed-tensors W4A16 with
group_size=32. Uses the existing TurboMindAsymLinearKernel (already in
this fork; no new code path needed). The
compile_ranges_split_points:[] setting disables the chunked-prefill
split that otherwise triggers a silent FLASH_ATTN_V100 fallback path
producing all-token-id-0 ("!") garbage. Cudagraph capture engages
cleanly -- do not add --enforce-eager (eager mode is ~3x slower
on this model). Local bench (TP=2, dual V100 32GB SXM2, 32-prompt ->
128-gen): 126.6 tok/s decode at batch=1; 586.8 tok/s aggregate / 73.3
per-seq at batch=8.
```bash
docker run --rm --gpus '"device=0,1"' --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/cyankiwi/granite-4.1-8b-AWQ-INT4 \
  -e VLLM_SERVED_MODEL_NAME=granite-4.1-8b-AWQ-INT4 \
  -e VLLM_QUANTIZATION=compressed-tensors \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=2 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.85 \
  -e VLLM_MAX_MODEL_LEN=8192 \
  -e VLLM_MAX_NUM_SEQS=16 \
  -e VLLM_MAX_NUM_BATCHED_TOKENS=4096 \
  -e VLLM_COMPILATION_CONFIG='{"compile_ranges_split_points":[]}' \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --attention-backend FLASH_ATTN_V100
```

DeepSeek-V4-Flash-W4A16-AutoRound: single-request only (bsz==1); compressor and indexer KV are kept on
module buffers rather than the paged cache for now. --enforce-eager
is required (cudagraph engagement is blocked by three uncapturable
paths in the model -- TileLang JIT, a TileLang deprecation warning, and a
Hash-MoE Python-state contract; the realistic post-cudagraph speedup
ceiling is also bounded by TP all-reduce dominating ~38% of decode-time
GPU work, so eager is the practical ship target on V100 SXM2).
--max-num-seqs=4 leaves headroom to avoid OOM during sampler warmup; block_size=64
matches the V100 sparse-attn kernel's BLOCK_N. Decode-only throughput
in this configuration is ~5.66 tok/s warm (median ~5.27 across 4
fresh-process runs at TP=8, 4096-token context).
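The all-reduce ceiling mentioned above is an Amdahl's-law bound. A minimal worked sketch, assuming the ~38% all-reduce share of decode GPU time is untouched by cudagraphs and only the remaining share can get faster; the helper name is illustrative.

```python
def amdahl_speedup(fixed_fraction: float, rest_speedup: float = float("inf")) -> float:
    """Overall speedup when only the (1 - fixed_fraction) share gets faster."""
    if not 0.0 < fixed_fraction <= 1.0:
        raise ValueError("fixed_fraction must be in (0, 1]")
    rest = 1.0 - fixed_fraction
    if rest_speedup == float("inf"):
        return 1.0 / fixed_fraction
    return 1.0 / (fixed_fraction + rest / rest_speedup)


if __name__ == "__main__":
    # Upper bound: all non-all-reduce decode work vanishes entirely.
    print(f"{amdahl_speedup(0.38):.2f}x")       # 2.63x
    # A more modest case: the remaining 62% gets 2x faster.
    print(f"{amdahl_speedup(0.38, 2.0):.2f}x")  # 1.45x
```

Even the unreachable upper bound is under 3×, and a realistic improvement to the non-all-reduce work lands well below that, which is the trade-off behind shipping eager here.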
```bash
docker run --rm --gpus all --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/Intel/DeepSeek-V4-Flash-W4A16-AutoRound \
  -e VLLM_SERVED_MODEL_NAME=V4-Flash-W4A16 \
  -e VLLM_QUANTIZATION=auto-round \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=8 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.85 \
  -e VLLM_MAX_MODEL_LEN=4096 \
  -e VLLM_MAX_NUM_SEQS=4 \
  -e VLLM_BLOCK_SIZE=64 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --enforce-eager \
  --no-enable-prefix-caching
```

Bartowski's GGUF for mistralai/Mistral-Small-4-119B-2603. Requires
transformers >= 5.12.1 (ships Mistral4Config/Mistral4ForCausalLM; the
pin enforces it). Three flags make this launch non-trivial:
- `--hf-config-path` must point at a FLAT `mistral4` text config, not the HF repo directly. transformers has no `mistral4` in its GGUF arch allowlist, so the config has to come from `--hf-config-path` -- but the HF repo's `config.json` is the multimodal `mistral3` wrapper, which transformers 5.x loads as a `pixtral` config whose nested text config mis-types to `deepseek_v3`. So build a flat config dir once from the repo's raw `text_config` (where `model_type` really is `mistral4`):

  ```bash
  python - <<'PY'
  import json, os
  from huggingface_hub import hf_hub_download

  # Pull the repo's multimodal wrapper config and keep only the nested text config.
  with open(hf_hub_download('mistralai/Mistral-Small-4-119B-2603', 'config.json')) as f:
      tc = json.load(f)['text_config']
  tc['architectures'] = ['Mistral4ForCausalLM']
  tc.pop('quantization_config', None)  # weights come from the GGUF instead

  d = os.path.expanduser('~/models/mistral4-hf-config')
  os.makedirs(d, exist_ok=True)
  with open(os.path.join(d, 'config.json'), 'w') as f:
      json.dump(tc, f, indent=2)
  PY
  ```
- Do NOT force `--attention-backend`. Mistral4 is MLA; a forced backend (e.g. the entrypoint default `TRITON_ATTN`) errors with `MLA not supported`. Leave it unset so vLLM auto-selects `TRITON_MLA` (with the SM70 SDPA prefill fallback).
- `--tokenizer mistralai/...` pulls the official chat template (`[MODEL_SETTINGS]{...}[INST]...[/INST]`); the repo's weights are gated, but its tokenizer/config metadata is public (no token needed).
Cudagraph capture engages -- do not add --enforce-eager. Prefix caching is
supported (the LSE-SDPA fallback in mla_attention.py keeps merge_attn_states
happy on V100). Verified via vllm serve on 8x V100 (TP=8, max_model_len=16384):
~86 tok/s short-prompt decode, ~24 tok/s at 6k-token prompt + 512-token gen
(chunked prefill, max_num_batched_tokens=2048), ~26 tok/s prefix-cache replay.
```bash
vllm serve /models/.../mistralai_Mistral-Small-4-119B-2603-Q4_K_M-00001-of-00002.gguf \
  --served-model-name Mistral-Small-4-119B-Q4_K_M \
  --hf-config-path ~/models/mistral4-hf-config \
  --tokenizer mistralai/Mistral-Small-4-119B-2603 \
  --quantization gguf --dtype float16 \
  --tensor-parallel-size 8 --gpu-memory-utilization 0.70 \
  --max-model-len 16384 --max-num-seqs 1 --max-num-batched-tokens 2048 \
  --disable-custom-all-reduce --enable-prefix-caching \
  --enable-auto-tool-choice --tool-call-parser mistral \
  --host 0.0.0.0 --port 8000
# (VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 for the slow first-request Triton JIT)
```

Bartowski's GGUF for XiaomiMiMo/MiMo-V2.5. Three things make this launch
non-trivial vs Mistral4: (a) --hf-config-path routes config through the
full HF repo because transformers' GGUF parser doesn't have mimo2 in
its arch allowlist, (b) --hf-overrides strips the fp8 native-quant
declaration plus the unused vision/audio/processor sub-configs, and
(c) --trust-remote-code is needed for the GGUF loader's dummy
meta-model build (transformers ships no native MiMoV2 class).
Cudagraph capture engages -- do not add --enforce-eager. Tool
calling works via the qwen3_coder parser (MiMo's
<tool_call><function=...><parameter=...></parameter></function></tool_call>
envelope is token-identical to qwen3-coder's). Local bench (TP=8,
max_model_len=4096, cudagraph + chunked-prefill + prefix-cache,
max_num_seqs=8): ~42 tok/s single-stream short-decode, ~64 tok/s
aggregate at batch=8.
```bash
docker run --rm --gpus all --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/bartowski/MiMo-V2.5-GGUF/MiMo-V2.5-Q3_K_M/MiMo-V2.5-Q3_K_M-00001-of-00004.gguf \
  -e VLLM_SERVED_MODEL_NAME=MiMo-V2.5-Q3_K_M \
  -e VLLM_QUANTIZATION=gguf \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=8 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.92 \
  -e VLLM_MAX_MODEL_LEN=4096 \
  -e VLLM_MAX_NUM_SEQS=8 \
  -e VLLM_MAX_NUM_BATCHED_TOKENS=2048 \
  -e VLLM_TOKENIZER=XiaomiMiMo/MiMo-V2.5 \
  -e VLLM_HF_CONFIG_PATH=XiaomiMiMo/MiMo-V2.5 \
  -e VLLM_HF_OVERRIDES='{"quantization_config":null,"vision_config":null,"audio_config":null,"processor_config":null}' \
  -e VLLM_TRUST_REMOTE_CODE=1 \
  -e VLLM_ATTENTION_BACKEND=FLASH_ATTN_V100 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --enable-chunked-prefill \
  --enable-prefix-caching \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder
```

Bartowski's GGUF for Qwen/Qwen3.6-35B-A3B. The text GGUF (arch qwen35moe)
carries no vision tensors, so --hf-config-path routes config through the full
HF repo (transformers' GGUF parser has no qwen35moe) and --hf-overrides
binds the text backbone Qwen3_5MoeForCausalLM while nulling vision_config.
TP=2 is the minimum -- the 35 GiB Q8_0 weights don't fit one 32 GiB card.
--mamba-cache-mode align unifies the hybrid Gated-DeltaNet recurrent state
with the full-attention KV pages. Cudagraph capture engages -- do not add
--enforce-eager (eager decode is ~11× slower). Local bench (TP=2, max_model_len=8192,
cudagraph): ~100 tok/s single-stream; ~1900 tok/s aggregate across a 4×TP=2
replica fleet (8 GPUs). The native MTP speculative-decode head loads but is
net-negative on V100 (see notes above), so it is left disabled here.
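The "TP=2 is the minimum" statement is simple capacity arithmetic. A back-of-the-envelope sketch, assuming weights shard evenly across TP ranks (they do not exactly: some tensors are replicated) and using `--gpu-memory-utilization` times card memory as the per-GPU budget; the helper is hypothetical and ignores the KV cache and activations, which need headroom on top.

```python
def min_tp(weight_gib: float, gpu_gib: float = 32.0, mem_util: float = 0.90,
           max_tp: int = 8) -> int:
    """Smallest power-of-two TP whose per-GPU weight share fits the budget."""
    budget = gpu_gib * mem_util
    tp = 1
    while tp <= max_tp:
        if weight_gib / tp <= budget:
            return tp
        tp *= 2
    raise ValueError(f"{weight_gib} GiB of weights does not fit on {max_tp} GPUs")


if __name__ == "__main__":
    print(min_tp(35.0))  # 2  (Qwen3.6-35B-A3B Q8_0 weights)
```

At TP=1 the 35 GiB of weights exceed a 28.8 GiB budget (32 GiB × 0.90); at TP=2 each card holds roughly 17.5 GiB, leaving the rest for KV cache.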
docker run --rm --gpus '"device=0,1"' --ipc=host \
-v /path/to/models:/models:ro \
-e VLLM_MODEL=/models/Qwen3.6-35B-A3B-GGUF/Qwen_Qwen3.6-35B-A3B-Q8_0.gguf \
-e VLLM_SERVED_MODEL_NAME=Qwen3.6-35B-A3B \
-e VLLM_QUANTIZATION=gguf \
-e VLLM_DTYPE=float16 \
-e VLLM_TENSOR_PARALLEL_SIZE=2 \
-e VLLM_GPU_MEMORY_UTILIZATION=0.90 \
-e VLLM_MAX_MODEL_LEN=8192 \
-e VLLM_TOKENIZER=Qwen/Qwen3.6-35B-A3B \
-e VLLM_HF_CONFIG_PATH=Qwen/Qwen3.6-35B-A3B \
-e VLLM_HF_OVERRIDES='{"architectures":["Qwen3_5MoeForCausalLM"],"vision_config":null}' \
-e VLLM_TRUST_REMOTE_CODE=1 \
-e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
-p 8000:8000 \
vllm-v100:latest \
--mamba-cache-mode alignBartowski's Q6_K_L GGUF for Qwen/Qwen3.5-122B-A10B (3 split shards, ~102 GiB;
VLLM_MODEL points at the first, the loader globs the siblings). Same
qwen3_5_moe hybrid Gated-DeltaNet arch as the 35B, so the launch flags match
(--hf-config-path + --hf-overrides bind the text backbone
Qwen3_5MoeForCausalLM and null vision_config; --mamba-cache-mode align).
Two differences from the 35B route through the loader automatically: the Q6_K_L
quant uses the GDN out_proj dequant-to-F16 path, and the 2-KV-head
full-attention layers use the KV-head replication path (both described above).
On this 8x V100 SXM2 node NVLink forms two 4-GPU islands (0-3, 4-7), so PP=2 ×
TP=4 keeps every tensor-parallel all-reduce inside an island and runs faster
than TP=8 (which crosses the slower inter-island link on every all-reduce) --
pass --pipeline-parallel-size 2 as a trailing arg and set
VLLM_TENSOR_PARALLEL_SIZE=4. ~18 GiB/GPU weights, large KV headroom. Cudagraph
capture engages -- do not add --enforce-eager. (The env-var block below is
the Docker translation of the verified bare-metal vllm serve config. For the
single-TP fallback, set VLLM_TENSOR_PARALLEL_SIZE=8 and drop the
--pipeline-parallel-size trailing arg; on bare metal, pass
--tensor-parallel-size 8 with no --pipeline-parallel-size.)
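Why PP=2 × TP=4 keeps every all-reduce inside an NVLink island can be checked in a few lines. This sketch assumes vLLM's usual rank layout with TP as the innermost dimension (global rank = `pp_rank * tp + tp_rank`) and that rank `r` runs on GPU `r`; the helper names are hypothetical, and the 4-GPU island size (GPUs 0-3, 4-7) follows the topology described above.

```python
from typing import List


def tp_groups(pp: int, tp: int) -> List[List[int]]:
    """GPU ids per tensor-parallel group, assuming rank = pp_rank * tp + tp_rank."""
    return [[p * tp + t for t in range(tp)] for p in range(pp)]


def tp_within_islands(pp: int, tp: int, island_size: int = 4) -> bool:
    """True if every TP group's GPUs sit inside a single NVLink island."""
    return all(len({gpu // island_size for gpu in group}) == 1
               for group in tp_groups(pp, tp))


if __name__ == "__main__":
    for pp, tp in [(2, 4), (1, 8)]:
        print(f"PP={pp} x TP={tp}: groups={tp_groups(pp, tp)} "
              f"island-local={tp_within_islands(pp, tp)}")
```

Under these assumptions only the pipeline hand-off between stages (one activation transfer per micro-batch) crosses the inter-island link, while TP=8 puts every per-layer all-reduce on it.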
```bash
docker run --rm --gpus all --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/Qwen3.5-122B-A10B-GGUF/Qwen_Qwen3.5-122B-A10B-Q6_K_L/Qwen_Qwen3.5-122B-A10B-Q6_K_L-00001-of-00003.gguf \
  -e VLLM_SERVED_MODEL_NAME=Qwen3.5-122B-A10B \
  -e VLLM_QUANTIZATION=gguf \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=4 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.90 \
  -e VLLM_MAX_MODEL_LEN=8192 \
  -e VLLM_TOKENIZER=Qwen/Qwen3.5-122B-A10B \
  -e VLLM_HF_CONFIG_PATH=Qwen/Qwen3.5-122B-A10B \
  -e VLLM_HF_OVERRIDES='{"architectures":["Qwen3_5MoeForCausalLM"],"vision_config":null}' \
  -e VLLM_TRUST_REMOTE_CODE=1 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --pipeline-parallel-size 2 \
  --mamba-cache-mode align
```

Qwen3.5-27B-AWQ (plain AWQ, TP=2):

```bash
docker run --rm --gpus '"device=0,1"' --ipc=host \
  -v /path/to/models:/models:ro \
  -e VLLM_MODEL=/models/Qwen3.5-27B-AWQ \
  -e VLLM_SERVED_MODEL_NAME=Qwen3.5-27B-AWQ \
  -e VLLM_QUANTIZATION=awq \
  -e VLLM_DTYPE=float16 \
  -e VLLM_TENSOR_PARALLEL_SIZE=2 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.90 \
  -e VLLM_MAX_MODEL_LEN=262144 \
  -e VLLM_MAX_NUM_SEQS=4 \
  -e VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3000 \
  -p 8000:8000 \
  vllm-v100:latest \
  --attention-backend TRITON_ATTN \
  --skip-mm-profiling \
  --limit-mm-per-prompt '{"image":0,"video":0}'
```

Test request:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "MiniMax-M2.7",
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "max_tokens": 32
  }'
```

For compressed-tensors W4A16 on V100, the kernel selection order is:
- TurboMindLinearKernel (preferred) -- converts CT pack-quantized weights to AWQ format and uses `awq_gemm_sm70` (<0.1% error)
- TritonLinearKernel (fallback) -- Triton GPTQ kernel from PR #32597 (~2% error, unsuitable for deep networks)
- ExllamaLinearKernel (existing) -- standard Exllama path
The TurboMindLinearKernel handles weight format conversion at load time:

- Call `permute_param_layout_` to get CT `[K/8, N]` with sequential packing
- Unpack the CT nibbles to `[K, N]`
- Repack as AWQ `[K, N/8]` with interleaved order
- Generate symmetric qzeros (`0x88888888`)
- Run `awq_sm70_prepare` to produce the TurboMind format
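The unpack/repack steps above can be sketched on plain Python lists of 32-bit words. This is an illustration under stated assumptions, not the kernel's code: it assumes CT packs 8 nibbles sequentially along K (nibble `i` of word `[k8][n]` holds row `8*k8 + i`) and that the AWQ interleave is the `[0, 2, 4, 6, 1, 3, 5, 7]` column order used by AutoAWQ-style packing; scales and the `awq_sm70_prepare` step are omitted.

```python
from typing import List

Matrix = List[List[int]]
# Assumed AWQ interleave: nibble i of a packed word holds column AWQ_ORDER[i].
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]


def pack_ct(q: Matrix) -> Matrix:
    """[K, N] nibbles -> CT-style [K/8, N] words, 8 nibbles sequential along K."""
    K, N = len(q), len(q[0])
    return [[sum((q[k8 * 8 + i][n] & 0xF) << (4 * i) for i in range(8))
             for n in range(N)] for k8 in range(K // 8)]


def unpack_ct(packed: Matrix) -> Matrix:
    """CT [K/8, N] words -> [K, N] nibbles."""
    K8, N = len(packed), len(packed[0])
    return [[(packed[k // 8][n] >> (4 * (k % 8))) & 0xF for n in range(N)]
            for k in range(K8 * 8)]


def pack_awq(q: Matrix) -> Matrix:
    """[K, N] nibbles -> AWQ [K, N/8] words with interleaved column order."""
    N = len(q[0])
    return [[sum((row[n8 * 8 + col] & 0xF) << (4 * i)
                 for i, col in enumerate(AWQ_ORDER))
             for n8 in range(N // 8)] for row in q]


def unpack_awq(packed: Matrix) -> Matrix:
    """AWQ [K, N/8] words -> [K, N] nibbles (inverse of pack_awq)."""
    out = []
    for words in packed:
        row = [0] * (len(words) * 8)
        for n8, word in enumerate(words):
            for i, col in enumerate(AWQ_ORDER):
                row[n8 * 8 + col] = (word >> (4 * i)) & 0xF
        out.append(row)
    return out


def symmetric_qzeros(num_groups: int, N: int) -> Matrix:
    """Zero point 8 in every nibble: 0x88888888 per word, shape [groups, N/8]."""
    return [[0x88888888] * (N // 8) for _ in range(num_groups)]


if __name__ == "__main__":
    K, N = 16, 8
    q = [[(3 * k + n) % 16 for n in range(N)] for k in range(K)]
    awq = pack_awq(unpack_ct(pack_ct(q)))  # CT words -> nibbles -> AWQ words
    assert unpack_awq(awq) == q
    print(hex(pack_awq([list(range(8))])[0][0]))  # 0x75316420
```

Python ints are unbounded, so `0x88888888` appears positive here; in an `int32` tensor the same bit pattern reads as a negative number, which does not matter to a kernel that extracts nibbles with shifts and masks.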
For MoE models using compressed-tensors quantization, CompressedTensorsSM70WNA16MoEMethod converts weights from CT to AWQ format, then delegates to AWQSM70MoEMethod for TurboMind setup (alignment, strided ptrs, buffer allocation).
- GPU: Tesla V100 SXM2 32GB
- CUDA: 12.8
- Python: 3.12
- PyTorch: 2.9.1+cu128
- Driver: 570.x
- 1CatAI/1Cat-vLLM -- TurboMind SM70 AWQ CUDA kernels and base V100 support
- vLLM -- upstream inference engine
- lmdeploy / TurboMind -- original SM70 WMMA kernels
Apache 2.0 -- same as upstream vLLM. See LICENSE.