Skip to content

Add Prompt Lookup Decoding (ngram-simple) and Rolling-Hash Speculative Memory (ngram-mod)#1297

Open
mayank2130 wants to merge 3 commits into
ml-explore:mainfrom
mayank2130:pld-ngram-simple
Open

Add Prompt Lookup Decoding (ngram-simple) and Rolling-Hash Speculative Memory (ngram-mod)#1297
mayank2130 wants to merge 3 commits into
ml-explore:mainfrom
mayank2130:pld-ngram-simple

Conversation

@mayank2130
Copy link
Copy Markdown

@mayank2130 mayank2130 commented May 22, 2026

Closes #851

Summary

Adds Prompt Lookup Decoding (PLD) and rolling-hash speculative decoding to mlx_lm via a generalized DraftStrategy abstraction.

Instead of generating speculative drafts with a smaller neural model, the new strategies reuse previously observed token trajectories:

  • ngram-simple performs exact prompt-history lookup
  • ngram-mod implements a rolling-hash associative memory ported from llama.cpp PR #19164

Both strategies preserve output correctness because speculative tokens are only accepted if verified by the target model under the same sampling configuration.

This PR adds:

  • DraftStrategy interface for pluggable speculative drafters
  • ModelDraftStrategy for existing neural drafting
  • NgramSimpleStrategy for prompt lookup decoding
  • NgramModStrategy + NgramModTable for rolling-hash speculative memory
  • optional adaptive repetition gating
  • CLI, Python, and server integration
  • process-shared speculative memory for cross-request reuse

Usage

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a Python function `add(a, b)`."}],
    add_generation_prompt=True,
    enable_thinking=False,
)

for response in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=256,
    sampler=make_sampler(temp=0.0),
    draft_type="ngram-simple",      # or "ngram-mod"
    num_draft_tokens=4,
    ngram_size=3,                   # use 16 for ngram-mod
    disable_adaptive_gate=True,
):
    print(response.text, end="", flush=True)

For ngram-mod, reuse a table across related generations to preserve learned n-gram memory:

from mlx_lm import load, stream_generate
from mlx_lm.generate import NgramModTable
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
table = NgramModTable(n=16)

for prompt_text in prompts:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}],
        add_generation_prompt=True,
        enable_thinking=False,
    )

    for response in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=256,
        sampler=make_sampler(temp=0.0),
        draft_type="ngram-mod",
        num_draft_tokens=6,
        ngram_size=16,
        ngram_mod_table=table,
        disable_adaptive_gate=True,
    ):
        print(response.text, end="", flush=True)

CLI: multi-turn ngram-simple

printf '%s\n%s\n%s\nq\n' \
'Write a Python function summarize_orders(orders) where each order has id, customer, total, and status. Return only the code.' \
'Now add a currency="$" parameter and use it when formatting money values. Keep the cancelled-order behavior unchanged. Return the full updated function only.' \
'Update summarize_orders so it skips orders whose status is cancelled. Keep the same structure and return the full updated function only.' \
| python -m mlx_lm chat \
  --model Qwen/Qwen3-8B-MLX-4bit \
  --max-tokens 500 \
  --temp 0 \
  --chat-template-config '{"enable_thinking": false}' \
  --draft-type ngram-simple \
  --num-draft-tokens 4 \
  --ngram-size 3 \
  --disable-adaptive-gate

CLI: multi-turn ngram-mod

printf '%s\n%s\n%s\nq\n' \
'Write a Python function summarize_orders(orders) where each order has id, customer, total, and status. Return only the code.' \
'Now add a currency="$" parameter and use it when formatting money values. Keep the cancelled-order behavior unchanged. Return the full updated function only.' \
'Update summarize_orders so it skips orders whose status is cancelled. Keep the same structure and return the full updated function only.' \
| python -m mlx_lm chat \
  --model Qwen/Qwen3-8B-MLX-4bit \
  --max-tokens 500 \
  --temp 0 \
  --chat-template-config '{"enable_thinking": false}' \
  --draft-type ngram-mod \
  --num-draft-tokens 6 \
  --ngram-size 16 \
  --disable-adaptive-gate

The chat command keeps the conversation history and prompt cache alive across turns, so T2/T3 can reuse the generated structure from T1.

Server

Per-request JSON overrides: draft_type, ngram_size disable_adaptive_gate

Architecture

Speculative drafting is abstracted behind:

class DraftStrategy(Protocol):
    def propose(self, y, n_max, ctx) -> mx.array: ...
    def rewind(self, n: int) -> None: ...
    def observe(self, tokens) -> None: ...
    def accept(self, n_accepted: int, n_drafted: int) -> None: ...

NgramSimpleStrategy scans backward for matching n-grams and proposes the following continuation tokens directly from prior history.

  • Best suited for: short iterative edits, local repetition, single-user coding flows

NgramModStrategy ports llama.cpp's rolling-hash speculative memory.
Architecture mirrors llama.cpp's split between:

  • process-global speculative memory
  • per-request runtime state

The shared table stores:
hash(ngram) -> next_token
allowing speculative reuse across requests handled by the same running server process.

Implementation behavior intentionally matches llama.cpp:

  • fixed-size lossy hash table
  • silent overwrite collision policy
  • verifier-corrected speculative drafts
  • adaptive reset on repeated low acceptance

Adaptive Gate

An optional adaptive gate computes a 3-gram repetition score over the prompt. If repetition falls below:
NGRAM_GATE_THRESHOLD = 0.02
speculation is skipped automatically.

This is particularly important for ngram-mod, whose cold-start behavior can regress below baseline throughput on low-repetition prompts.

Benchmarks

All benchmarks used:

  • mlx-community/Llama-3.2-3B-Instruct-4bit
  • Apple Silicon

LONG MULTI-TURN EDITING (~280 TOK/TURN) — OVERALL

config tok/s acc% speedup
baseline 54.09 1.00×
ngram-simple nd=4 91.57 62.7% 1.69×
ngram-simple nd=6 89.01 67.8% 1.65×
ngram-mod nd=6 84.69 59.4% 1.57×
ngram-mod nd=8 82.69 61.7% 1.53×

ngram-mod nd=6 per-turn behavior

turn prompt baseline tok/s ngram-simple nd=6 ngram-mod nd=6
T1 write EmailValidator class 57.1 62.4 (1.09×) 54.4 (0.95×)
T2 add is_disposable method 57.9 118.3 (2.04×) 123.9 (2.14×)
T3 reject whitespace in domain 55.0 118.4 (2.15×) 122.9 (2.24×)
T4 log inside validate 49.9 118.0 (2.37×) 135.8 (2.72×)

mayank2130 and others added 2 commits May 19, 2026 15:35
@mayank2130 mayank2130 changed the title Add Prompt Lookup Decoding (ngram-simple) via DraftStrategy abstraction Add Prompt Lookup Decoding (ngram-simple & ngram-mod) via DraftStrategy abstraction May 22, 2026
@mayank2130 mayank2130 changed the title Add Prompt Lookup Decoding (ngram-simple & ngram-mod) via DraftStrategy abstraction Add Prompt Lookup Decoding (ngram-simple) and Rolling-Hash Speculative Memory (ngram-mod) May 22, 2026
@mayank2130
Copy link
Copy Markdown
Author

hey @angeloskath can this PLD/n-gram decoding be reviewed.

If you're not the one to reachout for mlx-lm PRs could you point me to someone else. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

n-gram hashing for speculative decoding

1 participant