diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ea0ecb4f..4574c29e 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1228,6 +1228,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
             "finish_reason": finish_reason,
             "stop_str": stop_str,
             "full_text": full_text,
+            "accepted_draft_tokens": result.get("accepted_draft_tokens", 0),
+            "rejected_draft_tokens": result.get("rejected_draft_tokens", 0),
         }
 
         return finish_chunk
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index bac27adc..1d775235 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -154,24 +154,47 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
-        self.use_draft_model = draft_args and draft_model_name
+        draft_arch_override = draft_args.get("draft_arch_override")
+        self._draft_args = draft_args
 
-        # Always disable draft if params are incorrectly configured
-        if draft_args and draft_model_name is None:
+        # Two ways to enable a draft model:
+        # 1) Separate dir+name (regular draft, any arch).
+        # 2) MTP head loaded from the main model's checkpoint: set draft_arch_override
+        #    (e.g. "Qwen3_5MTPDraftModel") and leave draft_model_name unset.
+        self.use_draft_model = bool(draft_args) and bool(draft_model_name or draft_arch_override)
+
+        # Misconfiguration: draft section present but no way to locate weights
+        if draft_args and not draft_model_name and not draft_arch_override:
             xlogger.warning(
-                "Draft model is disabled because a model name "
-                "wasn't provided. Please check your config.yml!"
+                "Draft model section is set but neither draft_model_name nor "
+                "draft_arch_override is provided. Disabling draft model. "
+                "Set draft_model_name to load from a separate directory, or "
+                "draft_arch_override (e.g. Qwen3_5MTPDraftModel) to load the "
+                "MTP head from the main model directory."
             )
             self.use_draft_model = False
 
         if self.use_draft_model:
-            draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
-            draft_model_path = draft_model_path / draft_model_name
+            if draft_model_name:
+                # Separate draft model directory
+                draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
+                draft_model_path = draft_model_path / draft_model_name
+            else:
+                # MTP from the same dir as the main model — checkpoint has both trunk and
+                # mtp.* tensors; arch_override picks just the MTP weights via Qwen3_5MTPDraftConfig.
+                draft_model_path = model_directory
+                xlogger.info("Loading draft model from main model directory (self-spec)")
+
             self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
             self.draft_model_dir = draft_model_path
-            self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
+            self.draft_config = Config.from_directory(
+                str(draft_model_path.resolve()),
+                arch_override=draft_arch_override,
+            )
             self.draft_model = Model.from_config(self.draft_config)
             xlogger.info(f"Using draft model: {str(draft_model_path.resolve())}")
+            if draft_arch_override:
+                xlogger.info(f"Draft arch override: {draft_arch_override}")
         else:
             self.draft_model = None
             self.draft_cache = None
@@ -555,6 +578,9 @@ async def create_generator(self):
         await self.wait_for_jobs(skip_wait=True)
 
         # Create new generator
+        draft_args_runtime = unwrap(getattr(self, "_draft_args", None), {})
+        num_draft_tokens = draft_args_runtime.get("num_draft_tokens") \
+            if isinstance(draft_args_runtime, dict) else None
         self.generator = AsyncGenerator(
             model=self.model,
             cache=self.cache,
@@ -563,6 +589,7 @@
             tokenizer=self.tokenizer,
             max_batch_size=self.max_batch_size,
             max_chunk_size=self.chunk_size,
+            num_draft_tokens=num_draft_tokens,
             recurrent_cache_size=config.memory.sysmem_recurrent_cache * 1024**2,
         )
 
@@ -910,6 +937,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
             "finish_reason": finish_reason,
             "stop_str": stop_str,
             "full_text": full_text,
+            "accepted_draft_tokens": result.get("accepted_draft_tokens", 0),
+            "rejected_draft_tokens": result.get("rejected_draft_tokens", 0),
         }
 
         return finish_chunk
diff --git a/common/config_models.py b/common/config_models.py
index 5be86af1..87640018 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -347,7 +347,9 @@ class DraftModelConfig(BaseConfigModel):
     draft_model_name: Optional[str] = Field(
         None,
         description=(
-            "An initial draft model to load.\nEnsure the model is in the model directory."
+            "An initial draft model to load.\nEnsure the model is in the model directory.\n"
+            "Leave blank when using draft_arch_override to load an MTP head from the\n"
+            "main model directory (self-spec)."
         ),
     )
     draft_rope_scale: Optional[float] = Field(
@@ -381,6 +383,25 @@
             "If this isn't filled in, the draft model is autosplit."
         ),
     )
+    draft_arch_override: Optional[str] = Field(
+        None,
+        description=(
+            "Override the architecture string read from the draft model's config.json.\n"
+            "Use 'Qwen3_5MTPDraftModel' to load only the MTP head. Two ways:\n"
+            "  - With draft_model_name: load the MTP head from a separate directory\n"
+            "    (e.g. point draft_model_name at the original BF16 Qwen3.6 repo).\n"
+            "  - Without draft_model_name: load the MTP head from the SAME directory\n"
+            "    as the main model, when that checkpoint already contains the mtp.*\n"
+            "    tensors alongside the regular trunk weights."
+        ),
+    )
+    num_draft_tokens: Optional[int] = Field(
+        None,
+        description=(
+            "Number of draft tokens generated per spec-decode round (default: draft model\n"
+            "preference, else 4). For MTP-1 (Qwen3.6), 2-3 is typical."
+        ),
+    )
 
 
 class SamplingConfig(BaseConfigModel):
diff --git a/common/gen_logging.py b/common/gen_logging.py
index 8aa54dc3..f0b00fc5 100644
--- a/common/gen_logging.py
+++ b/common/gen_logging.py
@@ -82,6 +82,16 @@ def log_metrics(
 
     itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")
 
+    # Add draft token acceptance rate if available
+    accepted_draft = metrics.get("accepted_draft_tokens", 0)
+    rejected_draft = metrics.get("rejected_draft_tokens", 0)
+    total_draft = accepted_draft + rejected_draft
+    if total_draft > 0:
+        acceptance_rate = round(accepted_draft / total_draft * 100, 1)
+        itemization.append(
+            f"Draft: {accepted_draft}/{total_draft} accepted ({acceptance_rate}% acceptance)"
+        )
+
     # Add context (original token count)
     if context_len:
         itemization.append(f"Context: {context_len} tokens")