From 7c80e09a5f8c62d630200a1dae27fd054560680a Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 26 May 2026 12:49:01 +0100 Subject: [PATCH] onnx: com.microsoft GroupQueryAttention (prefill), ORT-validated Prefill-only GroupQueryAttention lowered onto tract Sdpa: reshapes Q/K/V to 4D, applies an explicit lower-triangular causal mask, and returns present_key/present_value (the reshaped K/V). Sdpa handles the grouped-query head sharing (kv_num_heads < num_heads). Decode-step KV cache, internal rotary (do_rotary), local-window attention and softcap are rejected with clear errors. Validated against onnxruntime across head_size 8/16/64, several num_heads/kv_num_heads ratios (incl. multi-query kv=1) and batch>1: attention output matches to <=3.6e-7 and present_key/present_value are bit-exact. ORT's GroupQueryAttention prefill is standard causal grouped-query attention; the seqlens_k input is the 0-indexed position of the last token (total_sequence_length - 1), not the token count. Co-Authored-By: Claude Opus 4.7 (1M context) --- onnx/src/ops/nn/group_query_attention.rs | 169 +++++++++++++++++++++++ onnx/src/ops/nn/mod.rs | 2 + 2 files changed, 171 insertions(+) create mode 100644 onnx/src/ops/nn/group_query_attention.rs diff --git a/onnx/src/ops/nn/group_query_attention.rs b/onnx/src/ops/nn/group_query_attention.rs new file mode 100644 index 0000000000..4de7657d1c --- /dev/null +++ b/onnx/src/ops/nn/group_query_attention.rs @@ -0,0 +1,169 @@ +use crate::model::ParsingContext; +use crate::pb::NodeProto; +use tract_core::ops::change_axes::AxisOp; +use tract_hir::internal::*; +use tract_transformers::ops::sdpa::Sdpa; + +// com.microsoft GroupQueryAttention (prefill only). +// inputs: query(0), key(1), value(2), past_key(3), past_value(4), seqlens_k(5), total_seq(6) +// outputs: output(0), present_key(1), present_value(2) +// Scoped to the prefill case (no past KV cache): query/key/value are [B, S, heads*head_size], +// attention is causal with q_seq == kv_seq, and present_key/value are the reshaped K/V. +// Decode-step KV cache, internal rotary (do_rotary) and local-window attention are rejected. +pub fn group_query_attention( + _ctx: &ParsingContext, + node: &NodeProto, +) -> TractResult<(Box, Vec)> { + let num_heads: usize = node.get_attr("num_heads")?; + let kv_num_heads: usize = node.get_attr("kv_num_heads")?; + let scale = node.get_attr_opt::("scale")?; + ensure!( + node.get_attr_opt::("do_rotary")?.unwrap_or(0) == 0, + "GroupQueryAttention: internal rotary (do_rotary) is unsupported; apply RotaryEmbedding separately" + ); + ensure!( + node.get_attr_opt::("local_window_size")?.unwrap_or(-1) < 0, + "GroupQueryAttention: local_window_size is unsupported" + ); + ensure!( + node.get_attr_opt::("softcap")?.unwrap_or(0.0) == 0.0, + "GroupQueryAttention: softcap is unsupported" + ); + let have_past = (node.input.len() > 3 && !node.input[3].is_empty()) + || (node.input.len() > 4 && !node.input[4].is_empty()); + ensure!( + !have_past, + "GroupQueryAttention: past KV cache (decode step) is unsupported; only prefill is handled" + ); + Ok((expand(GroupQueryAttention { num_heads, kv_num_heads, scale }), vec![])) +} + +#[derive(Debug, Clone)] +struct GroupQueryAttention { + num_heads: usize, + kv_num_heads: usize, + scale: Option, +} + +// [B, S, heads*head_size] -> [B, heads, S, head_size] +fn to_4d( + model: &mut TypedModel, + prefix: &str, + x: OutletId, + total: TDim, + heads: usize, +) -> TractResult { + let head_dim = total.clone() / heads; + let reshaped = model.wire_node( + format!("{prefix}.reshape"), + AxisOp::Reshape(2, tvec![total], tvec![heads.to_dim(), head_dim]), + &[x], + )?[0]; + Ok(model.wire_node(format!("{prefix}.transpose"), AxisOp::Move(2, 1), &[reshaped])?[0]) +} + +impl Expansion for GroupQueryAttention { + fn name(&self) -> StaticName { + "GroupQueryAttention".into() + } + + fn nboutputs(&self) -> TractResult { + Ok(3) + } + + fn rules<'r, 'p: 'r, 's: 'r>( + &'s self, + s: &mut Solver<'r>, + inputs: &'p [TensorProxy], + outputs: &'p [TensorProxy], + ) -> InferenceResult { + check_output_arity(outputs, 3)?; + s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?; + s.equals(&inputs[0].shape, &outputs[0].shape)?; + s.equals(&inputs[0].datum_type, &outputs[1].datum_type)?; + s.equals(&inputs[0].datum_type, &outputs[2].datum_type)?; + // present_key / present_value = key/value reshaped to [B, kv_num_heads, S, head_dim]. + let kvh = self.kv_num_heads; + s.given(&inputs[1].shape, move |s, ks| { + s.equals( + &outputs[1].shape, + tvec![ks[0].clone(), kvh.to_dim(), ks[1].clone(), ks[2].clone() / kvh], + ) + })?; + s.given(&inputs[2].shape, move |s, vs| { + s.equals( + &outputs[2].shape, + tvec![vs[0].clone(), kvh.to_dim(), vs[1].clone(), vs[2].clone() / kvh], + ) + })?; + Ok(()) + } + + fn wire( + &self, + prefix: &str, + model: &mut TypedModel, + inputs: &[OutletId], + ) -> TractResult> { + let q_fact = model.outlet_fact(inputs[0])?.clone(); + let dt = q_fact.datum_type; + ensure!(q_fact.rank() == 3, "GroupQueryAttention: expected 3D query [B, S, hidden]"); + let q_hidden = q_fact.shape[2].clone(); + let k_hidden = model.outlet_fact(inputs[1])?.shape[2].clone(); + let v_hidden = model.outlet_fact(inputs[2])?.shape[2].clone(); + + let q4 = to_4d(model, &format!("{prefix}.q"), inputs[0], q_hidden.clone(), self.num_heads)?; + let k4 = to_4d(model, &format!("{prefix}.k"), inputs[1], k_hidden, self.kv_num_heads)?; + let v4 = to_4d(model, &format!("{prefix}.v"), inputs[2], v_hidden, self.kv_num_heads)?; + + // Causal mask: materialise an explicit additive lower-triangular mask for concrete + // shapes (exact ONNX semantics: query i attends to keys j <= i); fall back to Sdpa's + // own is_causal for symbolic shapes. Sdpa handles GQA head grouping (kv heads < q heads). + let q_seq = model.outlet_fact(q4)?.shape[2].to_usize().ok(); + let kv_seq = model.outlet_fact(k4)?.shape[2].to_usize().ok(); + let (mask, is_causal) = if let (Some(qs), Some(ks)) = (q_seq, kv_seq) { + let arr = tract_ndarray::Array2::::from_shape_fn((qs, ks), |(i, j)| { + if j <= i { 0.0f32 } else { f32::NEG_INFINITY } + }); + let mask_tensor: Tensor = arr.into(); + let mut m = model.add_const(format!("{prefix}.causal_mask"), mask_tensor)?; + for i in 0..2 { + m = model.wire_node( + format!("{prefix}.mask_unsqueeze_{i}"), + AxisOp::Add(0), + &[m], + )?[0]; + } + (Some(m), false) + } else { + (None, true) + }; + let mut sdpa_inputs = tvec![q4, k4, v4]; + if let Some(m) = mask { + sdpa_inputs.push(m); + } + let sdpa = Sdpa { + scale: self.scale.map(tensor0), + datum_type: dt, + acc_datum_type: DatumType::F32, + is_causal, + }; + let y4 = model.wire_node(format!("{prefix}.sdpa"), sdpa, &sdpa_inputs)?[0]; + + // [B, num_heads, S, head_dim] -> [B, S, num_heads, head_dim] -> [B, S, hidden] + let y_t = model.wire_node(format!("{prefix}.y_transpose"), AxisOp::Move(1, 2), &[y4])?[0]; + let yf = model.outlet_fact(y4)?.clone(); + let (heads_dim, head_dim) = (yf.shape[1].clone(), yf.shape[3].clone()); + let y = model.wire_node( + format!("{prefix}.y_reshape"), + AxisOp::Reshape( + 2, + tvec![heads_dim.clone(), head_dim.clone()], + tvec![heads_dim * head_dim], + ), + &[y_t], + )?[0]; + + Ok(tvec!(y, k4, v4)) + } +} diff --git a/onnx/src/ops/nn/mod.rs b/onnx/src/ops/nn/mod.rs index 1b4d3d0390..e440afd4b5 100644 --- a/onnx/src/ops/nn/mod.rs +++ b/onnx/src/ops/nn/mod.rs @@ -14,6 +14,7 @@ mod dropout; mod gelu; mod gelu_contrib; mod group_norm; +mod group_query_attention; mod instance_norm; mod layer_norm; mod lp_norm; @@ -91,6 +92,7 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) { reg.insert("BiasGelu", gelu_contrib::bias_gelu); reg.insert("FastGelu", gelu_contrib::fast_gelu); reg.insert("QuickGelu", gelu_contrib::quick_gelu); + reg.insert("GroupQueryAttention", group_query_attention::group_query_attention); reg.insert("HardSwish", |_, _| Ok((ops::nn::hard_swish().into_hir(), vec![]))); reg.insert("Mish", |_, _| Ok((expand(mish::Mish), vec![]))); reg.insert("RMSNormalization", rms_norm::rms_normalization);