Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ name = "bench_decode_ab"
path = "src/bin/bench_decode_ab.rs"
required-features = ["f16", "metal-gpu"]

[[bin]]
name = "ppl_metal"
path = "src/bin/ppl_metal.rs"
required-features = ["f16", "metal-gpu"]

[[bin]]
name = "bench_logit_dump"
path = "src/bin/bench_logit_dump.rs"
Expand Down
44 changes: 44 additions & 0 deletions crates/inference/src/bin/ppl_metal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use lattice_inference::forward::metal_qwen35::MetalQwen35State;
use lattice_inference::model::qwen35::{PerplexityConfig, Qwen35Model};
use lattice_inference::tokenizer::{BpeTokenizer, Tokenizer};

/// Scores a text corpus with the Metal Qwen3.5 backend and prints its perplexity.
///
/// Configuration is read from environment variables:
/// - `LATTICE_MODEL_DIR`: model directory; defaults to
///   `$HOME/.lattice/models/qwen3.5-0.8b`.
/// - `PPL_TOKENS`: maximum number of corpus tokens to score; defaults to 2048.
/// - `CORPUS`: path of the text file to score; defaults to
///   `/tmp/wikitext2_test.txt`.
///
/// Progress messages go to stderr; the PPL / NLL / token-count / wall-time
/// report goes to stdout.
///
/// # Panics
///
/// Panics if the model directory cannot be determined, or if loading the
/// model, initialising the Metal state, loading `tokenizer.json`, reading the
/// corpus, or computing the perplexity fails.
fn main() {
    const DEFAULT_TOKEN_BUDGET: usize = 2048;

    // HOME is only needed to build the default model path, so it is resolved
    // lazily: an explicit LATTICE_MODEL_DIR must work even when HOME is unset.
    let model_dir = std::env::var("LATTICE_MODEL_DIR").unwrap_or_else(|_| {
        let home = std::env::var("HOME")
            .expect("HOME must be set when LATTICE_MODEL_DIR is not provided");
        format!("{home}/.lattice/models/qwen3.5-0.8b")
    });
    let dir = std::path::Path::new(&model_dir);

    // An unparsable PPL_TOKENS still falls back to the default, but no longer
    // silently: a typo would otherwise score a different number of tokens
    // than the user asked for.
    let n_tokens: usize = match std::env::var("PPL_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or_else(|_| {
            eprintln!(
                "[ppl_metal] ignoring invalid PPL_TOKENS={raw:?}; using {DEFAULT_TOKEN_BUDGET}"
            );
            DEFAULT_TOKEN_BUDGET
        }),
        Err(_) => DEFAULT_TOKEN_BUDGET,
    };

    eprintln!("[ppl_metal] loading {model_dir}");
    let model = Qwen35Model::from_safetensors(dir).expect("load model");
    let cfg = model.config().clone();
    // NOTE(review): 4096 is passed through unchanged from the original code;
    // presumably the maximum context length of the Metal state — confirm.
    let mut metal = MetalQwen35State::new(model.weights(), &cfg, 4096).expect("init metal");
    let tokenizer = BpeTokenizer::from_tokenizer_json(&dir.join("tokenizer.json"))
        .expect("load tokenizer.json");

    let corpus_path =
        std::env::var("CORPUS").unwrap_or_else(|_| "/tmp/wikitext2_test.txt".to_string());
    let corpus = std::fs::read_to_string(&corpus_path)
        .unwrap_or_else(|e| panic!("read corpus {corpus_path}: {e}"));

    // Keep only the real (non-padding) prefix reported by the tokenizer, then
    // cap it at the requested token budget.
    let input = tokenizer.tokenize(&corpus);
    let all_tokens: Vec<u32> = input.input_ids[..input.real_length].to_vec();
    let tokens = &all_tokens[..all_tokens.len().min(n_tokens)];
    eprintln!(
        "[ppl_metal] scoring {} tokens (Metal GPU, Q8 + f16 lm_head)",
        tokens.len()
    );

    // Only the perplexity computation is timed; model loading and
    // tokenisation above are excluded.
    let t = std::time::Instant::now();
    let ppl_cfg = PerplexityConfig {
        window: 512,
        stride: 256,
    };
    let report = metal.compute_perplexity(tokens, &ppl_cfg).expect("ppl");
    let elapsed = t.elapsed();

    println!("PPL: {:.4}", report.ppl);
    println!("NLL: {:.6}", report.mean_nll);
    println!("Tokens: {}", report.num_tokens_scored);
    println!("Time: {:.1}s", elapsed.as_secs_f64());
}
122 changes: 87 additions & 35 deletions crates/inference/src/forward/metal_qwen35.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,65 @@ kernel void gemv_q8_decode_wide(
}
}

