Eval bug: Gemma 4 MTP Performance Regression Post-merge #24596

@Schrikvis

Name and Version

The fast PR branch build says it's:

version: 9560 (efd651a8e)
built with GNU 16.1.1 for Linux x86_64

Operating systems

Linux

GGML backends

HIP

Hardware

RX 7800 XT

Models

gemma-4-12B-it-qat-UD-Q4_K_XL.gguf
gemma-4-12B-it-Q4_0-MTP.gguf (saved to disk as gemma-4-12B-it-qat-Q4_0-MTP.gguf)
https://huggingface.co/unsloth/gemma-4-12B-it-qat-GGUF
https://huggingface.co/unsloth/gemma-4-12B-it-qat-GGUF/tree/main/MTP

Problem description & steps to reproduce

MTP performance appears to have degraded compared to the PR branch.

I'm using the ROCm build with the HIP backend to get these numbers.

I built #23398 ( https://github.com/am17an/llama.cpp/tree/gemma4-mtp ) like this:

cmake -B build -S . -DGGML_HIP=ON -DAMDGPU_TARGETS="gfx1101" -DBUILD_SHARED_LIBS=OFF && time cmake --build build --config Release -j$(nproc) --clean-first --target llama-cli llama-server

And I run it like this:

./llama-server -m "gemma-4-12B-it-qat-UD-Q4_K_XL.gguf" -md "gemma-4-12B-it-qat-Q4_0-MTP.gguf" -c 8192 -fa on -np 1 --fit-target 4096 --spec-type draft-mtp --spec-draft-n-max 2 -lv 4 --no-log-timestamps

(This is also reproducible with -fit off.)

Then, I give it a short prompt like: "Explain multi-token prediction."

And run it with normal sampling parameters like: Top-K 64, top-P 0.95, temp 1.

I get around 84 t/s:

I slot print_timing: id  0 | task 0 | n_decoded =    100, tg =  90.54 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    374, tg =  90.88 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    603, tg =  84.65 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    853, tg =  84.25 t/s
I slot print_timing: id  0 | task 0 | prompt eval time =      71.27 ms /    15 tokens (    4.75 ms per token,   210.46 tokens per second)
I slot print_timing: id  0 | task 0 |        eval time =   10281.13 ms /   864 tokens (   11.90 ms per token,    84.04 tokens per second)
I slot print_timing: id  0 | task 0 |       total time =   10352.41 ms /   879 tokens
I slot print_timing: id  0 | task 0 |    graphs reused =          0
I slot print_timing: id  0 | task 0 | draft acceptance = 0.57463 (  462 accepted /   804 generated)
I statistics        draft-mtp: #calls(b,g,a) =    1    402    402, #gen drafts =    402, #acc drafts =   275, #gen tokens =    804, #acc tokens =   462, dur(b,g,a) = 0.000, 1689.840, 0.226 ms
I slot      release: id  0 | task 0 | stop processing: n_tokens = 879, truncated = 0

This is fast and roughly 60% faster than running without MTP.

Running without MTP looks like this, for reference:

0.10.011.121 I slot print_timing: id  0 | task 0 | n_decoded =    100, tg =  53.63 t/s
0.13.016.784 I slot print_timing: id  0 | task 0 | n_decoded =    261, tg =  53.59 t/s
0.16.033.790 I slot print_timing: id  0 | task 0 | n_decoded =    419, tg =  53.12 t/s
0.19.034.558 I slot print_timing: id  0 | task 0 | n_decoded =    576, tg =  52.90 t/s
0.22.036.239 I slot print_timing: id  0 | task 0 | n_decoded =    731, tg =  52.63 t/s
0.25.050.493 I slot print_timing: id  0 | task 0 | n_decoded =    886, tg =  52.41 t/s
0.27.361.938 I slot print_timing: id  0 | task 0 | prompt eval time =      66.97 ms /    14 tokens (    4.78 ms per token,   209.06 tokens per second)
0.27.361.940 I slot print_timing: id  0 | task 0 |        eval time =   19215.30 ms /  1005 tokens (   19.12 ms per token,    52.30 tokens per second)
0.27.361.941 I slot print_timing: id  0 | task 0 |       total time =   19282.26 ms /  1019 tokens
0.27.361.941 I slot print_timing: id  0 | task 0 |    graphs reused =          0
0.27.361.947 I slot      release: id  0 | task 0 | stop processing: n_tokens = 1018, truncated = 0

But when I run a post-merge build, for example b9549 or b9616, I get around 56 t/s:

I slot print_timing: id  0 | task 0 | n_decoded =    102, tg =  68.74 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    294, tg =  65.16 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    468, tg =  62.23 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    641, tg =  60.77 t/s
I slot print_timing: id  0 | task 0 | n_decoded =    787, tg =  58.06 t/s
I slot print_timing: id  0 | task 0 | prompt eval time =     223.22 ms /    15 tokens (   14.88 ms per token,    67.20 tokens per second)
I slot print_timing: id  0 | task 0 |        eval time =   15900.22 ms /   900 tokens (   17.67 ms per token,    56.60 tokens per second)
I slot print_timing: id  0 | task 0 |       total time =   16123.44 ms /   915 tokens
I slot print_timing: id  0 | task 0 |    graphs reused =        388
I slot print_timing: id  0 | task 0 | draft acceptance = 0.65217 (  510 accepted /   782 generated)
I statistics        draft-mtp: #calls(b,g,a) =    1    391    391, #gen drafts =    391, #acc drafts =   305, #gen tokens =    782, #acc tokens =   510, dur(b,g,a) = 0.002, 1615.848, 0.232 ms
I slot      release: id  0 | task 0 | stop processing: n_tokens = 916, truncated = 0

This is slow: only about 8% faster than running without MTP.
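
To put numbers on "fast" and "slow", the eval-time lines quoted above can be reduced to speedups over the no-MTP run. This is a minimal sketch that assumes the log format shown in this issue; the regex and the function name `eval_tps` are my own and not part of llama.cpp.

```python
import re

# Matches the generation "eval time" line but not the "prompt eval time" line.
EVAL_RE = re.compile(r"(?<!prompt )eval time =\s*([\d.]+) ms /\s*(\d+) tokens")


def eval_tps(line: str) -> float:
    """Return generated tokens per second from one print_timing eval line."""
    m = EVAL_RE.search(line)
    if m is None:
        raise ValueError("not an eval-time line")
    ms, tokens = float(m.group(1)), int(m.group(2))
    return tokens / (ms / 1000.0)


# Eval-time lines copied from the three 12B runs above.
pr_mtp = "I slot print_timing: id  0 | task 0 |        eval time =   10281.13 ms /   864 tokens"
no_mtp = "I slot print_timing: id  0 | task 0 |        eval time =   19215.30 ms /  1005 tokens"
post_mtp = "I slot print_timing: id  0 | task 0 |        eval time =   15900.22 ms /   900 tokens"

baseline = eval_tps(no_mtp)
pr_speedup = eval_tps(pr_mtp) / baseline
post_speedup = eval_tps(post_mtp) / baseline

print(f"PR branch MTP:  {pr_speedup:.2f}x over no MTP")
print(f"post-merge MTP: {post_speedup:.2f}x over no MTP")
```

For these three runs, that works out to about 1.61x on the PR branch and about 1.08x post-merge.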

This relative speed loss is consistent across runs and across other Gemma 4 models (QAT and non-QAT, 26B MoE and 31B dense). I'm demonstrating with the 12B QAT model for convenience.

Both builds output coherent text. I deliberately picked these two runs to show that a run with a lower draft acceptance rate can still be faster, so the speed loss does not appear to be caused by the draft acceptance rate.
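
The two MTP runs above can also be broken down per draft step. This is a rough sketch, not a profile. It assumes that the middle value of `dur(b,g,a)` is the total time in ms spent generating drafts, and that each of the `g` calls corresponds to one decode step. Neither assumption is confirmed by the logs, and the class and field names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class MtpRun:
    eval_ms: float        # "eval time" from print_timing
    n_tokens: int         # tokens reported on the same line
    n_draft_calls: int    # second value of "#calls(b,g,a)"
    draft_gen_ms: float   # second value of "dur(b,g,a)" (assumed: total draft time)

    def per_step(self) -> dict:
        step_ms = self.eval_ms / self.n_draft_calls
        draft_ms = self.draft_gen_ms / self.n_draft_calls
        return {
            "tokens_per_step": self.n_tokens / self.n_draft_calls,
            "step_ms": step_ms,
            "draft_ms": draft_ms,
            # Everything that is not draft generation: verification, sampling, overhead.
            "other_ms": step_ms - draft_ms,
        }


pr_branch = MtpRun(eval_ms=10281.13, n_tokens=864, n_draft_calls=402, draft_gen_ms=1689.840)
post_merge = MtpRun(eval_ms=15900.22, n_tokens=900, n_draft_calls=391, draft_gen_ms=1615.848)

for name, run in (("PR branch", pr_branch), ("post-merge", post_merge)):
    s = run.per_step()
    print(f"{name:10s} {s['tokens_per_step']:.2f} tok/step, "
          f"{s['step_ms']:.1f} ms/step = {s['draft_ms']:.1f} ms draft + {s['other_ms']:.1f} ms other")
```

Under those assumptions, drafting costs about 4 ms per step in both builds and the post-merge run even yields slightly more tokens per step, while the remaining per-step time rises from roughly 21 ms to roughly 37 ms. That would be consistent with the slowdown sitting outside draft generation and acceptance.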

Performance without MTP is equivalent between these builds.

Brief MoE MTP digression: MTP seems equally effective for sparse models.

Gemma 4 26B runs at around 112 t/s with MTP enabled on the PR branch:

0.54.132.669 I slot print_timing: id  0 | task 0 | n_decoded =    102, tg = 113.50 t/s
0.57.151.141 I slot print_timing: id  0 | task 0 | n_decoded =    447, tg = 114.11 t/s
1.00.151.475 I slot print_timing: id  0 | task 0 | n_decoded =    777, tg = 112.32 t/s
1.00.794.891 I slot print_timing: id  0 | task 0 | prompt eval time =     116.25 ms /    14 tokens (    8.30 ms per token,   120.43 tokens per second)
1.00.794.894 I slot print_timing: id  0 | task 0 |        eval time =    7560.90 ms /   848 tokens (    8.92 ms per token,   112.16 tokens per second)
1.00.794.894 I slot print_timing: id  0 | task 0 |       total time =    7677.15 ms /   862 tokens
1.00.794.894 I slot print_timing: id  0 | task 0 |    graphs reused =        365
1.00.794.895 I slot print_timing: id  0 | task 0 | draft acceptance = 0.65082 (  479 accepted /   736 generated)
1.00.796.259 I statistics        draft-mtp: #calls(b,g,a) =    1    368    368, #gen drafts =    368, #acc drafts =   275, #gen tokens =    736, #acc tokens =   479, dur(b,g,a) = 0.002, 1726.222, 0.208 ms

And at around 68 t/s with MTP disabled:

0.41.157.144 I slot print_timing: id  0 | task 0 | n_decoded =    100, tg =  69.78 t/s
0.44.165.966 I slot print_timing: id  0 | task 0 | n_decoded =    308, tg =  69.34 t/s
0.47.180.024 I slot print_timing: id  0 | task 0 | n_decoded =    517, tg =  69.34 t/s
0.50.192.958 I slot print_timing: id  0 | task 0 | n_decoded =    725, tg =  69.25 t/s
0.53.195.533 I slot print_timing: id  0 | task 0 | n_decoded =    929, tg =  68.96 t/s
0.54.716.498 I slot print_timing: id  0 | task 0 | prompt eval time =      86.03 ms /    14 tokens (    6.14 ms per token,   162.74 tokens per second)
0.54.716.500 I slot print_timing: id  0 | task 0 |        eval time =   14992.44 ms /  1031 tokens (   14.54 ms per token,    68.77 tokens per second)
0.54.716.501 I slot print_timing: id  0 | task 0 |       total time =   15078.47 ms /  1045 tokens
0.54.716.501 I slot print_timing: id  0 | task 0 |    graphs reused =       1026
0.54.716.507 I slot      release: id  0 | task 0 | stop processing: n_tokens = 1044, truncated = 0

I've seen claims and evidence that MTP is less effective for sparse models, but on my hardware the speedup looks proportional: about 63% for the 26B MoE here versus about 60% for the 12B dense model above.

Lastly, the same 26B model on b9616 with MTP enabled:

0.31.906.486 I slot print_timing: id  0 | task 0 | n_decoded =    102, tg =  86.95 t/s
0.34.909.373 I slot print_timing: id  0 | task 0 | n_decoded =    354, tg =  84.77 t/s
0.37.923.224 I slot print_timing: id  0 | task 0 | n_decoded =    591, tg =  82.20 t/s
0.40.956.385 I slot print_timing: id  0 | task 0 | n_decoded =    794, tg =  77.67 t/s
0.43.959.339 I slot print_timing: id  0 | task 0 | n_decoded =   1001, tg =  75.68 t/s
0.43.995.440 I slot print_timing: id  0 | task 0 | prompt eval time =     262.39 ms /    14 tokens (   18.74 ms per token,    53.36 tokens per second)
0.43.995.442 I slot print_timing: id  0 | task 0 |        eval time =   13262.04 ms /  1003 tokens (   13.22 ms per token,    75.63 tokens per second)
0.43.995.443 I slot print_timing: id  0 | task 0 |       total time =   13524.43 ms /  1017 tokens
0.43.995.448 I slot print_timing: id  0 | task 0 |    graphs reused =        422
0.43.995.449 I slot print_timing: id  0 | task 0 | draft acceptance = 0.68000 (  578 accepted /   850 generated)
0.43.995.459 I statistics        draft-mtp: #calls(b,g,a) =    1    425    425, #gen drafts =    425, #acc drafts =   330, #gen tokens =    850, #acc tokens =   578, dur(b,g,a) = 0.003, 1934.726, 0.259 ms

And disabled:

0.15.965.977 I slot print_timing: id  0 | task 0 | n_decoded =    100, tg =  67.32 t/s
0.18.971.792 I slot print_timing: id  0 | task 0 | n_decoded =    303, tg =  67.46 t/s
0.21.983.365 I slot print_timing: id  0 | task 0 | n_decoded =    507, tg =  67.57 t/s
0.24.985.958 I slot print_timing: id  0 | task 0 | n_decoded =    710, tg =  67.58 t/s
0.27.989.837 I slot print_timing: id  0 | task 0 | n_decoded =    910, tg =  67.36 t/s
0.28.698.980 I slot print_timing: id  0 | task 0 | prompt eval time =      97.24 ms /    14 tokens (    6.95 ms per token,   143.97 tokens per second)
0.28.698.983 I slot print_timing: id  0 | task 0 |        eval time =   14218.42 ms /   957 tokens (   14.86 ms per token,    67.31 tokens per second)
0.28.698.984 I slot print_timing: id  0 | task 0 |       total time =   14315.66 ms /   971 tokens
0.28.698.987 I slot print_timing: id  0 | task 0 |    graphs reused =        953
0.28.699.010 I slot      release: id  0 | task 0 | stop processing: n_tokens = 970, truncated = 0
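
For an overview, the MTP speedups from all of the runs above can be put side by side. The tokens-per-second values are copied from the eval-time lines in this issue. The sketch assumes that the last two logs are from the 26B MoE model, and it reuses the single 12B no-MTP run as the baseline for both builds, since no-MTP performance is equivalent between them.

```python
# (tokens/s with MTP, tokens/s without MTP), copied from the eval-time lines above.
RUNS = {
    ("12B dense", "PR branch"): (84.04, 52.30),
    ("12B dense", "post-merge"): (56.60, 52.30),  # same no-MTP baseline reused
    ("26B MoE", "PR branch"): (112.16, 68.77),
    ("26B MoE", "b9616"): (75.63, 67.31),
}


def speedup(with_mtp: float, without_mtp: float) -> float:
    """Ratio of MTP throughput to plain decoding throughput."""
    return with_mtp / without_mtp


speedups = {key: speedup(*tps) for key, tps in RUNS.items()}

for (model, build), s in speedups.items():
    print(f"{model:10s} {build:11s} {s:.2f}x")
```

Both models drop from roughly 1.6x on the PR branch to roughly 1.1x after the merge.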

First Bad Commit

b9549 is slow, but the PR branch was fast.

So the regression seems to have appeared after the "fix typo" commit in the PR branch, during or shortly after the merge. I don't know what else changed in the meantime.
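
If a known-good commit exists on the same history as b9549 (this issue does not confirm that), the search could be automated with `git bisect run`. It treats exit code 0 as good, 125 as "skip this commit", and other codes from 1 to 127 as bad. Below is a minimal sketch of the classification step only. The 70 t/s threshold is a hypothetical cut-off between the ~84 t/s and ~56 t/s runs above, and building each commit and capturing the server log is left to a wrapper script.

```python
import re
import sys

# Matches the generation "eval time" line but not the "prompt eval time" line.
EVAL_RE = re.compile(r"(?<!prompt )eval time =\s*([\d.]+) ms /\s*(\d+) tokens")

# Hypothetical cut-off between the fast (~84 t/s) and slow (~56 t/s) 12B runs.
THRESHOLD_TPS = 70.0


def bisect_exit_code(log_text: str, threshold: float = THRESHOLD_TPS) -> int:
    """Map a llama-server log to a `git bisect run` exit code."""
    matches = EVAL_RE.findall(log_text)
    if not matches:
        return 125  # no timing line found: ask git bisect to skip this commit
    ms, tokens = matches[-1]  # use the last completed request in the log
    tps = int(tokens) / (float(ms) / 1000.0)
    return 0 if tps >= threshold else 1


if __name__ == "__main__" and len(sys.argv) == 2:
    # Usage (hypothetical file names): python classify_log.py server.log
    with open(sys.argv[1], encoding="utf-8") as f:
        sys.exit(bisect_exit_code(f.read()))
```

Because throughput varies between runs, a threshold well away from both observed values, as here, keeps the classification stable.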

Relevant log output

Verbose logs:

b9549-verbose-slow.log

b9637-verbose-slow.log

gemma4-mtp-verbose-fast.log

Less verbose logs:

llamacpp gemma4-mtp fast.txt

llamacpp b9616 slow.txt

llamacpp b9549 slow.txt

The generated text is along the lines of:

Multi-Token Prediction explained, with minor blemishes that seem to be typical of Gemma 4 QAT models in general. I'm only including this to show that it's a somewhat open-ended answer: not the Quake 3 fast inverse square root or a list of prime numbers, but also not a creative writing exercise, where MTP is known to be less consistently effective.

Multi-token prediction (MTP) is an evolution of the training objective used in Large Language Models (LLMs).

To understand it, you first need to understand how standard LLMs work.


1. The Standard Approach: Next-Token Prediction (NTP)

Current LLMs (like GPT-4 or Llama) are trained using Next-Token Prediction.

When the model sees the sequence: "The capital of France is...", the model’s mathematical objective is to calculate a probability distribution for only the very next word ("Paris").

The limitation: The model is only ever penalized for being wrong about the immediate next step. It doesn't inherently "plan" or look ahead to the structure of the entire sentence; it just masters the art of predicting the single next increment.

2. What is Multi-Token Prediction?

In Multi-Token Prediction, instead of asking the model, "What is the next word?", we ask the model, "What are the next N words?"

During training, the model is forced to predict not just word $n+1$, but also $n+2, n+3,$ and so on, all in a single forward pass.

How it works technically:

In a standard model, there is one "prediction head" (the final layer) that sits on top of the transformer blocks. In an MTP model:

  • The transformer backbone processes the text.
  • Instead of one output head, the model has multiple output heads (e.g., one for the 1st next token, one for the 2nd, one for the 3rd).
  • Each head is trained to predict a different position in the future sequence.

3. Why does this matter? (The Benefits)

A. Better "Global" Planning

Because the model is forced to predict several tokens at once, it must learn the underlying structure of language, logic, and syntax more deeply. To predict the 4th word correctly, the model has to "understand" the relationship between the current word and the context, rather than just looking at the immediate neighbor. This encourages long-term dependency modeling.

B. Improved Efficiency (Speed)

One of the biggest bottlenecks in running LLMs (inference) is that they generate text one word at a time, which is slow.

  • Standard LLM: To get 4 words, you must run the entire massive model 4 separate times.
  • MTP LLM: If the model is trained to predict 4 tokens at once, it can potentially output multiple tokens in a single "thought," significantly increasing the speed of text generation (throughput).

C. Better Reasoning

Recent research (notably from Meta/FAIR) suggests that MTP helps with reasoning tasks. Because the model is forced to "see" further into the future, it is less likely to get stuck in local loops or lose the thread of a complex mathematical or logical argument.

4. Summary Comparison

Feature | Next-Token Prediction (Standard) | Multi-Token Prediction (MTP)
Goal | Predict $t+1$ | Predict $t+1, t+2, t+3...$
Training Focus | Local accuracy (short-term) | Structural/Global accuracy (long-term)
Inference Speed | Slower (one word per pass) | Faster (multiple words per pass)
Complexity | Simpler to implement | More complex (multiple output heads)

Conclusion

Multi-token prediction moves AI from being a "stochastic parrot" that just guesses the next word, toward a model that captures the intent and structure of a sequence. It is a key area of research for creating models that are not only smarter but much faster to use.
