Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7817,7 +7817,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Gemma4AssistantForCausalLM")
@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT

Expand Down
10 changes: 6 additions & 4 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
break;
// MMA template instances for DKQ=512 only exist for ncols2 in {4,8},
// requiring gqa_ratio < 3. Gemma-4 (12B/26B/31B) has gqa_ratio=8 and
// aborts in switch_ncols2<512,512>. fattn-tile supports DKQ=512 with
// ncols2 fallback to {2,1}. Route ALL DKQ=512 cases to TILE here so
// the early-return path and the no-mask MTP cross-attn path are both covered.
return BEST_FATTN_KERNEL_TILE;
case 576:
case 640:
if (V->ne[0] != 512) {
Expand Down
15 changes: 12 additions & 3 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,10 +1105,14 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
// Set the embeddings flag on this context.
//
// Logs the requested value, stores it in cparams.embeddings and, only when the
// stored flag actually changes, marks the scheduler as needing a re-reserve.
// Calling this with the value that is already set leaves both cparams.embeddings
// and sched_need_reserve untouched.
void llama_context::set_embeddings(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    // Redundant same-value call: nothing to store, no re-reserve requested.
    if (cparams.embeddings == value) {
        return;
    }

    cparams.embeddings = value;

    // Switching embeddings mode alters the graph topology (the backbone
    // hidden-state output node is added or removed), so request a re-reserve
    // to get compute buffers sized for the new graph.
    sched_need_reserve = true;
}

void llama_context::set_embeddings_pre_norm(bool value) {
Expand Down Expand Up @@ -2692,6 +2696,11 @@ int32_t llama_context::decode_mtp_async(
return -8;
}

// Flush main scheduler so KV cache writes from the preceding llama_decode are
// visible to the MTP worker's CUDA reads before it starts. Without this sync
// the worker races the still-in-flight CUDA async writes and reads stale K/V.
ggml_backend_sched_synchronize(sched.get());

{
std::unique_lock<std::mutex> lk(mtp_mu);
if (mtp_pending.has_value() || mtp_in_flight || mtp_completed.has_value()) {
Expand Down
8 changes: 7 additions & 1 deletion src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
}

if (ubatch->embd) {
// Guard on `embd` being non-null (mirrors the can_reuse() check below): gemma4's
// per-layer-token input reuses this same input class but only allocates `tokens`
// (never `embd`). When the batch carries raw embeddings (ubatch->embd != null) --
// e.g. the Gemma-4 MTP / embeddings verify path -- that per-layer input would
// otherwise dereference a null `embd` here. Only the input that actually built an
// `embd` tensor should consume ubatch->embd.
if (ubatch->embd && embd) {
GGML_ASSERT(n_embd == embd->ne[0]);

const int64_t n_tokens = ubatch->n_tokens;
Expand Down
27 changes: 22 additions & 5 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers;
hparams.f_attention_scale = 1.0f;

if (hparams.n_layer > 0 && hparams.n_layer_kv_from_start <= 0) {
LLAMA_LOG_WARN("%s: gemma4_assistant KV sharing metadata leaves no dedicated KV layers "
"(n_layer=%u, shared_kv_layers=%u); disabling reuse\n",
__func__, hparams.n_layer, n_kv_shared_layers);
hparams.n_layer_kv_from_start = (int32_t) hparams.n_layer;
}

ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
Expand Down Expand Up @@ -9573,11 +9580,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
GGML_ABORT("fatal error");
}

// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm);

// add backend sampling layers (if any)
llm->build_sampling();
// The Gemma-4 MTP graph is self-contained: it produces its own logits / h_post /
// on-device argmax and is driven on a dedicated scheduler. The shared embedding
// pooling and backend-sampling epilogues are for the main decode path only. They
// must be skipped for MTP graphs: the speculative path forces cparams.embeddings=true
// (so the main decode emits backbone hidden states), and that flag would otherwise
// make build_pooling run against the MTP graph's t_embd (h_post) -- which has no
// pooling inputs and crashes -- and build_sampling attach backend samplers to the
// MTP logits, which MTP never consumes.
if (params.gtype != LLM_GRAPH_TYPE_MTP) {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm);

// add backend sampling layers (if any)
llm->build_sampling();
}

// if the gguf model was converted with --sentence-transformers-dense-modules
// there will be two additional dense projection layers
Expand Down