// ===== FP16-weight GEMV Decode (wide): NR=4 rows per threadgroup =====
// Same scheduling as gemv_q8_decode_wide but reads half-precision weights
// directly (no Q8 block structure). Dispatched for large-N matmuls (lm_head)
// where the NR=1 gemv_decode_m1 kernel creates excessive threadgroup count.
// Dispatch: threadgroups=(ceil(N/4), 1, 1), threads=(32, 4, 1)
//
// Computes y[r] = sum_k x[k] * float(W[r * K + k]) for r in [0, N).
//   x : input vector, read as floats in 32-element blocks
//   W : half-precision weights, indexed row-major with row stride K
//   y : output vector, one float per row r < N
//   N : number of output rows
//   K : length of each row / of x
// K is processed in blocks of 32 elements (nb = K / 32, integer division), so
// any trailing K % 32 elements are not accumulated.
// The lane mapping below (tiisg / 4 selecting one of 8 blocks) covers 8 blocks
// per simdgroup only when a simdgroup has 32 lanes.
kernel void gemv_decode_wide_f16(
device const float* x [[buffer(0)]],
device const half* W [[buffer(1)]],
device float* y [[buffer(2)]],
constant uint& N [[buffer(3)]],
constant uint& K [[buffer(4)]],
uint3 tgpig [[threadgroup_position_in_grid]],
uint tiisg [[thread_index_in_simdgroup]],
uint sgitg [[simdgroup_index_in_threadgroup]])
{
// NR: output rows produced by one threadgroup.
const uint NR = 4;
// NSG: simdgroups per threadgroup that share the K dimension.
const uint NSG = 4;
// Number of whole 32-element blocks along K.
const uint nb = K / 32;
// First output row owned by this threadgroup.
const uint first_row = tgpig.x * NR;
// ix: which of the simdgroup's 8 concurrent blocks this lane works on.
const uint ix = tiisg / 4;
// il: which 8-element quarter of that 32-element block this lane reads.
const uint il = tiisg % 4;

// Per-lane partial dot products, one per output row.
float sumf[NR] = {0.0f};
// Simdgroup `sgitg` starts at blocks [sgitg*8, sgitg*8 + 8) and then advances
// by NSG*8 blocks, so the NSG simdgroups interleave over K.
const uint ib_start = sgitg * 8 + ix;
const uint ib_stride = NSG * 8;
// Pointer to this lane's 8 input values inside its first block.
device const float* xb = x + ib_start * 32 + il * 8;

for (uint ib = ib_start; ib < nb; ib += ib_stride) {
// Load the 8 input values once; they are reused for all NR rows below.
float xl[8];
for (uint i = 0; i < 8; i++) xl[i] = xb[i];
xb += ib_stride * 32;

for (uint row = 0; row < NR; row++) {
uint r = first_row + row;
// Rows past N exist only in the last threadgroup when N is not a multiple of NR.
if (r >= N) continue;
// NOTE(review): r * K is evaluated in uint arithmetic; it would wrap if
// N * K exceeded the uint range — confirm against the largest lm_head used.
device const half* wrow = W + r * K + ib * 32 + il * 8;
float dot = 0.0f;
for (uint i = 0; i < 8; i++) dot += xl[i] * float(wrow[i]);
sumf[row] += dot;
}
}

// Reduce each row's partial sum across the lanes of this simdgroup.
for (uint row = 0; row < NR; row++) sumf[row] = simd_sum(sumf[row]);

// One slot per (row, simdgroup); the second dimension (4) matches NSG.
threadgroup float shared[NR][4];
// Lane 0 of every simdgroup publishes that simdgroup's per-row totals.
if (tiisg == 0) {
for (uint row = 0; row < NR; row++) shared[row][sgitg] = sumf[row];
}
// Make all simdgroups' writes to `shared` visible before the final sum.
threadgroup_barrier(mem_flags::mem_threadgroup);
// A single thread combines the four simdgroup totals and writes the result
// for every row that is in range.
if (sgitg == 0 && tiisg == 0) {
for (uint row = 0; row < NR; row++) {
uint r = first_row + row;
if (r < N) {
y[r] = shared[row][0] + shared[row][1] + shared[row][2] + shared[row][3];
}
}
}
}

