From aeb0bb2d63b58beb5a4c7ab898b8729c994ad51c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 26 May 2026 09:38:58 +0000 Subject: [PATCH 1/5] cuda+metal/scaled_masked_softmax: bool mask + post_softmax_mask Adds a sibling kernel that loads the mask as uchar/char, substitutes -inf at masked positions before softmax, and when post_softmax_mask is set scrubs fully-masked rows (sum == 0 / NaN) to 0 on write-back. Lifts the GpuScaledMaskedSoftmax guards so bool masks aren't rejected by output_facts, and drops the rule_if!(!post_softmax_mask) gate on both backends. For the nemotron-streaming encoder this moves all 24 SMS nodes off CPU (--cuda matches the recorded io bundle at --approx very). Metal mirror matches structurally; CI's macOS nemotron harness covers the numeric check there. --- cuda/src/kernels/cu/nn.cu | 146 ++++++++++++++++++ cuda/src/kernels/nn/mod.rs | 6 +- cuda/src/kernels/nn/scaled_masked_softmax.rs | 110 +++++++++++-- gpu/src/ops/scaled_masked_softmax.rs | 31 +++- metal/src/kernels/nn/mod.rs | 8 +- metal/src/kernels/nn/nn_ops.metal | 120 ++++++++++++++ metal/src/kernels/nn/scaled_masked_softmax.rs | 115 +++++++++++--- 7 files changed, 489 insertions(+), 47 deletions(-) diff --git a/cuda/src/kernels/cu/nn.cu b/cuda/src/kernels/cu/nn.cu index 7eb34dc4e9..fa3151acaa 100644 --- a/cuda/src/kernels/cu/nn.cu +++ b/cuda/src/kernels/cu/nn.cu @@ -588,6 +588,136 @@ __device__ void scaled_masked_softmax( out_stride_4); \ } +// Bool-mask variant: mask is char (0/1), substitutes -inf at masked positions +// before the softmax. When post_mask is non-zero, fully-masked rows (all +// inputs -inf) are written as 0 instead of the NaN the naive softmax would +// emit. Partially-masked rows are unaffected: exp(-inf) = 0 already zeros +// masked positions in the output. +template +__device__ void scaled_bool_masked_softmax( + const T *x, const char *mask, const float scale, T *dst, + const int32_t post_mask, const int32_t shape_0, const int32_t shape_1, + const int32_t shape_2, const int32_t shape_3, const int32_t shape_4, + const int32_t stride_0, const int32_t stride_1, const int32_t stride_2, + const int32_t stride_3, const int32_t stride_4, + const int32_t mask_stride_0, const int32_t mask_stride_1, + const int32_t mask_stride_2, const int32_t mask_stride_3, + const int32_t mask_stride_4, const int32_t out_stride_0, + const int32_t out_stride_1, const int32_t out_stride_2, + const int32_t out_stride_3, const int32_t out_stride_4) { + int32_t z0 = blockIdx.z / shape_1; + int32_t z1 = blockIdx.z % shape_1; + x += blockIdx.x * stride_3 + blockIdx.y * stride_2 + z1 * stride_1 + + z0 * stride_0; + mask += blockIdx.x * mask_stride_3 + blockIdx.y * mask_stride_2 + + z1 * mask_stride_1 + z0 * mask_stride_0; + dst += blockIdx.x * out_stride_3 + blockIdx.y * out_stride_2 + + z1 * out_stride_1 + z0 * out_stride_0; + + const int block_size = BLOCK_SIZE == 0 ? blockDim.x : BLOCK_SIZE; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + extern __shared__ float data_soft_max_f32[]; + float *buf_iw = data_soft_max_f32; + float *vals = buf_iw + WARP_SIZE; + + float max_val = -CUDART_INF_F; + _Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) { + const int col = col0 + threadIdx.x; + if (col >= shape_4) { + break; + } + + const bool m = mask[col * mask_stride_4] != 0; + const float val = m ? ((float)x[col * stride_4]) * scale : -CUDART_INF_F; + vals[col] = val; + max_val = max(max_val, val); + } + + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf_iw[lane_id] = -CUDART_INF_F; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = max_val; + } + __syncthreads(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); + } + + float tmp = 0.0f; + _Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) { + const int col = col0 + threadIdx.x; + if (col >= shape_4) { + break; + } + + const float val = expf(vals[col] - max_val); + tmp += val; + vals[col] = val; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __syncthreads(); + if (warp_id == 0) { + buf_iw[lane_id] = 0.0f; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = tmp; + } + __syncthreads(); + + tmp = buf_iw[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + // Row-uniform: tmp <= 0 (or NaN) iff every position was masked. When + // post_mask is set we write 0 in that case to scrub the NaN; otherwise + // we fall through to the normal 1/sum path and let it propagate. + const bool zero_row = post_mask && !(tmp > 0.0f); + const float inv_sum = 1.0f / tmp; + + _Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) { + const int col = col0 + threadIdx.x; + if (col >= shape_4) { + return; + } + dst[col * out_stride_4] = zero_row ? (T)0.0f : (T)(vals[col] * inv_sum); + } +} + +#define INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, bname, \ + block_size_template) \ + extern "C" __global__ void scaled_bool_masked_softmax_##bname##name( \ + const T *x, const char *mask, const float scale, T *dst, \ + const int32_t post_mask, const int32_t shape_0, \ + const int32_t shape_1, const int32_t shape_2, const int32_t shape_3, \ + const int32_t shape_4, const int32_t stride_0, \ + const int32_t stride_1, const int32_t stride_2, \ + const int32_t stride_3, const int32_t stride_4, \ + const int32_t mask_stride_0, const int32_t mask_stride_1, \ + const int32_t mask_stride_2, const int32_t mask_stride_3, \ + const int32_t mask_stride_4, const int32_t out_stride_0, \ + const int32_t out_stride_1, const int32_t out_stride_2, \ + const int32_t out_stride_3, const int32_t out_stride_4) { \ + scaled_bool_masked_softmax( \ + x, mask, scale, dst, post_mask, shape_0, shape_1, shape_2, \ + shape_3, shape_4, stride_0, stride_1, stride_2, stride_3, \ + stride_4, mask_stride_0, mask_stride_1, mask_stride_2, \ + mask_stride_3, mask_stride_4, out_stride_0, out_stride_1, \ + out_stride_2, out_stride_3, out_stride_4); \ + } + #define INSTANTIATE_RMS_NORM(name, T, bname, block_size) \ extern "C" __global__ void rms_norm_##bname##name( \ const T *x, T *dst, const int32_t shape_0, const int32_t shape_1, \ @@ -652,8 +782,24 @@ INSTANTIATE_SOFTMAX(f16, __half, , 1024) INSTANTIATE_SCALED_MASKED_SOFTMAX(name, T, 32768_, 1024) \ INSTANTIATE_SCALED_MASKED_SOFTMAX(name, T, 0_, 0) +#define INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(name, T) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 32_, 32) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 64_, 64) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 128_, 126) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 256_, 256) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 512_, 512) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 1024_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 2048_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 4096_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 8192_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 16384_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 32768_, 1024) \ + INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 0_, 0) + INSTANTIATE_SCALED_MASKED_SOFTMAX_FOR_T(f32, float) INSTANTIATE_SCALED_MASKED_SOFTMAX_FOR_T(f16, __half) +INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(f32, float) +INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(f16, __half) INSTANTIATE_REDUCE(f32, float, small_, 32) INSTANTIATE_REDUCE(f32, float, , 1024) diff --git a/cuda/src/kernels/nn/mod.rs b/cuda/src/kernels/nn/mod.rs index 668fad0763..a2bb0e78e1 100644 --- a/cuda/src/kernels/nn/mod.rs +++ b/cuda/src/kernels/nn/mod.rs @@ -54,7 +54,11 @@ pub fn all_functions() -> Vec { tract_gpu::tensor::DeviceTensor::SUPPORTED_DT .into_iter() .flat_map(|dt| sms_block_sizes().into_iter().map(move |bs| (dt, bs as usize))) - .flat_map(|(dt, bs)| ScaledMaskedSoftmax.kernel_name(dt, bs).into_iter()), + .flat_map(|(dt, bs)| { + [false, true] + .into_iter() + .flat_map(move |mb| ScaledMaskedSoftmax.kernel_name(dt, mb, bs).into_iter()) + }), ); functions.extend( diff --git a/cuda/src/kernels/nn/scaled_masked_softmax.rs b/cuda/src/kernels/nn/scaled_masked_softmax.rs index 8e5d80a0c8..3653a490af 100644 --- a/cuda/src/kernels/nn/scaled_masked_softmax.rs +++ b/cuda/src/kernels/nn/scaled_masked_softmax.rs @@ -18,14 +18,25 @@ impl ScaledMaskedSoftmax { matches!(dt, DatumType::F32 | DatumType::F16) } - pub fn kernel_name(&self, dt: DatumType, block_size: usize) -> TractResult { + pub fn is_supported_mask_dt(input_dt: DatumType, mask_dt: DatumType) -> bool { + mask_dt == input_dt || mask_dt == bool::datum_type() + } + + pub fn kernel_name( + &self, + input_dt: DatumType, + mask_is_bool: bool, + block_size: usize, + ) -> TractResult { ensure!( - Self::is_supported_dt(dt), - "Unsupported dt {:?} for cuda scaled masked softmaxop", - dt + Self::is_supported_dt(input_dt), + "Unsupported dt {:?} for cuda scaled masked softmax op", + input_dt ); - let tname = DeviceTensor::tname(dt)?; - Ok(format!("scaled_masked_softmax_{block_size}_{tname}")) + let tname = DeviceTensor::tname(input_dt)?; + let stem = + if mask_is_bool { "scaled_bool_masked_softmax" } else { "scaled_masked_softmax" }; + Ok(format!("{stem}_{block_size}_{tname}")) } pub fn eval( @@ -34,9 +45,10 @@ impl ScaledMaskedSoftmax { input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, ) -> TractResult { let output = unsafe { DeviceTensor::uninitialized_dt(input.datum_type(), input.shape())? }; - self.dispatch_eval(stream, input, scale, mask, &output)?; + self.dispatch_eval(stream, input, scale, mask, post_softmax_mask, &output)?; stream.synchronize()?; Ok(output) } @@ -47,19 +59,22 @@ impl ScaledMaskedSoftmax { input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, output: &DeviceTensor, ) -> TractResult<()> { ensure!(output.shape() == input.shape()); ensure!(input.rank() >= 2 && input.rank() <= 5); ensure!(mask.rank() == input.rank()); ensure!(output.datum_type() == input.datum_type()); - ensure!(mask.datum_type() == input.datum_type()); + let mask_is_bool = mask.datum_type() == bool::datum_type(); + ensure!(Self::is_supported_mask_dt(input.datum_type(), mask.datum_type())); + // post_softmax_mask is meaningful only with a bool mask (CPU contract). + ensure!(!post_softmax_mask || mask_is_bool); let shape = pad(input.shape(), 1); let strides = pad(input.strides(), 0); let mask_strides = pad(&compute_broadcast_strides::(mask.shape(), mask.strides())?, 0); let output_strides = pad(output.strides(), 0); - let inner_len = shape[4]; let i_view = get_cuda_view(input); let mask_view = get_cuda_view(mask); @@ -74,14 +89,19 @@ impl ScaledMaskedSoftmax { let block_size = if inner_len.is_power_of_two() && inner_len > 32 { inner_len.min(1024) } else { 0 }; - let func = cuda_context() - .load_pipeline(LibraryName::NN, self.kernel_name(input.datum_type(), block_size)?)?; + let func = cuda_context().load_pipeline( + LibraryName::NN, + self.kernel_name(input.datum_type(), mask_is_bool, block_size)?, + )?; let mut launch_args = TractLaunchArgs::new(stream, &func); launch_args.push_view(&i_view); launch_args.push_view(&mask_view); launch_args.push::(scale.cast_to_scalar::()?); launch_args.push_view(&o_view); + if mask_is_bool { + launch_args.push::(post_softmax_mask as i32); + } launch_args.push_slice_i32(&shape); launch_args.push_slice_i32(&strides); launch_args.push_slice_i32(&mask_strides); @@ -111,22 +131,28 @@ pub fn cuda_scaled_masked_softmax_dispatch( input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, output: &DeviceTensor, ) -> TractResult<()> { crate::with_cuda_stream(|stream| { - ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, output) + ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, post_softmax_mask, output) }) } crate::register_cuda_op!( tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax, |source, node, op| { - rule_if!(!op.post_softmax_mask); - rule_if!(ScaledMaskedSoftmax::is_supported_dt( - source.node_input_facts(node.id)?[0].datum_type + let facts = source.node_input_facts(node.id)?; + rule_if!(ScaledMaskedSoftmax::is_supported_dt(facts[0].datum_type)); + rule_if!(ScaledMaskedSoftmax::is_supported_mask_dt( + facts[0].datum_type, + facts[1].datum_type, )); + // post_softmax_mask requires a bool mask (CPU contract). + rule_if!(!op.post_softmax_mask || facts[1].datum_type == bool::datum_type()); Ok(Some(Box::new(tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax::new( op.scale.clone(), + op.post_softmax_mask, "Cuda", cuda_scaled_masked_softmax_dispatch, )))) @@ -170,13 +196,63 @@ mod tests { .eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0] .clone() .into_tensor(); - let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?; + let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?; cpu_output .close_enough(&cuda_output.to_host()?.into_tensor(), Approximation::Approximate)?; Ok(()) }) } + /// Bool-mask path with a fully-masked row. Without post_softmax_mask + /// the output is NaN (matches CPU); with it on, the NaN is scrubbed to 0. + #[test] + fn test_scaled_bool_masked_softmax_post_mask_scrubs_nan() -> TractResult<()> { + crate::with_cuda_stream(|stream| { + let m = 3; + let n = 5; + let scale: Arc<_> = tensor0(0.125f32).into(); + // Row 0: fully masked. Row 1: partially masked. Row 2: fully unmasked. + let mask_data: Vec = (0..m) + .flat_map(|r| { + (0..n).map(move |c| match r { + 0 => false, + 1 => c >= 2, + _ => true, + }) + }) + .collect(); + let mask = Tensor::from_shape(&[1, 1, m, n], &mask_data)?.into_device()?; + let a = Tensor::from_shape( + &[1, 1, m, n], + &(0..m * n).map(|f| f as f32).collect::>(), + )? + .into_device()?; + + for post in [false, true] { + let cpu = scaled_masked_softmax::ScaledMaskedSoftmax { + scale: scale.clone(), + post_softmax_mask: post, + }; + let cpu_out = cpu + .eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0] + .clone() + .into_tensor(); + let cuda_out = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, post)?; + let cuda_host = cuda_out.to_host()?.into_tensor(); + let cpu_slice = cpu_out.view().as_slice::().unwrap(); + let cuda_slice = cuda_host.view().as_slice::().unwrap(); + for (i, (c, g)) in cpu_slice.iter().zip(cuda_slice.iter()).enumerate() { + if c.is_nan() { + assert!(g.is_nan(), "post={post} idx={i}: cpu NaN, cuda {g}"); + } else { + assert!((c - g).abs() < 1e-5, "post={post} idx={i}: cpu {c} cuda {g}"); + } + } + } + Ok(()) + }) + } + proptest::proptest! { #[test] fn scaled_masked_softmax_prop_f32(pb in any::>()) { @@ -269,7 +345,7 @@ mod tests { let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_device()?; let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); - let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?; + let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?; Ok(cuda_output.to_host()?.into_tensor()) }) } diff --git a/gpu/src/ops/scaled_masked_softmax.rs b/gpu/src/ops/scaled_masked_softmax.rs index 2e6faaefe1..8aae92b261 100644 --- a/gpu/src/ops/scaled_masked_softmax.rs +++ b/gpu/src/ops/scaled_masked_softmax.rs @@ -2,14 +2,25 @@ use crate::tensor::{DeviceTensor, DeviceTensorExt}; use derive_new::new; use tract_core::internal::*; -/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) -/// Only input of rank of 3 is supported -pub type DispatchScaledMaskedSoftmaxFn = - fn(&DeviceTensor, &Tensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>; +/// Fused scale + mask + softmax over the last axis. When the mask is float +/// it is added in log-space (`out = softmax(x*scale + mask)`); when it is +/// bool, masked positions are substituted with `-inf` before softmax. +/// +/// If `post_softmax_mask` is true (bool mask only), fully-masked rows — whose +/// softmax would otherwise be NaN — are written as `0` instead. Partially- +/// masked rows are unaffected. +pub type DispatchScaledMaskedSoftmaxFn = fn( + input: &DeviceTensor, + scale: &Tensor, + mask: &DeviceTensor, + post_softmax_mask: bool, + output: &DeviceTensor, +) -> TractResult<()>; #[derive(Clone, new)] pub struct GpuScaledMaskedSoftmax { pub scale: Arc, + pub post_softmax_mask: bool, pub backend_name: &'static str, pub dispatch: DispatchScaledMaskedSoftmaxFn, } @@ -22,7 +33,9 @@ impl std::fmt::Debug for GpuScaledMaskedSoftmax { impl PartialEq for GpuScaledMaskedSoftmax { fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name && self.scale == other.scale + self.backend_name == other.backend_name + && self.scale == other.scale + && self.post_softmax_mask == other.post_softmax_mask } } impl Eq for GpuScaledMaskedSoftmax {} @@ -31,6 +44,7 @@ impl std::hash::Hash for GpuScaledMaskedSoftmax { fn hash(&self, state: &mut H) { self.backend_name.hash(state); self.scale.hash(state); + self.post_softmax_mask.hash(state); } } @@ -61,7 +75,7 @@ impl EvalOp for GpuScaledMaskedSoftmax { input.datum_type(), input.shape(), )?; - (self.dispatch)(input, &self.scale, mask, &output)?; + (self.dispatch)(input, &self.scale, mask, self.post_softmax_mask, &output)?; Ok(tvec!(output.into_tensor().into_tvalue())) } } @@ -71,7 +85,10 @@ impl TypedOp for GpuScaledMaskedSoftmax { crate::utils::facts_to_device_facts(inputs, |facts| { ensure!(facts.len() == 2); let dt = facts[0].datum_type; - ensure!(dt == facts[1].datum_type); + let mask_dt = facts[1].datum_type; + ensure!(mask_dt == dt || mask_dt == bool::datum_type()); + // post_softmax_mask is bool-mask-only per the CPU contract. + ensure!(!self.post_softmax_mask || mask_dt == bool::datum_type()); ensure!(facts[0].rank() <= 5); ensure!(facts[0].rank() >= 2); ensure!(facts[0].rank() == facts[1].rank()); diff --git a/metal/src/kernels/nn/mod.rs b/metal/src/kernels/nn/mod.rs index 270dd7a674..9ff4b13ea8 100644 --- a/metal/src/kernels/nn/mod.rs +++ b/metal/src/kernels/nn/mod.rs @@ -40,11 +40,11 @@ pub fn all_functions() -> Vec { .flat_map(|dt| Softmax.kernel_name(dt).into_iter()), ); - functions.extend( - tract_gpu::tensor::DeviceTensor::SUPPORTED_DT + functions.extend(tract_gpu::tensor::DeviceTensor::SUPPORTED_DT.into_iter().flat_map(|dt| { + [false, true] .into_iter() - .flat_map(|dt| ScaledMaskedSoftmax.kernel_name(dt).into_iter()), - ); + .flat_map(move |mb| ScaledMaskedSoftmax.kernel_name(dt, mb).into_iter()) + })); functions.extend( tract_gpu::tensor::DeviceTensor::SUPPORTED_DT diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal index f67d08f201..5d27078345 100644 --- a/metal/src/kernels/nn/nn_ops.metal +++ b/metal/src/kernels/nn/nn_ops.metal @@ -495,6 +495,126 @@ template [[host_name("nn_ops::scaled_masked_softmax_nd5_" "f16")]] [[kernel]] scaled_masked_softmax_nd5_t scaled_masked_softmax_nd5; +// Bool-mask variant: mask is uchar (0/1). Masked positions are substituted +// with -inf before softmax (so exp(-inf)=0 naturally zeroes them in the +// output). When post_mask is non-zero, fully-masked rows — whose softmax +// would otherwise be NaN — are written as 0 instead. +template +[[kernel]] void scaled_bool_masked_softmax_nd5( + device const void *input_b, device const void *mask_b, + constant float *scale_b, device void *output_b, + constant uint *post_mask_b, constant const size_t shape[5], + constant const size_t strides[5], + constant const size_t mask_strides[5], + constant const size_t out_strides[5], + + uint3 tgpig [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint tpsg [[threads_per_simdgroup]], + uint3 tptg [[thread_position_in_threadgroup]], + uint3 tptgN [[threads_per_threadgroup]], + + threadgroup float *tgmem [[threadgroup(0)]]) { + + const uint tid = tptg.x; + const uint tg_sz = tptgN.x; + const uint sg_id = tid / tpsg; + const uint lane = tiisg; + + const size_t row = (size_t)tgpig.x; + const size_t h = (size_t)tgpig.y; + const size_t z = (size_t)tgpig.z; + const size_t z0 = z / shape[1]; + const size_t z1 = z % shape[1]; + + device const F *x = (device const F *)input_b; + device const uchar *mask = (device const uchar *)mask_b; + device F *out = (device F *)output_b; + + const float scale = *scale_b; + const bool post_mask = *post_mask_b != 0; + + x += row * strides[3] + h * strides[2] + z1 * strides[1] + z0 * strides[0]; + mask += row * mask_strides[3] + h * mask_strides[2] + + z1 * mask_strides[1] + z0 * mask_strides[0]; + out += row * out_strides[3] + h * out_strides[2] + z1 * out_strides[1] + + z0 * out_strides[0]; + + threadgroup float *buf_iw = tgmem; + threadgroup float *vals = tgmem + 32; + + const uint simd_size = tpsg; + const uint num_sg = (tg_sz + simd_size - 1u) / simd_size; + const size_t cols = shape[4]; + + // 1) Substitute -inf at masked positions, then take row max + float max_val = -INFINITY; + for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) { + const bool m = mask[col * mask_strides[4]] != 0; + const float xv = (float)x[col * strides[4]] * scale; + const float v = m ? xv : -INFINITY; + vals[col] = v; + max_val = metal::max(max_val, v); + } + + float sg_max = simd_max(max_val); + if (lane == 0) { + buf_iw[sg_id] = sg_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sg_id == 0) { + float x0 = (lane < num_sg) ? buf_iw[lane] : -INFINITY; + float block_max = simd_max(x0); + if (lane == 0) + buf_iw[0] = block_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_val = buf_iw[0]; + + // 2) exp(vals - max) and row sum + float sum = 0.0f; + for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) { + float e = exp(vals[col] - max_val); + vals[col] = e; + sum += e; + } + + float sg_sum = simd_sum(sum); + if (lane == 0) { + buf_iw[sg_id] = sg_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sg_id == 0) { + float x0 = (lane < num_sg) ? buf_iw[lane] : 0.0f; + float block_sum = simd_sum(x0); + if (lane == 0) + buf_iw[0] = block_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sum = buf_iw[0]; + + // Row-uniform: sum <= 0 (or NaN) iff every position was masked. With + // post_mask we write 0 in that case to scrub the NaN; otherwise we let + // 1/sum propagate. + const bool zero_row = post_mask && !(sum > 0.0f); + const float inv_sum = 1.0f / sum; + + for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) { + float y = zero_row ? 0.0f : vals[col] * inv_sum; + out[col * out_strides[4]] = (F)y; + } +} + +typedef decltype(scaled_bool_masked_softmax_nd5) + scaled_bool_masked_softmax_nd5_t; + +template [[host_name("nn_ops::scaled_bool_masked_softmax_nd5_" + "f32")]] [[kernel]] scaled_bool_masked_softmax_nd5_t + scaled_bool_masked_softmax_nd5; +template [[host_name("nn_ops::scaled_bool_masked_softmax_nd5_" + "f16")]] [[kernel]] scaled_bool_masked_softmax_nd5_t + scaled_bool_masked_softmax_nd5; + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index 5ce0323d6d..89b7c9f263 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -14,14 +14,23 @@ impl ScaledMaskedSoftmax { matches!(dt, DatumType::F32 | DatumType::F16) } - pub fn kernel_name(&self, dt: DatumType) -> TractResult { + pub fn is_supported_mask_dt(input_dt: DatumType, mask_dt: DatumType) -> bool { + mask_dt == input_dt || mask_dt == bool::datum_type() + } + + pub fn kernel_name(&self, dt: DatumType, mask_is_bool: bool) -> TractResult { ensure!( Self::is_supported_dt(dt), - "Unsupported dt {:?} for metal scaled masked softmaxop", + "Unsupported dt {:?} for metal scaled masked softmax op", dt ); let tname = DeviceTensor::tname(dt)?; - Ok(format!("nn_ops::scaled_masked_softmax_nd5_{tname}")) + let stem = if mask_is_bool { + "scaled_bool_masked_softmax_nd5" + } else { + "scaled_masked_softmax_nd5" + }; + Ok(format!("nn_ops::{stem}_{tname}")) } pub fn eval( @@ -30,9 +39,10 @@ impl ScaledMaskedSoftmax { input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, ) -> TractResult { let output = unsafe { DeviceTensor::uninitialized_dt(input.datum_type(), input.shape())? }; - self.dispatch_eval(stream, input, scale, mask, &output)?; + self.dispatch_eval(stream, input, scale, mask, post_softmax_mask, &output)?; stream.wait_until_completed()?; Ok(output) } @@ -43,6 +53,7 @@ impl ScaledMaskedSoftmax { input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, output: &DeviceTensor, ) -> TractResult<()> { stream.retain_tensor(input); @@ -54,7 +65,10 @@ impl ScaledMaskedSoftmax { ensure!(input.rank() <= 5); ensure!(mask.rank() == input.rank()); ensure!(output.datum_type() == input.datum_type()); - ensure!(mask.datum_type() == input.datum_type()); + let mask_is_bool = mask.datum_type() == bool::datum_type(); + ensure!(Self::is_supported_mask_dt(input.datum_type(), mask.datum_type())); + // post_softmax_mask is meaningful only with a bool mask (CPU contract). + ensure!(!post_softmax_mask || mask_is_bool); let scale = scale.cast_to::()?; let shape = pad(input.shape(), 1); @@ -72,8 +86,10 @@ impl ScaledMaskedSoftmax { let tg_floats = 32 + inner_len; let tg_bytes = tg_floats * f32::datum_type().size_of(); - let pipeline = - stream.load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?; + let pipeline = stream.load_pipeline( + LibraryName::NNOps, + &self.kernel_name(input.datum_type(), mask_is_bool)?, + )?; let command_buffer = stream.command_buffer(); command_buffer.encode(|encoder| { @@ -82,10 +98,18 @@ impl ScaledMaskedSoftmax { encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read); encoder.set_tensor(2, &scale); encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write); - encoder.set_slice(4, &shape); - encoder.set_slice(5, &strides); - encoder.set_slice(6, &mask_strides); - encoder.set_slice(7, &out_strides); + // Bool-mask kernel takes a `post_mask` flag at slot 4; the + // float-mask kernel doesn't, so slots shift down by one. + let next_slot = if mask_is_bool { + encoder.set_slice(4, &[post_softmax_mask as u32]); + 5 + } else { + 4 + }; + encoder.set_slice(next_slot, &shape); + encoder.set_slice(next_slot + 1, &strides); + encoder.set_slice(next_slot + 2, &mask_strides); + encoder.set_slice(next_slot + 3, &out_strides); encoder.set_threadgroup_memory_length(0, tg_bytes as _); let grid_size = MTLSize { width: shape[3] as _, @@ -111,22 +135,27 @@ pub fn metal_scaled_masked_softmax_dispatch( input: &DeviceTensor, scale: &Tensor, mask: &DeviceTensor, + post_softmax_mask: bool, output: &DeviceTensor, ) -> TractResult<()> { crate::with_metal_stream(|stream| { - ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, output) + ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, post_softmax_mask, output) }) } crate::register_metal_op!( tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax, |source, node, op| { - rule_if!(!op.post_softmax_mask); - rule_if!(ScaledMaskedSoftmax::is_supported_dt( - source.node_input_facts(node.id)?[0].datum_type + let facts = source.node_input_facts(node.id)?; + rule_if!(ScaledMaskedSoftmax::is_supported_dt(facts[0].datum_type)); + rule_if!(ScaledMaskedSoftmax::is_supported_mask_dt( + facts[0].datum_type, + facts[1].datum_type, )); + rule_if!(!op.post_softmax_mask || facts[1].datum_type == bool::datum_type()); Ok(Some(Box::new(tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax::new( op.scale.clone(), + op.post_softmax_mask, "Metal", metal_scaled_masked_softmax_dispatch, )))) @@ -171,7 +200,7 @@ mod tests { .eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0] .clone() .into_tensor(); - let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?; + let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?; cpu_output .close_enough(&metal_output.to_host()?.into_tensor(), Approximation::Approximate)?; Ok(()) @@ -201,13 +230,63 @@ mod tests { .eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0] .clone() .into_tensor(); - let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?; + let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?; cpu_output .close_enough(&metal_output.to_host()?.into_tensor(), Approximation::Approximate)?; Ok(()) }) } + /// Bool-mask path with a fully-masked row. Without post_softmax_mask + /// the output is NaN (matches CPU); with it on, the NaN is scrubbed to 0. + #[test] + fn test_scaled_bool_masked_softmax_post_mask_scrubs_nan() -> TractResult<()> { + with_borrowed_metal_stream(|stream| { + let m = 3; + let n = 5; + let scale: Arc<_> = tensor0(0.125f32).into(); + // Row 0: fully masked. Row 1: partially masked. Row 2: fully unmasked. + let mask_data: Vec = (0..m) + .flat_map(|r| { + (0..n).map(move |c| match r { + 0 => false, + 1 => c >= 2, + _ => true, + }) + }) + .collect(); + let mask = Tensor::from_shape(&[1, 1, m, n], &mask_data)?.into_device()?; + let a = Tensor::from_shape( + &[1, 1, m, n], + &(0..m * n).map(|f| f as f32).collect::>(), + )? + .into_device()?; + + for post in [false, true] { + let cpu = scaled_masked_softmax::ScaledMaskedSoftmax { + scale: scale.clone(), + post_softmax_mask: post, + }; + let cpu_out = cpu + .eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0] + .clone() + .into_tensor(); + let metal_out = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, post)?; + let metal_host = metal_out.to_host()?.into_tensor(); + let cpu_slice = cpu_out.view().as_slice::().unwrap(); + let metal_slice = metal_host.view().as_slice::().unwrap(); + for (i, (c, g)) in cpu_slice.iter().zip(metal_slice.iter()).enumerate() { + if c.is_nan() { + assert!(g.is_nan(), "post={post} idx={i}: cpu NaN, metal {g}"); + } else { + assert!((c - g).abs() < 1e-5, "post={post} idx={i}: cpu {c} metal {g}"); + } + } + } + Ok(()) + }) + } + proptest::proptest! { #[test] fn scaled_masked_softmax_prop_f32(pb in any::>()) { @@ -300,7 +379,7 @@ mod tests { let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_device()?; let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); - let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?; + let metal_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?; Ok(metal_output.to_host()?.into_tensor()) }) } From 4bba45411b12c147535e4ce4885af3c7199c4df3 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 27 May 2026 09:17:37 +0200 Subject: [PATCH 2/5] unstick ci From b1c18d91f3c3bbc3dfd543c4b347ec16555b0504 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 27 May 2026 07:46:47 +0000 Subject: [PATCH 3/5] harness/nemotron: tighten gpu allowlists (drop SMS, DiagGather, IsNan, Reduce) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both SMS (this branch) and DiagGather (already on main) now have CUDA + Metal kernels. IsNan and Reduce don't appear in any of the 4 streaming models — IsNan never did, Reduce is always F32 which both backends handle. Audit on --cuda confirms zero CPU instances; --metal mirrors the same allowlist now that the Metal kernels exist. Tight placement check: any regression that puts one of these on CPU now fails CI. --- harness/nemotron-speech-streaming-en-0.6b/ci.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/harness/nemotron-speech-streaming-en-0.6b/ci.sh b/harness/nemotron-speech-streaming-en-0.6b/ci.sh index ee01c886c5..a07d990de5 100755 --- a/harness/nemotron-speech-streaming-en-0.6b/ci.sh +++ b/harness/nemotron-speech-streaming-en-0.6b/ci.sh @@ -12,8 +12,8 @@ for rt in $TRACT_RUNTIMES do gpu_assert="" case "$rt" in - --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,IsNan,Add,Range,Cast,Eq,Div,Sub,Scan,Gather,DiagGather,ScaledMaskedSoftmax";; - --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,IsNan,Add,Range,Cast,Eq,Div,Sub,Scan,Gather,Reduce*,DiagGather,ScaledMaskedSoftmax";; + --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Scan,Gather";; + --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Scan,Gather";; esac for m in preprocessor encoder decoder joint From d4a988291967be1a56969d69ebed360de9ed356c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 27 May 2026 08:28:34 +0000 Subject: [PATCH 4/5] harness/nemotron: inline decoder Scan via force_scan_external_state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The decoder is stepped one token at a time by the caller (external state plumbed through the outer graph), so iters resolves to 1 and the Scan body can be inlined. Apply the existing core force_scan_external_state transform on the decoder run; the two LSTM cells now land on GPU. Drop Scan from the gpu allowlists — no model in the harness keeps a Scan node on CPU after this. --- harness/nemotron-speech-streaming-en-0.6b/ci.sh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/harness/nemotron-speech-streaming-en-0.6b/ci.sh b/harness/nemotron-speech-streaming-en-0.6b/ci.sh index a07d990de5..8076d30134 100755 --- a/harness/nemotron-speech-streaming-en-0.6b/ci.sh +++ b/harness/nemotron-speech-streaming-en-0.6b/ci.sh @@ -12,8 +12,8 @@ for rt in $TRACT_RUNTIMES do gpu_assert="" case "$rt" in - --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Scan,Gather";; - --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Scan,Gather";; + --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Gather";; + --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Gather";; esac for m in preprocessor encoder decoder joint @@ -24,11 +24,18 @@ do else nnef_file=$MODEL.$m.nnef.tgz fi + # Decoder is stepped one token per call by the caller (external state + # carry); force the Scan op into single-iter inlining so the LSTM body + # lands on the GPU instead of bouncing through CPU each step. + extra_transforms="" + if [ "$m" = "decoder" ]; then + extra_transforms="-t force_scan_external_state" + fi $CACHE_FILE \ $S3DIR/$nnef_file \ $S3DIR/$MODEL.$m.io.npz - $TRACT_RUN $MODELS/$S3DIR/$nnef_file $rt --nnef-tract-transformers -t transformers_detect_all run \ + $TRACT_RUN $MODELS/$S3DIR/$nnef_file $rt --nnef-tract-transformers -t transformers_detect_all $extra_transforms run \ --input-from-bundle $MODELS/$S3DIR/$MODEL.$m.io.npz --assert-output-bundle $MODELS/$S3DIR/$MODEL.$m.io.npz \ --approx very $gpu_assert done From 3800d8488b5eca148763eb20d3cec30bc03c3df4 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 27 May 2026 11:06:21 +0000 Subject: [PATCH 5/5] harness/nemotron: drop Gather from gpu allowlists Now that cuda + metal Gather kernels are on main, the decoder embedding lookup runs on GPU. Audit confirms zero CPU Gather across all 4 models (decoder is run with -t force_scan_external_state so its embedding input is fed directly to CudaGather/MetalGather). --- harness/nemotron-speech-streaming-en-0.6b/ci.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/harness/nemotron-speech-streaming-en-0.6b/ci.sh b/harness/nemotron-speech-streaming-en-0.6b/ci.sh index 8076d30134..352fa87757 100755 --- a/harness/nemotron-speech-streaming-en-0.6b/ci.sh +++ b/harness/nemotron-speech-streaming-en-0.6b/ci.sh @@ -12,8 +12,8 @@ for rt in $TRACT_RUNTIMES do gpu_assert="" case "$rt" in - --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Gather";; - --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub,Gather";; + --cuda) gpu_assert="--assert-op-only Cuda*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub";; + --metal) gpu_assert="--assert-op-only Metal*,Gpu*,DeviceSync*,Const,Source,STFT,Pad,Add,Range,Cast,Eq,Div,Sub";; esac for m in preprocessor encoder decoder joint