Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
"finish_reason": finish_reason,
"stop_str": stop_str,
"full_text": full_text,
"accepted_draft_tokens": result.get("accepted_draft_tokens", 0),
"rejected_draft_tokens": result.get("rejected_draft_tokens", 0),
}

return finish_chunk
Expand Down
45 changes: 37 additions & 8 deletions backends/exllamav3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,24 +154,47 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
self.use_draft_model = draft_args and draft_model_name
draft_arch_override = draft_args.get("draft_arch_override")
self._draft_args = draft_args

# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
# Two ways to enable a draft model:
# 1) Separate dir+name (regular draft, any arch).
# 2) MTP head loaded from the main model's checkpoint: set draft_arch_override
# (e.g. "Qwen3_5MTPDraftModel") and leave draft_model_name unset.
self.use_draft_model = bool(draft_args) and bool(draft_model_name or draft_arch_override)

# Misconfiguration: draft section present but no way to locate weights
if draft_args and not draft_model_name and not draft_arch_override:
xlogger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
"Draft model section is set but neither draft_model_name nor "
"draft_arch_override is provided. Disabling draft model. "
"Set draft_model_name to load from a separate directory, or "
"draft_arch_override (e.g. Qwen3_5MTPDraftModel) to load the "
"MTP head from the main model directory."
)
self.use_draft_model = False

if self.use_draft_model:
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = draft_model_path / draft_model_name
if draft_model_name:
# Separate draft model directory
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = draft_model_path / draft_model_name
else:
# MTP from the same dir as the main model — checkpoint has both trunk and
# mtp.* tensors; arch_override picks just the MTP weights via Qwen3_5MTPDraftConfig.
draft_model_path = model_directory
xlogger.info("Loading draft model from main model directory (self-spec)")

self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
self.draft_model_dir = draft_model_path
self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
self.draft_config = Config.from_directory(
str(draft_model_path.resolve()),
arch_override=draft_arch_override,
)
self.draft_model = Model.from_config(self.draft_config)
xlogger.info(f"Using draft model: {str(draft_model_path.resolve())}")
if draft_arch_override:
xlogger.info(f"Draft arch override: {draft_arch_override}")
else:
self.draft_model = None
self.draft_cache = None
Expand Down Expand Up @@ -555,6 +578,9 @@ async def create_generator(self):
await self.wait_for_jobs(skip_wait=True)

# Create new generator
draft_args_runtime = unwrap(getattr(self, "_draft_args", None), {})
num_draft_tokens = draft_args_runtime.get("num_draft_tokens") \
if isinstance(draft_args_runtime, dict) else None
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
Expand All @@ -563,6 +589,7 @@ async def create_generator(self):
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
max_chunk_size=self.chunk_size,
num_draft_tokens=num_draft_tokens,
recurrent_cache_size=config.memory.sysmem_recurrent_cache * 1024**2,
)

Expand Down Expand Up @@ -910,6 +937,8 @@ def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
"finish_reason": finish_reason,
"stop_str": stop_str,
"full_text": full_text,
"accepted_draft_tokens": result.get("accepted_draft_tokens", 0),
"rejected_draft_tokens": result.get("rejected_draft_tokens", 0),
}

return finish_chunk
Expand Down
23 changes: 22 additions & 1 deletion common/config_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ class DraftModelConfig(BaseConfigModel):
draft_model_name: Optional[str] = Field(
None,
description=(
"An initial draft model to load.\nEnsure the model is in the model directory."
"An initial draft model to load.\nEnsure the model is in the model directory.\n"
"Leave blank when using draft_arch_override to load an MTP head from the\n"
"main model directory (self-spec)."
),
)
draft_rope_scale: Optional[float] = Field(
Expand Down Expand Up @@ -381,6 +383,25 @@ class DraftModelConfig(BaseConfigModel):
"If this isn't filled in, the draft model is autosplit."
),
)
draft_arch_override: Optional[str] = Field(
None,
description=(
"Override the architecture string read from the draft model's config.json.\n"
"Use 'Qwen3_5MTPDraftModel' to load only the MTP head. Two ways:\n"
" - With draft_model_name: load the MTP head from a separate directory\n"
" (e.g. point draft_model_name at the original BF16 Qwen3.6 repo).\n"
" - Without draft_model_name: load the MTP head from the SAME directory\n"
" as the main model, when that checkpoint already contains the mtp.*\n"
" tensors alongside the regular trunk weights."
),
)
num_draft_tokens: Optional[int] = Field(
None,
description=(
"Number of draft tokens generated per spec-decode round (default: draft model\n"
"preference, else 4). For MTP-1 (Qwen3.6), 2-3 is typical."
),
)


class SamplingConfig(BaseConfigModel):
Expand Down
10 changes: 10 additions & 0 deletions common/gen_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def log_metrics(

itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")

# Add draft token acceptance rate if available
accepted_draft = metrics.get("accepted_draft_tokens", 0)
rejected_draft = metrics.get("rejected_draft_tokens", 0)
total_draft = accepted_draft + rejected_draft
if total_draft > 0:
acceptance_rate = round(accepted_draft / total_draft * 100, 1)
itemization.append(
f"Draft: {accepted_draft}/{total_draft} accepted ({acceptance_rate}% acceptance)"
)

# Add context (original token count)
if context_len:
itemization.append(f"Context: {context_len} tokens")
Expand Down