Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions cuda/src/kernels/cu/nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,136 @@ __device__ void scaled_masked_softmax(
out_stride_4); \
}

// Bool-mask variant: mask is char (0/1), substitutes -inf at masked positions
// before the softmax. When post_mask is non-zero, fully-masked rows (all
// inputs -inf) are written as 0 instead of the NaN the naive softmax would
// emit. Partially-masked rows are unaffected: exp(-inf) = 0 already zeros
// masked positions in the output.
template <typename T, int BLOCK_SIZE>
__device__ void scaled_bool_masked_softmax(
const T *x, const char *mask, const float scale, T *dst,
const int32_t post_mask, const int32_t shape_0, const int32_t shape_1,
const int32_t shape_2, const int32_t shape_3, const int32_t shape_4,
const int32_t stride_0, const int32_t stride_1, const int32_t stride_2,
const int32_t stride_3, const int32_t stride_4,
const int32_t mask_stride_0, const int32_t mask_stride_1,
const int32_t mask_stride_2, const int32_t mask_stride_3,
const int32_t mask_stride_4, const int32_t out_stride_0,
const int32_t out_stride_1, const int32_t out_stride_2,
const int32_t out_stride_3, const int32_t out_stride_4) {
int32_t z0 = blockIdx.z / shape_1;
int32_t z1 = blockIdx.z % shape_1;
x += blockIdx.x * stride_3 + blockIdx.y * stride_2 + z1 * stride_1 +
z0 * stride_0;
mask += blockIdx.x * mask_stride_3 + blockIdx.y * mask_stride_2 +
z1 * mask_stride_1 + z0 * mask_stride_0;
dst += blockIdx.x * out_stride_3 + blockIdx.y * out_stride_2 +
z1 * out_stride_1 + z0 * out_stride_0;

const int block_size = BLOCK_SIZE == 0 ? blockDim.x : BLOCK_SIZE;

const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;

extern __shared__ float data_soft_max_f32[];
float *buf_iw = data_soft_max_f32;
float *vals = buf_iw + WARP_SIZE;

float max_val = -CUDART_INF_F;
_Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) {
const int col = col0 + threadIdx.x;
if (col >= shape_4) {
break;
}

const bool m = mask[col * mask_stride_4] != 0;
const float val = m ? ((float)x[col * stride_4]) * scale : -CUDART_INF_F;
vals[col] = val;
max_val = max(max_val, val);
}

max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -CUDART_INF_F;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();

max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}

float tmp = 0.0f;
_Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) {
const int col = col0 + threadIdx.x;
if (col >= shape_4) {
break;
}

const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}

tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();

tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}

// Row-uniform: tmp <= 0 (or NaN) iff every position was masked. When
// post_mask is set we write 0 in that case to scrub the NaN; otherwise
// we fall through to the normal 1/sum path and let it propagate.
const bool zero_row = post_mask && !(tmp > 0.0f);
const float inv_sum = 1.0f / tmp;

_Pragma("unroll") for (int col0 = 0; col0 < shape_4; col0 += block_size) {
const int col = col0 + threadIdx.x;
if (col >= shape_4) {
return;
}
dst[col * out_stride_4] = zero_row ? (T)0.0f : (T)(vals[col] * inv_sum);
}
}

#define INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, bname, \
block_size_template) \
extern "C" __global__ void scaled_bool_masked_softmax_##bname##name( \
const T *x, const char *mask, const float scale, T *dst, \
const int32_t post_mask, const int32_t shape_0, \
const int32_t shape_1, const int32_t shape_2, const int32_t shape_3, \
const int32_t shape_4, const int32_t stride_0, \
const int32_t stride_1, const int32_t stride_2, \
const int32_t stride_3, const int32_t stride_4, \
const int32_t mask_stride_0, const int32_t mask_stride_1, \
const int32_t mask_stride_2, const int32_t mask_stride_3, \
const int32_t mask_stride_4, const int32_t out_stride_0, \
const int32_t out_stride_1, const int32_t out_stride_2, \
const int32_t out_stride_3, const int32_t out_stride_4) { \
scaled_bool_masked_softmax<T, block_size_template>( \
x, mask, scale, dst, post_mask, shape_0, shape_1, shape_2, \
shape_3, shape_4, stride_0, stride_1, stride_2, stride_3, \
stride_4, mask_stride_0, mask_stride_1, mask_stride_2, \
mask_stride_3, mask_stride_4, out_stride_0, out_stride_1, \
out_stride_2, out_stride_3, out_stride_4); \
}

#define INSTANTIATE_RMS_NORM(name, T, bname, block_size) \
extern "C" __global__ void rms_norm_##bname##name( \
const T *x, T *dst, const int32_t shape_0, const int32_t shape_1, \
Expand Down Expand Up @@ -652,8 +782,24 @@ INSTANTIATE_SOFTMAX(f16, __half, , 1024)
INSTANTIATE_SCALED_MASKED_SOFTMAX(name, T, 32768_, 1024) \
INSTANTIATE_SCALED_MASKED_SOFTMAX(name, T, 0_, 0)

