diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index c92e40bb..206b637a 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -2358,6 +2358,290 @@ kernel void moe_shared_gate_add( if (gid >= hidden) return; scratch_out[gid] = fma(gate_val, expert_out[gid], scratch_out[gid]); } + +// ===== Batch scatter Q and gate from interleaved q_proj output (Win 1) ===== +// Processes num_tokens rows in one dispatch. +// Source: q[num_tokens, num_heads * 2 * head_dim] with (Q_h, gate_h) per head. +// Output: q_out[num_tokens, num_heads * head_dim], gate_out same shape. +kernel void scatter_q_gate_batch( + device const float* qg_interleaved [[buffer(0)]], + device float* q_out [[buffer(1)]], + device float* gate_out [[buffer(2)]], + constant uint& num_tokens [[buffer(3)]], + constant uint& num_heads [[buffer(4)]], + constant uint& head_dim [[buffer(5)]], + uint gid [[thread_position_in_grid]]) +{ + uint q_dim = num_heads * head_dim; + uint total = num_tokens * q_dim; + if (gid >= total) return; + + uint t = gid / q_dim; + uint rem = gid % q_dim; + uint head = rem / head_dim; + uint d = rem % head_dim; + + // Source row: token t, head layout (Q_h, gate_h) per head in q_proj output. + uint src_base = t * (2u * q_dim) + head * (2u * head_dim); + q_out[gid] = qg_interleaved[src_base + d]; + gate_out[gid] = qg_interleaved[src_base + head_dim + d]; +} + +// ===== Batch per-head RMS norm (Win 1) ===== +// One threadgroup per (token, head) pair; same math as per_head_rms_norm. +kernel void per_head_rms_norm_batch( + device float* x [[buffer(0)]], + device const float* gamma [[buffer(1)]], + constant uint& num_tokens [[buffer(2)]], + constant uint& num_heads [[buffer(3)]], + constant uint& head_dim [[buffer(4)]], + constant float& eps [[buffer(5)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint tgs [[threads_per_threadgroup]]) +{ + if (gid >= num_tokens * num_heads) return; + + uint t = gid / num_heads; + uint head = gid % num_heads; + uint base = t * (num_heads * head_dim) + head * head_dim; + + constexpr uint NORM_WG = 256; + threadgroup float shared[NORM_WG]; + + float local_sum = 0.0f; + for (uint i = lid; i < head_dim; i += tgs) { + float v = x[base + i]; + local_sum += v * v; + } + shared[lid] = local_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = tgs / 2; s > 0; s >>= 1) { + if (lid < s) shared[lid] += shared[lid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + float rms = rsqrt(shared[0] / float(head_dim) + eps); + + // Qwen3.5 shifted RMSNorm: same (1 + gamma) convention as per_head_rms_norm. + for (uint i = lid; i < head_dim; i += tgs) { + x[base + i] = x[base + i] * rms * (1.0f + gamma[i]); + } +} + +// ===== Batch stride-half partial RoPE (Win 1) ===== +// Extends partial_rope_interleaved to num_tokens rows. +// RoPE absolute position for token t is base_pos + t, not chunk-local t. +kernel void partial_rope_batch( + device float* x [[buffer(0)]], + device const float* cos_tab [[buffer(1)]], + device const float* sin_tab [[buffer(2)]], + constant uint& num_tokens [[buffer(3)]], + constant uint& num_heads [[buffer(4)]], + constant uint& head_dim [[buffer(5)]], + constant uint& half_rope_dim [[buffer(6)]], + constant uint& base_pos [[buffer(7)]], + uint gid [[thread_position_in_grid]]) +{ + uint total_pairs = num_tokens * num_heads * half_rope_dim; + if (gid >= total_pairs) return; + + uint pair = gid % half_rope_dim; + uint head = (gid / half_rope_dim) % num_heads; + uint t = gid / (num_heads * half_rope_dim); + + uint base = t * (num_heads * head_dim) + head * head_dim; + // Absolute position keeps RoPE consistent across chunk boundaries. + uint cs_base = (base_pos + t) * half_rope_dim; + + float cos_val = cos_tab[cs_base + pair]; + float sin_val = sin_tab[cs_base + pair]; + + // Stride-half pairing: (pair, half_rope_dim + pair), matching HF rotate_half. + uint idx0 = base + pair; + uint idx1 = base + half_rope_dim + pair; + float x0 = x[idx0]; + float x1 = x[idx1]; + x[idx0] = x0 * cos_val - x1 * sin_val; + x[idx1] = x0 * sin_val + x1 * cos_val; +} + +// ===== Batch K/V cache store (Win 1) ===== +// Copies num_tokens rows of K and V into cache buffers in one dispatch. +// Cache layout is token-major: cache[row * kv_dim + d] where row = base_pos + t. +kernel void copy_kv_cache_batch( + device const float* k_src [[buffer(0)]], + device const float* v_src [[buffer(1)]], + device float* k_cache [[buffer(2)]], + device float* v_cache [[buffer(3)]], + constant uint& num_tokens [[buffer(4)]], + constant uint& kv_dim [[buffer(5)]], + constant uint& base_pos [[buffer(6)]], + uint gid [[thread_position_in_grid]]) +{ + uint total = num_tokens * kv_dim; + if (gid >= total) return; + + uint t = gid / kv_dim; + uint d = gid % kv_dim; + uint src = t * kv_dim + d; + uint dst = (base_pos + t) * kv_dim + d; + k_cache[dst] = k_src[src]; + v_cache[dst] = v_src[src]; +} + +// ===== Batched causal prefill attention (Win 2) ===== +// Replaces the per-token decode_attention loop for full-attention prefill chunks. +// Grid: [num_kv_heads, num_tokens, 1]. Threads: [256, 1, 1] — one thread per output dim. +// One threadgroup per (kv_head, query_token) processes all Q heads in the GQA group. +kernel void prefill_attention_batched_causal( + device const float* q [[buffer(0)]], + device const float* k_cache [[buffer(1)]], + device const float* v_cache [[buffer(2)]], + device float* out [[buffer(3)]], + constant uint& base_pos [[buffer(4)]], + constant uint& num_tokens [[buffer(5)]], + constant uint& cache_len_total [[buffer(6)]], + constant uint& head_dim [[buffer(7)]], + constant uint& num_q_heads [[buffer(8)]], + constant uint& num_kv_heads [[buffer(9)]], + constant uint& q_dim [[buffer(10)]], + constant uint& kv_dim [[buffer(11)]], + constant float& scale [[buffer(12)]], + uint3 gid3 [[threadgroup_position_in_grid]], + uint3 lid3 [[thread_position_in_threadgroup]], + uint3 tgs3 [[threads_per_threadgroup]]) +{ + constexpr uint HEAD_DIM = 256; + constexpr uint MAX_GRP = 8; + constexpr uint TILE_TOKENS = 256; + + // Extract scalar thread/threadgroup indices from the 3D vectors. + const uint lid = lid3.x; + const uint tgs = tgs3.x; + + if (head_dim != HEAD_DIM || num_kv_heads == 0) return; + const uint kvh = gid3.x; + const uint qt = gid3.y; + if (kvh >= num_kv_heads || qt >= num_tokens) return; + if ((num_q_heads % num_kv_heads) != 0) return; + const uint group_size = num_q_heads / num_kv_heads; + if (group_size == 0 || group_size > MAX_GRP) return; + const uint qh_base = kvh * group_size; + + // Causal bound: query token qt at absolute position base_pos+qt + // sees only cache rows [0 .. base_pos+qt+1] to exclude future tokens. + const uint causal_len = min(cache_len_total, base_pos + qt + 1u); + if (causal_len == 0) { + for (uint qi = 0; qi < group_size; qi++) { + out[qt * q_dim + (qh_base + qi) * HEAD_DIM + lid] = 0.0f; + } + return; + } + + // Threadgroup memory — same layout and size as decode_attention (~17 KB). + threadgroup float q_s [MAX_GRP * HEAD_DIM]; + threadgroup float score_s[MAX_GRP * TILE_TOKENS]; + threadgroup float reduce_s[TILE_TOKENS]; + threadgroup float m_s [MAX_GRP]; + threadgroup float l_s [MAX_GRP]; + threadgroup float alpha_s[MAX_GRP]; + + if (lid < group_size) { + m_s[lid] = -INFINITY; + l_s[lid] = 0.0f; + } + + // Load Q for all GQA query heads from this token's row in q_separated[N, q_dim]. + const uint q_row_base = qt * q_dim + qh_base * HEAD_DIM; + for (uint idx = lid; idx < group_size * HEAD_DIM; idx += tgs) { + q_s[idx] = q[q_row_base + idx]; + } + + float acc[MAX_GRP]; + for (uint qi = 0; qi < MAX_GRP; qi++) acc[qi] = 0.0f; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Tiled online-softmax — same structure as decode_attention. + for (uint tile_start = 0; tile_start < causal_len; tile_start += TILE_TOKENS) { + const uint tile_count = min(TILE_TOKENS, causal_len - tile_start); + + if (lid < tile_count) { + float dot[MAX_GRP]; + for (uint qi = 0; qi < MAX_GRP; qi++) dot[qi] = 0.0f; + const uint k_base = (tile_start + lid) * kv_dim + kvh * HEAD_DIM; + for (uint d = 0; d < HEAD_DIM; d++) { + const float kd = k_cache[k_base + d]; + for (uint qi = 0; qi < group_size; qi++) { + dot[qi] += q_s[qi * HEAD_DIM + d] * kd; + } + } + for (uint qi = 0; qi < group_size; qi++) { + score_s[qi * TILE_TOKENS + lid] = dot[qi] * scale; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Online-softmax rescale: when a later tile raises the max, rescale the prior + // accumulator by alpha=exp(m_old-m_new) to preserve the running weighted sum. + for (uint qi = 0; qi < group_size; qi++) { + reduce_s[lid] = (lid < tile_count) ? score_s[qi * TILE_TOKENS + lid] : -INFINITY; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tgs >> 1; s > 0; s >>= 1) { + if (lid < s) reduce_s[lid] = max(reduce_s[lid], reduce_s[lid + s]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float tile_max = reduce_s[0]; + + if (lid == 0) { + const float m_old = m_s[qi]; + const float m_new = max(m_old, tile_max); + alpha_s[qi] = isfinite(m_old) ? exp(m_old - m_new) : 0.0f; + m_s[qi] = m_new; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (lid < tile_count) { + score_s[qi * TILE_TOKENS + lid] = + exp(score_s[qi * TILE_TOKENS + lid] - m_s[qi]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + reduce_s[lid] = (lid < tile_count) ? score_s[qi * TILE_TOKENS + lid] : 0.0f; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tgs >> 1; s > 0; s >>= 1) { + if (lid < s) reduce_s[lid] += reduce_s[lid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + if (lid == 0) { + l_s[qi] = alpha_s[qi] * l_s[qi] + reduce_s[0]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint qi = 0; qi < group_size; qi++) { + acc[qi] *= alpha_s[qi]; + } + + const uint d = lid; + for (uint local_t = 0; local_t < tile_count; local_t++) { + const float v = v_cache[(tile_start + local_t) * kv_dim + kvh * HEAD_DIM + d]; + for (uint qi = 0; qi < group_size; qi++) { + acc[qi] += score_s[qi * TILE_TOKENS + local_t] * v; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write output at the query token's row: out[qt * q_dim + qh * HEAD_DIM + lid]. + for (uint qi = 0; qi < group_size; qi++) { + const uint qh = qh_base + qi; + const float dn = l_s[qi]; + out[qt * q_dim + qh * HEAD_DIM + lid] = dn > 0.0f ? acc[qi] / dn : 0.0f; + } +} "#; const MSL_Q4_TILED_SOURCE: &str = concat!( @@ -2999,6 +3283,13 @@ kernel void moe_shared_gate_add( moe_scale_add: ComputePipelineState, moe_shared_gate_add: ComputePipelineState, moe_zero_buf: ComputePipelineState, + // Win 1 batch prefill prep kernels (ADR-126A) + scatter_q_gate_batch: ComputePipelineState, + per_head_rms_norm_batch: ComputePipelineState, + partial_rope_batch: ComputePipelineState, + copy_kv_cache_batch: ComputePipelineState, + // Win 2 batched causal prefill attention (ADR-126B) + prefill_attention_batched: ComputePipelineState, } // ----------------------------------------------------------------------- @@ -3848,6 +4139,11 @@ kernel void moe_shared_gate_add( moe_scale_add: make_pipeline("moe_scale_add")?, moe_shared_gate_add: make_pipeline("moe_shared_gate_add")?, moe_zero_buf: make_pipeline("zero_buf")?, + scatter_q_gate_batch: make_pipeline("scatter_q_gate_batch")?, + per_head_rms_norm_batch: make_pipeline("per_head_rms_norm_batch")?, + partial_rope_batch: make_pipeline("partial_rope_batch")?, + copy_kv_cache_batch: make_pipeline("copy_kv_cache_batch")?, + prefill_attention_batched: make_pipeline("prefill_attention_batched_causal")?, }; // Upload per-layer weights @@ -6839,161 +7135,126 @@ kernel void moe_shared_gate_add( ); } - // Per-token: scatter, norm, RoPE, cache store, attention, gate - for t in 0..n { - // abs_pos is the token's position in the full sequence across all chunks. - let abs_pos = start_pos + t; - let q_off = (t * 2 * q_dim) as u64 * 4; - let qs_off = (t * q_dim) as u64 * 4; - let gz_off = (t * q_dim) as u64 * 4; - let k_off = (t * kv_dim) as u64 * 4; - let v_off = (t * kv_dim) as u64 * 4; - let ao_off = (t * q_dim) as u64 * 4; - - // Scatter Q + gate from interleaved q_proj - // SAFETY: Per-token q/gate offsets are within the batch - // activation buffers, and q_dim is num_q_heads * head_dim. - unsafe { - enc.set_compute_pipeline_state(&self.engine.pipelines.scatter_q_gate); - enc.set_buffer(0, Some(&self.session.activations.q), q_off); - enc.set_buffer(1, Some(&self.session.activations.q_separated), qs_off); - enc.set_buffer(2, Some(&self.session.activations.gate_z), gz_off); - let nh = num_q_heads as u32; - let hd = head_dim as u32; - enc.set_bytes(3, 4, &nh as *const u32 as *const _); - enc.set_bytes(4, 4, &hd as *const u32 as *const _); - let wg = 256u64; - enc.dispatch_threads( - MTLSize::new(div_ceil(q_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); - } - - // Per-head RMS norm Q and K - // SAFETY: Q/K norm buffers are live and per-token offsets - // are within activation buffers sized from the same config. - unsafe { - enc.set_compute_pipeline_state( - &self.engine.pipelines.per_head_rms_norm, - ); - enc.set_buffer(0, Some(&self.session.activations.q_separated), qs_off); - enc.set_buffer(1, Some(&*w_qn), 0); - let nh = num_q_heads as u32; - let hd = head_dim as u32; - enc.set_bytes(2, 4, &nh as *const u32 as *const _); - enc.set_bytes(3, 4, &hd as *const u32 as *const _); - enc.set_bytes(4, 4, &cfg.rms_norm_eps as *const f32 as *const _); - enc.dispatch_thread_groups( - MTLSize::new(nh as u64, 1, 1), - MTLSize::new(256, 1, 1), - ); - - enc.set_buffer(0, Some(&self.session.activations.k), k_off); - enc.set_buffer(1, Some(&*w_kn), 0); - let nkh = num_kv_heads as u32; - enc.set_bytes(2, 4, &nkh as *const u32 as *const _); - enc.dispatch_thread_groups( - MTLSize::new(nkh as u64, 1, 1), - MTLSize::new(256, 1, 1), - ); - } - - // Partial RoPE for Q and K - // SAFETY: RoPE table buffers cover max positions and - // per-token Q/K offsets are within activation buffers. - unsafe { - let pos = abs_pos as u32; - enc.set_compute_pipeline_state(&self.engine.pipelines.partial_rope); - enc.set_buffer(0, Some(&self.session.activations.q_separated), qs_off); - enc.set_buffer(1, Some(&self.engine.rope_cos), 0); - enc.set_buffer(2, Some(&self.engine.rope_sin), 0); - let nh = num_q_heads as u32; - let hd = head_dim as u32; - enc.set_bytes(3, 4, &nh as *const u32 as *const _); - enc.set_bytes(4, 4, &hd as *const u32 as *const _); - enc.set_bytes(5, 4, &half_rope_dim as *const u32 as *const _); - enc.set_bytes(6, 4, &pos as *const u32 as *const _); - let wg = 256u64; - enc.dispatch_threads( - MTLSize::new(div_ceil(q_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); - - enc.set_buffer(0, Some(&self.session.activations.k), k_off); - let nkh = num_kv_heads as u32; - enc.set_bytes(3, 4, &nkh as *const u32 as *const _); - enc.dispatch_threads( - MTLSize::new(div_ceil(kv_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); - } - - // Store K, V to cache; use absolute row so cross-chunk tokens - // write to the correct KV cache rows rather than overwriting chunk 0. - { - let dst_offset = (abs_pos * kv_dim) as u32; - enc.set_compute_pipeline_state(&self.engine.pipelines.copy_offset); - enc.set_buffer(0, Some(&self.session.activations.k), k_off); - enc.set_buffer(1, Some(&self.session.kv_cache.k_bufs[full_idx]), 0); - let cnt = kv_dim as u32; - enc.set_bytes(2, 4, &cnt as *const u32 as *const _); - enc.set_bytes(3, 4, &dst_offset as *const u32 as *const _); - let wg = 256u64; - enc.dispatch_threads( - MTLSize::new(div_ceil(kv_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); + // Win 1: batch scatter, norm, RoPE, and KV store — one dispatch each + // instead of n dispatches. Attention loop (Win 2) kept per-token. + { + let base_pos = start_pos as u32; + let num_tok = m; + let nqh = num_q_heads as u32; + let nkh = num_kv_heads as u32; + let hd = head_dim as u32; + let kvd = kv_dim as u32; + // SAFETY: Layer norm weight pointers are live for the command buffer duration. + let (qn_ref, kn_ref): (&Buffer, &Buffer) = unsafe { (&*w_qn, &*w_kn) }; - enc.set_buffer(0, Some(&self.session.activations.v), v_off); - enc.set_buffer(1, Some(&self.session.kv_cache.v_bufs[full_idx]), 0); - enc.dispatch_threads( - MTLSize::new(div_ceil(kv_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); - } + self.dispatch_scatter_q_gate_batch(enc, num_tok, nqh, hd); + self.dispatch_per_head_rms_norm_batch( + enc, + &self.session.activations.q_separated, + qn_ref, + num_tok, + nqh, + hd, + cfg.rms_norm_eps, + ); + self.dispatch_per_head_rms_norm_batch( + enc, + &self.session.activations.k, + kn_ref, + num_tok, + nkh, + hd, + cfg.rms_norm_eps, + ); + self.dispatch_partial_rope_batch( + enc, + &self.session.activations.q_separated, + num_tok, + nqh, + hd, + half_rope_dim, + base_pos, + ); + self.dispatch_partial_rope_batch( + enc, + &self.session.activations.k, + num_tok, + nkh, + hd, + half_rope_dim, + base_pos, + ); + self.dispatch_copy_kv_cache_batch( + enc, + &self.session.activations.k, + &self.session.activations.v, + &self.session.kv_cache.k_bufs[full_idx], + &self.session.kv_cache.v_bufs[full_idx], + num_tok, + kvd, + base_pos, + ); + } - // Causal attention: Q[t] attends against cache[0..abs_pos+1], - // covering all prior chunks when start_pos > 0. - { + // Win 2: batched causal prefill attention — one dispatch for all n tokens. + // Falls back to the per-token loop only if shape validation fails. + if self + .dispatch_prefill_attention_batched( + enc, + &self.session.kv_cache.k_bufs[full_idx], + &self.session.kv_cache.v_bufs[full_idx], + start_pos as u32, + m, + head_dim as u32, + num_q_heads as u32, + num_kv_heads as u32, + q_dim as u32, + kv_dim as u32, + scale, + ) + .is_err() + { + for t in 0..n { + let abs_pos = start_pos + t; + let qs_off = (t * q_dim) as u64 * 4; + let ao_off = (t * q_dim) as u64 * 4; let cache_len = (abs_pos + 1) as u32; - enc.set_compute_pipeline_state(&self.engine.pipelines.decode_attention); - enc.set_buffer(0, Some(&self.session.activations.q_separated), qs_off); - enc.set_buffer(1, Some(&self.session.kv_cache.k_bufs[full_idx]), 0); - enc.set_buffer(2, Some(&self.session.kv_cache.v_bufs[full_idx]), 0); - enc.set_buffer(3, Some(&self.session.activations.attn_out), ao_off); - enc.set_bytes(4, 4, &cache_len as *const u32 as *const _); let hd = head_dim as u32; let nqh = num_q_heads as u32; let nkh = num_kv_heads as u32; let qd = q_dim as u32; let kvd = kv_dim as u32; - enc.set_bytes(5, 4, &hd as *const u32 as *const _); - enc.set_bytes(6, 4, &nqh as *const u32 as *const _); - enc.set_bytes(7, 4, &nkh as *const u32 as *const _); - enc.set_bytes(8, 4, &qd as *const u32 as *const _); - enc.set_bytes(9, 4, &kvd as *const u32 as *const _); - enc.set_bytes(10, 4, &scale as *const f32 as *const _); - enc.dispatch_thread_groups( - MTLSize::new(nqh as u64, 1, 1), - MTLSize::new(256, 1, 1), - ); - } - - // Sigmoid gate - { - let cnt = q_dim as u32; - enc.set_compute_pipeline_state(&self.engine.pipelines.sigmoid_gate); - enc.set_buffer(0, Some(&self.session.activations.attn_out), ao_off); - enc.set_buffer(1, Some(&self.session.activations.gate_z), gz_off); - enc.set_bytes(2, 4, &cnt as *const u32 as *const _); - let wg = 256u64; - enc.dispatch_threads( - MTLSize::new(div_ceil(q_dim as u64, wg) * wg, 1, 1), - MTLSize::new(wg, 1, 1), - ); + // Per-token q/attn_out offsets are within batch activation buffers. + { + enc.set_compute_pipeline_state( + &self.engine.pipelines.decode_attention, + ); + enc.set_buffer( + 0, + Some(&self.session.activations.q_separated), + qs_off, + ); + enc.set_buffer(1, Some(&self.session.kv_cache.k_bufs[full_idx]), 0); + enc.set_buffer(2, Some(&self.session.kv_cache.v_bufs[full_idx]), 0); + enc.set_buffer(3, Some(&self.session.activations.attn_out), ao_off); + enc.set_bytes(4, 4, &cache_len as *const u32 as *const _); + enc.set_bytes(5, 4, &hd as *const u32 as *const _); + enc.set_bytes(6, 4, &nqh as *const u32 as *const _); + enc.set_bytes(7, 4, &nkh as *const u32 as *const _); + enc.set_bytes(8, 4, &qd as *const u32 as *const _); + enc.set_bytes(9, 4, &kvd as *const u32 as *const _); + enc.set_bytes(10, 4, &scale as *const f32 as *const _); + enc.dispatch_thread_groups( + MTLSize::new(nqh as u64, 1, 1), + MTLSize::new(256, 1, 1), + ); + } } } + // Sigmoid gate over all n*q_dim outputs in one dispatch. + self.dispatch_sigmoid_gate(enc, m * q_dim as u32); + // Batch: O projection (attn_out[N, q_dim] → ffn_out[N, hidden]) // SAFETY: O-projection pointer is live and dimensions match // [m, q_dim] by [hidden, q_dim]. @@ -10241,6 +10502,155 @@ kernel void moe_shared_gate_add( ); } + fn dispatch_scatter_q_gate_batch( + &self, + enc: &ComputeCommandEncoderRef, + num_tokens: u32, + num_heads: u32, + head_dim: u32, + ) { + let total = num_tokens * num_heads * head_dim; + enc.set_compute_pipeline_state(&self.engine.pipelines.scatter_q_gate_batch); + enc.set_buffer(0, Some(&self.session.activations.q), 0); + enc.set_buffer(1, Some(&self.session.activations.q_separated), 0); + enc.set_buffer(2, Some(&self.session.activations.gate_z), 0); + enc.set_bytes(3, 4, &num_tokens as *const u32 as *const _); + enc.set_bytes(4, 4, &num_heads as *const u32 as *const _); + enc.set_bytes(5, 4, &head_dim as *const u32 as *const _); + let wg = 256u64; + enc.dispatch_threads( + MTLSize::new(div_ceil(total as u64, wg) * wg, 1, 1), + MTLSize::new(wg, 1, 1), + ); + } + + fn dispatch_per_head_rms_norm_batch( + &self, + enc: &ComputeCommandEncoderRef, + x: &Buffer, + gamma: &Buffer, + num_tokens: u32, + num_heads: u32, + head_dim: u32, + eps: f32, + ) { + let total_groups = num_tokens * num_heads; + enc.set_compute_pipeline_state(&self.engine.pipelines.per_head_rms_norm_batch); + enc.set_buffer(0, Some(x), 0); + enc.set_buffer(1, Some(gamma), 0); + enc.set_bytes(2, 4, &num_tokens as *const u32 as *const _); + enc.set_bytes(3, 4, &num_heads as *const u32 as *const _); + enc.set_bytes(4, 4, &head_dim as *const u32 as *const _); + enc.set_bytes(5, 4, &eps as *const f32 as *const _); + let wg = 256u64; + enc.dispatch_thread_groups( + MTLSize::new(total_groups as u64, 1, 1), + MTLSize::new(wg, 1, 1), + ); + } + + #[allow(clippy::too_many_arguments)] + fn dispatch_partial_rope_batch( + &self, + enc: &ComputeCommandEncoderRef, + x: &Buffer, + num_tokens: u32, + num_heads: u32, + head_dim: u32, + half_rope_dim: u32, + base_pos: u32, + ) { + let total_pairs = num_tokens * num_heads * half_rope_dim; + enc.set_compute_pipeline_state(&self.engine.pipelines.partial_rope_batch); + enc.set_buffer(0, Some(x), 0); + enc.set_buffer(1, Some(&self.engine.rope_cos), 0); + enc.set_buffer(2, Some(&self.engine.rope_sin), 0); + enc.set_bytes(3, 4, &num_tokens as *const u32 as *const _); + enc.set_bytes(4, 4, &num_heads as *const u32 as *const _); + enc.set_bytes(5, 4, &head_dim as *const u32 as *const _); + enc.set_bytes(6, 4, &half_rope_dim as *const u32 as *const _); + enc.set_bytes(7, 4, &base_pos as *const u32 as *const _); + let wg = 256u64; + enc.dispatch_threads( + MTLSize::new(div_ceil(total_pairs as u64, wg) * wg, 1, 1), + MTLSize::new(wg, 1, 1), + ); + } + + #[allow(clippy::too_many_arguments)] + fn dispatch_copy_kv_cache_batch( + &self, + enc: &ComputeCommandEncoderRef, + k_src: &Buffer, + v_src: &Buffer, + k_cache: &Buffer, + v_cache: &Buffer, + num_tokens: u32, + kv_dim: u32, + base_pos: u32, + ) { + let total = num_tokens * kv_dim; + enc.set_compute_pipeline_state(&self.engine.pipelines.copy_kv_cache_batch); + enc.set_buffer(0, Some(k_src), 0); + enc.set_buffer(1, Some(v_src), 0); + enc.set_buffer(2, Some(k_cache), 0); + enc.set_buffer(3, Some(v_cache), 0); + enc.set_bytes(4, 4, &num_tokens as *const u32 as *const _); + enc.set_bytes(5, 4, &kv_dim as *const u32 as *const _); + enc.set_bytes(6, 4, &base_pos as *const u32 as *const _); + let wg = 256u64; + enc.dispatch_threads( + MTLSize::new(div_ceil(total as u64, wg) * wg, 1, 1), + MTLSize::new(wg, 1, 1), + ); + } + + #[allow(clippy::too_many_arguments)] + fn dispatch_prefill_attention_batched( + &self, + enc: &ComputeCommandEncoderRef, + k_cache: &Buffer, + v_cache: &Buffer, + base_pos: u32, + num_tokens: u32, + head_dim: u32, + num_q_heads: u32, + num_kv_heads: u32, + q_dim: u32, + kv_dim: u32, + scale: f32, + ) -> Result<(), String> { + validate_flash_decode_shape( + head_dim as usize, + num_q_heads as usize, + num_kv_heads as usize, + q_dim as usize, + kv_dim as usize, + )?; + let cache_len_total = base_pos.checked_add(num_tokens).ok_or_else(|| { + "prefill_attention_batched: base_pos + num_tokens overflow".to_string() + })?; + enc.set_compute_pipeline_state(&self.engine.pipelines.prefill_attention_batched); + enc.set_buffer(0, Some(&self.session.activations.q_separated), 0); + enc.set_buffer(1, Some(k_cache), 0); + enc.set_buffer(2, Some(v_cache), 0); + enc.set_buffer(3, Some(&self.session.activations.attn_out), 0); + enc.set_bytes(4, 4, &base_pos as *const u32 as *const _); + enc.set_bytes(5, 4, &num_tokens as *const u32 as *const _); + enc.set_bytes(6, 4, &cache_len_total as *const u32 as *const _); + enc.set_bytes(7, 4, &head_dim as *const u32 as *const _); + enc.set_bytes(8, 4, &num_q_heads as *const u32 as *const _); + enc.set_bytes(9, 4, &num_kv_heads as *const u32 as *const _); + enc.set_bytes(10, 4, &q_dim as *const u32 as *const _); + enc.set_bytes(11, 4, &kv_dim as *const u32 as *const _); + enc.set_bytes(12, 4, &scale as *const f32 as *const _); + enc.dispatch_thread_groups( + MTLSize::new(num_kv_heads as u64, num_tokens as u64, 1), + MTLSize::new(256, 1, 1), + ); + Ok(()) + } + fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, count: u32) { enc.set_compute_pipeline_state(&self.engine.pipelines.silu_mul); enc.set_buffer(0, Some(&self.session.activations.gate), 0); @@ -11675,6 +12085,11 @@ kernel void moe_shared_gate_add( moe_scale_add: make_pipeline("moe_scale_add")?, moe_shared_gate_add: make_pipeline("moe_shared_gate_add")?, moe_zero_buf: make_pipeline("zero_buf")?, + scatter_q_gate_batch: make_pipeline("scatter_q_gate_batch")?, + per_head_rms_norm_batch: make_pipeline("per_head_rms_norm_batch")?, + partial_rope_batch: make_pipeline("partial_rope_batch")?, + copy_kv_cache_batch: make_pipeline("copy_kv_cache_batch")?, + prefill_attention_batched: make_pipeline("prefill_attention_batched_causal")?, }; let hidden = cfg.hidden_size;