-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[Speculative decoding] feat: add EAGLE3 speculative decoding support #18039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0806a96
800494f
16e6555
752bf23
b32d9eb
7c5f428
91b9cfc
4ca8087
413c16d
6c21222
ac7e2b2
544aaa2
5738c9a
b9f41d1
f3fbbed
8002c4c
33b02df
7857221
9b2543d
1d55316
2de116b
f408879
d373233
5caedbc
0274f0f
9baa68b
0bd5449
7c42aff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,12 +5,13 @@ | |
|
|
||
| from typing import Callable, Iterable, TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch import Tensor | ||
|
|
||
| from .base import ModelBase, TextModel, gguf | ||
| from .base import ModelBase, TextModel, gguf, logger | ||
|
|
||
|
|
||
| @ModelBase.register( | ||
|
|
@@ -21,6 +22,9 @@ | |
| "VLlama3ForCausalLM", | ||
| "LlavaForConditionalGeneration", | ||
| "VoxtralForConditionalGeneration", | ||
| "LlamaForCausalLMEagle3", | ||
| "Eagle3Speculator", | ||
| "Eagle3DraftModel", | ||
| "IQuestCoderForCausalLM", | ||
| "LlamaModel") | ||
| class LlamaModel(TextModel): | ||
|
|
@@ -39,7 +43,61 @@ def __init__(self, *args, **kwargs): | |
| hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False) | ||
| self.origin_hf_arch = hparams.get('architectures', [None])[0] | ||
|
|
||
| # Detect eagle3 draft checkpoint by hparams (some models don't use a distinct HF arch name) | ||
| if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not important right now, but I'm guessing all this will basically be duplicated for every arch supported with very little if any differences? Would be nice if it can be refactored in a reusable way.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! I kept this local for now because all Eagle3 checkpoints I have encountered so far are based on Llama decoder (no matter where they come from RedHat, LMSYS, NVIDIA, etc), and this PR only targets that path unless we find an Eagle3 checkpoint based on a different architecture. (potentially this #18039 (comment) but not sure). If another architecture needs Eagle3 conversion later, this should be the first piece to factor out. |
||
| self.is_eagle3 = True | ||
| self.model_arch = gguf.MODEL_ARCH.EAGLE3 | ||
| logger.info("Detected EAGLE-3 draft model, switching to EAGLE3 architecture") | ||
| # Re-initialize tensor_map with eagle3 architecture | ||
| self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) | ||
| # Update gguf_writer architecture | ||
| self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] | ||
| self.gguf_writer.add_architecture() | ||
| if self.target_model_dir is None: | ||
| raise ValueError( | ||
| "EAGLE-3 model requires --target-model-dir to be specified. " | ||
| "Please provide the path to the target model directory to read config.json" | ||
| ) | ||
| # Read both eagle3 raw config and target model config | ||
| with open(self.dir_model / "config.json", 'r', encoding='utf-8') as f: | ||
| eagle3_raw_config = json.load(f) | ||
| with open(self.target_model_dir / "config.json", 'r', encoding='utf-8') as f: | ||
| target_config = json.load(f) | ||
|
|
||
| if "text_config" in target_config: | ||
| target_config = {**target_config, **target_config["text_config"]} | ||
| self.target_vocab_size = target_config["vocab_size"] | ||
|
|
||
| # target_layers: derived from target model layer count (low/mid/high) | ||
| target_num_layers = target_config["num_hidden_layers"] | ||
| target_layers = [2, target_num_layers // 2, target_num_layers - 3] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also prefer the eagle3 config when Same question for vocab size when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good question. First, many Eagle3 checkpoints do not include
Both
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yikes, guess the EAGLE3 rollout has not been smooth 😅 thanks for the clarity! unfortunate but logical :)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah it is. Thanks for the review! |
||
| logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)") | ||
| self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers) | ||
|
|
||
| # target_hidden_size: prefer eagle3 config, fallback to target config | ||
| if eagle3_raw_config.get("target_hidden_size") is not None: | ||
| target_hidden_size = eagle3_raw_config["target_hidden_size"] | ||
| src = "EAGLE-3 config" | ||
| else: | ||
| target_hidden_size = target_config["hidden_size"] | ||
| src = "target model config" | ||
| logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})") | ||
| self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size) | ||
|
|
||
| # norm_before_residual (RedHat-style eagle3 specific) | ||
| norm_before_residual = eagle3_raw_config.get("norm_before_residual", False) | ||
| logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}") | ||
| self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual) | ||
|
|
||
| def set_vocab(self): | ||
| # eagle3: use tokenizer from target model if provided | ||
| original_dir_model = None | ||
| if getattr(self, 'is_eagle3', False): | ||
| assert self.target_model_dir is not None | ||
| logger.info(f"EAGLE-3: Using tokenizer from target model: {self.target_model_dir}") | ||
| original_dir_model = self.dir_model | ||
| self.dir_model = self.target_model_dir | ||
|
|
||
| if self.origin_hf_arch == "GlmasrModel": | ||
| return self._set_vocab_glmedge() | ||
|
|
||
|
|
@@ -85,6 +143,10 @@ def set_vocab(self): | |
| if self.hparams.get("vocab_size", 32000) == 49152: | ||
| self.gguf_writer.add_add_bos_token(False) | ||
|
|
||
| # eagle3: Restore original dir_model | ||
| if original_dir_model is not None: | ||
| self.dir_model = original_dir_model | ||
|
|
||
| def set_gguf_parameters(self): | ||
| super().set_gguf_parameters() | ||
| hparams = self.hparams | ||
|
|
@@ -129,7 +191,49 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca | |
|
|
||
| return super().filter_tensors((name, gen)) | ||
|
|
||
| def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: | ||
| tensors = super().index_tensors(remote_hf_model_id) | ||
|
|
||
| # Handle Eagle3Speculator nested config | ||
| if "transformer_layer_config" in self.hparams: | ||
| self.hparams = {**self.hparams, **self.hparams["transformer_layer_config"]} | ||
|
|
||
| # eagle3 detection | ||
| if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: | ||
| logger.info("EAGLE-3: renaming midlayer.* / layers.0.* to model.layers.0.*") | ||
| new_tensors = {} | ||
| for name, gen in tensors.items(): | ||
| if name.startswith("midlayer."): | ||
| new_name = "model.layers.0." + name[len("midlayer."):] | ||
| new_tensors[new_name] = gen | ||
| elif name.startswith("layers.0."): # Eagle3Speculator format | ||
| new_name = "model." + name | ||
| new_tensors[new_name] = gen | ||
| else: | ||
| new_tensors[name] = gen | ||
| return new_tensors | ||
|
|
||
| return tensors | ||
|
|
||
| def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||
| # eagle3: special tensors that bypass standard llama mapping | ||
| if getattr(self, 'is_eagle3', False): | ||
| if name == "fc.weight": | ||
| yield (name, data_torch) | ||
| return | ||
| if name == "d2t": | ||
| # store for manual int64 handling in prepare_tensors (avoid F32 conversion) | ||
| if not hasattr(self, '_eagle3_int_tensors'): | ||
| self._eagle3_int_tensors = {} | ||
| self._eagle3_int_tensors[name] = data_torch | ||
| return | ||
| if name == "t2d": | ||
| # not used at runtime, skip | ||
| return | ||
| if name.endswith(".hidden_norm.weight"): | ||
| yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_NORM_2, bid), data_torch) | ||
| return | ||
|
|
||
| n_head = self.find_hparam(["n_heads", "num_attention_heads"]) | ||
| n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) | ||
|
|
||
|
|
@@ -205,8 +309,33 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: | |
| yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) | ||
|
|
||
| def prepare_tensors(self): | ||
| # eagle3: collect d2t original dtype before parent converts tensors to F32 | ||
| eagle3_original_dtypes = {} | ||
| if getattr(self, 'is_eagle3', False): | ||
| for name, data_torch in self.get_tensors(): | ||
| if name == "d2t": | ||
| eagle3_original_dtypes[name] = data_torch.dtype | ||
|
|
||
| super().prepare_tensors() | ||
|
|
||
| # eagle3: write d2t as absolute target token ids | ||
| if getattr(self, 'is_eagle3', False) and hasattr(self, '_eagle3_int_tensors'): | ||
| for name, data_torch in self._eagle3_int_tensors.items(): | ||
| old_dtype = eagle3_original_dtypes.get(name, data_torch.dtype) | ||
| data = data_torch.to(torch.int64).cpu().numpy() | ||
| if name == "d2t": | ||
| data = data.reshape(-1) | ||
| data = data + np.arange(data.size, dtype=np.int64) | ||
| if np.any((data < 0) | (data >= self.target_vocab_size)): | ||
| raise ValueError(f"EAGLE-3 d2t target ids out of range for target vocab size {self.target_vocab_size}") | ||
| if np.unique(data).size != data.size: | ||
| raise ValueError("EAGLE-3 d2t contains duplicate target ids") | ||
| data_qtype = gguf.GGMLQuantizationType.I64 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Future-proofing is nice and all, but
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And we have a assert check here: https://github.com/ruixiang63/llama.cpp/blob/0bd54498f273bf290a8fd55152deedf8e7c878dc/src/models/eagle3.cpp#L309 |
||
|
|
||
| shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" | ||
| logger.info(f"{name + ',':<30} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") | ||
| self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) | ||
|
|
||
| if self._experts is not None: | ||
| # flatten `list[dict[str, Tensor]]` into `list[str]` | ||
| experts = [k for d in self._experts for k in d.keys()] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.