// ===== Qwen3.5 RMS Norm: x = x * (1 + gamma) / sqrt(mean(x^2) + eps) =====
// Shifted norm: (1 + gamma) instead of plain gamma.
// One threadgroup per row.
Expand Down Expand Up @@ -2812,6 +2871,7 @@ kernel void moe_shared_gate_add(
/// Compiled MSL pipeline state objects.
struct MetalQwen35Pipelines {
gemv_decode: ComputePipelineState,
gemv_decode_wide: ComputePipelineState,
gemv_q8: ComputePipelineState,
gemv_q8_wide: ComputePipelineState,
rms_norm: ComputePipelineState,
Expand Down Expand Up @@ -3632,6 +3692,7 @@ kernel void moe_shared_gate_add(

let pipelines = MetalQwen35Pipelines {
gemv_decode: make_pipeline("gemv_decode_m1")?,
gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?,
gemv_q8: make_pipeline("gemv_q8_decode")?,
gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?,
gemv_q4: make_pipeline("gemv_q4_decode")?,
Expand Down Expand Up @@ -7013,45 +7074,33 @@ kernel void moe_shared_gate_add(
// matmul kernel is faster + handles QuaRot rotations correctly.
let use_fp16_lm_head = matches!(self.engine.quant_format, QuantFormat::Q8_0);
if use_fp16_lm_head {
let gemm_params = GemmParams {
m: 1,
n: cfg.vocab_size as u32,
k: hidden as u32,
lda: hidden as u32,
ldb: hidden as u32,
ldc: cfg.vocab_size as u32,
};
let vocab_n = cfg.vocab_size as u32;
let hidden_k = hidden as u32;
if let Some(ref pb) = ppl_buf {
for p in 0..n {
let hidden_off = (p * hidden) as u64 * 4;
let logits_off = (p * cfg.vocab_size) as u64 * 4;
enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode);
enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide);
enc.set_buffer(0, Some(&self.session.activations.hidden), hidden_off);
enc.set_buffer(1, Some(&self.engine.embed_tokens), 0);
enc.set_buffer(2, Some(pb), logits_off);
enc.set_bytes(
3,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const _,
);
enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _);
enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _);
enc.dispatch_thread_groups(
MTLSize::new(cfg.vocab_size as u64, 1, 1),
MTLSize::new(256, 1, 1),
MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1),
MTLSize::new(32, 4, 1),
);
}
} else {
enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode);
enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide);
enc.set_buffer(0, Some(&self.session.activations.hidden), last_off);
enc.set_buffer(1, Some(&self.engine.embed_tokens), 0);
enc.set_buffer(2, Some(&self.session.activations.logits), 0);
enc.set_bytes(
3,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const _,
);
enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _);
enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _);
enc.dispatch_thread_groups(
MTLSize::new(cfg.vocab_size as u64, 1, 1),
MTLSize::new(256, 1, 1),
MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1),
MTLSize::new(32, 4, 1),
);
}
} else if let Some(ref pb) = ppl_buf {
Expand Down Expand Up @@ -9670,19 +9719,21 @@ kernel void moe_shared_gate_add(
1,
cfg.rms_norm_eps,
);
// lm_head: for Q8 format use FP16 (matches MLX), for Q4 use Q4 path.
// Q8 per-row symmetric quantization of embeddings causes ~0.79 PPL
// gap from per-row scale not capturing outlier embeddings cleanly.
// lm_head: for Q8 format use FP16 wide kernel (NR=4, fewer TGs than
// gemv_decode_m1 which dispatches one TG per vocab row). For Q4 use Q4 path.
match self.engine.quant_format {
QuantFormat::Q8_0 => {
self.dispatch_matmul_half(
enc,
&self.session.activations.hidden,
&self.engine.embed_tokens,
&self.session.activations.logits,
1,
cfg.vocab_size as u32,
hidden as u32,
let vocab_n = cfg.vocab_size as u32;
let hidden_k = hidden as u32;
enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide);
enc.set_buffer(0, Some(&self.session.activations.hidden), 0);
enc.set_buffer(1, Some(&self.engine.embed_tokens), 0);
enc.set_buffer(2, Some(&self.session.activations.logits), 0);
enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _);
enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _);
enc.dispatch_thread_groups(
MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1),
MTLSize::new(32, 4, 1),
);
}
QuantFormat::Q4_0 => {
Expand Down Expand Up @@ -11439,6 +11490,7 @@ kernel void moe_shared_gate_add(

let pipelines = MetalQwen35Pipelines {
gemv_decode: make_pipeline("gemv_decode_m1")?,
gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?,
gemv_q8: make_pipeline("gemv_q8_decode")?,
gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?,
gemv_q4: make_pipeline("gemv_q4_decode")?,
Expand Down
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Packaging metadata for the Python helper scripts kept alongside the engine.
[project]
name = "lattice-dev"
version = "0.1.0"
description = "Dev scripts for lattice inference engine"
# Minimum Python version accepted by installers.
requires-python = ">=3.11"

# Packages the scripts depend on; every entry sets only a lower version bound.
dependencies = [
"pyarrow>=17.0",
"datasets>=3.0",
"numpy>=1.26",
"mlx>=0.22",
"mlx-lm>=0.21",
"transformers>=4.46",
"matplotlib>=3.9",
]

# uv-specific settings: packages installed only for development work.
# NOTE(review): newer uv releases appear to prefer the standard
# [dependency-groups] table over this key — confirm before upgrading uv.
[tool.uv]
dev-dependencies = [
"pytest>=8.0",
]
Loading