37 changes: 32 additions & 5 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -15,7 +15,10 @@

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import (
direct_register_custom_op,
get_device_name,
@@ -116,6 +119,7 @@ def fused_moe_kernel(
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.

This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
@@ -166,17 +170,38 @@ def fused_moe_kernel(
)
b_scale = tl.load(b_scale_ptrs)

if use_fp8_w8a8 or use_int8_w8a8:
if use_fp8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)

if use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
else:
# Load per-column scale for weights
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -495,9 +520,11 @@ def invoke_fused_moe_kernel(
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation tensor-wise fp8 quantization
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda:
@@ -510,9 +537,10 @@
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
padded_size = padding_size
A, A_scale = ops.scaled_int8_quant(A, A_scale)
# activation channel-wise int8 quantization
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
@@ -1040,7 +1068,6 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)

if activation == "silu":
if _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
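For context, a minimal sketch of what the per-token activation quantization introduced in this file is assumed to compute; the real per_token_quant_int8 kernel in sglang.srt.layers.quantization.int8_kernel may differ in rounding and epsilon handling:

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): s = max(|row|) / 127, then round-to-nearest int8.
    s = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8).to(torch.float32) / 127.0
    x_q = torch.clamp(torch.round(x.to(torch.float32) / s), -128, 127).to(torch.int8)
    return x_q, s  # s has shape (num_tokens, 1)

The kernel's new channel-wise branch then rescales the int32 accumulator roughly as out ~= (A_q @ B_q) * a_scale[:, None] * b_scale[None, :], with a_scale per token and b_scale per output channel.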
153 changes: 150 additions & 3 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -1,15 +1,16 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch

from sglang.srt.utils import is_cuda_available
from sglang.srt.utils import is_cuda_available, set_weight_attrs

is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm

from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
@@ -55,9 +56,12 @@ def get_quant_method(
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
@@ -81,7 +85,7 @@ def create_weights(
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs
**extra_weight_attrs,
):

weight_loader = extra_weight_attrs.get("weight_loader")
@@ -115,3 +119,146 @@ def apply(
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)


class W8A8Int8MoEMethod:
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""

def __new__(cls, *args, **kwargs):
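# NOTE: FusedMoEMethodBase is imported lazily and this class is re-parented onto it
# on first construction, presumably to avoid a circular import between this module
# and fused_moe_triton at load time.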
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)

def __init__(self, quant_config):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

tp_size = get_tensor_model_parallel_world_size()

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)

set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)

w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)

return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
# per_channel=True,
)
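For illustration only (not part of this PR's checkpoint tooling): the per-channel scale layout that W8A8Int8MoEMethod registers above, int8 weights plus one float32 scale per output row with scale shapes (num_experts, 2 * intermediate_size, 1) and (num_experts, hidden_size, 1), corresponds to a quantization like the following sketch:

import torch

def quant_weight_channelwise_int8(w: torch.Tensor):
    # w: (out_features, in_features) in fp16/bf16 -> int8 weight + per-row float32 scale.
    scale = w.abs().amax(dim=1, keepdim=True).to(torch.float32) / 127.0
    w_q = torch.clamp(torch.round(w.to(torch.float32) / scale), -128, 127).to(torch.int8)
    return w_q, scale  # scale: (out_features, 1)

Dequantization is then w_q.to(torch.float32) * scale, which is what fused_experts applies per expert through w13_weight_scale / w2_weight_scale when use_int8_w8a8=True.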
28 changes: 16 additions & 12 deletions python/sglang/srt/models/deepseek_v2.py
@@ -1155,18 +1155,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if (
hasattr(self.quant_config, "weight_block_size")
and w.dtype == torch.int8
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
if w.dtype == torch.int8:
if hasattr(self.quant_config, "weight_block_size"):
# block-wise int8: dequantize with the block-wise weight_scale_inv
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
else:
# channel-wise int8: dequantize with the per-channel weight_scale
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
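For contrast with the channel-wise branch above, a rough sketch of what block-wise dequantization amounts to (an assumed re-implementation; the actual int8_block_dequant helper lives elsewhere in sglang and may differ):

import torch

def int8_block_dequant_ref(w_q: torch.Tensor, scale: torch.Tensor, block_shape):
    # w_q: (out, in) int8; scale: one float32 value per (block_n x block_k) tile.
    block_n, block_k = block_shape
    out, in_ = w_q.shape
    # Broadcast each tile's scale over its block, then crop to the weight shape.
    s = scale.repeat_interleave(block_n, dim=0)[:out]
    s = s.repeat_interleave(block_k, dim=1)[:, :in_]
    return w_q.to(torch.float32) * s

Channel-wise checkpoints store one scale per output row instead, which is why the new else-branch can simply compute w.to(torch.bfloat16) * weight_scale.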
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -56,6 +56,7 @@
"test_w8a8_quantization.py",
"test_fp8_kernel.py",
"test_block_int8.py",
"test_int8_kernel.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",