Add diffusion-gemma block-diffusion support#1

Open
lnigam wants to merge 27 commits into master from diffusion-gemma-review

@lnigam lnigam commented Jun 9, 2026

Overview

This PR adds initial diffusion-gemma support for Gemma 4-based block-diffusion checkpoints. It is a draft, opened to get feedback on several design choices: block diffusion, approximation of the soft embeddings, separate versus single encoder-decoder graphs, prefix KV-cache reuse, and the diffusion server utility.

Additional information

  • Add GGUF conversion support for diffusion Gemma checkpoints, including self-conditioning tensors and multimodal Gemma 4 vision/mmproj export.
  • Register the diffusion-gemma architecture and model implementation.
  • Implement the diffusion Gemma graph using Gemma 4 decoder blocks with bidirectional canvas attention, prompt-prefix conditioning, KV-cache reuse, and self-conditioning.
  • Add sparse top-k self-conditioning through on-device embedding gather.
  • Add llama-diffusion-gemma-cli for block-diffusion generation.
  • Add llama-diffusion-gemma-server, an HTTP server with /v1/completions, /v1/chat/completions, /health, /props, /metrics, and /slots.
  • Add CUDA backend support for diffusion top-k sampling, entropy/stability decisions, self-conditioning buffers, device-side canvas updates, and device-loop early stopping.
  • Enable CUDA-graph-friendly execution by keeping persistent diffusion input/output state on device and avoiding inter-step host sampling copies.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: yes. AI was used to generate the initial architecture-support changes and some bug fixes, and to help with merge conflicts and code review.

lnigam added 27 commits June 9, 2026 23:05
…pport

Adds a new DIFFUSION_GEMMA4 architecture and a DiffusionGemma4Model converter
that reuses the existing gemma4 conversion path. The block-diffusion checkpoint
nests its language model under model.decoder.* and adds a self_conditioning MLP;
its text encoder shares all weights with the decoder except a per-layer
layer_scalar (dropped here, since the single-stack graph uses the decoder set).

- gguf-py: MODEL_ARCH.DIFFUSION_GEMMA4 + SELF_COND_* tensors and HF name maps
- conversion: DiffusionGemma4Model strips the model.decoder. prefix, drops
  encoder-only tensors, and otherwise inherits gemma4 hparam/tensor handling
- relax Gemma4Model num_kv_shared_layers lookup (key absent in diffusion config)
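The prefix handling described in the conversion bullet can be sketched as a small name-remapping function. This is a minimal illustration, not the converter's actual code: the `model.encoder.` prefix, the function name, and the example tensor names are assumptions made for the sketch.

```python
from typing import Optional

DECODER_PREFIX = "model.decoder."
ENCODER_PREFIX = "model.encoder."  # assumed name of the encoder nesting


def remap_tensor_name(name: str) -> Optional[str]:
    """Return the name to hand to the gemma4 path, or None to drop the tensor."""
    if name.startswith(ENCODER_PREFIX):
        # Encoder-only tensor (e.g. its per-layer layer_scalar): dropped,
        # because the single-stack graph uses the decoder weights.
        return None
    if name.startswith(DECODER_PREFIX):
        # Strip the nesting so the inherited gemma4 name maps apply.
        return name[len(DECODER_PREFIX):]
    # Top-level tensors (such as the self_conditioning MLP) pass through.
    return name
```

With hypothetical names, `remap_tensor_name("model.decoder.layers.0.mlp.up_proj.weight")` yields `layers.0.mlp.up_proj.weight`, while anything under the assumed encoder prefix is skipped.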

Dry-run validated: architecture and hparams recognized (5:1 sliding/global KV
pattern, dual head dims, softcapping, 128/8 experts), all 691 tensors map.

Adds LLM_ARCH_DIFFUSION_GEMMA4, reusing the gemma4 decoder block (hparams and
graph) by subclassing llama_model_gemma4. Adds the top-level self_conditioning
MLP tensors (SELF_COND_*); load_arch_tensors loads them alongside the inherited
gemma4 tensors. The block-diffusion sampling loop and the self_conditioning /
bidirectional / encoder-KV graph wiring are layered on top in a later step.

- llama-arch: arch name, SELF_COND_* tensor names and tensor infos
- llama-model: create / rope-type / kv-reuse dispatch for the new arch
- models: llama_model_diffusion_gemma4 subclass loading self_conditioning weights

Verified: the 50.5 GB bf16 GGUF loads, all 691 tensors map, and a forward pass
runs end-to-end with correct shapes (dual head dims, 128-expert MoE routing,
tied output). Runs as a gemma4-style causal LM for now; diffusion semantics next.

Replaces the inherited gemma4 (causal, KV-cache) graph with a dedicated
bidirectional, no-KV-cache denoising graph:

- load_arch_hparams: reuse gemma4 hparams, then set causal_attn = false
- register the arch in llm_arch_is_diffusion and on the no-memory path
  (res = nullptr), so it runs like the other diffusion LMs (DREAM/LLADA)
- graph: gemma4 per-layer block (QK-norm, scale-less V-norm, dual head dims,
  proportional rope on full layers, dense+MoE dual FFN, layer_scalar) but with
  build_attn_inp_no_cache (bidirectional attention)
- input: scaled embeddings -> scale-less RMS norm, which is exactly the
  self-conditioning transform on the first denoising step (soft-cond = 0)

Scope: a single bidirectional denoising pass over the canvas with no prompt
context. Valid while canvas_length <= sliding_window (sliding == full attn).
The soft-conditioning input path (later steps) and encoder-KV cross-attention
(prompted generation) come next.

Verified: builds clean; loads the 50.5 GB GGUF with no KV cache and runs a full
forward (self_cond_input node + bidirectional attention + dual-FFN MoE +
layer_scalar + softcapped logits).

Adds the self_conditioning gated MLP to the decoder input:
  sc_input = post_norm(inputs_embeds + down(gelu(gate(pre_norm(soft))) * up(pre_norm(soft))))
using the self_cond_norm/gate/up/down weights. The soft-embeddings input is a zero
placeholder for now (the block-diffusion sampler will feed softmax(prev_logits) @ embed
per denoising step); with soft = 0 this is numerically identical to the verified
first-step behaviour (rms_norm(0) = 0 -> sc_signal = 0 -> scale-less post-norm of the
scaled embeddings), so no regression, but the self_cond weights are now used in-graph.
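The claim that a zero soft input reproduces the first-step behaviour can be checked with a small pure-Python sketch of the formula above. The helper names, the tanh GELU approximation, the epsilon, and the plain multiplicative norm weight are assumptions for illustration; the real graph runs on ggml tensors.

```python
import math


def rms_norm(x, weight=None, eps=1e-6):
    """RMS norm; weight=None gives the scale-less variant."""
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v * inv for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    return out


def gelu(v):
    """tanh approximation of GELU (assumed variant)."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def matvec(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def self_cond_input(embeds, soft, norm_w, gate_w, up_w, down_w):
    """sc_input = post_norm(embeds + down(gelu(gate(h)) * up(h))), h = pre_norm(soft)."""
    h = rms_norm(soft, norm_w)
    gated = [gelu(g) * u for g, u in zip(matvec(gate_w, h), matvec(up_w, h))]
    sc_signal = matvec(down_w, gated)
    return rms_norm([e + s for e, s in zip(embeds, sc_signal)])


# Tiny example: n_embd = 4, hidden width = 3, arbitrary fixed weights.
embeds = [0.5, -1.0, 2.0, 0.25]
norm_w = [1.0, 1.0, 1.0, 1.0]
gate_w = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [-0.1, 0.0, 0.1, 0.2]]
up_w = [[0.2, -0.2, 0.1, 0.0], [0.3, 0.1, -0.1, 0.2], [0.0, 0.5, 0.5, 0.0]]
down_w = [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1], [0.0, -0.1, 0.1], [0.2, 0.2, 0.2]]

first_step = self_cond_input(embeds, [0.0] * 4, norm_w, gate_w, up_w, down_w)
baseline = rms_norm(embeds)  # scale-less norm of the (already scaled) embeddings
```

Because `rms_norm` of a zero vector is zero and `gelu(0) = 0`, `first_step` matches `baseline` element for element, which is the no-regression argument made in the commit message.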

First slice of the self-conditioning + block-diffusion-sampler unit; the runtime
soft-embeddings feed (a settable input channel) lands with the sampler.

Adds llama-diffusion-gemma4-cli implementing the reference block-diffusion loop:
random canvas init, per-step full-canvas decode, linear temperature schedule
(0.4 -> 0.8), entropy-bound token acceptance, renoise of non-accepted positions,
and stable-and-confident stopping. Drives the bidirectional no-KV-cache graph.

Scope: unconditioned generation with self-conditioning = 0 (the prompt is not used
yet). The loop runs end-to-end and denoises as expected -- over the steps the mean
token entropy falls (3.12 -> 0.15) and the accepted-token count rises (1 -> 7/16),
i.e. the canvas converges. Self-conditioning feedback and prompt/encoder-KV
conditioning are layered on next.
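Two pieces of the loop described above, the linear temperature schedule and an entropy test per canvas position, might look roughly like this. The interpolation over `n_steps - 1` and the simple per-position threshold are assumptions; the reference entropy-bound rule may be defined differently, and `bound` is a hypothetical parameter.

```python
import math


def temperature(step, n_steps, t_start=0.4, t_end=0.8):
    """Linear schedule from t_start (step 0) to t_end (last step)."""
    if n_steps <= 1:
        return t_start
    return t_start + (t_end - t_start) * step / (n_steps - 1)


def softmax(logits, temp=1.0):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((l - m) / temp) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def accept_mask(canvas_logits, temp, bound):
    """True where a position is confident enough to accept (assumed rule)."""
    return [entropy(softmax(row, temp)) <= bound for row in canvas_logits]
```

Positions whose mask entry is false would be renoised and retried on the next step, which is why the accepted count rises as the mean entropy falls.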

DG4_CANVAS / DG4_STEPS env vars override canvas length / step count for testing.

Adds a per-decode soft-conditioning input so the block-diffusion sampler can feed
the previous denoising step's token probabilities back into the decoder.

- core: llama_diffusion_cond + llm_graph_input_diffusion_self_cond + build_inp_*,
  threaded through llm_graph_params/llm_graph_context exactly like llama_cross;
  public API llama_set_diffusion_self_cond(ctx, probs, n_vocab, n_tokens).
- graph: soft_embeddings = (probs @ token_embd) * sqrt(n_embd) -> self_cond MLP ->
  added to the input embeddings (replaces the zero placeholder). Empty buffer ->
  zeros == the verified first-step behaviour, so no regression.
- sampler: feeds softmax(processed_logits) each step for the next decode; cleared
  for step 0.

Runs end-to-end; step 0 is unchanged (zero self-cond), later steps use the feedback.
Full numerical verification vs PyTorch self_conditioning is pending. Perf TODO: the
in-graph embedding transpose for the probs@embed matmul is recomputed per decode.

Adds prompted block-diffusion generation as a single-pass [prompt ; canvas] forward.

- core: llm_graph_input_attn_no_cache_prefix + build_attn_inp_no_cache_prefix(n_prompt)
  build a prefix mask (prompt attends causally within the prompt; canvas attends to
  everything -> bidirectional + cross to the prompt). n_prompt is carried on
  llama_diffusion_cond and set via llama_set_diffusion_prompt_len().
- graph: prompt rows use the raw scaled embeddings (encoder; no self-conditioning /
  post-norm); canvas rows use the self-conditioned input. Uses the prefix attention
  when n_prompt > 0, else fully-bidirectional no-cache.
- sampler: tokenizes the prompt, builds [prompt ; canvas] each step, requests logits
  for the canvas positions only, and feeds self-cond probs over the full sequence
  (prompt columns zero).
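The prefix mask from the first bullet can be written out for a small case. This sketch uses booleans (True means the query row may attend to the key column) rather than the additive float mask a real graph would use, and the function name is illustrative.

```python
def build_prefix_mask(n_prompt, n_canvas):
    """Rows are queries, columns are keys, over [prompt ; canvas]."""
    n = n_prompt + n_canvas
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            if q < n_prompt:
                # Prompt rows: causal, and only within the prompt.
                mask[q][k] = k <= q
            else:
                # Canvas rows: see the whole prompt and the whole canvas.
                mask[q][k] = True
    return mask


for row in build_prefix_mask(2, 2):
    print("".join("x" if allowed else "." for allowed in row))
```

For two prompt tokens and two canvas tokens the prompt rows form a lower triangle and the canvas rows are fully set. As the notes in this commit say, this shape only holds while the sequence fits inside the sliding window.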

Result: prompted generation runs end-to-end. "The capital of France is" denoises the
canvas toward relevant tokens (including "Paris"). Output quality is still rough
(repetition/filler); refinement is ongoing (self-conditioning numerical verification,
sampler tuning, full canvas/steps, reconvert from the current v5 source).

Notes: encoder and decoder layer_scalars are identical in this checkpoint, so no
separate encoder scalars are needed. The prefix mask currently assumes
n_tokens <= sliding_window (sliding == full); long prompts need a windowed prefix mask.

This is a chat-trained model (turn/channel special tokens); feeding raw text gives
poor results. Format the user prompt with the model's chat template
(common_chat_templates) before tokenizing, parsing special tokens.

HF reference (chat-formatted) answers cleanly ("The capital of France is **Paris**."
then <pad> filler). With chat formatting the canvas now denoises toward prompt-relevant
tokens, though full convergence still needs work (canvas size / sampler dynamics).

In the no-KV-cache path llama.cpp prunes tokens that are not requested as outputs.
The sampler marked only the canvas tokens as outputs (logits=1), so the prompt
tokens were pruned from the attention and the canvas never attended to the prompt
-- generation was effectively unconditioned (identical output for different prompts).

Fix: request logits for all [prompt; canvas] tokens and read only the canvas rows
(offset n_prompt). The prompted forward now matches the HF reference (pos-0 top-5
logits agree to bf16 precision; output is prompt-dependent), and generation produces
the correct structure and reasoning (thought channel -> "The capital of France is ...").

Residual: filler positions don't fully converge to <pad> (sampler-dynamics polish).

After the denoising loop, decode the converged canvas once more and emit the greedy
argmax per canvas position instead of the accumulated `accepted` array. Positions that
were never accepted by the entropy-bound sampler carried stale/renoised tokens, which
showed up as garbage in the output; reading the model's argmax given the settled answer
cleans them up (filler -> <pad>).

With this, prompted generation produces correct, coherent answers, e.g.:
  "What is the capital of France?" ->
  <|channel>thought
  The user is asking for the capital of France.
      * Country: France
      * Capital: Paris
  The capital of France is Paris.

The model answers inside a "<|channel>thought ... <channel|>" block followed by the
response, and the fixed 256-canvas tail repeats the answer. Print the full canvas for
reference, then extract the final response (tokens after the last <channel|>, truncated
at the first end-of-generation token) and drop a trailing exact-duplicate.
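The extraction steps above could be sketched over a list of token strings as follows. Treating `<pad>` as the end-of-generation token and reading "trailing exact-duplicate" as "the second half repeats the first half" are both assumptions made for this example; the CLI's actual rule may differ.

```python
def extract_answer(tokens, channel_end="<channel|>", eog=("<pad>",)):
    """Return the response tokens that follow the last channel-end marker."""
    # 1. Keep only what follows the last <channel|>.
    start = 0
    for i, tok in enumerate(tokens):
        if tok == channel_end:
            start = i + 1
    # 2. Truncate at the first end-of-generation token.
    answer = []
    for tok in tokens[start:]:
        if tok in eog:
            break
        answer.append(tok)
    # 3. Drop an exact repeat of the answer (assumed interpretation).
    half = len(answer) // 2
    if half > 0 and len(answer) % 2 == 0 and answer[:half] == answer[half:]:
        answer = answer[:half]
    return answer


canvas = ["<|channel>", "thought", "...", "<channel|>",
          "The", " capital", " is", " Paris", ".",
          "The", " capital", " is", " Paris", ".", "<pad>", "<pad>"]
print("".join(extract_answer(canvas)))
```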

"What is the capital of France?" now yields a clean:
  === answer ===
  The capital of France is Paris.
The CLI passed common_params.n_gpu_layers (-1) straight to the model loader without
resolving it, so it ran on CPU even in a CUDA build. Default to offloading all layers
(999) when -ngl is not given; -ngl N limits offload and -ngl 0 forces CPU. No effect in
a CPU-only build. With -DGGML_CUDA=ON the model now runs on the GPU by default.

Encode the prompt (and previously-finalized canvases) once into the unified
sliding-window KV cache and reuse it, instead of re-encoding [prompt; canvas]
on every denoising step.

- Encoder phase (causal, no self-conditioning): prefill the prompt / commit a
  finalized canvas into the cache; its K/V becomes a read-only prefix.
- Decoder phase (bidirectional, self-conditioned): each denoising step decodes
  only the canvas against the cached prefix, then rolls back its own K/V.
- Multi-block autoregressive loop: commit each finalized canvas and advance the
  cache pointer by canvas_length.
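The three phases above amount to simple bookkeeping on the cache length. The toy model below stores token ids instead of K/V tensors and does not use the llama.cpp KV-cache API; the class and function names are made up for the sketch.

```python
class PrefixCache:
    """Toy stand-in for the KV cache: a committed prefix plus scratch entries."""

    def __init__(self):
        self.cells = []
        self.n_committed = 0  # length of the read-only prefix

    def commit(self, tokens):
        """Encoder phase: append tokens to the prefix permanently."""
        assert len(self.cells) == self.n_committed, "scratch entries not rolled back"
        self.cells.extend(tokens)
        self.n_committed = len(self.cells)

    def denoise_step(self, canvas):
        """Decoder phase: the canvas sees prefix + itself, then is rolled back."""
        self.cells.extend(canvas)
        n_visible = len(self.cells)
        del self.cells[self.n_committed:]  # drop the canvas entries again
        return n_visible


def generate(prompt, canvas_length, n_blocks, n_steps):
    cache = PrefixCache()
    cache.commit(prompt)  # prompt is prefilled exactly once
    visible = []
    for _ in range(n_blocks):
        canvas = [0] * canvas_length  # placeholder canvas tokens
        for _ in range(n_steps):
            visible.append(cache.denoise_step(canvas))
        cache.commit(canvas)  # finalized canvas extends the prefix
    return cache.n_committed, visible
```

With a 5-token prompt, a canvas of 4 and two blocks, the prefix grows 5, 9, 13, and every denoising step only processes the canvas tokens rather than the whole sequence.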

Two graph variants share the gemma4 transformer body: a single phase-branching
graph (default) and a separate encoder/decoder pair (DG4_SEPARATE_ENC_DEC).

Also:
- enable the iswa KV cache for the arch (was res=nullptr), reusing the gemma4
  layer-reuse / has_kv handling;
- precompute a transposed F32 token embedding once at load (load_arch_post) for
  the self-conditioning soft-embedding matmul, avoiding a per-decode dequantize
  + transpose of the whole embedding;
- add a per-decode phase selector API: llama_set_diffusion_decoder_phase.

Verified on CPU and CUDA (partial and full GPU offload): correct prompted
generation, multi-block coherence, and clean entropy-bound convergence.

Update the block-diffusion Gemma port to the v7 checkpoint, which renames the
text architecture diffusion_gemma4 -> diffusion_gemma and adds a gemma4 vision
tower (model_type gemma4_vision) + projector for image input.

Rename:
- DIFFUSION_GEMMA4 -> DIFFUSION_GEMMA; arch string "diffusion-gemma4" ->
  "diffusion-gemma"; model class llama_model_diffusion_gemma; converter class
  DiffusionGemmaModel registered for "DiffusionGemmaForBlockDiffusion".
- renamed files src/models/diffusion-gemma.cpp, examples/diffusion-gemma/, and
  the CLI env knobs (DG_*).

Multimodal (vision-only; the v7 diffusion checkpoint has no audio tower):
- new DiffusionGemmaVisionModel mmproj converter reusing the existing GEMMA4V
  vision export: strips the v7 model.encoder.* nesting, skips audio, registered
  in MMPROJ_MODEL_MAP. The clip.cpp GEMMA4V encoder is reused unchanged.
- diffusion-gemma CLI gains --mmproj/--image (enabled for LLAMA_EXAMPLE_DIFFUSION):
  the image marker is tokenized via libmtmd and the 280 GEMMA4V vision embeddings
  are fed into the diffusion encoder-phase prefill (mtmd_helper_eval_chunks); the
  canvas is then denoised against the cached prefix.

Verified on CUDA (RTX 5090): text Q4_K_M answers correctly, and image+text
produces an accurate, OCR-level description via the GEMMA4V vision tower.

…ne-argmax output

- -n / --n-predict now sets the number of 256-token canvas blocks
  (max_canvases = ceil(n_predict / canvas_length)); canvas_length is fixed at the
  trained block size (256). DG_MAX_CANVASES still overrides; n_ctx auto-sizes.
- print a generation timing summary at the end (blocks, denoising steps, canvas
  tokens, wall-clock, canvas tok/s, s/step, answer tokens).
- emit each block via the inline argmax of the last (stable) denoising step,
  matching the reference DiffusionGemma _denoising_step, instead of a separate
  read-out over a stale `accepted` scratch buffer. Fixes the unconverged canvas
  tail: never-accepted high-entropy positions no longer carry stale-random tokens
  into the output / committed prefix. Commit also uses the inline argmax.
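The block-count rule in the first bullet is a ceiling division. A minimal sketch, with an illustrative function name:

```python
CANVAS_LENGTH = 256  # trained block size stated in the commit message


def max_canvases(n_predict, canvas_length=CANVAS_LENGTH):
    """Number of canvas blocks needed to cover n_predict tokens."""
    # Integer ceiling division avoids floating-point rounding.
    return (n_predict + canvas_length - 1) // canvas_length


for n in (256, 257, 1536):
    print(n, "->", max_canvases(n), "block(s)")
```

A request for 1536 tokens maps to 6 blocks, and anything from 257 to 512 maps to 2, so the generated length is always a multiple of the canvas length.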

Verified on CUDA: multi-block generation (e.g. -n 1536 -> 6 blocks, with
entropy-bound early-stopping) produces coherent, accurate prose and reports timing.

…ded backend

The self-conditioning soft-embedding does a full-vocabulary matmul
(softmax(prev_logits) @ token_embd) with the precomputed F32 tok_embd_t every
decode. token_embd is normally host-resident (it is only used for a cheap
get_rows lookup in AR models), and tok_embd_t inherited that host buffer, so the
scheduler copied the whole ~2.75 GiB tensor across PCIe on every forward,
dominating per-step time.

Allocate tok_embd_t on a non-host (offloaded) backend taken from a layer weight
when one is offloaded, so the matmul runs on-device with no per-decode copy; fall
back to token_embd's buffer for CPU-only runs.

Measured on RTX 5090 (Q4_K_M, -ngl 99, pp512/tg256 @ d512):
  pp512  1339 -> 3530 t/s (2.6x)
  tg256  4.84 -> 148.8 t/s (30.7x)
real generation ~0.43 s/step (was ~0.6-0.8); output unchanged.

Time the prompt prefill (encoder phase: causal, no self-conditioning) separately
from the denoising loop and print tokens / seconds / tok-per-second. Makes the
real prefill cost visible (it skips the self-cond matmul, unlike a decoder-phase
forward), so it is not conflated with the per-step denoising cost.

…) with k-annealing

The decoder/denoise phase computed softmax + entropy + self-conditioning over the
full 262k vocabulary on the host every step (~67M exp/log per step), which
dominated the per-step wall-clock. Add an opt-in top-k path: per canvas position,
select the top-k logits (size-k min-heap in one scan), then do softmax / entropy /
multinomial-sample / sparse self-conditioning over just those k. The self-cond
buffer is sparse (top-k filled), so the existing in-graph soft-embedding matmul
blends only the top-k embeddings (the dropped tail carries negligible embedding
weight; the following RMS norm absorbs the scale).
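The per-position top-k path can be sketched with Python's `heapq`, which also selects the k largest items in a single scan. This is an illustration of the idea rather than the host code in the PR; the function names are assumptions. For scale, the quoted ~67M figure matches 262144 vocabulary entries times 256 canvas positions.

```python
import heapq
import math


def topk_probs(logits, k):
    """Return [(token_id, prob)] for the k largest logits, renormalized."""
    top = heapq.nlargest(k, enumerate(logits), key=lambda item: item[1])
    m = top[0][1]  # largest logit, subtracted for numerical stability
    exps = [math.exp(logit - m) for _, logit in top]
    z = sum(exps)
    return [(tok, e / z) for (tok, _), e in zip(top, exps)]


def entropy(probs):
    """Shannon entropy in nats over the retained probabilities."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


logits = [2.0, 1.0, 0.0, -1.0]
sparse = topk_probs(logits, 2)      # entries for the sparse self-cond buffer
h_topk = entropy([p for _, p in sparse])
```

Only the k returned `(id, prob)` pairs are written into the self-conditioning buffer; every other vocabulary entry stays at zero.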

Knobs (CLI flags; default = full softmax, behaviour unchanged):
- --top-k N                 : fixed top-k per position (0 = full softmax)
- --top-k-start/--top-k-end : anneal k from high (first/high-entropy step) to low
                              (last step) -- early canvases are flat (need many
                              tokens), late ones are peaked (a few suffice)
- --top-k-tail-correction   : use the exact full-vocab entropy (logsumexp) for the
                              accept/stop signal instead of the under-estimating
                              top-k entropy (top-k truncation deflates entropy)
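Two of the knobs above lend themselves to a short sketch: annealing k across steps, and computing the exact full-vocabulary entropy through logsumexp. The linear interpolation is an assumption (the commit message only says k goes from high to low), and the function names are illustrative.

```python
import math


def anneal_k(step, n_steps, k_start, k_end):
    """Assumed linear interpolation from k_start (step 0) to k_end (last step)."""
    if n_steps <= 1:
        return k_start
    return round(k_start + (k_end - k_start) * step / (n_steps - 1))


def full_entropy(logits):
    """Exact entropy: H = logsumexp(l) - sum_i p_i * l_i."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return lse - sum(math.exp(l - lse) * l for l in logits)


def topk_entropy(logits, k):
    """Entropy of the renormalized top-k distribution."""
    top = sorted(logits, reverse=True)[:k]
    lse = top[0] + math.log(sum(math.exp(l - top[0]) for l in top))
    return lse - sum(math.exp(l - lse) * l for l in top)


logits = [2.0, 1.0, 0.0, -1.0]
print(topk_entropy(logits, 2), "<", full_entropy(logits))
```

On this example the top-2 entropy is noticeably lower than the full entropy, which illustrates why an accept/stop rule based on the truncated value can fire earlier than one based on the exact value.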

Also report encoder-phase prefill timing separately, and add the generation
timing summary. RTX 5090 (Q4_K_M, -ngl 99): ~2.6x faster per denoising step with
top-k (no correction); output preserved on the tested prompts.

This is host-side only -- it reuses the existing self-conditioning channel and
graph. The fuller graph-side variant (gather top-k embedding rows + a dedicated
top-k self-cond input, dropping the full-vocab matmul and the F32 transposed
embedding) is left as a follow-up.

Replace the dense full-vocab self-conditioning soft-embedding
(probs @ token_embd over all 262144 rows) with a sparse gather of only
the previous step's top-k token embeddings, blended by their
probabilities. The CLI feeds the top-256 (id, prob) per canvas position
each denoising step via a new API; the decoder graph gathers those rows
(ggml_get_rows), prob-weights and sums them (x sqrt(n_embd)).
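The equivalence between the dense product and the sparse gather can be checked on a tiny embedding table. The sketch below uses plain lists; in the graph the gather is `ggml_get_rows`, and the function names here are illustrative.

```python
import math


def dense_soft_embedding(probs, embd):
    """(probs @ embd) * sqrt(n_embd) over every vocabulary row."""
    n_embd = len(embd[0])
    scale = math.sqrt(n_embd)
    return [scale * sum(p * row[j] for p, row in zip(probs, embd))
            for j in range(n_embd)]


def sparse_soft_embedding(ids, probs, embd):
    """Gather only the top-k rows, weight them by probability, and sum."""
    n_embd = len(embd[0])
    out = [0.0] * n_embd
    for tok, p in zip(ids, probs):
        row = embd[tok]  # the gather step
        for j in range(n_embd):
            out[j] += p * row[j]
    scale = math.sqrt(n_embd)
    return [scale * v for v in out]


# 5-token vocabulary, n_embd = 4, top-2 entries for one canvas position.
embd = [[0.1 * (r + 1) * (c + 1) for c in range(4)] for r in range(5)]
ids, probs = [3, 1], [0.75, 0.25]
dense_probs = [0.0, 0.25, 0.0, 0.75, 0.0]

sparse_out = sparse_soft_embedding(ids, probs, embd)
dense_out = dense_soft_embedding(dense_probs, embd)
```

When the probability vector is zero outside the top-k ids, both functions return the same vector, but the sparse version touches k rows instead of the whole vocabulary.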

The gather embedding (tok_embd_gpu) is stored as F16 on the offloaded
backend. F16 (not the native Q4_K) is required because CUDA get_rows has
no Q4_K/Q6_K kernel -- a quantized gather falls back to CPU every step
and is a large regression. F16 keeps the gather on-device and halves its
VRAM vs the F32 dense transpose (1.47 GB vs 2.75 GB).

Measured (RTX 5090, Q4_K_M, France prompt): --top-k 64 ~0.13 s/step vs
0.158 dense (~1.2x faster), ~1.3 GB less VRAM; k=0 at parity (host
softmax dominated). Output unchanged.

The dense matmul path is retained as a fallback when the F16 gather copy
cannot be allocated.

API:
  - llama_set_diffusion_self_cond_topk(ctx, ids, probs, k, n_tokens)
  - llm_graph_input_diffusion_self_cond_topk + build_inp_diffusion_self_cond_topk

Add examples/diffusion-gemma/diffusion-gemma-server.cpp (target
llama-diffusion-gemma-server): loads a block-diffusion model once and serves
the same denoising generation loop as the CLI over HTTP, with the llama-server
observability surface re-mapped to diffusion semantics.

Endpoints: GET /health, /v1/health, /props, /v1/models, /models,
/slots (--slots), /metrics (--metrics); POST /v1/chat/completions (+ SSE
stream) and /v1/completions.

Responses carry the llama-server `timings` object (prompt_n/prompt_ms,
predicted_n/predicted_ms, ...) extended with a `diffusion` sub-object
(n_blocks, n_steps, canvas_tokens, ms_per_step, steps_per_second,
canvas_tokens_per_second, n_decode), plus the OpenAI `usage` object. Per-request
llama_perf-style timing logs, an access log, and a llama-server-style startup
banner are emitted. Prometheus /metrics exposes prompt/predicted token counters
and gauges alongside diffusion_blocks/steps/canvas_tokens totals and rate gauges.

Generation is serialized behind a mutex (one slot; the context is single-
threaded). Server-only flags (--host, --port, --api-key, --metrics, --slots)
are stripped before common_params_parse so all diffusion flags still parse.
Request fields: max_tokens -> canvas blocks, top_k, seed, ignore_eos (run the
full block count for sustained long-generation benchmarking).
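A request body using the fields listed above might be assembled like this. The `messages` layout follows the usual OpenAI chat format; the default values and the helper name are assumptions, and the host and port depend on how the server was started.

```python
import json


def build_chat_request(prompt, max_tokens=256, top_k=64, seed=0, ignore_eos=False):
    """Build a JSON body for POST /v1/chat/completions (field values are examples)."""
    body = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,   # mapped by the server to canvas blocks
        "top_k": top_k,
        "seed": seed,
        "ignore_eos": ignore_eos,   # True runs the full block count
    }
    return json.dumps(body)


payload = build_chat_request("What is the capital of France?", max_tokens=512)
print(payload)
```

The resulting string can be sent with any HTTP client. Because generation is serialized behind a mutex, concurrent requests are processed one at a time.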