Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import functools
import json
import logging
import sys
import time
from collections import deque
Expand Down Expand Up @@ -40,8 +41,9 @@
)
from .sample_utils import make_sampler
from .tokenizer_utils import TokenizerWrapper
from .utils import does_model_support_input_embeddings, load
from .utils import does_model_support_input_embeddings, get_total_parameters, load

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Module-level defaults; presumably used for generation / CLI arguments,
# but their consumers are outside this excerpt -- TODO confirm.
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
Expand Down Expand Up @@ -304,6 +306,40 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


def _warn_speculative_moe(model: nn.Module, draft_model: nn.Module) -> None:
    """Warn if speculative decoding is unlikely to help for MoE models.

    The target model's per-token active parameter count is estimated as
    ``total_params * num_experts_per_tok / num_experts`` (the total count is
    scaled as a whole, so this is an approximation). A warning is logged when
    that estimate is less than 4x the draft model's parameter count.

    Args:
        model: Target model. Its ``args`` attribute is inspected for
            ``num_experts`` (or ``num_local_experts``) and
            ``num_experts_per_tok`` (or ``num_experts_per_token``).
        draft_model: Draft model used for speculative decoding.

    Returns:
        None. Returns without logging when the model has no ``args``, when
        either expert field is missing or zero (e.g. a dense model), or when
        the draft model reports zero parameters.
    """
    model_config = getattr(model, "args", None)
    if model_config is None:
        return

    num_experts = getattr(model_config, "num_experts", None) or getattr(
        model_config, "num_local_experts", None
    )
    num_experts_per_tok = getattr(
        model_config, "num_experts_per_tok", None
    ) or getattr(model_config, "num_experts_per_token", None)

    # Dense models define neither field, so both resolve to None; dividing
    # None by None would raise TypeError. Treat missing or zero values as
    # "not an MoE model" and skip the check before counting any parameters.
    if not num_experts or not num_experts_per_tok:
        return

    draft_params = get_total_parameters(draft_model)
    if not draft_params:
        # Nothing to compare against; also avoids dividing by zero below.
        return

    target_params = get_total_parameters(model)
    active_params = target_params * (num_experts_per_tok / num_experts)
    ratio = active_params / draft_params

    # warn if draft model is >25% of active params (empirically shows slowdown)
    if ratio < 4.0:
        logger.warning(
            "Target model active parameters (%.1fB) are close to "
            "draft model size (%.1fB). Speculative decoding may hurt "
            "throughput for MoE architectures. Consider benchmarking with and "
            "without --draft-model.",
            active_params / 1e9,
            draft_params / 1e9,
        )


def generate_step(
prompt: mx.array,
model: nn.Module,
Expand Down Expand Up @@ -2057,6 +2093,7 @@ def main():
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
_warn_speculative_moe(model, draft_model)
else:
draft_model = None
sampler = make_sampler(
Expand Down