diff --git a/atom/config.py b/atom/config.py index 9ecbdb365..a7864b63c 100644 --- a/atom/config.py +++ b/atom/config.py @@ -609,19 +609,35 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config(self.model) + if is_plugin_mode(): + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" + self.hf_config = self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config(self.model) + + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = ( + [eos_ids] if isinstance(eos_ids, int) else eos_ids + ) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) + rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} rope_params["rope_theta"] = self.hf_config.rope_theta + rope_params["rope_type"] = rope_params.get("rope_type", "default") self.hf_config.rope_parameters = rope_params - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids + # if self.generation_config is not None: + # if ( + # eos_ids := getattr(self.generation_config, "eos_token_id", None) + # ) is not None: + # self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = get_quant_config(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 diff --git a/atom/model_ops/radix_attention.py 
b/atom/model_ops/radix_attention.py index b25e1aaba..8547add53 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -11,6 +11,8 @@ from atom.models.utils import maybe_prefix from atom.utils import envs +from aiter.rotary_embedding import AiterFusedSetKVBufferArg + class RadixAttention(BaseAttention): """ @@ -50,14 +52,22 @@ def __init__( ) if is_sglang(): - from sglang.srt.layers.radix_attention import RadixAttention + self.rotary_emb = rotary_emb + self.layer_num = layer_num + + self.k_scale = torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + + # if True, save cache will be done in rope + self.use_rope_fused_qknorm = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION + from sglang.srt.layers.radix_attention import RadixAttention self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, scaling=scale, num_kv_heads=num_kv_heads, - layer_id=layer_num, + layer_id=self.layer_num, prefix=maybe_prefix(prefix, "attn"), ) else: @@ -65,7 +75,7 @@ def __init__( "RadixAttention is only supported for plugin mode for sglang for now" ) # if True, save cache will be done in rope - self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE + # self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE def forward_impl_plugin_mode( self, @@ -78,24 +88,64 @@ def forward_impl_plugin_mode( output_block_scale: torch.Tensor | None = None, positions: torch.Tensor = None, q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, **kwargs, ): - if is_sglang(): - # for sglang, forward_batch is required - forward_batch = kwargs.get("forward_batch", None) - assert forward_batch is not None, "forward_batch is required for sglang" - save_kv_cache = not self.use_rope_fused_qknorm - return self.attn( - query, - key, - value, - forward_batch=forward_batch, - save_kv_cache=save_kv_cache, + # for sglang, forward_batch is required + forward_batch = 
kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + + if self.use_rope_fused_qknorm: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num ) - else: - raise NotImplementedError( - "RadixAttention is only supported for plugin mode for sglang for now" + block_size = 1024 # Default fallback + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): + block_size = forward_batch.attn_backend.page_size + elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( + forward_batch.token_to_kv_pool.allocator, "page_size" + ): + block_size = forward_batch.token_to_kv_pool.allocator.page_size + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): + block_size = forward_batch.token_to_kv_pool.page_size + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, + ) + else: + # calculate the q and k with rotary embedding + assert self.rotary_emb is not None, "rotary_emb is required" + q, k = self.rotary_emb(positions, query, key) + v = value + + save_kv_cache = not self.use_rope_fused_qknorm + return self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) + def forward( self, @@ -104,6 +154,7 @@ def forward( value: torch.Tensor, positions: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, **kwargs, ): if is_plugin_mode(): @@ -113,6 +164,7 @@ def forward( value=value, positions=positions, q_scale=q_scale, + qkv=qkv, 
**kwargs, ) else: diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 51927fe2c..a5fa065a9 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -5,7 +5,7 @@ from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from atom.model_ops.rotary_embedding import get_rope -from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg +from aiter.rotary_embedding import get_rope from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul @@ -30,7 +30,6 @@ from atom.utils.decorators import support_torch_compile from torch import nn from atom.model_loader.loader import load_model_in_plugin_mode -from atom.plugin.prepare import is_sglang # import torch.distributed as dist from transformers import PretrainedConfig @@ -39,7 +38,6 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) -ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE class Qwen3MoeMLP(nn.Module): @@ -226,65 +224,6 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.layer_num = layer_num - self.k_scale = torch.tensor([1.0], dtype=torch.float32) - self.v_scale = torch.tensor([1.0], dtype=torch.float32) - - def forward_sgl_plugin_mode( - self, - positions: torch.Tensor, - qkv: torch.Tensor, - **model_kwargs: dict[str, Any] | None, - ): - if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: - forward_batch = model_kwargs.get("forward_batch", None) - assert forward_batch is not None, "forward_batch is required for sglang" - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - self.layer_num - ) - block_size = 1024 # Default fallback - if hasattr(forward_batch, "attn_backend") and hasattr( - forward_batch.attn_backend, "page_size" - ): - block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( - 
forward_batch.token_to_kv_pool.allocator, "page_size" - ): - block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, "page_size"): - block_size = forward_batch.token_to_kv_pool.page_size - x = 16 // k_buffer.element_size() - aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( - kv_cache=(k_buffer, v_buffer), - cache_loc=forward_batch.out_cache_loc, - k_scale=self.k_scale, - v_scale=self.v_scale, - return_kv=True, - use_shuffle_layout=True, - block_size=block_size, - x=x, - ) - q, k, v = self.rotary_emb( - qkv, - self.q_norm.weight, - self.k_norm.weight, - positions, - self.num_heads, - self.num_kv_heads, - self.q_norm.eps, - fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, - ) - else: - q, k, v = torch.split( - qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 - ) - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) - - q, k = self.rotary_emb(positions, q, k) - - attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) - return attn_output def forward( self, @@ -295,25 +234,23 @@ def forward( qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - q, k, v = torch.split( - qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 - ) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv - ) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + q_scale=None, + qkv=qkv, + **model_kwargs) else: - if is_sglang(): - attn_output = self.forward_sgl_plugin_mode( - positions, qkv, **model_kwargs - ) - else: - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, **model_kwargs - ) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + **model_kwargs) output = 
self.o_proj(attn_output) return output diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 0f6b77760..62ce11bb5 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,7 +42,6 @@ "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1" ) == "1", - "ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE", "0") == "1", } diff --git a/bench_qwen.sh b/bench_qwen.sh new file mode 100644 index 000000000..864b18773 --- /dev/null +++ b/bench_qwen.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +MODEL=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 + +RANGE_RATIO=0.8 + +# 1K/1K +ISL=1024 +OSL=1024 +CON=128 +#CON=128 +NUM=$(( CON * 4 )) + +# 4K/1K +#ISL=4000 +#OSL=1000 +#CON=128 +#CON=64 +#NUM=$(( CON * 4 )) + +# 10K/1K +#ISL=10000 +#OSL=1000 +#CON=64 +#CON=32 +#NUM=$(( CON * 4 )) + +echo "ATOM Model=${MODEL}" +echo "ATOM ISL=${ISL}, OSL=${OSL}, NUM=${NUM}, CON=${CON} RANGE_RATIO=${RANGE_RATIO}" + +sleep 2 + +# git clone https://github.com/kimbochen/bench_serving.git +python bench_serving/benchmark_serving.py \ + --model=$MODEL \ + --backend=vllm \ + --base-url=http://localhost:8000 \ + --dataset-name=random \ + --random-input-len=$ISL \ + --random-output-len=$OSL \ + --random-range-ratio ${RANGE_RATIO} \ + --num-prompts=${NUM} \ + --max-concurrency=${CON} \ + --request-rate=inf \ + --ignore-eos \ + --save-result \ + --percentile-metrics="ttft,tpot,itl,e2el" \ + --result-dir=./ \ + 2>&1 | tee log.bench.log + +echo "ATOM Model=${MODEL}" +echo "ATOM ISL=${ISL}, OSL=${OSL}, NUM=${NUM}, CON=${CON}" +rm -rf ./*.json \ No newline at end of file diff --git a/launch_qwen_atom.sh b/launch_qwen_atom.sh new file mode 100644 index 000000000..418653d4f --- /dev/null +++ b/launch_qwen_atom.sh @@ -0,0 +1,24 @@ +set -x +# export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 +# export AITER_ROPE_FUSED_QKNORM=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +# 
model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 +# model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-VL-235B-A22B-Instruct-FP8 +model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 + +# export SGLANG_TORCH_PROFILER_DIR=./profile_qwen3_sglang_atom + +TORCHINDUCTOR_COMPILE_THREADS=128 CUDA_VISIBLE_DEVICES="4,5,6,7" python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 4 \ + --expert-parallel-size 4 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.7 \ + --model-impl atom \ + --page-size 1024 \ + 2>&1 | tee log.serve.log \ No newline at end of file diff --git a/launch_qwen_sglang.sh b/launch_qwen_sglang.sh new file mode 100644 index 000000000..8b6da9952 --- /dev/null +++ b/launch_qwen_sglang.sh @@ -0,0 +1,33 @@ +set -x +# export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 +# export AITER_ROPE_FUSED_QKNORM=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +# model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 +# model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-VL-235B-A22B-Instruct-FP8 +model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 + +# export SGLANG_TORCH_PROFILER_DIR=./profile_qwen3_sglang + +TORCHINDUCTOR_COMPILE_THREADS=128 CUDA_VISIBLE_DEVICES="4,5,6,7" python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 4 \ + --expert-parallel-size 4 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.7 \ + --page-size 1024 \ + 2>&1 | tee log.serve.log + + + +# curl -X POST "http://localhost:8000/v1/completions" \ +# -H "Content-Type: application/json" \ +# -d '{ +# "prompt": "The capital of China", "temperature": 0, "top_p": 1, +# "top_k": 0, "repetition_penalty": 1.0, "presence_penalty": 0, "frequency_penalty": 0, +# "stream": false, "ignore_eos": 
false, "n": 1, "seed": 123 +# }' \ No newline at end of file diff --git a/val_gsm8k.sh b/val_gsm8k.sh new file mode 100644 index 000000000..040f30458 --- /dev/null +++ b/val_gsm8k.sh @@ -0,0 +1,10 @@ +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log \ No newline at end of file