From e4a4ae31b35428548074c1a699861b48f1a37cb8 Mon Sep 17 00:00:00 2001 From: Shylin Date: Tue, 26 May 2026 21:48:33 +0530 Subject: [PATCH 1/2] warn when speculative decoding may hurt MoE model throughput --- mlx_lm/generate.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..b0211c3dd 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -40,7 +40,7 @@ ) from .sample_utils import make_sampler from .tokenizer_utils import TokenizerWrapper -from .utils import does_model_support_input_embeddings, load +from .utils import does_model_support_input_embeddings, get_total_parameters, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 @@ -304,6 +304,39 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) +def _warn_speculative_moe(model: nn.Module, draft_model: nn.Module) -> None: + """Warn if speculative decoding is unlikely to help for MoE models.""" + model_config = getattr(model, "args", None) + if model_config is None: + return + + num_experts = getattr(model_config, "num_experts", None) or getattr( + model_config, "num_local_experts", None + ) + num_experts_per_tok = getattr(model_config, "num_experts_per_tok", None) or getattr( + model_config, "num_experts_per_token", None + ) + + if num_experts is None or num_experts_per_tok is None: + return + + # Calculate active parameters as a fraction of total + draft_params = get_total_parameters(draft_model) + target_params = get_total_parameters(model) + active_params = target_params * (num_experts_per_tok / num_experts) + ratio = active_params / draft_params + + if ratio < 4.0: + active_b = active_params / 1e9 + draft_b = draft_params / 1e9 + print( + f"[WARNING] Target model active parameters ({active_b:.1f}B) are close to " + f"draft model size ({draft_b:.1f}B). Speculative decoding may hurt " + f"throughput for MoE architectures. Consider benchmarking with and without " + f"--draft-model." + ) + + def generate_step( prompt: mx.array, model: nn.Module, @@ -2057,6 +2090,7 @@ def main(): draft_model, draft_tokenizer = load(args.draft_model) if draft_tokenizer.vocab_size != tokenizer.vocab_size: raise ValueError("Draft model tokenizer does not match model tokenizer.") + _warn_speculative_moe(model, draft_model) else: draft_model = None sampler = make_sampler( From 07b01742006c791c078e44b9d458ece18f0c5c08 Mon Sep 17 00:00:00 2001 From: Shylin Date: Wed, 27 May 2026 03:49:33 +0530 Subject: [PATCH 2/2] use logging.warning instead of print, add division by zero guard --- mlx_lm/generate.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index b0211c3dd..35111623c 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -5,6 +5,7 @@ import copy import functools import json +import logging import sys import time from collections import deque @@ -42,6 +43,7 @@ from .tokenizer_utils import TokenizerWrapper from .utils import does_model_support_input_embeddings, get_total_parameters, load +logger = logging.getLogger(__name__) DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 @@ -317,23 +319,24 @@ def _warn_speculative_moe(model: nn.Module, draft_model: nn.Module) -> None: model_config, "num_experts_per_token", None ) - if num_experts is None or num_experts_per_tok is None: + if num_experts == 0: return - # Calculate active parameters as a fraction of total draft_params = get_total_parameters(draft_model) target_params = get_total_parameters(model) active_params = target_params * (num_experts_per_tok / num_experts) ratio = active_params / draft_params - if ratio < 4.0: + if ( + ratio < 4.0 + ): # warn if draft model is >25% of active params (empirically shows slowdown) active_b = active_params / 1e9 draft_b = draft_params / 1e9 - print( - f"[WARNING] Target model active parameters ({active_b:.1f}B) are close to " + logger.warning( + f"Target model active parameters ({active_b:.1f}B) are close to " f"draft model size ({draft_b:.1f}B). Speculative decoding may hurt " - f"throughput for MoE architectures. Consider benchmarking with and without " - f"--draft-model." + f"throughput for MoE architectures. Consider benchmarking with and " + f"without --draft-model." )