#define INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(name, T) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 32_, 32) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 64_, 64) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 128_, 126) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 256_, 256) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 512_, 512) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 1024_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 2048_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 4096_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 8192_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 16384_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 32768_, 1024) \
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX(name, T, 0_, 0)

INSTANTIATE_SCALED_MASKED_SOFTMAX_FOR_T(f32, float)
INSTANTIATE_SCALED_MASKED_SOFTMAX_FOR_T(f16, __half)
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(f32, float)
INSTANTIATE_SCALED_BOOL_MASKED_SOFTMAX_FOR_T(f16, __half)

INSTANTIATE_REDUCE(f32, float, small_, 32)
INSTANTIATE_REDUCE(f32, float, , 1024)
Expand Down
6 changes: 5 additions & 1 deletion cuda/src/kernels/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ pub fn all_functions() -> Vec<String> {
tract_gpu::tensor::DeviceTensor::SUPPORTED_DT
.into_iter()
.flat_map(|dt| sms_block_sizes().into_iter().map(move |bs| (dt, bs as usize)))
.flat_map(|(dt, bs)| ScaledMaskedSoftmax.kernel_name(dt, bs).into_iter()),
.flat_map(|(dt, bs)| {
[false, true]
.into_iter()
.flat_map(move |mb| ScaledMaskedSoftmax.kernel_name(dt, mb, bs).into_iter())
}),
);

