diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..35111623c 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -5,6 +5,7 @@ import copy import functools import json +import logging import sys import time from collections import deque @@ -40,8 +41,9 @@ ) from .sample_utils import make_sampler from .tokenizer_utils import TokenizerWrapper -from .utils import does_model_support_input_embeddings, load +from .utils import does_model_support_input_embeddings, get_total_parameters, load +logger = logging.getLogger(__name__) DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 @@ -304,6 +306,40 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) +def _warn_speculative_moe(model: nn.Module, draft_model: nn.Module) -> None: + """Warn if speculative decoding is unlikely to help for MoE models.""" + model_config = getattr(model, "args", None) + if model_config is None: + return + + num_experts = getattr(model_config, "num_experts", None) or getattr( + model_config, "num_local_experts", None + ) + num_experts_per_tok = getattr(model_config, "num_experts_per_tok", None) or getattr( + model_config, "num_experts_per_token", None + ) + + if num_experts == 0: + return + + draft_params = get_total_parameters(draft_model) + target_params = get_total_parameters(model) + active_params = target_params * (num_experts_per_tok / num_experts) + ratio = active_params / draft_params + + if ( + ratio < 4.0 + ): # warn if draft model is >25% of active params (empirically shows slowdown) + active_b = active_params / 1e9 + draft_b = draft_params / 1e9 + logger.warning( + f"Target model active parameters ({active_b:.1f}B) are close to " + f"draft model size ({draft_b:.1f}B). Speculative decoding may hurt " + f"throughput for MoE architectures. Consider benchmarking with and " + f"without --draft-model." + ) + + def generate_step( prompt: mx.array, model: nn.Module, @@ -2057,6 +2093,7 @@ def main(): draft_model, draft_tokenizer = load(args.draft_model) if draft_tokenizer.vocab_size != tokenizer.vocab_size: raise ValueError("Draft model tokenizer does not match model tokenizer.") + _warn_speculative_moe(model, draft_model) else: draft_model = None sampler = make_sampler(