From fa98104f5639df6af279e3ed29faff92b510f88d Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Thu, 21 May 2026 05:35:03 -0700 Subject: [PATCH 1/3] Add LLM-generated level 1 kernels (1-50) for Triton CPU Co-authored-by: Marcin Spoczynski --- .../10_3D_tensor_matrix_multiplication.py | 139 ++++++++++ .../11_4D_tensor_matrix_multiplication.py | 158 ++++++++++++ .../12_Matmul_with_diagonal_matrices_.py | 85 +++++++ .../13_Matmul_for_symmetric_matrices.py | 156 ++++++++++++ ...14_Matmul_for_upper_triangular_matrices.py | 160 ++++++++++++ ...15_Matmul_for_lower_triangular_matrices.py | 151 +++++++++++ .../level1/16_Matmul_with_transposed_A.py | 137 ++++++++++ .../level1/17_Matmul_with_transposed_B.py | 142 +++++++++++ .../level1/18_Matmul_with_transposed_both.py | 138 ++++++++++ .../triton/cpu/KernelBench/level1/19_ReLU.py | 62 +++++ .../level1/1_Square_matrix_multiplication_.py | 155 ++++++----- .../cpu/KernelBench/level1/20_LeakyReLU.py | 60 +++++ .../cpu/KernelBench/level1/21_Sigmoid.py | 55 ++++ .../triton/cpu/KernelBench/level1/22_Tanh.py | 62 +++++ .../cpu/KernelBench/level1/23_Softmax.py | 92 +++++++ .../cpu/KernelBench/level1/24_LogSoftmax.py | 85 +++++++ .../triton/cpu/KernelBench/level1/25_Swish.py | 56 ++++ .../triton/cpu/KernelBench/level1/26_GELU_.py | 63 +++++ .../triton/cpu/KernelBench/level1/27_SELU_.py | 55 ++++ .../cpu/KernelBench/level1/28_HardSigmoid.py | 53 ++++ .../cpu/KernelBench/level1/29_Softplus.py | 52 ++++ .../2_Standard_matrix_multiplication_.py | 129 ++++++++++ .../cpu/KernelBench/level1/30_Softsign.py | 49 ++++ .../triton/cpu/KernelBench/level1/31_ELU.py | 60 +++++ .../cpu/KernelBench/level1/32_HardTanh.py | 50 ++++ .../cpu/KernelBench/level1/33_BatchNorm.py | 191 ++++++++++++++ .../cpu/KernelBench/level1/34_InstanceNorm.py | 82 ++++++ .../cpu/KernelBench/level1/35_GroupNorm_.py | 175 +++++++++++++ .../cpu/KernelBench/level1/36_RMSNorm_.py | 98 +++++++ .../KernelBench/level1/37_FrobeniusNorm_.py | 99 ++++++++ 
.../cpu/KernelBench/level1/38_L1Norm_.py | 79 ++++++ .../cpu/KernelBench/level1/39_L2Norm_.py | 75 ++++++ .../level1/3_Batched_matrix_multiplication.py | 151 +++++++++++ .../cpu/KernelBench/level1/40_LayerNorm.py | 108 ++++++++ .../KernelBench/level1/41_Max_Pooling_1D.py | 104 ++++++++ .../KernelBench/level1/42_Max_Pooling_2D.py | 112 ++++++++ .../KernelBench/level1/43_Max_Pooling_3D.py | 156 ++++++++++++ .../level1/44_Average_Pooling_1D.py | 112 ++++++++ .../level1/45_Average_Pooling_2D.py | 103 ++++++++ .../level1/46_Average_Pooling_3D.py | 142 +++++++++++ .../47_Sum_reduction_over_a_dimension.py | 92 +++++++ .../48_Mean_reduction_over_a_dimension.py | 96 +++++++ .../49_Max_reduction_over_a_dimension.py | 95 +++++++ .../level1/4_Matrix_vector_multiplication_.py | 66 +++++ ...tandard_2D__square_input__square_kernel.py | 240 ++++++++++++++++++ .../level1/5_Matrix_scalar_multiplication.py | 52 ++++ .../6_Matmul_with_large_K_dimension_.py | 138 ++++++++++ .../7_Matmul_with_small_K_dimension_.py | 141 ++++++++++ .../level1/8_Matmul_with_irregular_shapes_.py | 133 ++++++++++ .../9_Tall_skinny_matrix_multiplication_.py | 147 +++++++++++ .../10_3D_tensor_matrix_multiplication.yaml | 10 + .../11_4D_tensor_matrix_multiplication.yaml | 11 + .../12_Matmul_with_diagonal_matrices_.yaml | 8 + .../13_Matmul_for_symmetric_matrices.yaml | 7 + ..._Matmul_for_upper_triangular_matrices.yaml | 7 + ..._Matmul_for_lower_triangular_matrices.yaml | 7 + .../level1/16_Matmul_with_transposed_A.yaml | 9 + .../level1/17_Matmul_with_transposed_B.yaml | 9 + .../18_Matmul_with_transposed_both.yaml | 9 + .../1_Square_matrix_multiplication_.yaml | 4 +- .../KernelBench/level1/20_LeakyReLU.yaml | 8 + .../specs/KernelBench/level1/21_Sigmoid.yaml | 8 + .../specs/KernelBench/level1/22_Tanh.yaml | 8 + .../specs/KernelBench/level1/23_Softmax.yaml | 8 + .../KernelBench/level1/24_LogSoftmax.yaml | 8 + .../specs/KernelBench/level1/25_Swish.yaml | 8 + .../specs/KernelBench/level1/26_GELU_.yaml | 8 + 
.../specs/KernelBench/level1/27_SELU_.yaml | 8 + .../KernelBench/level1/28_HardSigmoid.yaml | 8 + .../specs/KernelBench/level1/29_Softplus.yaml | 8 + .../2_Standard_matrix_multiplication_.yaml | 8 +- .../specs/KernelBench/level1/30_Softsign.yaml | 8 + problems/specs/KernelBench/level1/31_ELU.yaml | 9 + .../specs/KernelBench/level1/32_HardTanh.yaml | 8 + .../KernelBench/level1/33_BatchNorm.yaml | 10 + .../KernelBench/level1/34_InstanceNorm.yaml | 10 + .../KernelBench/level1/35_GroupNorm_.yaml | 11 + .../specs/KernelBench/level1/36_RMSNorm_.yaml | 11 + .../KernelBench/level1/37_FrobeniusNorm_.yaml | 10 + .../specs/KernelBench/level1/38_L1Norm_.yaml | 8 + .../specs/KernelBench/level1/39_L2Norm_.yaml | 8 + .../3_Batched_matrix_multiplication.yaml | 10 + .../KernelBench/level1/40_LayerNorm.yaml | 11 + .../KernelBench/level1/41_Max_Pooling_1D.yaml | 14 + .../KernelBench/level1/42_Max_Pooling_2D.yaml | 14 + .../KernelBench/level1/43_Max_Pooling_3D.yaml | 14 +- .../level1/44_Average_Pooling_1D.yaml | 12 + .../level1/45_Average_Pooling_2D.yaml | 11 + .../level1/46_Average_Pooling_3D.yaml | 14 + .../47_Sum_reduction_over_a_dimension.yaml | 10 + .../48_Mean_reduction_over_a_dimension.yaml | 10 + .../49_Max_reduction_over_a_dimension.yaml | 10 + .../4_Matrix_vector_multiplication_.yaml | 9 + ...ndard_2D__square_input__square_kernel.yaml | 11 + .../5_Matrix_scalar_multiplication.yaml | 9 + .../6_Matmul_with_large_K_dimension_.yaml | 13 +- .../7_Matmul_with_small_K_dimension_.yaml | 9 + .../8_Matmul_with_irregular_shapes_.yaml | 9 + .../9_Tall_skinny_matrix_multiplication_.yaml | 8 + pyproject.toml | 2 +- 100 files changed, 5786 insertions(+), 67 deletions(-) create mode 100644 backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py create mode 100644 backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py create mode 100644 backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py create mode 100644 
backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py create mode 100644 backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py create mode 100644 backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py create mode 100644 backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py create mode 100644 backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py create mode 100644 backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py create mode 100644 backends/triton/cpu/KernelBench/level1/19_ReLU.py create mode 100644 backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py create mode 100644 backends/triton/cpu/KernelBench/level1/21_Sigmoid.py create mode 100644 backends/triton/cpu/KernelBench/level1/22_Tanh.py create mode 100644 backends/triton/cpu/KernelBench/level1/23_Softmax.py create mode 100644 backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py create mode 100644 backends/triton/cpu/KernelBench/level1/25_Swish.py create mode 100644 backends/triton/cpu/KernelBench/level1/26_GELU_.py create mode 100644 backends/triton/cpu/KernelBench/level1/27_SELU_.py create mode 100644 backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py create mode 100644 backends/triton/cpu/KernelBench/level1/29_Softplus.py create mode 100644 backends/triton/cpu/KernelBench/level1/2_Standard_matrix_multiplication_.py create mode 100644 backends/triton/cpu/KernelBench/level1/30_Softsign.py create mode 100644 backends/triton/cpu/KernelBench/level1/31_ELU.py create mode 100644 backends/triton/cpu/KernelBench/level1/32_HardTanh.py create mode 100644 backends/triton/cpu/KernelBench/level1/33_BatchNorm.py create mode 100644 backends/triton/cpu/KernelBench/level1/34_InstanceNorm.py create mode 100644 backends/triton/cpu/KernelBench/level1/35_GroupNorm_.py create mode 100644 backends/triton/cpu/KernelBench/level1/36_RMSNorm_.py create mode 100644 
backends/triton/cpu/KernelBench/level1/37_FrobeniusNorm_.py create mode 100644 backends/triton/cpu/KernelBench/level1/38_L1Norm_.py create mode 100644 backends/triton/cpu/KernelBench/level1/39_L2Norm_.py create mode 100644 backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py create mode 100644 backends/triton/cpu/KernelBench/level1/40_LayerNorm.py create mode 100644 backends/triton/cpu/KernelBench/level1/41_Max_Pooling_1D.py create mode 100644 backends/triton/cpu/KernelBench/level1/42_Max_Pooling_2D.py create mode 100644 backends/triton/cpu/KernelBench/level1/43_Max_Pooling_3D.py create mode 100644 backends/triton/cpu/KernelBench/level1/44_Average_Pooling_1D.py create mode 100644 backends/triton/cpu/KernelBench/level1/45_Average_Pooling_2D.py create mode 100644 backends/triton/cpu/KernelBench/level1/46_Average_Pooling_3D.py create mode 100644 backends/triton/cpu/KernelBench/level1/47_Sum_reduction_over_a_dimension.py create mode 100644 backends/triton/cpu/KernelBench/level1/48_Mean_reduction_over_a_dimension.py create mode 100644 backends/triton/cpu/KernelBench/level1/49_Max_reduction_over_a_dimension.py create mode 100644 backends/triton/cpu/KernelBench/level1/4_Matrix_vector_multiplication_.py create mode 100644 backends/triton/cpu/KernelBench/level1/50_conv_standard_2D__square_input__square_kernel.py create mode 100644 backends/triton/cpu/KernelBench/level1/5_Matrix_scalar_multiplication.py create mode 100644 backends/triton/cpu/KernelBench/level1/6_Matmul_with_large_K_dimension_.py create mode 100644 backends/triton/cpu/KernelBench/level1/7_Matmul_with_small_K_dimension_.py create mode 100644 backends/triton/cpu/KernelBench/level1/8_Matmul_with_irregular_shapes_.py create mode 100644 backends/triton/cpu/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.py diff --git a/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py new file mode 
100644 index 0000000..c50d8a3 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/10_3D_tensor_matrix_multiplication.py @@ -0,0 +1,139 @@ +# ruff: noqa: E731, E741 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune(configs=_configs(), key=["M", "N", "K"]) +@triton.jit +def _matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, 
stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a_tile = a_desc.load([pid_m * BLOCK_M, off_k]) + b_tile = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a_tile, b_tile) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(c_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A, B): + batch, m, k = A.shape + _, l = B.shape + + a = A.to(torch.bfloat16).contiguous() + b = B.to(torch.bfloat16).contiguous() + + a_flat = a.reshape(batch * m, k) + total_m = batch * m + + c_flat = torch.empty((total_m, l), device=a.device, dtype=torch.bfloat16) + + def grid(META): + return ( + triton.cdiv(total_m, META["BLOCK_M"]) * triton.cdiv(l, META["BLOCK_N"]), + ) + + _matmul_kernel[grid]( + a_flat, + b, + c_flat, + total_m, + l, + k, + a_flat.stride(0), + a_flat.stride(1), + b.stride(0), + b.stride(1), + c_flat.stride(0), + c_flat.stride(1), + ) + + return c_flat.reshape(batch, m, l) diff --git a/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py new file mode 100644 index 0000000..0ea0b6c --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/11_4D_tensor_matrix_multiplication.py @@ -0,0 +1,158 @@ +# ruff: noqa: E731, E741 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def swizzle_tile( 
+ tile_id, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + ((tile_id % width) % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +def get_autotune_configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune( + configs=get_autotune_configs(), + key=["M", "N", "K"], +) +@triton.jit +def _gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_M) + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, 
BLOCK_K): + a_block = a_desc.load([pid_m * BLOCK_M, off_k]) + b_block = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a_block, b_block) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(c_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A, B): + b_dim, i_dim, j_dim, l_dim = A.shape + k_dim = B.shape[1] + + A_flat = A.contiguous().view(-1, l_dim) + if A_flat.dtype != torch.bfloat16: + A_flat = A_flat.to(torch.bfloat16) + B_fp16 = B.contiguous() + if B_fp16.dtype != torch.bfloat16: + B_fp16 = B_fp16.to(torch.bfloat16) + + M = A_flat.shape[0] + N = k_dim + K = l_dim + + C_2d = torch.empty((M, N), device=A.device, dtype=torch.bfloat16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + _gemm_kernel[grid]( + A_flat, + B_fp16, + C_2d, + M, + N, + K, + A_flat.stride(0), + A_flat.stride(1), + B_fp16.stride(0), + B_fp16.stride(1), + C_2d.stride(0), + C_2d.stride(1), + ) + + result = C_2d.view(b_dim, i_dim, j_dim, k_dim) + if A.dtype != torch.bfloat16: + result = result.to(A.dtype) + return result diff --git a/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py b/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py new file mode 100644 index 0000000..c14ea7b --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/12_Matmul_with_diagonal_matrices_.py @@ -0,0 +1,85 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=32, 
num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=32, num_stages=3), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=32, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=32, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=32, num_stages=2), + ], + key=["N", "M"], +) +@triton.jit +def _diag_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + N, + M, + stride_bn, + stride_bm, + stride_cn, + stride_cm, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + mask_n = offs_n < N + mask_m = offs_m < M + + a_vals = tl.load(a_ptr + offs_n, mask=mask_n, other=0.0) + + b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_m[None, :] * stride_bm + mask = mask_n[:, None] & mask_m[None, :] + b_vals = tl.load(b_ptrs, mask=mask, other=0.0) + + c_vals = a_vals[:, None].to(tl.float32) * b_vals.to(tl.float32) + + c_ptrs = c_ptr + offs_n[:, None] * stride_cn + offs_m[None, :] * stride_cm + tl.store(c_ptrs, c_vals.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A, B): + N = A.shape[0] + M = B.shape[1] + + C = torch.empty((N, M), device=A.device, dtype=A.dtype) + + grid = lambda META: ( + triton.cdiv(N, META["BLOCK_N"]), + triton.cdiv(M, META["BLOCK_M"]), + ) + + _diag_matmul_kernel[grid]( + A, + B, + C, + N, + M, + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + + return C diff --git a/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py b/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py new file mode 100644 index 0000000..36ce840 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/13_Matmul_for_symmetric_matrices.py @@ -0,0 +1,156 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: 
LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def swizzle_tile( + tile_id, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +def _configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune(configs=_configs(), key=["M", "N", "K"]) +@triton.jit +def _matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + DIVISIBLE_M: tl.constexpr, + DIVISIBLE_N: tl.constexpr, + DIVISIBLE_K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_M) + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + 
shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for off_k in range(0, K, BLOCK_K): + a_tile = a_desc.load([off_m, off_k]) + b_tile = b_desc.load([off_k, off_n]) + acc += tl.dot(a_tile, b_tile) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([off_m, off_n], acc.to(c_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A, B): + if A.dtype != torch.bfloat16: + A = A.to(torch.bfloat16) + if B.dtype != torch.bfloat16: + B = B.to(torch.bfloat16) + if not A.is_contiguous(): + A = A.contiguous() + if not B.is_contiguous(): + B = B.contiguous() + + M, K = A.shape + N = B.shape[1] + C = torch.empty((M, N), device=A.device, dtype=torch.bfloat16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + _matmul_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + DIVISIBLE_M=(M % 256 == 0), + DIVISIBLE_N=(N % 128 == 0), + DIVISIBLE_K=(K % 32 == 0), + ) + return C diff --git a/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py b/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py new file mode 100644 index 0000000..97669fb --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.py @@ -0,0 +1,160 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not 
representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def swizzle_tile( + tile_id, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _triu_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M) + + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + + # Skip tiles entirely below the diagonal (output is upper triangular) + if off_m >= off_n + BLOCK_N: + return + + # K-loop trimming: A is upper tri so A[i,k]=0 for k<i, B[k,j]=0 for k>j + # Effective K range: [off_m,
min(off_n + BLOCK_N, K)) + k_start = (off_m // BLOCK_K) * BLOCK_K + k_end_raw = off_n + BLOCK_N + k_end = k_end_raw if k_end_raw < K else K + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_offset in range(k_start, k_end, BLOCK_K): + a_block = a_desc.load([off_m, k_offset]) + b_block = b_desc.load([k_offset, off_n]) + acc += tl.dot(a_block, b_block) + # Apply triu mask + row_idx = off_m + tl.arange(0, BLOCK_M) + col_idx = off_n + tl.arange(0, BLOCK_N) + triu_mask = row_idx[:, None] <= col_idx[None, :] + acc = tl.where(triu_mask, acc, 0.0) + + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([off_m, off_n], acc.to(c_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, A, B): + M, K = A.shape + N = B.shape[1] + C = torch.zeros((M, N), device=A.device, dtype=A.dtype) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + _triu_matmul_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C diff --git a/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py b/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py new file mode 100644 index 0000000..7f6b7ff --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.py @@ -0,0 +1,151 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# 
Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.jit +def swizzle_tile( + tile_id, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ], + key=["M"], +) +@triton.jit +def tril_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, M, BLOCK_M, BLOCK_N, GROUP_SIZE_M) + + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + + # Skip tiles entirely in the upper triangle + if off_n > off_m + BLOCK_M - 1: + return + + # K-range optimization for triangular matrices: + # A is lower triangular: A[i,k]=0 for k>i, so max useful K = off_m + BLOCK_M + # B is lower triangular: 
B[k,j]=0 for j>k, so min useful K = off_n + # Align to BLOCK_K boundaries + k_start = (off_n // BLOCK_K) * BLOCK_K + k_end_raw = off_m + BLOCK_M + k_end = tl.minimum(k_end_raw, M) + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, M), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(M, M), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(k_start, k_end, BLOCK_K): + a_tile = a_desc.load([off_m, k]) + b_tile = b_desc.load([k, off_n]) + acc += tl.dot(a_tile, b_tile) + # Apply tril mask + row_idx = off_m + tl.arange(0, BLOCK_M) + col_idx = off_n + tl.arange(0, BLOCK_N) + tril_mask = row_idx[:, None] >= col_idx[None, :] + acc = tl.where(tril_mask, acc, 0.0) + + # Store using raw pointers + bounds_mask = (row_idx[:, None] < M) & (col_idx[None, :] < M) + c_ptrs = c_ptr + row_idx[:, None] * stride_cm + col_idx[None, :] * stride_cn + tl.store(c_ptrs, acc.to(c_ptr.type.element_ty), mask=bounds_mask & tril_mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A, B): + M = A.shape[0] + C = torch.zeros(M, M, device=A.device, dtype=A.dtype) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(M, META["BLOCK_N"]), + ) + tril_matmul_kernel[grid]( + A, + B, + C, + M, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C diff --git a/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py b/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py new file mode 100644 index 0000000..2b7f9ea --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/16_Matmul_with_transposed_A.py @@ -0,0 +1,137 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# 
Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def get_autotune_configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune( + configs=get_autotune_configs(), + key=["M", "N", "K"], +) +@triton.jit +def _matmul_at_kernel( + A_ptr, + B_ptr, + C_ptr, + M, + N, + K, + stride_ak: tl.constexpr, + stride_am: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # A is [K, M] — load [BLOCK_K, BLOCK_M] tiles, transpose in register + A_desc = tl.make_tensor_descriptor( + base=A_ptr, + shape=(K, M), + strides=(stride_ak, stride_am), + block_shape=(BLOCK_K, BLOCK_M), + ) + + # B is [K, N] — load [BLOCK_K, BLOCK_N] tiles + B_desc = tl.make_tensor_descriptor( + base=B_ptr, + shape=(K, N), + 
strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + A_tile = A_desc.load([off_k, pid_m * BLOCK_M]) + B_tile = B_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(A_tile.T, B_tile) + C_desc = tl.make_tensor_descriptor( + base=C_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + C_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(C_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + K, M = A.shape + _, N = B.shape + A = A.contiguous() + B = B.contiguous() + C = torch.empty((M, N), device=A.device, dtype=A.dtype) + + def grid(META): + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _matmul_at_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C diff --git a/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py b/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py new file mode 100644 index 0000000..75e8da9 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/17_Matmul_with_transposed_B.py @@ -0,0 +1,142 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def get_autotune_configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, 
"BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune( + configs=get_autotune_configs(), + key=["M", "N", "K"], +) +@triton.jit +def _matmul_bt_kernel( + A_ptr, + B_ptr, + C_ptr, + M, + N, + K, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bn: tl.constexpr, + stride_bk: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # A is [M, K] — load [BLOCK_M, BLOCK_K] tiles in natural layout + a_desc = tl.make_tensor_descriptor( + A_ptr, + shape=[M, K], + strides=[stride_am, stride_ak], + block_shape=[BLOCK_M, BLOCK_K], + ) + + # B is [N, K] — load [BLOCK_N, BLOCK_K] tiles in natural layout, then transpose + b_desc = tl.make_tensor_descriptor( + B_ptr, + shape=[N, K], + strides=[stride_bn, stride_bk], + block_shape=[BLOCK_N, BLOCK_K], + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + A_tile = a_desc.load([pid_m * BLOCK_M, off_k]) + B_tile = b_desc.load([pid_n * BLOCK_N, off_k]) + + # Transpose B in-register: [BLOCK_N, BLOCK_K] -> [BLOCK_K, BLOCK_N] + acc += tl.dot(A_tile, B_tile.T) + c_desc = tl.make_tensor_descriptor( + C_ptr, + shape=[M, N], + strides=[stride_cm, stride_cn], + 
block_shape=[BLOCK_M, BLOCK_N], + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(C_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + M, K = A.shape + N, _ = B.shape + + A = A.contiguous() + B = B.contiguous() + + C = torch.empty((M, N), device=A.device, dtype=A.dtype) + + def grid(META): + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _matmul_bt_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + + return C diff --git a/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py b/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py new file mode 100644 index 0000000..5849576 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/18_Matmul_with_transposed_both.py @@ -0,0 +1,138 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune(configs=_configs(), key=["M", "N", "K"]) +@triton.jit 
+def _matmul_tt_kernel( + A_ptr, + B_ptr, + C_ptr, + M, + N, + K, + stride_ak: tl.constexpr, + stride_am: tl.constexpr, + stride_bn: tl.constexpr, + stride_bk: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # A is (K, M): load (BLOCK_K, BLOCK_M), transpose in-register + a_desc = tl.make_tensor_descriptor( + A_ptr, + shape=[K, M], + strides=[stride_ak, stride_am], + block_shape=[BLOCK_K, BLOCK_M], + ) + + # B is (N, K): load (BLOCK_N, BLOCK_K), transpose in-register + b_desc = tl.make_tensor_descriptor( + B_ptr, + shape=[N, K], + strides=[stride_bn, stride_bk], + block_shape=[BLOCK_N, BLOCK_K], + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = a_desc.load([off_k, pid_m * BLOCK_M]) + b = b_desc.load([pid_n * BLOCK_N, off_k]) + + acc += tl.dot(a.T, b.T) + c_desc = tl.make_tensor_descriptor( + C_ptr, + shape=[M, N], + strides=[stride_cm, stride_cn], + block_shape=[BLOCK_M, BLOCK_N], + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(C_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + K, M = A.shape + N, _ = B.shape + + A = A.contiguous() + B = B.contiguous() + + C = torch.empty((M, N), device=A.device, dtype=A.dtype) + + def grid(META): + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _matmul_tt_kernel[grid]( + A, + 
B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + + return C diff --git a/backends/triton/cpu/KernelBench/level1/19_ReLU.py b/backends/triton/cpu/KernelBench/level1/19_ReLU.py new file mode 100644 index 0000000..f76813e --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/19_ReLU.py @@ -0,0 +1,62 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 4096, "NUM_PROGRAMS": 256}, num_warps=8, num_stages=3 + ), + triton.Config( + {"BLOCK_SIZE": 8192, "NUM_PROGRAMS": 256}, num_warps=16, num_stages=3 + ), + triton.Config( + {"BLOCK_SIZE": 8192, "NUM_PROGRAMS": 512}, num_warps=16, num_stages=3 + ), + triton.Config( + {"BLOCK_SIZE": 4096, "NUM_PROGRAMS": 512}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_SIZE": 16384, "NUM_PROGRAMS": 256}, num_warps=16, num_stages=3 + ), + ], + key=["n_elements"], +) +@triton.jit +def relu_kernel_persistent( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + NUM_PROGRAMS: tl.constexpr, +): + pid = tl.program_id(0) + num_blocks = tl.cdiv(n_elements, BLOCK_SIZE) + + for block_id in tl.range(pid, num_blocks, NUM_PROGRAMS): + block_start = block_id * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = tl.maximum(x, 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = x.numel() + grid = lambda META: (META["NUM_PROGRAMS"],) + relu_kernel_persistent[grid](x, output, 
n_elements) + return output diff --git a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py index 1582894..e873a3a 100644 --- a/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py +++ b/backends/triton/cpu/KernelBench/level1/1_Square_matrix_multiplication_.py @@ -1,5 +1,6 @@ # ruff: noqa: E731 -# Example Triton CPU kernel +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation # Status: Experimental / uncurated # Expectation: Correctness-first, performance not representative @@ -9,13 +10,37 @@ import triton.language as tl -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 32}), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}), - ], - key=["M", "N", "K"], # autotune per problem size -) +def _configs(): + return [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ] + + +@triton.autotune(configs=_configs(), key=["M", "N", "K"]) @triton.jit def _matmul_kernel( a_ptr, @@ -24,70 +49,82 @@ def _matmul_kernel( M, N, K, + stride_am, + stride_ak: tl.constexpr, + stride_bk, + stride_bn: tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, 
BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + a_desc = tl.make_tensor_descriptor( - base=a_ptr, shape=(M, K), strides=(K, 1), block_shape=(BLOCK_M, BLOCK_K) + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), ) b_desc = tl.make_tensor_descriptor( - base=b_ptr, shape=(K, N), strides=(N, 1), block_shape=(BLOCK_K, BLOCK_N) - ) - c_desc = tl.make_tensor_descriptor( - base=c_ptr, shape=(M, N), strides=(N, 1), block_shape=(BLOCK_M, BLOCK_N) + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), ) - m = tl.program_id(0) * BLOCK_M - n = tl.program_id(1) * BLOCK_N acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(0, K, BLOCK_K): - a = a_desc.load((m, k)) - b = b_desc.load((k, n)) - acc = tl.dot(a, b, acc) - - c_desc.store((m, n), acc) - - -def _kernel_function_cpu(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor) - assert A.device.type == "cpu" and B.device.type == "cpu", "A and B must be on CPU" - assert A.is_floating_point() and B.is_floating_point(), ( - "A and B must be floating point tensors" - ) - assert A.dtype == B.dtype, f"dtype mismatch: {A.dtype} vs {B.dtype}" - - orig_dtype = A.dtype - - M, K = A.shape - K2, N = B.shape - assert K == K2, f"Incompatible K dimensions: {K} vs {K2}" - - C32 = torch.empty((M, N), device=A.device, dtype=torch.float32) - - # Autotuned grid: depends on BLOCK_M/BLOCK_N chosen by config - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]), - triton.cdiv(N, META["BLOCK_N"]), - ) - - _matmul_kernel[grid]( - A, - B, - C32, - M, - N, - K, + for off_k in range(0, K, BLOCK_K): + a_tile = 
a_desc.load([pid_m * BLOCK_M, off_k]) + b_tile = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a_tile, b_tile) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), ) - - return C32.to(orig_dtype) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(c_ptr.type.element_ty)) class Model(nn.Module): - """KernelBench-compatible wrapper""" - - def __init__(self, *args, **kwargs): - super(Model, self).__init__() + def __init__(self): + super().__init__() def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - return _kernel_function_cpu(A, B) + M, K = A.shape + K2, N = B.shape + + A = A.contiguous() + B = B.contiguous() + + C = torch.empty((M, N), device=A.device, dtype=A.dtype) + + def grid(META): + return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _matmul_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C diff --git a/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py b/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py new file mode 100644 index 0000000..cb900a0 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/20_LeakyReLU.py @@ -0,0 +1,60 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + ], + key=["n_elements"], +) 
+@triton.jit +def leaky_relu_kernel( + x_ptr, + output_ptr, + neg_slope_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + neg_slope = tl.load(neg_slope_ptr) + + x = tl.load(x_ptr + offsets, mask=mask) + out = tl.where(x >= 0, x, x * neg_slope) + tl.store(output_ptr + offsets, out, mask=mask) + + +class Model(nn.Module): + def __init__(self, negative_slope: float = 0.01): + super(Model, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = x.numel() + neg_slope_t = torch.tensor( + [self.negative_slope], dtype=x.dtype, device=x.device + ) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + leaky_relu_kernel[grid]( + x, + output, + neg_slope_t, + n_elements, + ) + return output diff --git a/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py b/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py new file mode 100644 index 0000000..b4b21f6 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/21_Sigmoid.py @@ -0,0 +1,55 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + ], + key=["N"], +) +@triton.jit +def _sigmoid_kernel( + x_ptr, + out_ptr, + N, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + inv_ln2 = 1.4426950408889634 + e = tl.math.exp2((-x) * inv_ln2) + y = 1.0 / (1.0 + e) + + tl.store(out_ptr + offsets, y.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_flat = x.contiguous().view(-1) + N = x_flat.numel() + out_flat = torch.empty_like(x_flat) + + grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),) + _sigmoid_kernel[grid](x_flat, out_flat, N) + + return out_flat.view_as(x) diff --git a/backends/triton/cpu/KernelBench/level1/22_Tanh.py b/backends/triton/cpu/KernelBench/level1/22_Tanh.py new file mode 100644 index 0000000..c07a6f1 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/22_Tanh.py @@ -0,0 +1,62 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +batch_size = 4096 +dim = 393216 + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + ], + key=["n_elements"], +) +@triton.jit +def _tanh_kernel( + x_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + # tanh(x) = 2*sigmoid(2x) - 1 + # sigmoid(z) = 1/(1 + exp2(-z * log2(e))) + inv_ln2: tl.constexpr = 
1.4426950408889634 + z = 2.0 * x + e = tl.math.exp2((-z) * inv_ln2) + sig = 1.0 / (1.0 + e) + result = 2.0 * sig - 1.0 + + tl.store(out_ptr + offsets, result.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_flat = x.contiguous().view(-1) + n_elements = x_flat.numel() + output_flat = torch.empty_like(x_flat) + + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + _tanh_kernel[grid](x_flat, output_flat, n_elements) + + return output_flat.view(x.shape) diff --git a/backends/triton/cpu/KernelBench/level1/23_Softmax.py b/backends/triton/cpu/KernelBench/level1/23_Softmax.py new file mode 100644 index 0000000..26ce427 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/23_Softmax.py @@ -0,0 +1,92 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +def _softmax_configs(): + return [ + triton.Config({"BLOCK_N": 2048}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_N": 4096}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_N": 4096}, num_warps=16, num_stages=3), + triton.Config({"BLOCK_N": 8192}, num_warps=16, num_stages=3), + triton.Config({"BLOCK_N": 16384}, num_warps=16, num_stages=2), + ] + + +@triton.autotune(configs=_softmax_configs(), key=["N"]) +@triton.jit +def _softmax_kernel( + inp_ptr, + out_ptr, + M, + N, + stride_im, + stride_in, + stride_om, + stride_on, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + row_inp = inp_ptr + pid_m * stride_im + row_out = out_ptr + pid_m * stride_om + + LOG2E: tl.constexpr = 1.4426950408889634 + + # Pass 1: Online max + sum_exp + row_max = -float("inf") + row_sum = 0.0 + for start in range(0, N, BLOCK_N): + offs = start + 
tl.arange(0, BLOCK_N) + mask = offs < N + x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")).to( + tl.float32 + ) + block_max = tl.max(x, axis=0) + new_max = tl.maximum(row_max, block_max) + row_sum = row_sum * tl.math.exp2((row_max - new_max) * LOG2E) + tl.sum( + tl.math.exp2((x - new_max) * LOG2E), axis=0 + ) + row_max = new_max + + inv_sum = 1.0 / row_sum + + # Pass 2: normalize and store + for start in range(0, N, BLOCK_N): + offs = start + tl.arange(0, BLOCK_N) + mask = offs < N + x = tl.load(row_inp + offs * stride_in, mask=mask, other=-float("inf")).to( + tl.float32 + ) + e = tl.math.exp2((x - row_max) * LOG2E) + y = (e * inv_sum).to(tl.bfloat16) + tl.store(row_out + offs * stride_on, y, mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous() + M, N = x.shape + out = torch.empty_like(x) + + grid = (M,) + _softmax_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + ) + return out diff --git a/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py b/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py new file mode 100644 index 0000000..ab7bdcd --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/24_LogSoftmax.py @@ -0,0 +1,85 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 2048, "warp_size": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 4096, "warp_size": 32}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 4096, "warp_size": 16}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_N": 8192, "warp_size": 32}, num_warps=8, num_stages=2), + 
triton.Config({"BLOCK_N": 8192, "warp_size": 16}, num_warps=16, num_stages=2), + ], + key=["N"], +) +@triton.jit +def _logsoftmax_kernel( + inp_ptr, + out_ptr, + M, + N, + stride_im, + stride_om, + BLOCK_N: tl.constexpr, + warp_size: tl.constexpr, +): + pid_m = tl.program_id(0) + row_inp = inp_ptr + pid_m.to(tl.int64) * stride_im + row_out = out_ptr + pid_m.to(tl.int64) * stride_om + + LOG2E = 1.4426950408889634 + LN2 = 0.6931471805599453 + + m = -float("inf") + s = 0.0 + + for start in range(0, N, BLOCK_N): + offs = start + tl.arange(0, BLOCK_N) + mask = offs < N + x = tl.load(row_inp + offs, mask=mask, other=-float("inf")).to(tl.float32) + block_max = tl.max(x, axis=0) + m_new = tl.maximum(m, block_max) + s = s * tl.math.exp2((m - m_new) * LOG2E) + tl.sum( + tl.math.exp2((x - m_new) * LOG2E), axis=0 + ) + m = m_new + + log_s = tl.math.log2(s) * LN2 + + for start in range(0, N, BLOCK_N): + offs = start + tl.arange(0, BLOCK_N) + mask = offs < N + x = tl.load(row_inp + offs, mask=mask, other=-float("inf")).to(tl.float32) + y = x - m - log_s + tl.store(row_out + offs, y.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self, dim: int = 1): + super(Model, self).__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous() + + M, N = x.shape + out = torch.empty_like(x) + + grid = (M,) + _logsoftmax_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + out.stride(0), + ) + return out diff --git a/backends/triton/cpu/KernelBench/level1/25_Swish.py b/backends/triton/cpu/KernelBench/level1/25_Swish.py new file mode 100644 index 0000000..73b8228 --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/25_Swish.py @@ -0,0 +1,56 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as 
tl + + +@triton.jit +def _sigmoid_exp2(x): + inv_ln2 = 1.4426950408889634 + e = tl.math.exp2((-x) * inv_ln2) + return 1.0 / (1.0 + e) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=2), + ], + key=["n_elements"], +) +@triton.jit +def swish_kernel( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + x_f32 = x.to(tl.float32) + sig = _sigmoid_exp2(x_f32) + result = x_f32 * sig + tl.store(output_ptr + offsets, result.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + swish_kernel[grid](x, output, n_elements) + return output diff --git a/backends/triton/cpu/KernelBench/level1/26_GELU_.py b/backends/triton/cpu/KernelBench/level1/26_GELU_.py new file mode 100644 index 0000000..eafefec --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/26_GELU_.py @@ -0,0 +1,63 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 4096, "NUM_PROGS": 160}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_SIZE": 4096, "NUM_PROGS": 640}, num_warps=4, 
num_stages=2 + ), + triton.Config( + {"BLOCK_SIZE": 8192, "NUM_PROGS": 160}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_SIZE": 8192, "NUM_PROGS": 320}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_SIZE": 8192, "NUM_PROGS": 640}, num_warps=8, num_stages=2 + ), + ], + key=["n_elements"], +) +@triton.jit +def gelu_persistent_kernel( + x_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + NUM_PROGS: tl.constexpr, +): + pid = tl.program_id(0) + num_tiles = tl.cdiv(n_elements, BLOCK_SIZE) + + for tile_id in range(pid, num_tiles, NUM_PROGS): + offsets = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + out = 0.5 * x * (1.0 + tl.math.erf(x * 0.70710678118654752440)) + tl.store(out_ptr + offsets, out.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_flat = x.view(-1) + out_flat = torch.empty_like(x_flat) + n_elements = x_flat.numel() + grid = lambda META: (META["NUM_PROGS"],) + gelu_persistent_kernel[grid](x_flat, out_flat, n_elements) + return out_flat.view_as(x) diff --git a/backends/triton/cpu/KernelBench/level1/27_SELU_.py b/backends/triton/cpu/KernelBench/level1/27_SELU_.py new file mode 100644 index 0000000..7dd616c --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/27_SELU_.py @@ -0,0 +1,55 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=3), + 
triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + ], + key=["n_elements"], +) +@triton.jit +def selu_kernel( + x_ptr, + out_ptr, + n_elements, + alpha, + scale, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + x_f32 = x.to(tl.float32) + + result = tl.where(x_f32 > 0.0, scale * x_f32, scale * alpha * (tl.exp(x_f32) - 1.0)) + + tl.store(out_ptr + offsets, result.to(tl.bfloat16), mask=mask) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.alpha = 1.6732632423543772848170429916717 + self.scale = 1.0507009873554804934193349852946 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + selu_kernel[grid](x, output, n_elements, self.alpha, self.scale) + return output diff --git a/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py b/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py new file mode 100644 index 0000000..75ad83f --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/28_HardSigmoid.py @@ -0,0 +1,53 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2), + ], + 
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2),
    ],
    key=["n_elements"],
)
@triton.jit
def softplus_kernel(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise softplus over a flat buffer of ``n_elements`` values.

    Each program handles one ``BLOCK_SIZE``-wide slice; lanes past
    ``n_elements`` are masked on both the load and the store.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Do the arithmetic in float32 whatever the storage dtype is.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # softplus(x) = log(1 + exp(x)). Above THRESHOLD the kernel returns x
    # itself, so an overflowing exp() in such a lane never reaches the output.
    THRESHOLD: tl.constexpr = 20.0
    result = tl.where(x > THRESHOLD, x, tl.math.log(1.0 + tl.exp(x)))

    # Cast to the output buffer's element type instead of a fixed bfloat16 so
    # that non-bfloat16 outputs are not rounded through bfloat16 first.
    out_dtype = output_ptr.dtype.element_ty
    tl.store(output_ptr + offsets, result.to(out_dtype), mask=mask)
class Model(nn.Module):
    """Matrix multiplication ``C = A @ B`` backed by ``_matmul_kernel``."""

    def __init__(self):
        super().__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Multiply ``A`` (M, K) by ``B`` (K, N) and return a bfloat16 (M, N) result.

        Both operands are converted to contiguous bfloat16 before the launch.

        Raises:
            ValueError: if an operand is not 2-D or the inner dimensions differ.
        """
        if A.dim() != 2 or B.dim() != 2:
            raise ValueError(
                f"expected 2-D operands, got shapes {tuple(A.shape)} "
                f"and {tuple(B.shape)}"
            )
        if A.shape[1] != B.shape[0]:
            # Without this check the kernel would index B with A's K extent.
            raise ValueError(
                f"inner dimensions differ: {tuple(A.shape)} vs {tuple(B.shape)}"
            )

        A_bf16 = A.to(torch.bfloat16).contiguous()
        B_bf16 = B.to(torch.bfloat16).contiguous()
        M, K = A_bf16.shape
        N = B_bf16.shape[1]
        C = torch.empty((M, N), device=A.device, dtype=torch.bfloat16)
        if M == 0 or N == 0:
            # Empty output: nothing to launch.
            return C

        # One program per (BLOCK_M, BLOCK_N) output tile.
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )
        _matmul_kernel[grid](
            A_bf16,
            B_bf16,
            C,
            M,
            N,
            K,
            A_bf16.stride(0),
            A_bf16.stride(1),
            B_bf16.stride(0),
            B_bf16.stride(1),
            C.stride(0),
            C.stride(1),
        )
        return C
class Model(nn.Module):
    """ELU activation evaluated by the ``elu_kernel`` Triton kernel."""

    def __init__(self, alpha=1.0):
        """Store ``alpha`` as a float, using 1.0 when it cannot be converted."""
        super().__init__()
        try:
            self.alpha = float(alpha)
        except (ValueError, TypeError):
            # Best-effort fallback kept from the original implementation.
            self.alpha = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ELU applied to ``x`` as a new tensor of the same shape."""
        # The kernel walks memory with flat offsets, so strided views must be
        # materialized first; contiguous inputs are passed through unchanged.
        x = x.contiguous()
        out = torch.empty_like(x)
        n_elements = x.numel()
        if n_elements == 0:
            # No work: skip autotuning and the launch of an empty grid.
            return out
        grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
        elu_kernel[grid](x, out, self.alpha, n_elements)
        return out
class Model(nn.Module):
    """HardTanh activation evaluated by the ``hardtanh_kernel`` Triton kernel.

    Values are clamped to ``[min_val, max_val]``; the defaults reproduce the
    previously hard-coded range of ``[-1.0, 1.0]``.
    """

    def __init__(self, min_val: float = -1.0, max_val: float = 1.0):
        """Store the clamp bounds.

        Raises:
            ValueError: if ``min_val`` is greater than ``max_val``.
        """
        super().__init__()
        if min_val > max_val:
            raise ValueError(
                f"min_val ({min_val}) must not exceed max_val ({max_val})"
            )
        self.min_val = float(min_val)
        self.max_val = float(max_val)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` clamped to the configured range as a new tensor."""
        # The kernel walks memory with flat offsets, so strided views must be
        # materialized first; contiguous inputs are passed through unchanged.
        x = x.contiguous()
        out = torch.empty_like(x)
        n_elements = x.numel()
        if n_elements == 0:
            # No work: skip autotuning and the launch of an empty grid.
            return out
        grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
        # The bounds are compile-time constants of the kernel.
        hardtanh_kernel[grid](x, out, n_elements, self.min_val, self.max_val)
        return out
@triton.jit
def _bn_stats_kernel(
    partial_sum_ptr,
    partial_sq_ptr,
    scale_ptr,
    shift_ptr,
    weight_ptr,
    bias_ptr,
    inv_count,
    eps,
    B: tl.constexpr,
    BLOCK_B: tl.constexpr,
):
    """Fold per-batch partial sums into a per-channel scale and shift.

    One program per channel. It reads the ``B`` partial sums and partial sums
    of squares stored at ``pid_c * B``, derives mean and variance using
    ``inv_count``, and writes ``scale = weight / sqrt(var + eps)`` and
    ``shift = bias - mean * scale`` for that channel.
    """
    pid_c = tl.program_id(0)
    offs_b = tl.arange(0, BLOCK_B)
    # BLOCK_B may exceed B; the extra lanes contribute zero.
    mask = offs_b < B
    s = tl.load(partial_sum_ptr + pid_c * B + offs_b, mask=mask, other=0.0)
    sq = tl.load(partial_sq_ptr + pid_c * B + offs_b, mask=mask, other=0.0)
    total_sum = tl.sum(s, axis=0)
    total_sq = tl.sum(sq, axis=0)
    mean_val = total_sum * inv_count
    var_val = total_sq * inv_count - mean_val * mean_val
    # E[x^2] - E[x]^2 can round to a small negative number; clamp it so the
    # square root below cannot produce NaN.
    var_val = tl.maximum(var_val, 0.0)
    w = tl.load(weight_ptr + pid_c)
    bi = tl.load(bias_ptr + pid_c)
    inv_std = 1.0 / tl.sqrt(var_val + eps)
    scale_val = w * inv_std
    shift_val = bi - mean_val * scale_val
    tl.store(scale_ptr + pid_c, scale_val)
    tl.store(shift_ptr + pid_c, shift_val)
class Model(nn.Module):
    """Batch normalization over ``(B, C, H, W)`` inputs using Triton kernels.

    In training mode the per-channel statistics are computed from the current
    batch; otherwise ``running_mean`` / ``running_var`` are used. This module
    reads the running statistics but never updates them. The input is
    converted to bfloat16 before normalization.
    """

    def __init__(self, num_features: int):
        super().__init__()
        self.num_features = num_features
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.eps = 1e-5
        self.momentum = 0.1
        self._moved = False
        self._bufs_ready = False
        # Device the parameters were last moved to, and the (B, C, device)
        # the scratch buffers were last sized for.
        self._param_device = None
        self._bufs_key = None

    def _move_params(self, device):
        """Place parameters and running statistics on ``device`` as float32."""
        self.weight.data = self.weight.data.to(device, dtype=torch.float32).contiguous()
        self.bias.data = self.bias.data.to(device, dtype=torch.float32).contiguous()
        self.running_mean = self.running_mean.to(
            device, dtype=torch.float32
        ).contiguous()
        self.running_var = self.running_var.to(device, dtype=torch.float32).contiguous()
        self._moved = True
        self._param_device = device

    def _alloc_bufs(self, B, C, device):
        """Allocate float32 scratch buffers for batch size ``B`` and ``C`` channels."""
        self._partial_sum = torch.empty((C, B), device=device, dtype=torch.float32)
        self._partial_sq = torch.empty((C, B), device=device, dtype=torch.float32)
        self._scale = torch.empty(C, device=device, dtype=torch.float32)
        self._shift = torch.empty(C, device=device, dtype=torch.float32)
        self._bufs_ready = True
        self._bufs_key = (B, C, device)

    def forward(self, x):
        """Normalize ``x`` per channel and return a bfloat16 tensor of the same shape.

        Raises:
            ValueError: if the channel dimension differs from ``num_features``.
        """
        device = x.device
        # Move again if the input arrives on a different device than before.
        if not self._moved or self._param_device != device:
            self._move_params(device)
        x = x.to(dtype=torch.bfloat16).contiguous()
        B, C, H, W = x.shape
        if C != self.num_features:
            # weight/bias hold num_features entries; a larger C would make the
            # kernels read past them.
            raise ValueError(
                f"expected {self.num_features} channels, got input with {C}"
            )
        HW = H * W
        stride_b = x.stride(0)
        stride_c = x.stride(1)
        total_elements = B * C * HW

        # The scratch buffers are sized by (C, B); reusing them after the batch
        # size, channel count or device changed would let the kernels write
        # out of bounds, so reallocate whenever the key differs.
        if not self._bufs_ready or self._bufs_key != (B, C, device):
            self._alloc_bufs(B, C, device)

        if self.training:
            # Pass 1: one partial sum / sum of squares per (channel, batch).
            _bn_reduce_kernel[(C, B)](
                x,
                self._partial_sum,
                self._partial_sq,
                C,
                HW,
                stride_b,
                stride_c,
                B=B,
                BLOCK_HW=8192,
                num_warps=8,
            )
            # Pass 2: combine the partials into per-channel scale and shift.
            _bn_stats_kernel[(C,)](
                self._partial_sum,
                self._partial_sq,
                self._scale,
                self._shift,
                self.weight,
                self.bias,
                1.0 / (B * HW),
                self.eps,
                B=B,
                BLOCK_B=triton.next_power_of_2(B),
                num_warps=4,
            )
        else:
            # The scratch buffers are plain tensors; keep autograd from
            # recording the in-place copies from the parameters.
            with torch.no_grad():
                inv_std = 1.0 / torch.sqrt(self.running_var + self.eps)
                self._scale.copy_(self.weight * inv_std)
                self._shift.copy_(self.bias - self.running_mean * self._scale)

        out = torch.empty_like(x)
        grid = lambda META: (triton.cdiv(total_elements, META["BLOCK_SIZE"]),)
        _bn_normalize_flat_kernel[grid](
            x,
            out,
            self._scale,
            self._shift,
            C,
            HW,
            total_elements,
        )
        return out
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_HW": 4096, "warp_size": 32}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_HW": 8192, "warp_size": 32}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_HW": 16384, "warp_size": 32}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_HW": 8192, "warp_size": 16}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_HW": 16384, "warp_size": 16}, num_warps=16, num_stages=3),
    ],
    key=["HW", "channels_per_group"],
)
@triton.jit
def group_norm_stats_kernel(
    x_ptr,
    mean_ptr,
    invstd_ptr,
    C: tl.constexpr,
    HW: tl.constexpr,
    num_groups: tl.constexpr,
    eps: tl.constexpr,
    channels_per_group: tl.constexpr,
    BLOCK_HW: tl.constexpr,
    warp_size: tl.constexpr,
):
    """Compute mean and inverse standard deviation for one (batch, group) pair.

    Program ``pid`` handles batch ``pid // num_groups`` and group
    ``pid % num_groups``. It accumulates the sum and sum of squares over the
    group's ``channels_per_group * HW`` elements in float32 and stores the
    mean and ``1 / sqrt(var + eps)`` at index ``pid``.
    """
    pid = tl.program_id(0)
    batch_idx = pid // num_groups
    group_idx = pid % num_groups

    channel_start = group_idx * channels_per_group
    # 64-bit offsets so that large tensors do not overflow the address math.
    batch_offset = batch_idx.to(tl.int64) * C * HW
    group_elems = channels_per_group * HW

    sum_val = 0.0
    sq_val = 0.0

    for c in range(channels_per_group):
        c_offset = (channel_start + c).to(tl.int64) * HW
        base = batch_offset + c_offset
        for hw_start in range(0, HW, BLOCK_HW):
            offs = hw_start + tl.arange(0, BLOCK_HW)
            mask = offs < HW
            x_val = tl.load(x_ptr + base + offs.to(tl.int64), mask=mask, other=0.0).to(
                tl.float32
            )
            sum_val += tl.sum(x_val, axis=0)
            sq_val += tl.sum(x_val * x_val, axis=0)

    mean = sum_val / group_elems
    variance = sq_val / group_elems - mean * mean
    # E[x^2] - E[x]^2 can round to a small negative number; clamp it so the
    # square root below cannot produce NaN.
    variance = tl.maximum(variance, 0.0)
    inv_std = 1.0 / tl.sqrt(variance + eps)

    tl.store(mean_ptr + pid, mean)
    tl.store(invstd_ptr + pid, inv_std)
class Model(nn.Module):
    """Group normalization over ``(N, C, H, W)`` inputs using Triton kernels.

    The affine weight, bias and ``eps`` come from an internal ``nn.GroupNorm``.
    """

    def __init__(self, num_features: int, num_groups: int):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
        self.num_features = num_features
        self.num_groups = num_groups
        self._packed = False

    def _pack_weights(self, device):
        """Expose contiguous copies of the affine parameters on ``device``."""
        self.weight_packed = self.gn.weight.data.to(device).contiguous()
        self.bias_packed = self.gn.bias.data.to(device).contiguous()
        self._packed = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the group-normalized input as a new tensor shaped like ``x``.

        Raises:
            ValueError: if the channel dimension differs from ``num_features``.
        """
        device = x.device
        x = x.contiguous()
        # Pack on every call instead of once: a one-time cache keeps pointing
        # at the first device seen. When the parameters are already contiguous
        # on ``device`` this creates no copies.
        self._pack_weights(device)

        N, C, H, W_dim = x.shape
        if C != self.num_features:
            # weight/bias hold num_features entries; a larger C would make the
            # apply kernel read past them.
            raise ValueError(
                f"expected {self.num_features} channels, got input with {C}"
            )
        HW = H * W_dim
        channels_per_group = C // self.num_groups
        y = torch.empty_like(x)
        if x.numel() == 0:
            # No work: skip autotuning and the launch of empty grids.
            return y
        eps = self.gn.eps

        # One mean / inverse-std entry per (sample, group).
        mean_buf = torch.empty(N * self.num_groups, device=device, dtype=torch.float32)
        invstd_buf = torch.empty(
            N * self.num_groups, device=device, dtype=torch.float32
        )

        stats_grid = (N * self.num_groups,)
        group_norm_stats_kernel[stats_grid](
            x,
            mean_buf,
            invstd_buf,
            C,
            HW,
            self.num_groups,
            eps,
            channels_per_group,
        )

        # One program per (sample, channel) applies scale and shift.
        apply_grid = (N * C,)
        group_norm_apply_kernel[apply_grid](
            x,
            y,
            mean_buf,
            invstd_buf,
            self.weight_packed,
            self.bias_packed,
            C,
            HW,
            self.num_groups,
            channels_per_group,
        )
        return y
class Model(nn.Module):
    """RMS normalization across dimension 1 using ``rms_norm_kernel``.

    The input is treated as ``(batch, features, *spatial)``; every spatial
    position is divided by the root mean square of its feature values. The
    result is bfloat16.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the RMS-normalized input as a bfloat16 tensor shaped like ``x``."""
        if x.dtype != torch.bfloat16:
            x = x.to(torch.bfloat16)
        x = x.contiguous()

        B = x.shape[0]
        F = x.shape[1]
        # Product of all trailing dimensions (1 when there are none).
        spatial = x.shape[2:].numel()

        x_flat = x.view(B, F, spatial)
        out_flat = torch.empty_like(x_flat)
        if x_flat.numel() == 0:
            # No work: skip autotuning and the launch of an empty grid.
            return out_flat.view_as(x)

        # Element strides of the contiguous (B, F, spatial) layout.
        stride_batch = F * spatial
        stride_feat = spatial

        # The kernel reads eps through a pointer, so it is passed as a tensor.
        eps_t = torch.tensor([self.eps], dtype=torch.float32, device=x.device)
        grid = lambda META: (B * triton.cdiv(spatial, META["BLOCK_S"]),)
        rms_norm_kernel[grid](
            x_flat,
            out_flat,
            eps_t,
            spatial,
            F,
            stride_batch,
            stride_feat,
        )

        return out_flat.view_as(x)
class Model(nn.Module):
    """Scale a tensor by the reciprocal square root of its sum of squares.

    Three kernels are launched: per-block partial sums of squares, a
    single-program reduction that stores ``rsqrt(total)``, and an element-wise
    multiply by that value.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized tensor with the shape of ``x``."""
        original_shape = x.shape
        x_flat = x.contiguous().view(-1)
        N = x_flat.numel()
        output_flat = torch.empty_like(x_flat)
        if N == 0:
            # No work: skip the launches (the first grid would be empty).
            return output_flat.view(original_shape)

        # Block size of the partial-sum pass; it fixes both the grid size and
        # the number of partial sums, so keep a single definition.
        partial_block = 8192
        num_blocks = triton.cdiv(N, partial_block)
        partial_sums = torch.empty(num_blocks, device=x.device, dtype=torch.float32)

        _partial_sum_sq_kernel[(num_blocks,)](
            x_flat, partial_sums, N, BLOCK_SIZE=partial_block
        )

        inv_norm = torch.empty(1, device=x.device, dtype=torch.float32)
        _reduce_kernel[(1,)](
            partial_sums, inv_norm, num_blocks, REDUCE_BLOCK=REDUCE_BLOCK
        )

        grid_norm = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
        _normalize_kernel[grid_norm](x_flat, output_flat, inv_norm, N)

        return output_flat.view(original_shape)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2),
    ],
    key=["N"],
)
@triton.jit
def l1_norm_kernel(
    x_ptr,
    out_ptr,
    M,
    N,
    stride_xm,
    stride_xn,
    stride_om,
    stride_on,
    BLOCK_SIZE: tl.constexpr,
):
    """Divide each row by the mean absolute value of that row.

    One program per row. ``N`` is the row length; ``M`` is accepted but not
    read by the kernel. The arithmetic is done in float32.
    """
    row = tl.program_id(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # 64-bit row offsets so that large tensors do not overflow.
    row_x = x_ptr + row.to(tl.int64) * stride_xm
    row_o = out_ptr + row.to(tl.int64) * stride_om

    # Phase 1: compute sum(abs(x)) for this row
    abs_sum = 0.0
    for col_start in range(0, N, BLOCK_SIZE):
        cols = col_start + col_offsets
        mask = cols < N
        x = tl.load(row_x + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        abs_sum += tl.sum(tl.abs(x), axis=0)

    # mean = sum(abs(x)) / N
    mean_val = abs_sum / N

    # Phase 2: normalize x / mean
    # Cast to the output buffer's element type instead of a fixed bfloat16 so
    # that non-bfloat16 outputs are not rounded through bfloat16 first.
    out_dtype = out_ptr.dtype.element_ty
    for col_start in range(0, N, BLOCK_SIZE):
        cols = col_start + col_offsets
        mask = cols < N
        x = tl.load(row_x + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
        out = x / mean_val
        tl.store(row_o + cols * stride_on, out.to(out_dtype), mask=mask)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_SIZE": 8192}, num_warps=16, num_stages=2),
    ],
    key=["N"],
)
@triton.jit
def l2_norm_kernel(
    x_ptr,
    out_ptr,
    M,
    N,
    stride_m,
    stride_n,
    BLOCK_SIZE: tl.constexpr,
):
    """Scale each row by ``1 / sqrt(sum(x^2) + 1e-12)``.

    One program per row. Input and output are addressed with the same
    ``stride_m`` / ``stride_n``; ``M`` is accepted but not read by the kernel.
    The arithmetic is done in float32.
    """
    row = tl.program_id(0)
    # Widen before multiplying: a 32-bit ``row * stride_m`` overflows once the
    # tensor holds more than 2**31 elements.
    row_start = row.to(tl.int64) * stride_m

    sum_sq = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_ptr + row_start + cols * stride_n, mask=mask, other=0.0).to(
            tl.float32
        )
        sum_sq += x * x

    norm_sq = tl.sum(sum_sq, axis=0)
    # The small constant keeps the reciprocal finite for all-zero rows.
    inv_norm = 1.0 / tl.sqrt(norm_sq + 1e-12)

    # Cast to the output buffer's element type instead of a fixed bfloat16 so
    # that non-bfloat16 outputs are not rounded through bfloat16 first.
    out_dtype = out_ptr.dtype.element_ty
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_ptr + row_start + cols * stride_n, mask=mask, other=0.0).to(
            tl.float32
        )
        out = x * inv_norm
        tl.store(out_ptr + row_start + cols * stride_n, out.to(out_dtype), mask=mask)
self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous() + M, N = x.shape + out = torch.empty_like(x) + grid = (M,) + l2_norm_kernel[grid]( + x, + out, + M, + N, + x.stride(0), + x.stride(1), + ) + return out diff --git a/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py b/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py new file mode 100644 index 0000000..ec0de2d --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/3_Batched_matrix_multiplication.py @@ -0,0 +1,151 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=3, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _batched_matmul_kernel( + A_ptr, + B_ptr, + C_ptr, + M, + N, + K, + stride_ab: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bb: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cb: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BATCH: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + 
pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_tiles_mn = num_pid_m * num_pid_n + + batch_id = pid // num_tiles_mn + tile_id = pid % num_tiles_mn + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + batch_offset_a = batch_id.to(tl.int64) * stride_ab + batch_offset_b = batch_id.to(tl.int64) * stride_bb + batch_offset_c = batch_id.to(tl.int64) * stride_cb + + a_desc = tl.make_tensor_descriptor( + base=A_ptr + batch_offset_a, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + + b_desc = tl.make_tensor_descriptor( + base=B_ptr + batch_offset_b, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a = a_desc.load([pid_m * BLOCK_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a, b) + c_desc = tl.make_tensor_descriptor( + base=C_ptr + batch_offset_c, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(C_ptr.type.element_ty)) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + BATCH, M, K = A.shape + _, _, N = B.shape + + A = A.contiguous() + B = B.contiguous() + C = torch.empty((BATCH, M, N), device=A.device, dtype=A.dtype) + + def grid(META): + return ( + BATCH + * triton.cdiv(M, META["BLOCK_M"]) + * triton.cdiv(N, META["BLOCK_N"]), + ) + + _batched_matmul_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + 
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": 4096, "warp_size": 32}, num_warps=32, num_stages=2
        ),
        triton.Config(
            {"BLOCK_SIZE": 2048, "warp_size": 32}, num_warps=32, num_stages=2
        ),
        triton.Config(
            {"BLOCK_SIZE": 1024, "warp_size": 32}, num_warps=32, num_stages=2
        ),
        triton.Config(
            {"BLOCK_SIZE": 2048, "warp_size": 32}, num_warps=16, num_stages=2
        ),
        triton.Config(
            {"BLOCK_SIZE": 1024, "warp_size": 32}, num_warps=16, num_stages=2
        ),
    ],
    key=["N"],
)
@triton.jit
def _layer_norm_best_kernel(
    X_ptr,
    Y_ptr,
    W_ptr,
    B_ptr,
    M,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
    warp_size: tl.constexpr,
):
    """Layer-normalise one row of an (M, N) row-major matrix.

    One program per row.  The row offset is ``row * N``, so rows must be
    densely packed.  Pass 1 accumulates sum and sum of squares in float32;
    pass 2 applies ``(x - mean) * rstd * w + b``.

    ``warp_size`` is not read by the body; it is kept because every autotune
    config supplies it.
    """
    row = tl.program_id(0)
    row_start = row.to(tl.int64) * N

    _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    _sum_sq = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X_ptr + row_start + cols, mask=mask, other=0.0).to(tl.float32)
        _sum += x
        _sum_sq += x * x
    mean = tl.sum(_sum, axis=0) / N
    var = tl.sum(_sum_sq, axis=0) / N - mean * mean
    # E[x^2] - mean^2 can round slightly below zero for near-constant rows;
    # clamp so sqrt never sees a negative argument.
    var = tl.maximum(var, 0.0)
    rstd = 1.0 / tl.sqrt(var + eps)

    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X_ptr + row_start + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(B_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        y = (x - mean) * rstd * w + b
        # Store in the output's element type instead of a fixed bfloat16.
        tl.store(Y_ptr + row_start + cols, y.to(Y_ptr.dtype.element_ty), mask=mask)


class Model(nn.Module):
    """LayerNorm over the trailing ``normalized_shape`` dims via Triton."""

    def __init__(self, normalized_shape: tuple):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape=normalized_shape)
        self._moved = False
        # Device the cached flat parameters currently live on.
        self._param_device = None

    def _move_params(self, device):
        """Cache flattened bfloat16 copies of weight/bias on ``device``."""
        self.w_flat = (
            self.ln.weight.data.to(device, dtype=torch.bfloat16).contiguous().flatten()
        )
        self.b_flat = (
            self.ln.bias.data.to(device, dtype=torch.bfloat16).contiguous().flatten()
        )
        self._eps = self.ln.eps
        self._norm_n = 1
        for s in self.ln.normalized_shape:
            self._norm_n *= s
        self._param_device = torch.device(device)
        self._moved = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its trailing ``normalized_shape`` dimensions.

        Args:
            x: tensor whose trailing dims equal ``normalized_shape``.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if the trailing dims of ``x`` do not match.
        """
        norm_shape = tuple(self.ln.normalized_shape)
        if tuple(x.shape[-len(norm_shape):]) != norm_shape:
            raise ValueError(
                f"trailing dims of input {tuple(x.shape)} do not match "
                f"normalized_shape {norm_shape}"
            )
        # Rebuild the cache on first use and whenever the input device
        # changes; previously a stale copy on the first device was reused.
        if not self._moved or getattr(self, "_param_device", None) != x.device:
            self._move_params(x.device)
        x = x.contiguous()
        orig_shape = x.shape
        N = self._norm_n
        M = x.numel() // N
        x_flat = x.view(M, N)
        y_flat = torch.empty_like(x_flat)
        if M > 0:
            grid = (M,)
            _layer_norm_best_kernel[grid](
                x_flat,
                y_flat,
                self.w_flat,
                self.b_flat,
                M,
                N,
                self._eps,
            )
        return y_flat.view(orig_shape)
num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2), + ], + key=["output_length"], +) +@triton.jit +def maxpool1d_kernel( + input_ptr, + output_ptr, + seq_length, + output_length, + num_channels, + stride_b, + stride_c, + stride_out_b, + stride_out_c, + KERNEL_SIZE: tl.constexpr, + STRIDE: tl.constexpr, + PADDING: tl.constexpr, + DILATION: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_bc = tl.program_id(0) + pid_o = tl.program_id(1) + + b = pid_bc // num_channels + c = pid_bc % num_channels + + o_start = pid_o * BLOCK_SIZE + o_offsets = o_start + tl.arange(0, BLOCK_SIZE) + o_mask = o_offsets < output_length + + base_in = input_ptr + b * stride_b + c * stride_c + running_max = tl.full([BLOCK_SIZE], value=-float("inf"), dtype=tl.float32) + + for k in range(KERNEL_SIZE): + inp_idx = o_offsets * STRIDE + k * DILATION - PADDING + valid = (inp_idx >= 0) & (inp_idx < seq_length) & o_mask + vals = tl.load(base_in + inp_idx, mask=valid, other=-float("inf")) + running_max = tl.maximum(running_max, vals.to(tl.float32)) + + base_out = output_ptr + b * stride_out_b + c * stride_out_c + tl.store(base_out + o_offsets, running_max.to(tl.bfloat16), mask=o_mask) + + +class Model(nn.Module): + def __init__( + self, + kernel_size: int, + stride: int = None, + padding: int = 0, + dilation: int = 1, + return_indices: bool = False, + ): + super(Model, self).__init__() + self.kernel_size = kernel_size + self.stride = stride if stride is not None else kernel_size + self.padding = padding + self.dilation = dilation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + device = x.device + B, C, L = x.shape + x = x.to(device).contiguous() + + output_length = ( + L + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1 + ) // self.stride + 1 + output = torch.empty((B, C, output_length), device=device, dtype=x.dtype) + + grid = lambda META: (B * C, triton.cdiv(output_length, META["BLOCK_SIZE"])) + maxpool1d_kernel[grid]( + x, + output, 
+ L, + output_length, + C, + x.stride(0), + x.stride(1), + output.stride(0), + output.stride(1), + KERNEL_SIZE=self.kernel_size, + STRIDE=self.stride, + PADDING=self.padding, + DILATION=self.dilation, + ) + return output diff --git a/backends/triton/cpu/KernelBench/level1/42_Max_Pooling_2D.py b/backends/triton/cpu/KernelBench/level1/42_Max_Pooling_2D.py new file mode 100644 index 0000000..4e46f9a --- /dev/null +++ b/backends/triton/cpu/KernelBench/level1/42_Max_Pooling_2D.py @@ -0,0 +1,112 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_H": 1, "BLOCK_W": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_H": 2, "BLOCK_W": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_H": 2, "BLOCK_W": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_H": 4, "BLOCK_W": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_H": 8, "BLOCK_W": 128}, num_warps=16, num_stages=2), + ], + key=["OH", "OW"], +) +@triton.jit +def _maxpool2d_kernel( + x_ptr, + out_ptr, + H, + W, + OH, + OW, + BLOCK_H: tl.constexpr, + BLOCK_W: tl.constexpr, + KERNEL_SIZE: tl.constexpr, + STRIDE: tl.constexpr, + PADDING: tl.constexpr, + DILATION: tl.constexpr, + IS_BF16: tl.constexpr, +): + pid_bc = tl.program_id(0) + pid_oh = tl.program_id(1) + pid_ow = tl.program_id(2) + + x_base = x_ptr + pid_bc.to(tl.int64) * H * W + out_base = out_ptr + pid_bc.to(tl.int64) * OH * OW + + oh_offsets = pid_oh * BLOCK_H + tl.arange(0, BLOCK_H) + ow_offsets = pid_ow * BLOCK_W + tl.arange(0, BLOCK_W) + + max_vals = tl.full((BLOCK_H, BLOCK_W), float("-inf"), dtype=tl.float32) + + for kh in range(KERNEL_SIZE): + ih = oh_offsets * STRIDE - PADDING + kh * DILATION + valid_h = (ih >= 0) & (ih < H) & (oh_offsets 
def maxpool2d(x, kernel_size, stride, padding, dilation):
    """Run 2-D max pooling on an (B, C, H, W) tensor with ``_maxpool2d_kernel``.

    The same ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are
    used for both spatial axes.  Returns a (B, C, OH, OW) tensor with the
    dtype and device of ``x``.
    """
    batch, channels, in_h, in_w = x.shape

    # Extent covered by one dilated window along an axis.
    window = dilation * (kernel_size - 1) + 1
    out_h = (in_h + 2 * padding - window) // stride + 1
    out_w = (in_w + 2 * padding - window) // stride + 1

    x = x.contiguous()
    out = torch.empty(
        (batch, channels, out_h, out_w), device=x.device, dtype=x.dtype
    )

    def grid(meta):
        # One program per (plane, row tile, column tile).
        return (
            batch * channels,
            triton.cdiv(out_h, meta["BLOCK_H"]),
            triton.cdiv(out_w, meta["BLOCK_W"]),
        )

    _maxpool2d_kernel[grid](
        x,
        out,
        in_h,
        in_w,
        out_h,
        out_w,
        KERNEL_SIZE=kernel_size,
        STRIDE=stride,
        PADDING=padding,
        DILATION=dilation,
        # The kernel casts to bfloat16 on store only when this flag is set.
        IS_BF16=x.dtype == torch.bfloat16,
    )
    return out


class Model(nn.Module):
    """Module wrapper that stores pooling hyper-parameters for ``maxpool2d``."""

    def __init__(self, kernel_size: int, stride: int, padding: int, dilation: int):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.dilation = padding, dilation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply max pooling to ``x`` with the stored hyper-parameters."""
        return maxpool2d(
            x,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_OW": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_OW": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_OW": 64}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_OW": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_OW": 128}, num_warps=4, num_stages=3), + ], + key=["OW"], +) +@triton.jit +def _maxpool3d_kernel( + x_ptr, + out_ptr, + C, + D, + H, + W, + OD, + OH, + OW, + stride_xb, + stride_xc, + stride_xd, + stride_xh, + stride_xw, + stride_ob, + stride_oc, + stride_od, + stride_oh, + stride_ow, + POOL_STRIDE: tl.constexpr, + PAD: tl.constexpr, + DIL: tl.constexpr, + KS: tl.constexpr, + BLOCK_OW: tl.constexpr, +): + pid_ow = tl.program_id(0) + pid_dh = tl.program_id(1) + pid_bc = tl.program_id(2) + + b = pid_bc // C + c = pid_bc % C + od = pid_dh // OH + oh = pid_dh % OH + + ow_start = pid_ow * BLOCK_OW + ow_offs = ow_start + tl.arange(0, BLOCK_OW) + mask_ow = ow_offs < OW + + base = x_ptr + b.to(tl.int64) * stride_xb + c.to(tl.int64) * stride_xc + + d_base = od * POOL_STRIDE - PAD + h_base = oh * POOL_STRIDE - PAD + w_bases = ow_offs * POOL_STRIDE - PAD + + max_val = tl.full([BLOCK_OW], float("-inf"), dtype=tl.float32) + + for kd in range(KS): + d_in = d_base + kd * DIL + d_valid = (d_in >= 0) & (d_in < D) + if d_valid: + for kh in range(KS): + h_in = h_base + kh * DIL + h_valid = (h_in >= 0) & (h_in < H) + if h_valid: + dh_offset = d_in * stride_xd + h_in * stride_xh + for kw in range(KS): + w_in = w_bases + kw * DIL + w_valid = (w_in >= 0) & (w_in < W) + ptrs = base + dh_offset + w_in * stride_xw + valid_mask = mask_ow & w_valid + val = tl.load(ptrs, mask=valid_mask, other=float("-inf")) + max_val = tl.maximum(max_val, val.to(tl.float32)) + + 
def _pool_out_size(size, kernel_size, stride, padding, dilation, ceil_mode):
    """Return the pooled length of one axis.

    With ``ceil_mode`` the quotient is rounded up, and the last window is
    dropped again if it would start beyond the input plus its left padding.
    """
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    if not ceil_mode:
        return span // stride + 1
    out = -(-span // stride) + 1  # ceil division
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def maxpool3d_triton(x, kernel_size, stride_pool, padding, dilation, ceil_mode=False):
    """Run 3-D max pooling on a (B, C, D, H, W) tensor with ``_maxpool3d_kernel``.

    Args:
        x: 5-D input; its strides are forwarded, so it need not be contiguous.
        kernel_size, stride_pool, padding, dilation: applied to all three axes.
        ceil_mode: round output sizes up instead of down (default ``False``,
            which reproduces the previous behaviour).

    Returns:
        A (B, C, OD, OH, OW) tensor with the dtype and device of ``x``.
    """
    B, C, D, H, W = x.shape
    OD = _pool_out_size(D, kernel_size, stride_pool, padding, dilation, ceil_mode)
    OH = _pool_out_size(H, kernel_size, stride_pool, padding, dilation, ceil_mode)
    OW = _pool_out_size(W, kernel_size, stride_pool, padding, dilation, ceil_mode)

    out = torch.empty(B, C, OD, OH, OW, device=x.device, dtype=x.dtype)
    if out.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    grid = lambda META: (
        triton.cdiv(OW, META["BLOCK_OW"]),
        OD * OH,
        B * C,
    )

    _maxpool3d_kernel[grid](
        x,
        out,
        C,
        D,
        H,
        W,
        OD,
        OH,
        OW,
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        x.stride(4),
        out.stride(0),
        out.stride(1),
        out.stride(2),
        out.stride(3),
        out.stride(4),
        POOL_STRIDE=stride_pool,
        PAD=padding,
        DIL=dilation,
        KS=kernel_size,
    )

    return out


class Model(nn.Module):
    """3-D max pooling module backed by ``maxpool3d_triton``.

    ``return_indices`` is accepted for signature compatibility but is not
    used: ``forward`` always returns only the pooled tensor.
    """

    def __init__(
        self,
        kernel_size: int,
        stride: int = None,
        padding: int = 0,
        dilation: int = 1,
        return_indices: bool = False,
        ceil_mode: bool = False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        # As in the original, a missing stride defaults to the kernel size.
        self.stride_pool = stride if stride is not None else kernel_size
        self.padding = padding
        self.dilation = dilation
        # Previously dropped; now forwarded to the output-size computation.
        self.ceil_mode = ceil_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply max pooling to ``x`` with the stored hyper-parameters."""
        return maxpool3d_triton(
            x,
            self.kernel_size,
            self.stride_pool,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
        )
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8, num_stages=2),
    ],
    key=["output_length", "kernel_size"],
)
@triton.jit
def avg_pool1d_kernel(
    input_ptr,
    output_ptr,
    input_length,
    output_length,
    kernel_size,
    stride,
    padding,
    stride_b,
    stride_c,
    stride_l,
    out_stride_b,
    out_stride_c,
    out_stride_l,
    BLOCK_SIZE: tl.constexpr,
):
    """Average-pool one tile of one (batch*channel) row.

    Program axis 0 selects the row, axis 1 the tile of output positions.
    Taps that fall outside ``[0, input_length)`` contribute zero, and the sum
    is always divided by ``kernel_size``.  ``stride_b`` and ``out_stride_b``
    are not read by the body.
    """
    row = tl.program_id(0).to(tl.int64)
    tile = tl.program_id(1)

    out_pos = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = out_pos < output_length

    total = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for tap in range(kernel_size):
        src_pos = out_pos * stride - padding + tap
        readable = in_range & (src_pos >= 0) & (src_pos < input_length)
        src = input_ptr + row * stride_c + src_pos * stride_l
        total += tl.load(src, mask=readable, other=0.0).to(tl.float32)

    mean = total * (1.0 / kernel_size)

    dst = output_ptr + row * out_stride_c + out_pos * out_stride_l
    tl.store(dst, mean.to(tl.bfloat16), mask=in_range)


def kernel_function(
    x: torch.Tensor, kernel_size: int, stride: int, padding: int
) -> torch.Tensor:
    """Average-pool a (batch, channels, length) tensor along its last axis.

    NOTE: the input is converted to bfloat16 and the result is always
    bfloat16, regardless of the dtype of ``x``.
    """
    batch_size, in_channels, input_length = x.shape
    output_length = (input_length + 2 * padding - kernel_size) // stride + 1

    # .to() is a no-op when x is already bfloat16.
    x = x.to(torch.bfloat16).contiguous()

    output = torch.empty(
        (batch_size, in_channels, output_length),
        device=x.device,
        dtype=torch.bfloat16,
    )

    # Collapse batch and channel so the kernel sees a 2-D (rows, length) view.
    rows = batch_size * in_channels
    src = x.view(rows, input_length)
    dst = output.view(rows, output_length)

    def grid(meta):
        return (rows, triton.cdiv(output_length, meta["BLOCK_SIZE"]))

    src_row_stride, src_col_stride = src.stride(0), src.stride(1)
    dst_row_stride, dst_col_stride = dst.stride(0), dst.stride(1)
    avg_pool1d_kernel[grid](
        src,
        dst,
        input_length,
        output_length,
        kernel_size,
        stride,
        padding,
        src_row_stride,
        src_row_stride,
        src_col_stride,
        dst_row_stride,
        dst_row_stride,
        dst_col_stride,
    )

    return output


class Model(nn.Module):
    """1-D average pooling module backed by ``kernel_function``."""

    def __init__(self, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply average pooling to ``x`` with the stored hyper-parameters."""
        return kernel_function(
            x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
        )
tl.program_id(1) + pid_ow_tile = tl.program_id(2) + + ow_offsets = pid_ow_tile * BLOCK_OW + tl.arange(0, BLOCK_OW) + ow_mask = ow_offsets < OW + + in_base = input_ptr + pid_nc.to(tl.int64) * in_stride_nc + out_base = output_ptr + pid_nc.to(tl.int64) * out_stride_nc + + acc = tl.zeros((BLOCK_OW,), dtype=tl.float32) + + h_start = pid_oh * POOL_STRIDE + + for kh in range(KERNEL_SIZE): + h_in = h_start + kh + row_offset = h_in * W + for kw in range(KERNEL_SIZE): + w_in = ow_offsets * POOL_STRIDE + kw + vals = tl.load(in_base + row_offset + w_in, mask=ow_mask, other=0.0) + acc += vals.to(tl.float32) + + inv_area = 1.0 / (KERNEL_SIZE * KERNEL_SIZE) + result = acc * inv_area + + out_offset = pid_oh * OW + ow_offsets + if OUTPUT_BF16: + tl.store(out_base + out_offset, result.to(tl.bfloat16), mask=ow_mask) + else: + tl.store(out_base + out_offset, result, mask=ow_mask) + + +class Model(nn.Module): + def __init__(self, kernel_size: int, stride: int = None, padding: int = 0): + super(Model, self).__init__() + self.kernel_size = kernel_size + self.stride = stride if stride is not None else kernel_size + self.padding = padding + + def forward(self, x: torch.Tensor) -> torch.Tensor: + N, C, H, W = x.shape + x = x.contiguous() + + OH = (H + 2 * self.padding - self.kernel_size) // self.stride + 1 + OW = (W + 2 * self.padding - self.kernel_size) // self.stride + 1 + + output = torch.empty(N, C, OH, OW, device=x.device, dtype=x.dtype) + + NC = N * C + in_stride_nc = H * W + out_stride_nc = OH * OW + is_bf16 = x.dtype == torch.bfloat16 + + grid = lambda META: (NC, OH, triton.cdiv(OW, META["BLOCK_OW"])) + + avg_pool2d_kernel[grid]( + x, + output, + W, + OW, + in_stride_nc, + out_stride_nc, + KERNEL_SIZE=self.kernel_size, + POOL_STRIDE=self.stride, + OUTPUT_BF16=is_bf16, + ) + + return output diff --git a/backends/triton/cpu/KernelBench/level1/46_Average_Pooling_3D.py b/backends/triton/cpu/KernelBench/level1/46_Average_Pooling_3D.py new file mode 100644 index 0000000..f46434c --- 
/dev/null +++ b/backends/triton/cpu/KernelBench/level1/46_Average_Pooling_3D.py @@ -0,0 +1,142 @@ +# ruff: noqa: E731 +# AUTOGENERATED KERNEL (LLM) +# Source: LLM-generated candidate implementation +# Status: Experimental / uncurated +# Expectation: Correctness-first, performance not representative + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_OW": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_OW": 128}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_OW": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_OW": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_OW": 256}, num_warps=4, num_stages=3), + ], + key=["OW"], +) +@triton.jit +def avg_pool3d_kernel( + x_ptr, + out_ptr, + C, + D, + H, + W, + OD, + OH, + OW, + stride_xn, + stride_xc, + stride_xd, + stride_xh, + stride_on, + stride_oc, + stride_od, + stride_oh, + KERNEL_SIZE: tl.constexpr, + STRIDE: tl.constexpr, + PADDING: tl.constexpr, + BLOCK_OW: tl.constexpr, +): + pid_ow = tl.program_id(0) + pid_ncdoh = tl.program_id(1) + + oh = pid_ncdoh % OH + tmp = pid_ncdoh // OH + od = tmp % OD + tmp2 = tmp // OD + c = tmp2 % C + n = tmp2 // C + + ow_start = pid_ow * BLOCK_OW + ow_offsets = ow_start + tl.arange(0, BLOCK_OW) + mask_ow = ow_offsets < OW + + base_x = x_ptr + n.to(tl.int64) * stride_xn + c.to(tl.int64) * stride_xc + + acc = tl.zeros((BLOCK_OW,), dtype=tl.float32) + + for kd in range(KERNEL_SIZE): + d_in = od * STRIDE + kd - PADDING + d_valid = (d_in >= 0) & (d_in < D) + if d_valid: + for kh in range(KERNEL_SIZE): + h_in = oh * STRIDE + kh - PADDING + h_valid = (h_in >= 0) & (h_in < H) + if h_valid: + row_base = ( + base_x + + d_in.to(tl.int64) * stride_xd + + h_in.to(tl.int64) * stride_xh + ) + for kw in range(KERNEL_SIZE): + w_in = ow_offsets * STRIDE + kw - PADDING + w_valid = (w_in >= 0) & (w_in < W) + mask = mask_ow & w_valid + ptrs = row_base + w_in + vals = tl.load(ptrs, 
class Model(nn.Module):
    """3-D average pooling module backed by ``avg_pool3d_kernel``."""

    def __init__(self, kernel_size: int, stride: int = None, padding: int = 0):
        super().__init__()
        self.kernel_size = kernel_size
        # A missing stride defaults to the kernel size.
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Average-pool a (N, C, D, H, W) tensor.

        Args:
            x: 5-D input tensor.

        Returns:
            A (N, C, OD, OH, OW) tensor with the dtype and device of ``x``.

        Raises:
            ValueError: if ``x`` is not 5-D (from the shape unpacking).
        """
        N, C, D, H, W = x.shape
        OD = (D + 2 * self.padding - self.kernel_size) // self.stride + 1
        OH = (H + 2 * self.padding - self.kernel_size) // self.stride + 1
        OW = (W + 2 * self.padding - self.kernel_size) // self.stride + 1

        output = torch.empty(N, C, OD, OH, OW, device=x.device, dtype=x.dtype)
        if output.numel() == 0:
            # Nothing to compute; avoid launching an empty grid.
            return output

        # Only four strides per tensor are passed and the kernel steps along
        # W with unit stride, so the input must be densely packed.
        x = x.contiguous()

        grid = lambda META: (
            triton.cdiv(OW, META["BLOCK_OW"]),
            N * C * OD * OH,
        )

        avg_pool3d_kernel[grid](
            x,
            output,
            C,
            D,
            H,
            W,
            OD,
            OH,
            OW,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            output.stride(3),
            KERNEL_SIZE=self.kernel_size,
            STRIDE=self.stride,
            PADDING=self.padding,
        )

        return output
tl + + +def get_reduction_configs(): + return [ + triton.Config({"BLOCK_R": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_R": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_R": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_R": 64, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_R": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + ] + + +@triton.autotune( + configs=get_reduction_configs(), + key=["R", "C"], +) +@triton.jit +def sum_reduce_kernel( + x_ptr, + out_ptr, + B, + R, + C, + stride_xb, + stride_xr, + stride_xc: tl.constexpr, + stride_ob, + stride_oc, + BLOCK_R: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_b = tl.program_id(1) + + x_batch_offset = pid_b.to(tl.int64) * stride_xb + col_start = pid_n * BLOCK_N + + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_desc = tl.make_tensor_descriptor( + base=x_ptr + x_batch_offset, + shape=(R, C), + strides=(stride_xr, stride_xc), + block_shape=(BLOCK_R, BLOCK_N), + ) + for off_r in range(0, R, BLOCK_R): + tile = x_desc.load([off_r, col_start]) + acc += tl.sum(tile.to(tl.float32), axis=0) + offs_c = col_start + tl.arange(0, BLOCK_N) + out_offset = pid_b.to(tl.int64) * stride_ob + tl.store( + out_ptr + out_offset + offs_c * stride_oc, acc.to(tl.bfloat16), mask=offs_c < C + ) + + +class Model(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous() + + B, R, C = x.shape + output = torch.empty((B, 1, C), device=x.device, dtype=x.dtype) + + grid = lambda META: (triton.cdiv(C, META["BLOCK_N"]), B) + + sum_reduce_kernel[grid]( + x, + output, + B, + R, + C, + x.stride(0), + x.stride(1), + x.stride(2), + output.stride(0), + output.stride(2), + ) + + return output diff --git a/backends/triton/cpu/KernelBench/level1/48_Mean_reduction_over_a_dimension.py 
def get_reduction_configs():
    """Return the autotune candidates (row tile, column tile) for the reduction."""
    return [
        triton.Config({"BLOCK_R": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_R": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_R": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_R": 64, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_R": 128, "BLOCK_N": 256}, num_warps=16, num_stages=2),
    ]


@triton.autotune(
    configs=get_reduction_configs(),
    key=["R", "C"],
)
@triton.jit
def mean_reduce_kernel(
    x_ptr,
    out_ptr,
    B,
    R,
    C,
    stride_xb,
    stride_xr,
    stride_xc: tl.constexpr,
    stride_ob,
    stride_oc,
    BLOCK_R: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Average a (B, R, C) tensor over its middle axis.

    Program axis 0 selects a tile of ``BLOCK_N`` columns, axis 1 the batch
    entry.  Row tiles are summed column-wise in float32 and the total is
    multiplied by ``1 / R``.
    """
    pid_n = tl.program_id(0)
    pid_b = tl.program_id(1)

    x_batch_offset = pid_b.to(tl.int64) * stride_xb
    col_start = pid_n * BLOCK_N

    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    x_desc = tl.make_tensor_descriptor(
        base=x_ptr + x_batch_offset,
        shape=(R, C),
        strides=(stride_xr, stride_xc),
        block_shape=(BLOCK_R, BLOCK_N),
    )
    for off_r in range(0, R, BLOCK_R):
        tile = x_desc.load([off_r, col_start])
        acc += tl.sum(tile.to(tl.float32), axis=0)
    inv_dim = 1.0 / R
    result = acc * inv_dim

    offs_c = col_start + tl.arange(0, BLOCK_N)
    out_offset = pid_b.to(tl.int64) * stride_ob
    # Store in the output's element type instead of a fixed bfloat16.
    tl.store(
        out_ptr + out_offset + offs_c * stride_oc,
        result.to(out_ptr.dtype.element_ty),
        mask=offs_c < C,
    )


class Model(nn.Module):
    """Mean over dim 1 of a 3-D tensor; the reduced dim is dropped."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``x`` over its middle axis as a (B, C) tensor.

        Raises:
            ValueError: if ``x`` is not 3-D or ``self.dim`` is not the middle
                axis.  This replaces an ``assert``, which is stripped when
                Python runs with ``-O``.
        """
        if self.dim not in (1, -2):
            raise ValueError("This kernel only supports reduction over dim=1")
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D tensor, got {x.dim()} dimensions")
        x = x.contiguous()
        B, R, C = x.shape
        out = torch.empty((B, C), device=x.device, dtype=x.dtype)
        if out.numel() == 0:
            # Nothing to compute; avoid launching an empty grid.
            return out

        grid = lambda META: (triton.cdiv(C, META["BLOCK_N"]), B)

        mean_reduce_kernel[grid](
            x,
            out,
            B,
            R,
            C,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            out.stride(0),
            out.stride(1),
        )
        return out
class Model(nn.Module):
    """Max over dim 1 of a 3-D tensor using ``max_reduce_dim1_kernel``."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the maximum of ``x`` over its middle axis as a (B, C) tensor.

        The kernel writes float32 results, which are cast back to the dtype
        of ``x`` before returning.

        Raises:
            ValueError: if ``x`` is not 3-D or ``self.dim`` is not the middle
                axis.  The kernel can only reduce that axis; other values
                used to be ignored and silently produced the dim-1 maximum.
        """
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D tensor, got {x.dim()} dimensions")
        if self.dim not in (1, -2):
            raise ValueError("This kernel only supports reduction over dim=1")
        x = x.contiguous()
        batch_size, dim1, dim2 = x.shape

        output = torch.empty((batch_size, dim2), device=x.device, dtype=torch.float32)
        if output.numel() == 0:
            # Nothing to compute; avoid launching an empty grid.
            return output.to(x.dtype)

        grid = lambda META: (batch_size, triton.cdiv(dim2, META["BLOCK_N"]))

        max_reduce_dim1_kernel[grid](
            x,
            output,
            batch_size,
            dim1,
            dim2,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            output.stride(0),
            output.stride(1),
        )

        return output.to(x.dtype)
class Model(nn.Module):
    """Matrix-vector product ``A @ B`` computed by ``_gemv_kernel``."""

    def __init__(self):
        super().__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Multiply a 2-D matrix by a vector.

        Args:
            A: Matrix of shape ``(M, K)``.
            B: Vector holding ``K`` elements (e.g. shape ``(K, 1)``).

        Returns:
            Tensor of shape ``(M, 1)`` with the dtype of ``A``.

        Raises:
            ValueError: If ``A`` is not 2-D or ``B`` does not hold ``K`` elements.
        """
        if A.dim() != 2:
            raise ValueError(f"A must be 2-D, got shape {tuple(A.shape)}")
        M, K = A.shape
        if B.numel() != K:
            raise ValueError(
                f"B must hold K={K} elements, got shape {tuple(B.shape)}"
            )

        # The kernel indexes a row as ``row_ptr + k``, i.e. it assumes unit
        # stride along K, so A has to be contiguous.
        A = A.contiguous()
        # reshape() also works for non-contiguous B, where view(-1) raises.
        B_flat = B.reshape(-1).contiguous()

        C = torch.empty(M, device=A.device, dtype=A.dtype)

        grid = (M,)  # one program per output row
        _gemv_kernel[grid](
            A,
            B_flat,
            C,
            K,
            A.stride(0),
        )
        return C.view(M, 1)
GROUP_SIZE_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=16, + num_stages=2, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=32, + num_stages=2, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def conv2d_implicit_gemm_kernel( + x_ptr, + w_ptr, + bias_ptr, + out_ptr, + M, + N, + K, + OH, + OW, + H, + W, + stride_conv_h, + stride_conv_w, + pad_h, + pad_w, + stride_xn, + stride_xc, + stride_xh, + stride_xw, + stride_on, + stride_oc, + stride_oh, + stride_ow, + stride_wk, + stride_wn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + KH: tl.constexpr, + KW: tl.constexpr, + C_IN: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + ohow = OH * OW + n_idx = (offs_m // ohow).to(tl.int64) + rem = offs_m % ohow + oh_idx = rem // OW + ow_idx = rem % OW + + mask_m = offs_m < M + mask_n = offs_n < N + + w_desc = tl.make_tensor_descriptor( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, 
class Model(nn.Module):
    """2-D convolution (3 -> 96 channels, 11x11, stride 4, padding 2) run as
    an implicit GEMM through ``conv2d_implicit_gemm_kernel``.

    The convolution geometry is read from ``self.conv1`` so that the launch
    parameters cannot drift from the layer that owns the weights.
    """

    def __init__(self, num_classes=1000):
        # ``num_classes`` is accepted for interface compatibility; it is not
        # used by this module.
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2
        )
        self._packed = False

    def _pack_weights(self):
        """Cache bfloat16 copies of the weight and bias in GEMM layout.

        The weight ``(C_out, C_in, KH, KW)`` becomes ``(KH * KW * C_in, C_out)``
        so that row ``(kh * KW + kw) * C_in + cin`` matches the K index the
        kernel decomposes.
        """
        w = self.conv1.weight.data
        b = self.conv1.bias.data
        device = w.device
        self.w_packed = (
            w.permute(2, 3, 1, 0)
            .contiguous()
            .reshape(-1, w.shape[0])
            .to(device=device, dtype=torch.bfloat16)
            .contiguous()
        )
        self.bias_packed = b.to(device=device, dtype=torch.bfloat16).contiguous()
        self._packed = True

    def forward(self, x):
        """Apply the convolution to ``x`` of shape ``(N, C_in, H, W)``.

        Returns:
            bfloat16 tensor of shape ``(N, C_out, OH, OW)``.

        Raises:
            ValueError: If the channel count of ``x`` does not match the layer.
        """
        # Re-pack when nothing is cached yet or the layer was moved to another
        # device after the cache was built.
        if not self._packed or self.w_packed.device != self.conv1.weight.device:
            self._pack_weights()

        x = x.to(dtype=torch.bfloat16).contiguous()

        N_batch, C_in, H, W_in = x.shape
        if C_in != self.conv1.in_channels:
            raise ValueError(
                f"expected {self.conv1.in_channels} input channels, got {C_in}"
            )

        C_out = self.w_packed.shape[1]
        KH, KW = self.conv1.kernel_size
        stride_h, stride_w = self.conv1.stride
        pad_h, pad_w = self.conv1.padding
        OH = (H + 2 * pad_h - KH) // stride_h + 1
        OW = (W_in + 2 * pad_w - KW) // stride_w + 1

        # GEMM view: (N*OH*OW, C_in*KH*KW) x (C_in*KH*KW, C_out)
        M = N_batch * OH * OW
        N = C_out
        K = C_in * KH * KW

        output = torch.empty(
            (N_batch, C_out, OH, OW), device=x.device, dtype=torch.bfloat16
        )

        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )

        conv2d_implicit_gemm_kernel[grid](
            x,
            self.w_packed,
            self.bias_packed,
            output,
            M,
            N,
            K,
            OH,
            OW,
            H,
            W_in,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            output.stride(3),
            self.w_packed.stride(0),
            self.w_packed.stride(1),
            KH=KH,
            KW=KW,
            C_IN=C_in,
        )

        return output
class Model(nn.Module):
    """Element-wise ``A * s`` computed by ``scalar_mul_kernel``."""

    def __init__(self):
        super().__init__()

    def forward(self, A: torch.Tensor, s) -> torch.Tensor:
        """Scale every element of ``A`` by ``s``.

        ``s`` may be a Python number or a tensor; a tensor is converted with
        ``.item()``. The result has the shape and dtype of ``A``.
        """
        src = A.contiguous()
        dst = torch.empty_like(src)
        total = src.numel()

        factor = s.item() if isinstance(s, torch.Tensor) else float(s)

        def launch_grid(meta):
            # One program per BLOCK_SIZE-sized slice of the flattened input.
            return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

        scalar_mul_kernel[launch_grid](src, dst, factor, total)
        return dst
tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + block_shape=(BLOCK_M, BLOCK_K), + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N), + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for off_k in range(0, K, BLOCK_K): + a_tile = a_desc.load([pid_m * BLOCK_M, off_k]) + b_tile = b_desc.load([off_k, pid_n * BLOCK_N]) + acc += tl.dot(a_tile, b_tile) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_M, BLOCK_N), + ) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(c_ptr.type.element_ty)) + + +def _matmul_triton(A, B): + M, K = A.shape + K2, N = B.shape + + A = A.contiguous() + B = B.contiguous() + + C = torch.empty((M, N), device=A.device, dtype=A.dtype) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + _matmul_kernel[grid]( + A, + B, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return _matmul_triton(A, B) diff --git 
@triton.jit
def swizzle_tile(
    tile_id,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Map a linear tile id to (pid_m, pid_n) tile coordinates, walking the
    # output in groups of up to GROUP_SIZE_M tile rows.
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)

    tiles_per_group = GROUP_SIZE_M * tiles_n
    group = tile_id // tiles_per_group
    first_row = group * GROUP_SIZE_M

    # The last group is shorter when tiles_m is not a multiple of GROUP_SIZE_M.
    rows_in_group = tl.minimum(GROUP_SIZE_M, tiles_m - first_row)

    pid_m = first_row + (tile_id % rows_in_group)
    pid_n = (tile_id % tiles_per_group) // rows_in_group
    return pid_m, pid_n
class Model(nn.Module):
    """Matrix product ``A @ B`` computed by ``_matmul_small_k_kernel``."""

    def __init__(self):
        super().__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Multiply ``A`` of shape ``(M, K)`` by ``B`` of shape ``(K, N)``.

        Returns:
            Tensor of shape ``(M, N)`` with the dtype of ``A``.

        Raises:
            ValueError: If the inner dimensions of ``A`` and ``B`` differ.
        """
        M, K = A.shape
        K_b, N = B.shape
        # K is taken from A when launching the kernel, so a mismatch would
        # otherwise go unnoticed and yield wrong values.
        if K != K_b:
            raise ValueError(
                f"inner dimensions must match: A is {tuple(A.shape)}, "
                f"B is {tuple(B.shape)}"
            )

        A = A.contiguous()
        B = B.contiguous()

        C = torch.empty((M, N), device=A.device, dtype=A.dtype)

        # 1-D launch: one program per (BLOCK_M, BLOCK_N) output tile.
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )

        _matmul_small_k_kernel[grid](
            A,
            B,
            C,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(1),
            C.stride(0),
            C.stride(1),
        )
        return C
def get_autotune_configs():
    """Return the list of ``triton.Config`` candidates for ``matmul_kernel``."""
    # (BLOCK_M, BLOCK_N, num_stages) per candidate; BLOCK_K, GROUP_SIZE_M and
    # num_warps are the same for all of them.
    candidates = (
        (256, 256, 2),
        (256, 256, 3),
        (256, 128, 2),
        (128, 256, 2),
        (128, 128, 2),
    )
    return [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
            num_warps=32,
            num_stages=stages,
        )
        for block_m, block_n, stages in candidates
    ]
class Model(nn.Module):
    """Matrix product ``A @ B`` computed by ``matmul_kernel``."""

    def __init__(self):
        super().__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Multiply ``A`` of shape ``(M, K)`` by ``B`` of shape ``(K, N)``.

        Returns:
            Tensor of shape ``(M, N)`` with the dtype of ``A``.

        Raises:
            ValueError: If the inner dimensions of ``A`` and ``B`` differ.
        """
        M, K = A.shape
        K_b, N = B.shape
        # Only A's K is handed to the kernel; reject operands that disagree
        # rather than computing with a wrong extent for B.
        if K != K_b:
            raise ValueError(
                f"inner dimensions must match: A is {tuple(A.shape)}, "
                f"B is {tuple(B.shape)}"
            )

        A = A.contiguous()
        B = B.contiguous()
        C = torch.empty((M, N), device=A.device, dtype=A.dtype)

        # 1-D launch: one program per (BLOCK_M, BLOCK_N) output tile.
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        )
        matmul_kernel[grid](
            A,
            B,
            C,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(1),
            C.stride(0),
            C.stride(1),
        )
        return C
@triton.autotune(configs=_configs(), key=["M", "N", "K"])
@triton.jit
def _matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak: tl.constexpr,
    stride_bk,
    stride_bn: tl.constexpr,
    stride_cm,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Each program computes one (BLOCK_M, BLOCK_N) tile of C = A @ B.
    tile_m, tile_n = swizzle_tile(
        tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_M
    )
    row0 = tile_m * BLOCK_M
    col0 = tile_n * BLOCK_N

    # Descriptors for the two operands and the output.
    lhs_desc = tl.make_tensor_descriptor(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        block_shape=(BLOCK_M, BLOCK_K),
    )
    rhs_desc = tl.make_tensor_descriptor(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        block_shape=(BLOCK_K, BLOCK_N),
    )
    out_desc = tl.make_tensor_descriptor(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        block_shape=(BLOCK_M, BLOCK_N),
    )

    # Accumulate in float32 while stepping through K in BLOCK_K chunks.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        lhs_tile = lhs_desc.load([row0, k0])
        rhs_tile = rhs_desc.load([k0, col0])
        acc += tl.dot(lhs_tile, rhs_tile)

    # Cast to the element type of the output pointer before storing.
    out_desc.store([row0, col0], acc.to(c_ptr.type.element_ty))
dtype=A.dtype) + + grid = lambda META: ( + triton.cdiv(M_out, META["BLOCK_M"]) * triton.cdiv(N_out, META["BLOCK_N"]), + ) + + _matmul_kernel[grid]( + A, + B, + C, + M_out, + N_out, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + ) + return C diff --git a/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml index a72791f..2d51f4c 100644 --- a/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/10_3D_tensor_matrix_multiplication.yaml @@ -16,6 +16,16 @@ ci: L: 24 flop: "2*N*M*L*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + N: 16 + M: 1024 + K: 2048 + L: 768 + flop: "2*N*M*L*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml index e82b215..7d7d2b0 100644 --- a/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/11_4D_tensor_matrix_multiplication.yaml @@ -17,6 +17,17 @@ ci: K: 24 flop: "2*B*I*J*K*L" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + B: 8 + I: 256 + J: 512 + L: 256 + K: 768 + flop: "2*B*I*J*K*L" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml b/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml index d5f9f78..bc53fb5 100644 --- a/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml +++ b/problems/specs/KernelBench/level1/12_Matmul_with_diagonal_matrices_.yaml @@ -14,6 +14,14 @@ ci: N: 128 flop: "2*M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 4096 + N: 4096 + flop: "2*M*N" + bench-gpu: - params: [A, B] dtype: float16 diff --git 
a/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml b/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml index 66a59c3..be0dff1 100644 --- a/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml +++ b/problems/specs/KernelBench/level1/13_Matmul_for_symmetric_matrices.yaml @@ -15,3 +15,10 @@ ci: dtype: float32 dims: N: 64 + +bench-cpu: + - params: ['A', 'B'] + dtype: bfloat16 + dims: + N: 1024 + flop: "2*N*N*N" diff --git a/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml b/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml index 7369aff..ca00691 100644 --- a/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml +++ b/problems/specs/KernelBench/level1/14_Matmul_for_upper_triangular_matrices.yaml @@ -15,3 +15,10 @@ ci: dtype: float32 dims: N: 64 + +bench-cpu: + - params: ['A', 'B'] + dtype: bfloat16 + dims: + N: 1024 + flop: "N*N*N" diff --git a/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml b/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml index cafb725..83013ea 100644 --- a/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml +++ b/problems/specs/KernelBench/level1/15_Matmul_for_lower_triangular_matrices.yaml @@ -15,3 +15,10 @@ ci: dtype: float32 dims: M: 64 + +bench-cpu: + - params: ['A', 'B'] + dtype: bfloat16 + dims: + M: 1024 + flop: "M*M*M" diff --git a/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml b/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml index e665dcc..cc7720d 100644 --- a/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml +++ b/problems/specs/KernelBench/level1/16_Matmul_with_transposed_A.yaml @@ -15,6 +15,15 @@ ci: K: 256 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 512 + N: 1024 + K: 2048 + flop: "2*M*N*K" + bench-gpu: - params: 
[A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml b/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml index 9d984be..97b7e6a 100644 --- a/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml +++ b/problems/specs/KernelBench/level1/17_Matmul_with_transposed_B.yaml @@ -15,6 +15,15 @@ ci: K: 256 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 512 + N: 1024 + K: 2048 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml b/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml index 0ec07fb..92edbd9 100644 --- a/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml +++ b/problems/specs/KernelBench/level1/18_Matmul_with_transposed_both.yaml @@ -15,6 +15,15 @@ ci: K: 256 flop: "2*N*M*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 512 + N: 1024 + K: 2048 + flop: "2*N*M*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml index 3a0fc9c..31b0e97 100644 --- a/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/1_Square_matrix_multiplication_.yaml @@ -16,13 +16,15 @@ bench-cpu: - params: [A, B] dtype: float32 dims: - N: 1024 + N: 512 flop: "2*N*N*N" + mem_bytes: "3 * N * N * 4" # f32 - params: [A, B] dtype: bfloat16 dims: N: 1024 flop: "2*N*N*N" + mem_bytes: "3 * N * N * 2" # bf16 bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/20_LeakyReLU.yaml b/problems/specs/KernelBench/level1/20_LeakyReLU.yaml index 7071370..37d2df8 100644 --- a/problems/specs/KernelBench/level1/20_LeakyReLU.yaml +++ b/problems/specs/KernelBench/level1/20_LeakyReLU.yaml @@ -11,6 +11,14 @@ ci: DIM: 512 flop: "2*BATCH*DIM" 
+bench-cpu: + - params: [X] + dtype: bfloat16 + dims: + BATCH: 1024 + DIM: 16384 + flop: "2*BATCH*DIM" + bench-gpu: - params: [X] dtype: float16 diff --git a/problems/specs/KernelBench/level1/21_Sigmoid.yaml b/problems/specs/KernelBench/level1/21_Sigmoid.yaml index 0bbba8e..8935039 100644 --- a/problems/specs/KernelBench/level1/21_Sigmoid.yaml +++ b/problems/specs/KernelBench/level1/21_Sigmoid.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "4*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/22_Tanh.yaml b/problems/specs/KernelBench/level1/22_Tanh.yaml index 0bbba8e..8935039 100644 --- a/problems/specs/KernelBench/level1/22_Tanh.yaml +++ b/problems/specs/KernelBench/level1/22_Tanh.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "4*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/23_Softmax.yaml b/problems/specs/KernelBench/level1/23_Softmax.yaml index 0bbba8e..d7d1df4 100644 --- a/problems/specs/KernelBench/level1/23_Softmax.yaml +++ b/problems/specs/KernelBench/level1/23_Softmax.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "5*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/24_LogSoftmax.yaml b/problems/specs/KernelBench/level1/24_LogSoftmax.yaml index 0bbba8e..d7d1df4 100644 --- a/problems/specs/KernelBench/level1/24_LogSoftmax.yaml +++ b/problems/specs/KernelBench/level1/24_LogSoftmax.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "5*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/25_Swish.yaml b/problems/specs/KernelBench/level1/25_Swish.yaml index 0bbba8e..d7d1df4 100644 
--- a/problems/specs/KernelBench/level1/25_Swish.yaml +++ b/problems/specs/KernelBench/level1/25_Swish.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "5*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/26_GELU_.yaml b/problems/specs/KernelBench/level1/26_GELU_.yaml index 0bbba8e..fb8d1ab 100644 --- a/problems/specs/KernelBench/level1/26_GELU_.yaml +++ b/problems/specs/KernelBench/level1/26_GELU_.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "8*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/27_SELU_.yaml b/problems/specs/KernelBench/level1/27_SELU_.yaml index 0bbba8e..ecb92b1 100644 --- a/problems/specs/KernelBench/level1/27_SELU_.yaml +++ b/problems/specs/KernelBench/level1/27_SELU_.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "2*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/28_HardSigmoid.yaml b/problems/specs/KernelBench/level1/28_HardSigmoid.yaml index 0bbba8e..ecb92b1 100644 --- a/problems/specs/KernelBench/level1/28_HardSigmoid.yaml +++ b/problems/specs/KernelBench/level1/28_HardSigmoid.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "2*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/29_Softplus.yaml b/problems/specs/KernelBench/level1/29_Softplus.yaml index 0bbba8e..8935039 100644 --- a/problems/specs/KernelBench/level1/29_Softplus.yaml +++ b/problems/specs/KernelBench/level1/29_Softplus.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "4*BATCH_SIZE*DIM" diff 
--git a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml index 0ba0f91..f197200 100644 --- a/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/2_Standard_matrix_multiplication_.yaml @@ -18,11 +18,11 @@ bench-cpu: - params: [A, B] dtype: bfloat16 dims: - M: 128 - N: 256 - K: 512 + M: 768 + N: 768 + K: 1024 flop: "2*M*N*K" - mem_bytes: "(M*K + K*N + M*N) * 2" # f16 + mem_bytes: "(M*K + K*N + M*N) * 2" # bf16 bench-gpu: - params: [A, B] diff --git a/problems/specs/KernelBench/level1/30_Softsign.yaml b/problems/specs/KernelBench/level1/30_Softsign.yaml index 0bbba8e..ecb92b1 100644 --- a/problems/specs/KernelBench/level1/30_Softsign.yaml +++ b/problems/specs/KernelBench/level1/30_Softsign.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "2*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/31_ELU.yaml b/problems/specs/KernelBench/level1/31_ELU.yaml index 451772e..3679758 100644 --- a/problems/specs/KernelBench/level1/31_ELU.yaml +++ b/problems/specs/KernelBench/level1/31_ELU.yaml @@ -13,3 +13,12 @@ ci: BATCH_SIZE: 128 DIM: 512 ALPHA: 1.0 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + ALPHA: 1.0 + flop: "2*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/32_HardTanh.yaml b/problems/specs/KernelBench/level1/32_HardTanh.yaml index 0bbba8e..d66de33 100644 --- a/problems/specs/KernelBench/level1/32_HardTanh.yaml +++ b/problems/specs/KernelBench/level1/32_HardTanh.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 2048 + DIM: 12288 + flop: "BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/33_BatchNorm.yaml 
b/problems/specs/KernelBench/level1/33_BatchNorm.yaml index 2c34dfb..c8e180a 100644 --- a/problems/specs/KernelBench/level1/33_BatchNorm.yaml +++ b/problems/specs/KernelBench/level1/33_BatchNorm.yaml @@ -14,3 +14,13 @@ ci: FEATURES: 4 DIM1: 64 DIM2: 64 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 64 + FEATURES: 64 + DIM1: 512 + DIM2: 512 + flop: "5*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/34_InstanceNorm.yaml b/problems/specs/KernelBench/level1/34_InstanceNorm.yaml index 2c34dfb..4c454cd 100644 --- a/problems/specs/KernelBench/level1/34_InstanceNorm.yaml +++ b/problems/specs/KernelBench/level1/34_InstanceNorm.yaml @@ -14,3 +14,13 @@ ci: FEATURES: 4 DIM1: 64 DIM2: 64 + +bench-cpu: + - params: ['x'] + dtype: float32 # TODO: unable to pass result check with bfloat16, investigate + dims: + BATCH_SIZE: 16 + FEATURES: 64 + DIM1: 256 + DIM2: 256 + flop: "5*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/35_GroupNorm_.yaml b/problems/specs/KernelBench/level1/35_GroupNorm_.yaml index da07568..4cf0431 100644 --- a/problems/specs/KernelBench/level1/35_GroupNorm_.yaml +++ b/problems/specs/KernelBench/level1/35_GroupNorm_.yaml @@ -16,3 +16,14 @@ ci: NUM_GROUPS: 4 DIM1: 32 DIM2: 32 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 16 + FEATURES: 64 + NUM_GROUPS: 8 + DIM1: 256 + DIM2: 256 + flop: "5*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/36_RMSNorm_.yaml b/problems/specs/KernelBench/level1/36_RMSNorm_.yaml index 2c34dfb..3864819 100644 --- a/problems/specs/KernelBench/level1/36_RMSNorm_.yaml +++ b/problems/specs/KernelBench/level1/36_RMSNorm_.yaml @@ -14,3 +14,14 @@ ci: FEATURES: 4 DIM1: 64 DIM2: 64 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + rtol: 0.02 + dims: + BATCH_SIZE: 16 + FEATURES: 64 + DIM1: 256 + DIM2: 256 + flop: "4*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git 
a/problems/specs/KernelBench/level1/37_FrobeniusNorm_.yaml b/problems/specs/KernelBench/level1/37_FrobeniusNorm_.yaml index 72d91ec..fc1d2e7 100644 --- a/problems/specs/KernelBench/level1/37_FrobeniusNorm_.yaml +++ b/problems/specs/KernelBench/level1/37_FrobeniusNorm_.yaml @@ -13,3 +13,13 @@ ci: FEATURES: 4 DIM1: 64 DIM2: 64 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 16 + FEATURES: 64 + DIM1: 256 + DIM2: 256 + flop: "2*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/38_L1Norm_.yaml b/problems/specs/KernelBench/level1/38_L1Norm_.yaml index 0bbba8e..99224be 100644 --- a/problems/specs/KernelBench/level1/38_L1Norm_.yaml +++ b/problems/specs/KernelBench/level1/38_L1Norm_.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 4096 + DIM: 16384 + flop: "BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/39_L2Norm_.yaml b/problems/specs/KernelBench/level1/39_L2Norm_.yaml index 0bbba8e..6751e3c 100644 --- a/problems/specs/KernelBench/level1/39_L2Norm_.yaml +++ b/problems/specs/KernelBench/level1/39_L2Norm_.yaml @@ -11,3 +11,11 @@ ci: dims: BATCH_SIZE: 128 DIM: 512 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 4096 + DIM: 16384 + flop: "2*BATCH_SIZE*DIM" diff --git a/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml b/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml index 14c26dd..84797f0 100644 --- a/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml +++ b/problems/specs/KernelBench/level1/3_Batched_matrix_multiplication.yaml @@ -15,6 +15,16 @@ ci: N: 128 K: 128 +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + BATCH: 128 + M: 256 + N: 512 + K: 1024 + flop: "2*BATCH*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/40_LayerNorm.yaml 
b/problems/specs/KernelBench/level1/40_LayerNorm.yaml index 2fbd50a..0eddf6c 100644 --- a/problems/specs/KernelBench/level1/40_LayerNorm.yaml +++ b/problems/specs/KernelBench/level1/40_LayerNorm.yaml @@ -15,3 +15,14 @@ ci: DIM1: 64 DIM2: 64 NORMALIZED_SHAPE: [4, 64, 64] # TODO: bind these to other dims + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 16 + FEATURES: 64 + DIM1: 256 + DIM2: 256 + NORMALIZED_SHAPE: [64, 256, 256] + flop: "5*BATCH_SIZE*FEATURES*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/41_Max_Pooling_1D.yaml b/problems/specs/KernelBench/level1/41_Max_Pooling_1D.yaml index 49ef0fe..916c181 100644 --- a/problems/specs/KernelBench/level1/41_Max_Pooling_1D.yaml +++ b/problems/specs/KernelBench/level1/41_Max_Pooling_1D.yaml @@ -22,3 +22,17 @@ ci: PADDING: 1 DILATION: 3 RETURN_INDICES: false + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 64 + FEATURES: 192 + SEQUENCE_LENGTH: 4096 + KERNEL_SIZE: 8 + STRIDE: 1 + PADDING: 4 + DILATION: 3 + RETURN_INDICES: False + flop: "BATCH_SIZE*FEATURES*SEQUENCE_LENGTH" diff --git a/problems/specs/KernelBench/level1/42_Max_Pooling_2D.yaml b/problems/specs/KernelBench/level1/42_Max_Pooling_2D.yaml index 95debcf..c42df20 100644 --- a/problems/specs/KernelBench/level1/42_Max_Pooling_2D.yaml +++ b/problems/specs/KernelBench/level1/42_Max_Pooling_2D.yaml @@ -21,3 +21,17 @@ ci: STRIDE: 1 PADDING: 1 DILATION: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 4 + CHANNELS: 32 + HEIGHT: 256 + WIDTH: 256 + KERNEL_SIZE: 4 + STRIDE: 1 + PADDING: 1 + DILATION: 1 + flop: "BATCH_SIZE*CHANNELS*HEIGHT*WIDTH" diff --git a/problems/specs/KernelBench/level1/43_Max_Pooling_3D.yaml b/problems/specs/KernelBench/level1/43_Max_Pooling_3D.yaml index 2c34dfb..6afa2d6 100644 --- a/problems/specs/KernelBench/level1/43_Max_Pooling_3D.yaml +++ b/problems/specs/KernelBench/level1/43_Max_Pooling_3D.yaml @@ -1,6 +1,6 @@ inputs: x: - shape: [BATCH_SIZE, FEATURES, DIM1, 
DIM2] + shape: [BATCH_SIZE, FEATURES, DIM1, DIM2, DIM3] dtype: inherit inits: @@ -12,5 +12,17 @@ ci: dims: BATCH_SIZE: 2 FEATURES: 4 + DIM1: 16 + DIM2: 16 + DIM3: 16 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 4 + FEATURES: 16 DIM1: 64 DIM2: 64 + DIM3: 64 + flop: "BATCH_SIZE*FEATURES*DIM1*DIM2*DIM3" diff --git a/problems/specs/KernelBench/level1/44_Average_Pooling_1D.yaml b/problems/specs/KernelBench/level1/44_Average_Pooling_1D.yaml index f4577e5..ce7c4c4 100644 --- a/problems/specs/KernelBench/level1/44_Average_Pooling_1D.yaml +++ b/problems/specs/KernelBench/level1/44_Average_Pooling_1D.yaml @@ -18,3 +18,15 @@ ci: KERNEL_SIZE: 2 STRIDE: 1 PADDING: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 64 + IN_CHANNELS: 128 + INPUT_LENGTH: 4096 + KERNEL_SIZE: 8 + STRIDE: 1 + PADDING: 4 + flop: "BATCH_SIZE*IN_CHANNELS*INPUT_LENGTH" diff --git a/problems/specs/KernelBench/level1/45_Average_Pooling_2D.yaml b/problems/specs/KernelBench/level1/45_Average_Pooling_2D.yaml index 7ae2029..e754e7d 100644 --- a/problems/specs/KernelBench/level1/45_Average_Pooling_2D.yaml +++ b/problems/specs/KernelBench/level1/45_Average_Pooling_2D.yaml @@ -15,3 +15,14 @@ ci: HEIGHT: 128 WIDTH: 128 KERNEL_SIZE: 3 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 16 + CHANNELS: 64 + HEIGHT: 512 + WIDTH: 512 + KERNEL_SIZE: 11 + flop: "BATCH_SIZE*CHANNELS*HEIGHT*WIDTH" diff --git a/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml b/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml index 748cf0b..64fc446 100644 --- a/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml +++ b/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml @@ -20,3 +20,17 @@ ci: KERNEL_SIZE: 3 STRIDE: 2 PADDING: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 # also not for bfloat16 + dims: + BATCH_SIZE: 4 + CHANNELS: 16 + DEPTH: 128 + HEIGHT: 128 + WIDTH: 256 + KERNEL_SIZE: 3 + STRIDE: 2 + PADDING: 1 + flop: 
"BATCH_SIZE*CHANNELS*DEPTH*HEIGHT*WIDTH" diff --git a/problems/specs/KernelBench/level1/47_Sum_reduction_over_a_dimension.yaml b/problems/specs/KernelBench/level1/47_Sum_reduction_over_a_dimension.yaml index bf5fbda..35f47dd 100644 --- a/problems/specs/KernelBench/level1/47_Sum_reduction_over_a_dimension.yaml +++ b/problems/specs/KernelBench/level1/47_Sum_reduction_over_a_dimension.yaml @@ -14,3 +14,13 @@ ci: DIM1: 64 DIM2: 63 REDUCE_DIM: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 128 + DIM1: 1024 + DIM2: 1023 + REDUCE_DIM: 1 + flop: "BATCH_SIZE*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/48_Mean_reduction_over_a_dimension.yaml b/problems/specs/KernelBench/level1/48_Mean_reduction_over_a_dimension.yaml index bf5fbda..35f47dd 100644 --- a/problems/specs/KernelBench/level1/48_Mean_reduction_over_a_dimension.yaml +++ b/problems/specs/KernelBench/level1/48_Mean_reduction_over_a_dimension.yaml @@ -14,3 +14,13 @@ ci: DIM1: 64 DIM2: 63 REDUCE_DIM: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 128 + DIM1: 1024 + DIM2: 1023 + REDUCE_DIM: 1 + flop: "BATCH_SIZE*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/49_Max_reduction_over_a_dimension.yaml b/problems/specs/KernelBench/level1/49_Max_reduction_over_a_dimension.yaml index bf5fbda..35f47dd 100644 --- a/problems/specs/KernelBench/level1/49_Max_reduction_over_a_dimension.yaml +++ b/problems/specs/KernelBench/level1/49_Max_reduction_over_a_dimension.yaml @@ -14,3 +14,13 @@ ci: DIM1: 64 DIM2: 63 REDUCE_DIM: 1 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 128 + DIM1: 1024 + DIM2: 1023 + REDUCE_DIM: 1 + flop: "BATCH_SIZE*DIM1*DIM2" diff --git a/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml b/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml index 1536270..ce5ae57 100644 --- a/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml +++ 
b/problems/specs/KernelBench/level1/4_Matrix_vector_multiplication_.yaml @@ -15,6 +15,15 @@ ci: K: 256 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 2048 + N: 1 + K: 16384 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/50_conv_standard_2D__square_input__square_kernel.yaml b/problems/specs/KernelBench/level1/50_conv_standard_2D__square_input__square_kernel.yaml index 11d119c..a4acab2 100644 --- a/problems/specs/KernelBench/level1/50_conv_standard_2D__square_input__square_kernel.yaml +++ b/problems/specs/KernelBench/level1/50_conv_standard_2D__square_input__square_kernel.yaml @@ -15,3 +15,14 @@ ci: IN_CHANNELS: 3 HEIGHT: 64 WIDTH: 64 + +bench-cpu: + - params: ['x'] + dtype: bfloat16 + dims: + BATCH_SIZE: 128 + NUM_CLASSES: 1000 + IN_CHANNELS: 3 + HEIGHT: 224 + WIDTH: 224 + flop: "2*BATCH_SIZE*IN_CHANNELS*HEIGHT*WIDTH*NUM_CLASSES" diff --git a/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml b/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml index 782dc63..39ad343 100644 --- a/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml +++ b/problems/specs/KernelBench/level1/5_Matrix_scalar_multiplication.yaml @@ -15,6 +15,15 @@ ci: UNIT: 1 flop: "M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 16384 + N: 4096 + UNIT: 1 + flop: "M*N" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml b/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml index 09435cc..898d576 100644 --- a/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml +++ b/problems/specs/KernelBench/level1/6_Matmul_with_large_K_dimension_.yaml @@ -15,6 +15,17 @@ ci: K: 512 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 256 + N: 256 + K: 16384 + flop: "2*M*N*K" + rtol: .inf + atol: 1.0e-05 + bench-gpu: 
- params: [A, B] dtype: float16 @@ -23,5 +34,5 @@ bench-gpu: N: 256 K: 16384 flop: "2*M*N*K" - rtol: inf + rtol: .inf atol: 1.0e-05 diff --git a/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml b/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml index 2c55dd2..6bad9fe 100644 --- a/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml +++ b/problems/specs/KernelBench/level1/7_Matmul_with_small_K_dimension_.yaml @@ -15,6 +15,15 @@ ci: K: 16 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 2048 + N: 2048 + K: 64 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml b/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml index 4f0553d..fa3ec37 100644 --- a/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml +++ b/problems/specs/KernelBench/level1/8_Matmul_with_irregular_shapes_.yaml @@ -15,6 +15,15 @@ ci: K: 19 flop: "2*M*N*K" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 1243 + N: 5687 + K: 901 + flop: "2*M*N*K" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml b/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml index 94ced01..91360fd 100644 --- a/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml +++ b/problems/specs/KernelBench/level1/9_Tall_skinny_matrix_multiplication_.yaml @@ -14,6 +14,14 @@ ci: N: 16 flop: "2*M*M*N" +bench-cpu: + - params: [A, B] + dtype: bfloat16 + dims: + M: 8192 + N: 32 + flop: "2*M*M*N" + bench-gpu: - params: [A, B] dtype: float16 diff --git a/pyproject.toml b/pyproject.toml index d0c35b1..dfcf0d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ override-dependencies = [ torch = { index = "pytorch" } pytorch-triton-xpu = { index = "pytorch" } pytorch-triton 
= { index = "pytorch" } -triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "270e696" } +triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "028479f" } lighthouse = { git = "https://github.com/llvm/lighthouse", rev = "456475d" } mlir-python-bindings = { index = "eudsl" } From 5f671065cb307e87ec42dd8dc6cd80ae88bd0709 Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Tue, 26 May 2026 04:20:04 -0700 Subject: [PATCH 2/3] Bump TritonCPU --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dfcf0d0..ba285e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ override-dependencies = [ torch = { index = "pytorch" } pytorch-triton-xpu = { index = "pytorch" } pytorch-triton = { index = "pytorch" } -triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "028479f" } +triton = { git = "https://github.com/triton-lang/triton-cpu.git", rev = "eece2e9" } lighthouse = { git = "https://github.com/llvm/lighthouse", rev = "456475d" } mlir-python-bindings = { index = "eudsl" } From e007ad6cac316f6e220cf2cb71ecef507345671d Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Tue, 26 May 2026 04:54:01 -0700 Subject: [PATCH 3/3] 46: `avg_pool3d_out_frame` is not implemented for bfloat16 --- problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml b/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml index 64fc446..8123bf7 100644 --- a/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml +++ b/problems/specs/KernelBench/level1/46_Average_Pooling_3D.yaml @@ -23,7 +23,7 @@ ci: bench-cpu: - params: ['x'] - dtype: bfloat16 # also not for bfloat16 + dtype: float32 # `avg_pool3d_out_frame` is not implemented for bfloat16 dims: BATCH_SIZE: 4 CHANNELS: 16