functions.extend(
Expand Down
110 changes: 93 additions & 17 deletions cuda/src/kernels/nn/scaled_masked_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,25 @@ impl ScaledMaskedSoftmax {
matches!(dt, DatumType::F32 | DatumType::F16)
}

pub fn kernel_name(&self, dt: DatumType, block_size: usize) -> TractResult<String> {
pub fn is_supported_mask_dt(input_dt: DatumType, mask_dt: DatumType) -> bool {
mask_dt == input_dt || mask_dt == bool::datum_type()
}

pub fn kernel_name(
&self,
input_dt: DatumType,
mask_is_bool: bool,
block_size: usize,
) -> TractResult<String> {
ensure!(
Self::is_supported_dt(dt),
"Unsupported dt {:?} for cuda scaled masked softmaxop",
dt
Self::is_supported_dt(input_dt),
"Unsupported dt {:?} for cuda scaled masked softmax op",
input_dt
);
let tname = DeviceTensor::tname(dt)?;
Ok(format!("scaled_masked_softmax_{block_size}_{tname}"))
let tname = DeviceTensor::tname(input_dt)?;
let stem =
if mask_is_bool { "scaled_bool_masked_softmax" } else { "scaled_masked_softmax" };
Ok(format!("{stem}_{block_size}_{tname}"))
}

pub fn eval(
Expand All @@ -34,9 +45,10 @@ impl ScaledMaskedSoftmax {
input: &DeviceTensor,
scale: &Tensor,
mask: &DeviceTensor,
post_softmax_mask: bool,
) -> TractResult<DeviceTensor> {
let output = unsafe { DeviceTensor::uninitialized_dt(input.datum_type(), input.shape())? };
self.dispatch_eval(stream, input, scale, mask, &output)?;
self.dispatch_eval(stream, input, scale, mask, post_softmax_mask, &output)?;
stream.synchronize()?;
Ok(output)
}
Expand All @@ -47,19 +59,22 @@ impl ScaledMaskedSoftmax {
input: &DeviceTensor,
scale: &Tensor,
mask: &DeviceTensor,
post_softmax_mask: bool,
output: &DeviceTensor,
) -> TractResult<()> {
ensure!(output.shape() == input.shape());
ensure!(input.rank() >= 2 && input.rank() <= 5);
ensure!(mask.rank() == input.rank());
ensure!(output.datum_type() == input.datum_type());
ensure!(mask.datum_type() == input.datum_type());
let mask_is_bool = mask.datum_type() == bool::datum_type();
ensure!(Self::is_supported_mask_dt(input.datum_type(), mask.datum_type()));
// post_softmax_mask is meaningful only with a bool mask (CPU contract).
ensure!(!post_softmax_mask || mask_is_bool);

let shape = pad(input.shape(), 1);
let strides = pad(input.strides(), 0);
let mask_strides = pad(&compute_broadcast_strides::<i32>(mask.shape(), mask.strides())?, 0);
let output_strides = pad(output.strides(), 0);
let inner_len = shape[4];

let i_view = get_cuda_view(input);
let mask_view = get_cuda_view(mask);
Expand All @@ -74,14 +89,19 @@ impl ScaledMaskedSoftmax {
let block_size =
if inner_len.is_power_of_two() && inner_len > 32 { inner_len.min(1024) } else { 0 };

let func = cuda_context()
.load_pipeline(LibraryName::NN, self.kernel_name(input.datum_type(), block_size)?)?;
let func = cuda_context().load_pipeline(
LibraryName::NN,
self.kernel_name(input.datum_type(), mask_is_bool, block_size)?,
)?;

let mut launch_args = TractLaunchArgs::new(stream, &func);
launch_args.push_view(&i_view);
launch_args.push_view(&mask_view);
launch_args.push::<f32>(scale.cast_to_scalar::<f32>()?);
launch_args.push_view(&o_view);
if mask_is_bool {
launch_args.push::<i32>(post_softmax_mask as i32);
}
launch_args.push_slice_i32(&shape);
launch_args.push_slice_i32(&strides);
launch_args.push_slice_i32(&mask_strides);
Expand Down Expand Up @@ -111,22 +131,28 @@ pub fn cuda_scaled_masked_softmax_dispatch(
input: &DeviceTensor,
scale: &Tensor,
mask: &DeviceTensor,
post_softmax_mask: bool,
output: &DeviceTensor,
) -> TractResult<()> {
crate::with_cuda_stream(|stream| {
ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, output)
ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, post_softmax_mask, output)
})
}

crate::register_cuda_op!(
tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax,
|source, node, op| {
rule_if!(!op.post_softmax_mask);
rule_if!(ScaledMaskedSoftmax::is_supported_dt(
source.node_input_facts(node.id)?[0].datum_type
let facts = source.node_input_facts(node.id)?;
rule_if!(ScaledMaskedSoftmax::is_supported_dt(facts[0].datum_type));
rule_if!(ScaledMaskedSoftmax::is_supported_mask_dt(
facts[0].datum_type,
facts[1].datum_type,
));
// post_softmax_mask requires a bool mask (CPU contract).
rule_if!(!op.post_softmax_mask || facts[1].datum_type == bool::datum_type());
Ok(Some(Box::new(tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax::new(
op.scale.clone(),
op.post_softmax_mask,
"Cuda",
cuda_scaled_masked_softmax_dispatch,
))))
Expand Down Expand Up @@ -170,13 +196,63 @@ mod tests {
.eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0]
.clone()
.into_tensor();
let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?;
let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?;
cpu_output
.close_enough(&cuda_output.to_host()?.into_tensor(), Approximation::Approximate)?;
Ok(())
})
}

/// Bool-mask path with a fully-masked row. Without post_softmax_mask
/// the output is NaN (matches CPU); with it on, the NaN is scrubbed to 0.
#[test]
fn test_scaled_bool_masked_softmax_post_mask_scrubs_nan() -> TractResult<()> {
crate::with_cuda_stream(|stream| {
let m = 3;
let n = 5;
let scale: Arc<_> = tensor0(0.125f32).into();
// Row 0: fully masked. Row 1: partially masked. Row 2: fully unmasked.
let mask_data: Vec<bool> = (0..m)
.flat_map(|r| {
(0..n).map(move |c| match r {
0 => false,
1 => c >= 2,
_ => true,
})
})
.collect();
let mask = Tensor::from_shape(&[1, 1, m, n], &mask_data)?.into_device()?;
let a = Tensor::from_shape(
&[1, 1, m, n],
&(0..m * n).map(|f| f as f32).collect::<Vec<_>>(),
)?
.into_device()?;

for post in [false, true] {
let cpu = scaled_masked_softmax::ScaledMaskedSoftmax {
scale: scale.clone(),
post_softmax_mask: post,
};
let cpu_out = cpu
.eval(tvec![a.to_host()?.into_tvalue(), mask.to_host()?.into_tvalue()])?[0]
.clone()
.into_tensor();
let cuda_out = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, post)?;
let cuda_host = cuda_out.to_host()?.into_tensor();
let cpu_slice = cpu_out.view().as_slice::<f32>().unwrap();
let cuda_slice = cuda_host.view().as_slice::<f32>().unwrap();
for (i, (c, g)) in cpu_slice.iter().zip(cuda_slice.iter()).enumerate() {
if c.is_nan() {
assert!(g.is_nan(), "post={post} idx={i}: cpu NaN, cuda {g}");
} else {
assert!((c - g).abs() < 1e-5, "post={post} idx={i}: cpu {c} cuda {g}");
}
}
}
Ok(())
})
}

proptest::proptest! {
#[test]
fn scaled_masked_softmax_prop_f32(pb in any::<ScaledMaskedSoftmaxProblem<f32>>()) {
Expand Down Expand Up @@ -269,7 +345,7 @@ mod tests {
let mask =
Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_device()?;
let scale: Arc<_> = tensor0::<F>(0.125f32.as_()).into();
let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask)?;
let cuda_output = ScaledMaskedSoftmax.eval(stream, &a, &scale, &mask, false)?;
Ok(cuda_output.to_host()?.into_tensor())
})
}
Expand Down
Loading
Loading