diff --git a/docs/attention_tiling_analysis.md b/docs/attention_tiling_analysis.md new file mode 100644 index 00000000..e9172a7a --- /dev/null +++ b/docs/attention_tiling_analysis.md @@ -0,0 +1,461 @@ +# Attention Kernel Transformation Analysis + +This document describes the step-by-step transformation of an attention kernel implementation through MLIR compiler passes. The attention mechanism computes: `attention(Q, K, V) = softmax(Q × K^T) × V` + +--- + +## Stage 1: Initial Attention Computation + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + + // Allocate output tensors + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + // Transpose K: 4096x64 -> 64x4096 + %transposed = linalg.transpose ins(%1 : tensor<4096x64xf16>) + outs(%3 : tensor<64x4096xf16>) + permutation = [1, 0] + + // First matmul: Q × K^T + %cst = arith.constant 0.000000e+00 : f16 + %7 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %8 = linalg.matmul ins(%0, %transposed : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%7 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // Softmax operation (high-level) + %9 = linalg.softmax dimension(1) ins(%8 : tensor<128x4096xf16>) + outs(%5 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // Second matmul: softmax_out × V + %10 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %11 = linalg.matmul ins(%9, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%10 : 
tensor<128x64xf16>) -> tensor<128x64xf16> + + bufferization.materialize_in_destination %11 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 2: After Softmax Decomposition + +Softmax decomposition into constituent operations + +This is the numerically stable softmax: `softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))` + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + %transposed = linalg.transpose ins(%1 : tensor<4096x64xf16>) + outs(%3 : tensor<64x4096xf16>) + permutation = [1, 0] + + // First matmul (unchanged) + %cst = arith.constant 0.000000e+00 : f16 + %7 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %8 = linalg.matmul ins(%0, %transposed : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%7 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // === SOFTMAX DECOMPOSITION BEGINS === + + // Step 1: Find max along dimension 1 (reduction over 4096 elements) + %9 = tensor.empty() : tensor<128xf16> + %cst_0 = arith.constant 0xFE00 : f16 // -inf in f16 + %10 = linalg.fill ins(%cst_0 : f16) outs(%9 : tensor<128xf16>) -> tensor<128xf16> + %11 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%8 : tensor<128x4096xf16>) outs(%10 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %18 = 
arith.maxnumf %in, %out : f16 + linalg.yield %18 : f16 + } -> tensor<128xf16> + + // Step 2: Subtract max and compute exp + %12 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%8, %11 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %18 = arith.subf %in, %in_2 : f16 + %19 = math.exp %18 : f16 + linalg.yield %19 : f16 + } -> tensor<128x4096xf16> + + // Step 3: Sum exponentials (reduction) + %cst_1 = arith.constant 0.000000e+00 : f16 + %13 = linalg.fill ins(%cst_1 : f16) outs(%9 : tensor<128xf16>) -> tensor<128xf16> + %14 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%12 : tensor<128x4096xf16>) outs(%13 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %18 = arith.addf %in, %out : f16 + linalg.yield %18 : f16 + } -> tensor<128xf16> + + // Step 4: Divide by sum for normalization + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%12, %14 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %18 = arith.divf %in, %in_2 : f16 + linalg.yield %18 : f16 + } -> tensor<128x4096xf16> + + // === SOFTMAX DECOMPOSITION ENDS === + + // Second matmul (unchanged) + %16 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %17 = linalg.matmul ins(%15, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%16 : tensor<128x64xf16>) -> tensor<128x64xf16> + + bufferization.materialize_in_destination %17 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 3: 
After Matmul and Transpose Generalization + +Convert named operations (`linalg.matmul`, `linalg.transpose`) to `linalg.generic` + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + // Transpose converted to generic form + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, // Input: swapped (4096×64 accessed as d1,d0) + affine_map<(d0, d1) -> (d0, d1)>], // Output: 64×4096 (normal d0,d1) + iterator_types = ["parallel", "parallel"] + } ins(%1 : tensor<4096x64xf16>) outs(%3 : tensor<64x4096xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<64x4096xf16> + + // First matmul converted to generic form + %cst = arith.constant 0.000000e+00 : f16 + %8 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %9 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: 128×64 + affine_map<(d0, d1, d2) -> (d2, d1)>, // K^T: 64×4096 + affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×4096 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%0, %7 : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%8 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %19 = arith.mulf %in, %in_2 : f16 + %20 = arith.addf %out, %19 : f16 + linalg.yield %20 : f16 + } -> tensor<128x4096xf16> + + // Softmax decomposition (unchanged) + // .... 
+ + // Second matmul converted to generic form + %17 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %18 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Softmax: 128×4096 + affine_map<(d0, d1, d2) -> (d2, d1)>, // V: 4096×64 + affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×64 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%16, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%17 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %19 = arith.mulf %in, %in_2 : f16 + %20 = arith.addf %out, %19 : f16 + linalg.yield %20 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %18 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 4: After Elementwise Fusion (Final) + +Fusion of generic ops. + +### Notes +This is the most optimized form. Multiple transformations occur: +1. **Transpose + Matmul fusion**: The transpose operation is fused into the first matmul, directly reading K with transposed indexing +2. **Softmax reduction fusion**: The exp computation and sum reduction are fused into a single operation +3. 
**Final matmul fusion**: The normalization division is fused with the final matmul, computing `(exp(x - max) / sum) × V` in one pass + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + + // === FUSED: Transpose + First Matmul === + // Previously: separate transpose (4096×64 → 64×4096) + matmul + // Now: matmul directly reads K with transposed indexing + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: 128×64 + affine_map<(d0, d1, d2) -> (d1, d2)>, // K: 4096×64 (transpose on-the-fly!) 
+ affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×4096 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // Max reduction (unchanged) + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%6 : tensor<128x4096xf16>) outs(%8 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %14 = arith.maxnumf %in, %out : f16 + linalg.yield %14 : f16 + } -> tensor<128xf16> + + // === FUSED: exp and sum reduction === + // Previously: separate exp operation + sum reduction + // Now: compute exp(x - max) and accumulate sum in one pass + %10 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %11 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, // Input: QK^T + affine_map<(d0, d1) -> (d0)>, // Max values + affine_map<(d0, d1) -> (d0)>], // Sum accumulator (output) + iterator_types = ["parallel", "reduction"] + } ins(%6, %9 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%10 : tensor<128xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 // x - max + %15 = math.exp %14 : f16 // exp(x - max) + %16 = arith.addf %15, %out : f16 // accumulate sum + linalg.yield %16 : f16 + } -> tensor<128xf16> + + // === FUSED: softmax normalization + second matmul === + // Previously: separate divide operation + matmul + // Now: compute (exp(x - max) / sum) * V in one fused kernel + %12 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %13 = linalg.generic { + indexing_maps = [affine_map<(d0, 
d1, d2) -> (d0, d2)>, // QK^T scores + affine_map<(d0, d1, d2) -> (d0)>, // Max values + affine_map<(d0, d1, d2) -> (d0)>, // Sum values + affine_map<(d0, d1, d2) -> (d2, d1)>, // V matrix + affine_map<(d0, d1, d2) -> (d0, d1)>], // Output + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%6, %9, %11, %2 : tensor<128x4096xf16>, tensor<128xf16>, + tensor<128xf16>, tensor<4096x64xf16>) + outs(%12 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_1: f16, %in_2: f16, %in_3: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 // x - max + %15 = math.exp %14 : f16 // exp(x - max) + %16 = arith.divf %15, %in_2 : f16 // exp(x - max) / sum [softmax] + %17 = arith.mulf %16, %in_3 : f16 // softmax * V + %18 = arith.addf %out, %17 : f16 // accumulate result + linalg.yield %18 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %13 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 5: Online Softmax Optimization (Max-Sum Fusion) + +The previous stage still computes max and sum in two separate passes over the attention scores. We can fuse these into a single pass using the **online softmax algorithm**. + +### Changes needed upstream? + +This requires improving **GenericOp Fusion** to support complex reductions that are **fusible** (i.e., reductions whose error can be corrected: the first reduction can be factored out of each term in the second reduction and corrected afterwards).
+ + +```mlir +func.func @attention_max_sum_fused(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // First generic: Q @ K^T + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // === FUSED: Online Softmax (max + sum in single pass) === + // Takes: QK^T scores + // Returns: (max, sum) computed simultaneously + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %10:2 = linalg.generic {indexing_maps = [#map3, #map4, #map4], iterator_types = ["parallel", "reduction"]} + ins(%6 : tensor<128x4096xf16>) outs(%8, %9 : tensor<128xf16>, tensor<128xf16>) { + ^bb0(%in: f16, %out_max: f16, %out_sum: f16): + %14 = arith.maxnumf %in, %out_max : f16 // new_max = max(in, old_max) + %15 = arith.subf %out_max, %14 : f16 // old_max - new_max + %16 = math.exp %15 : f16 // correction_factor = exp(old_max - new_max) + %17 = arith.mulf %out_sum, %16 : f16 // rescaled_sum = old_sum * correction_factor + 
%18 = arith.subf %in, %14 : f16 // in - new_max + %19 = math.exp %18 : f16 // exp(in - new_max) + %20 = arith.addf %17, %19 : f16 // new_sum = rescaled_sum + exp(in - new_max) + linalg.yield %14, %20 : f16, f16 + } -> (tensor<128xf16>, tensor<128xf16>) + + // Final generic: compute attention output using max and sum + %11 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %12 = linalg.generic {indexing_maps = [#map, #map5, #map5, #map6, #map2], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%6, %10#0, %10#1, %2 : tensor<128x4096xf16>, tensor<128xf16>, tensor<128xf16>, tensor<4096x64xf16>) + outs(%11 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_1: f16, %in_2: f16, %in_3: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 + %15 = math.exp %14 : f16 + %16 = arith.divf %15, %in_2 : f16 + %17 = arith.mulf %16, %in_3 : f16 + %18 = arith.addf %out, %17 : f16 + linalg.yield %18 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %12 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 6: Fully Fused Online Attention (Max-Sum-Matmul Fusion) + +The ultimate optimization fuses **all three operations** (max, sum, and final matmul) into a single generic operation. This computes the attention output in one pass over the sequence dimension. + +### Key Insight +Not only does the sum need rescaling when max changes, but the **accumulated output** must also be rescaled by the same correction factor to maintain correctness. 
+ +```mlir +func.func @attention_final(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // First generic: Q @ K^T + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // === FULLY FUSED: Online attention (max + sum + matmul in single pass) === + // Takes: QK^T scores and V matrix + // Returns: (max, sum, attention_output) computed simultaneously + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %10 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %11:3 = linalg.generic {indexing_maps = [#map, #map6, #map5, #map5, #map2], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%6, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%8, %9, %10 : tensor<128xf16>, tensor<128xf16>, tensor<128x64xf16>) { + ^bb0(%in_qk: f16, %in_v: f16, %out_max: f16, %out_sum: f16, %out_acc: f16): + // Compute new max + %14 = arith.maxnumf %in_qk, 
%out_max : f16 + // Compute correction factor: exp(old_max - new_max) + %15 = arith.subf %out_max, %14 : f16 + %16 = math.exp %15 : f16 + // Rescale old sum with correction factor + %17 = arith.mulf %out_sum, %16 : f16 + // Rescale old acc with correction factor (CRITICAL!) + %18 = arith.mulf %out_acc, %16 : f16 + // Compute exp(in_qk - new_max) + %19 = arith.subf %in_qk, %14 : f16 + %20 = math.exp %19 : f16 + // Update sum + %21 = arith.addf %17, %20 : f16 + // Compute weighted value + %22 = arith.mulf %20, %in_v : f16 + // Update acc + %23 = arith.addf %18, %22 : f16 + linalg.yield %14, %21, %23 : f16, f16, f16 + } -> (tensor<128xf16>, tensor<128xf16>, tensor<128x64xf16>) + + // Final normalization: divide acc by sum + %12 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%11#2, %11#1 : tensor<128x64xf16>, tensor<128xf16>) outs(%4 : tensor<128x64xf16>) { + ^bb0(%in_acc: f16, %in_sum: f16, %out: f16): + %14 = arith.divf %in_acc, %in_sum : f16 + linalg.yield %14 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %12 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md new file mode 100644 index 00000000..d9d4f17f --- /dev/null +++ b/docs/fused_attention_lowering.md @@ -0,0 +1,720 @@ +# Fused Attention Kernel Lowering Flow + +This document describes the multi-stage lowering process for standard attention kernels in MLIR, showing how high-level operations are progressively transformed into hardware-specific GPU code. 
+ +--- + +## Stage 1: Initial Standard Attention + +**Input shape**: `2x8x4096x64xf16` (batch × heads × sequence × head_dim) + +### Key Operations + +```mlir +// Reshape from 4D to 3D: collapse batch and head dimensions +%q_3d = memref.collapse_shape %arg_q [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> +%k_3d = memref.collapse_shape %arg_k [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> +%v_3d = memref.collapse_shape %arg_v [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + +// Transpose K: [16, 4096, 64] -> [16, 64, 4096] +%k_transposed = linalg.transpose ins(%k_3d : tensor<16x4096x64xf16>) + outs(%empty_kt : tensor<16x64x4096xf16>) permutation = [0, 2, 1] + +// Q @ K^T: [16, 4096, 64] @ [16, 64, 4096] -> [16, 4096, 4096] +%qk_scores = linalg.batch_matmul ins(%q_3d, %k_transposed : ...) + outs(%qk_init : tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16> + +// Scale by 1/sqrt(d_k) = 0.125 +%qk_scaled = linalg.mul ins(%qk_scores, %scale_factor : tensor<16x4096x4096xf16>, ...) + -> tensor<16x4096x4096xf16> + +// Softmax over last dimension +%attention_weights = linalg.softmax dimension(2) ins(%qk_scaled : ...) + -> tensor<16x4096x4096xf16> + +// Attention @ V: [16, 4096, 4096] @ [16, 4096, 64] -> [16, 4096, 64] +%output = linalg.batch_matmul ins(%attention_weights, %v_3d : ...) + outs(%output_init : tensor<16x4096x64xf16>) -> tensor<16x4096x64xf16> +``` + +--- + +## Stage 2: Tiling and Softmax Decomposition + + +### Softmax Decomposition +The atomic `linalg.softmax` is decomposed into explicit operations: + +#### **Why decompose so early?** +The softmax operation does not work well with tile-and-fuse. +If the last matmul is tiled and we then try to fuse softmax into the `scf.forall`, +the fusion does not **bubble up** the `tensor.extract_slice`. Instead, the entire softmax is +computed first (including parallel dims), and slices are then extracted from the result to +feed into the last matmul.
+ +```mlir +// 1. Find max value per row (for numerical stability) +%max_per_row = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%qk_scaled) outs(%max_init) { +^bb0(%score: f16, %current_max: f16): + %new_max = arith.maxnumf %score, %current_max : f16 + linalg.yield %new_max : f16 +} -> tensor<16x4096xf16> + +// 2. Compute exp(x - max) for each element +%exp_scores = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel"] +} ins(%qk_scaled, %max_per_row) outs(%exp_init) { +^bb0(%score: f16, %max_val: f16, %out: f16): + %centered = arith.subf %score, %max_val : f16 + %exp_val = math.exp %centered : f16 + linalg.yield %exp_val : f16 +} -> tensor<16x4096x4096xf16> + +// 3. Sum exp values per row +%sum_per_row = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%exp_scores) outs(%sum_init) { +^bb0(%exp_val: f16, %current_sum: f16): + %new_sum = arith.addf %exp_val, %current_sum : f16 + linalg.yield %new_sum : f16 +} -> tensor<16x4096xf16> + +// 4. 
Normalize: divide each element by sum +%attention_weights = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel"] +} ins(%exp_scores, %sum_per_row) outs(%norm_init) { +^bb0(%exp_val: f16, %sum: f16, %out: f16): + %normalized = arith.divf %exp_val, %sum : f16 + linalg.yield %normalized : f16 +} -> tensor<16x4096x4096xf16> +``` + +### Tiling the Output Dimension + +The second matmul (attention @ V) is tiled into 32 tiles of size 128: + +```mlir +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) + shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Extract 128 rows of attention weights: [1, 128, 4096] + %attention_tile = tensor.extract_slice %attention_weights[%batch_head_idx, %row_offset, 0] + [1, 128, 4096] [1, 1, 1] : tensor<16x4096x4096xf16> to tensor<1x128x4096xf16> + + // Extract all of V: [1, 4096, 64] + %v_tile = tensor.extract_slice %v_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Compute partial result: [1, 128, 4096] @ [1, 4096, 64] -> [1, 128, 64] + %partial_output = linalg.batch_matmul ins(%attention_tile, %v_tile) + outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> + } +} +``` + +--- + +## Stage 3: Tiling Batch and Head Dimensions + +### Fusion of Operations + +The entire attention computation is now fused into a single parallel loop: + +```mlir +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) + shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Extract Q tile: [1, 128, 64] + %q_tile = tensor.extract_slice %q_3d[%batch_head_idx, 
%row_offset, 0] [1, 128, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x128x64xf16> + + // Extract full K: [1, 4096, 64] + %k_tile = tensor.extract_slice %k_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Transpose K within tile + %k_tile_transposed = linalg.transpose ins(%k_tile : tensor<1x4096x64xf16>) + outs(%kt_init : tensor<1x64x4096xf16>) permutation = [0, 2, 1] + + // Q @ K^T: [1, 128, 64] @ [1, 64, 4096] -> [1, 128, 4096] + %qk_scores = linalg.batch_matmul ins(%q_tile, %k_tile_transposed : ...) + outs(%qk_init : tensor<1x128x4096xf16>) -> tensor<1x128x4096xf16> + + // Scale + %qk_scaled = linalg.mul ins(%qk_scores, %scale_factor : ...) + -> tensor<1x128x4096xf16> + + // Softmax decomposition (max, exp, sum, normalize) + %max_per_row = linalg.generic { ... } // max reduction -> [1, 128] + %exp_scores = linalg.generic { ... } // exp(x - max) + %sum_per_row = linalg.generic { ... } // sum reduction + %attention_weights = linalg.generic { ... } // normalize + + // Extract V: [1, 4096, 64] + %v_tile = tensor.extract_slice %v_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Attention @ V: [1, 128, 4096] @ [1, 4096, 64] -> [1, 128, 64] + %partial_output = linalg.batch_matmul ins(%attention_weights, %v_tile : ...) + outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> + } +} +``` + +Each workgroup now processes: +- 128 rows of Q +- All of K and V (still materializes `128×4096` attention matrix) +- Produces 128 rows of output + +16 × 32 = 512 independent workgroups + +#### Why not introduce the fused attention loop at this point (at linalg level)? + +Tried this approach. 
This materializes the partial max and sum vectors in the fused +attention algorithm as SLM memory buffers. There is no way to promote these to +registers (*using only upstream magic?*). At this point, the vector-to-XeGPU lowering also does not support 1D SLM access (**WIP**). +For these reasons, the introduction of the fused loop was delayed until after bufferization. + +--- + +## Stage 4: Vectorization + +Linalg operations are converted to vector operations for SIMD execution. + +### First Matmul (Q @ K^T) + +```mlir +// Read K: [4096, 64] +%k_vec = vector.transfer_read %k_3d[%batch_head_idx, %c0, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<4096x64xf16> + +// Read Q tile: [128, 64] +%q_vec = vector.transfer_read %q_3d[%batch_head_idx, %row_offset, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<128x64xf16> + +// Contract (matmul): [128, 64] @ [4096, 64]^T -> [128, 4096] +%qk_scores_vec = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: reduce over d2 + affine_map<(d0, d1, d2) -> (d1, d2)>, // K^T: reduce over d2 + affine_map<(d0, d1, d2) -> (d0, d1)> // Output + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} %q_vec, %k_vec, %zero_init : vector<128x64xf16>, vector<4096x64xf16> into vector<128x4096xf16> +``` + +### Softmax with Vector Reductions + +```mlir +// Scale by 1/sqrt(d_k) +%qk_scaled_vec = arith.mulf %qk_scores_vec, %scale_factor_vec : vector<128x4096xf16> + +// Reshape for reduction: [128, 4096] -> [1, 128, 4096] +%qk_3d = vector.shape_cast %qk_scaled_vec : vector<128x4096xf16> to vector<1x128x4096xf16> + +// Max reduction along last dimension (across sequence length) +%max_per_row_vec = vector.multi_reduction , %qk_3d, %neg_inf_init [2] + : vector<1x128x4096xf16> to vector<1x128xf16> + +// Broadcast max back to full shape for subtraction +%max_broadcast_3d = vector.broadcast %max_per_row_vec : vector<1x128xf16> to vector<4096x1x128xf16>
+%max_broadcast_2d = vector.shape_cast %max_broadcast_3d : vector<4096x1x128xf16> to vector<4096x128xf16> +%max_broadcast = vector.transpose %max_broadcast_2d, [1, 0] : vector<4096x128xf16> to vector<128x4096xf16> + +// Exp of (x - max) for numerical stability +%centered_scores = arith.subf %qk_scaled_vec, %max_broadcast : vector<128x4096xf16> +%exp_scores_vec = math.exp %centered_scores : vector<128x4096xf16> + +// Sum reduction to get denominator +%exp_3d = vector.shape_cast %exp_scores_vec : vector<128x4096xf16> to vector<1x128x4096xf16> +%sum_per_row_vec = vector.multi_reduction , %exp_3d, %zero_init [2] + : vector<1x128x4096xf16> to vector<1x128xf16> + +// Broadcast sum for normalization +%sum_broadcast_3d = vector.broadcast %sum_per_row_vec : vector<1x128xf16> to vector<4096x1x128xf16> +%sum_broadcast_2d = vector.shape_cast %sum_broadcast_3d : vector<4096x1x128xf16> to vector<4096x128xf16> +%sum_broadcast = vector.transpose %sum_broadcast_2d, [1, 0] : vector<4096x128xf16> to vector<128x4096xf16> + +// Normalize to get attention weights +%attention_weights_vec = arith.divf %exp_scores_vec, %sum_broadcast : vector<128x4096xf16> +``` + +### Second Matmul (Attention @ V) + +```mlir +// Read V: [4096, 64] +%v_vec = vector.transfer_read %v_3d[%batch_head_idx, %c0, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<4096x64xf16> + +// Contract: [128, 4096] @ [4096, 64] -> [128, 64] +%output_vec = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, // Attention weights: reduce over sequence + affine_map<(d0, d1, d2) -> (d2, d1)>, // V: reduce over sequence + affine_map<(d0, d1, d2) -> (d0, d1)> // Output + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} %attention_weights_vec, %v_vec, %zero_output_init : + vector<128x4096xf16>, vector<4096x64xf16> into vector<128x64xf16> +``` + +--- + +## Stage 5: Bufferization + +Tensors are converted to memrefs (in-place memory buffers): + 
+```mlir +func.func @payload(%arg_output: memref<2x8x4096x64xf16>, + %arg_q: memref<2x8x4096x64xf16>, + %arg_k: memref<2x8x4096x64xf16>, + %arg_v: memref<2x8x4096x64xf16>) { + // Constants remain as vectors + %zero_vec_128x64 = arith.constant dense<0.000000e+00> : vector<128x64xf16> + %poison = ub.poison : f16 + + // Collapse shapes on memrefs + %q_3d = memref.collapse_shape %arg_q [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %k_3d = memref.collapse_shape %arg_k [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %v_3d = memref.collapse_shape %arg_v [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %output_3d = memref.collapse_shape %arg_output [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + + scf.forall (%batch_head_idx, %tile_idx) in (16, 32) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Direct reads from memrefs + %k_vec = vector.transfer_read %k_3d[%batch_head_idx, %c0, %c0], %poison + : memref<16x4096x64xf16>, vector<4096x64xf16> + + // ... computation ... + + // Create subview for output + %output_subview = memref.subview %output_3d[%batch_head_idx, %row_offset, 0] [1, 128, 64] [1, 1, 1] + : memref<16x4096x64xf16> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>> + + // Direct write to memref + vector.transfer_write %output_vec, %output_subview[%c0, %c0, %c0] {in_bounds = [true, true]} + : vector<128x64xf16>, memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>> + } +} +``` + +--- + +## Stage 6: Inner Tiling for Fused Attention (Online Softmax) + +1. Implement "online" softmax to avoid materializing the full attention matrix. +2. Tile the K and V loads into `16 x d_head` chunks and interleave them with DPAS to lower register pressure.
+ +### The Online Softmax Algorithm + +Instead of computing the full `128×4096` attention matrix, we process K/V in chunks of 64 and incrementally update: +- Running maximum `m_i` +- Running sum of exponentials `l_i` +- Partial output `O_i` + +```mlir +%final_max:3 = scf.for %kv_chunk_idx = %c0 to %c4096 step %c64 + iter_args(%m_old = %neg_inf_init, %l_old = %zero_init, %O_old = %zero_output_init) + -> (vector<128xf16>, vector<128xf16>, vector<128x64xf16>) { + // m_old = running maximum across previous chunks + // l_old = running sum of exponentials across previous chunks + // O_old = running partial output across previous chunks +``` + +#### Process 64 columns of K at a time (4 chunks of 16) + +```mlir + // Chunk 0: columns [0:16] of K + %k_chunk_0 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %kv_chunk_idx, %c0], %poison + : memref<2x8x4096x64xf16>, vector<16x64xf16> + %k_chunk_0_t = vector.transpose %k_chunk_0, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_0 = vector.contract { ... } %q_vec, %k_chunk_0_t, %zero_init : + vector<128x64xf16>, vector<64x16xf16> -> vector<128x16xf16> + + // Chunk 1: columns [16:32] of K + %k_offset_16 = arith.addi %kv_chunk_idx, %c16 : index + %k_chunk_1 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_16, %c0], %poison + %k_chunk_1_t = vector.transpose %k_chunk_1, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_1 = vector.contract { ... } %q_vec, %k_chunk_1_t, %zero_init : ... -> vector<128x16xf16> + + // Chunk 2: columns [32:48] of K + %k_offset_32 = arith.addi %kv_chunk_idx, %c32 : index + %k_chunk_2 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_32, %c0], %poison + %k_chunk_2_t = vector.transpose %k_chunk_2, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_2 = vector.contract { ... } %q_vec, %k_chunk_2_t, %zero_init : ... 
-> vector<128x16xf16> + + // Chunk 3: columns [48:64] of K + %k_offset_48 = arith.addi %kv_chunk_idx, %c48 : index + %k_chunk_3 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_48, %c0], %poison + %k_chunk_3_t = vector.transpose %k_chunk_3, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_3 = vector.contract { ... } %q_vec, %k_chunk_3_t, %zero_init : ... -> vector<128x16xf16> +``` + +#### Find new maximum across all 4 chunks + +```mlir + // Find max across all 4 chunks + %max_01 = arith.maximumf %qk_chunk_0, %qk_chunk_1 : vector<128x16xf16> + %max_012 = arith.maximumf %max_01, %qk_chunk_2 : vector<128x16xf16> + %max_0123 = arith.maximumf %max_012, %qk_chunk_3 : vector<128x16xf16> + %max_chunk_per_row = vector.multi_reduction <maximumf>, %max_0123, %neg_inf_init [1] : + vector<128x16xf16> -> vector<128xf16> + + // Scale and update running maximum + %max_chunk_scaled = arith.mulf %max_chunk_per_row, %scale_factor_vec : vector<128xf16> + %m_new = arith.maximumf %m_old, %max_chunk_scaled : vector<128xf16> +``` + +#### Compute exponentials for each chunk + +```mlir + // Broadcast m_new for subtraction from all chunks + %m_new_3d = vector.broadcast %m_new : vector<128xf16> to vector<16x128xf16> + %m_new_broadcast = vector.transpose %m_new_3d, [1, 0] : vector<16x128xf16> to vector<128x16xf16> + + // exp(chunk_0 * scale - m_new) + %qk_chunk_0_scaled = arith.mulf %qk_chunk_0, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_0_centered = arith.subf %qk_chunk_0_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_0 = math.exp %qk_chunk_0_centered : vector<128x16xf16> + + // exp(chunk_1 * scale - m_new) + %qk_chunk_1_scaled = arith.mulf %qk_chunk_1, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_1_centered = arith.subf %qk_chunk_1_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_1 = math.exp %qk_chunk_1_centered : vector<128x16xf16> + + // exp(chunk_2 * scale - m_new) + %qk_chunk_2_scaled = arith.mulf %qk_chunk_2, %scale_factor_2d :
vector<128x16xf16> + %qk_chunk_2_centered = arith.subf %qk_chunk_2_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_2 = math.exp %qk_chunk_2_centered : vector<128x16xf16> + + // exp(chunk_3 * scale - m_new) + %qk_chunk_3_scaled = arith.mulf %qk_chunk_3, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_3_centered = arith.subf %qk_chunk_3_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_3 = math.exp %qk_chunk_3_centered : vector<128x16xf16> +``` + +#### Update sum of exponentials + +```mlir + // Sum exponentials across the 4 chunks + %sum_01 = arith.addf %exp_chunk_0, %exp_chunk_1 : vector<128x16xf16> + %sum_012 = arith.addf %sum_01, %exp_chunk_2 : vector<128x16xf16> + %sum_0123 = arith.addf %sum_012, %exp_chunk_3 : vector<128x16xf16> + %l_chunk = vector.multi_reduction <add>, %sum_0123, %zero_init [1] : + vector<128x16xf16> -> vector<128xf16> + + // Correction factor for previous chunks: exp(m_old - m_new) + %m_delta = arith.subf %m_old, %m_new : vector<128xf16> + %correction_factor = math.exp %m_delta : vector<128xf16> + + // Update running sum: l_new = l_old * correction + l_chunk + %l_old_corrected = arith.mulf %l_old, %correction_factor : vector<128xf16> + %l_new = arith.addf %l_old_corrected, %l_chunk : vector<128xf16> +``` + +#### Update partial output + +```mlir + // Rescale old output by correction factor + %correction_3d = vector.broadcast %correction_factor : vector<128xf16> to vector<64x128xf16> + %correction_broadcast = vector.transpose %correction_3d, [1, 0] : vector<64x128xf16> to vector<128x64xf16> + %O_old_corrected = arith.mulf %O_old, %correction_broadcast : vector<128x64xf16> + + // Load corresponding 64 rows of V (chunk 0: rows [0:16]) + %v_chunk_0 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %kv_chunk_idx, %c0], %poison + : memref<2x8x4096x64xf16>, vector<16x64xf16> + + // Accumulate: O += exp_chunk_0 @ V[0:16, :] + %O_partial_0 = vector.contract { ...
} %exp_chunk_0, %v_chunk_0, %O_old_corrected : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_1 @ V[16:32, :] + %v_chunk_1 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_16, %c0], %poison + %O_partial_1 = vector.contract { ... } %exp_chunk_1, %v_chunk_1, %O_partial_0 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_2 @ V[32:48, :] + %v_chunk_2 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_32, %c0], %poison + %O_partial_2 = vector.contract { ... } %exp_chunk_2, %v_chunk_2, %O_partial_1 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_3 @ V[48:64, :] + %v_chunk_3 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_48, %c0], %poison + %O_new = vector.contract { ... } %exp_chunk_3, %v_chunk_3, %O_partial_2 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + scf.yield %m_new, %l_new, %O_new : vector<128xf16>, vector<128xf16>, vector<128x64xf16> +} +``` + +#### Final normalization + +```mlir +// Extract final values from loop +%m_final = %final_max#0 : vector<128xf16> +%l_final = %final_max#1 : vector<128xf16> +%O_accumulated = %final_max#2 : vector<128x64xf16> + +// Broadcast sum to full output shape for normalization +%l_final_3d = vector.broadcast %l_final : vector<128xf16> to vector<64x128xf16> +%l_final_broadcast = vector.transpose %l_final_3d, [1, 0] : vector<64x128xf16> to vector<128x64xf16> + +// Normalize: O_final = O_accumulated / l_final +%output_normalized = arith.divf %O_accumulated, %l_final_broadcast : vector<128x64xf16> + +// Write result back to output buffer +vector.transfer_write %output_normalized, %output_4d[%batch_idx, %head_idx, %row_offset, %c0] + {in_bounds = [true, true]} : vector<128x64xf16>, memref<2x8x4096x64xf16> +``` + +--- + +## Stage 7: GPU Outlining + +The `scf.forall` loop is distributed to workgroups.
+ +```mlir +module attributes {gpu.container_module} { + func.func @payload(%arg_output: memref<2x8x4096x64xf16>, + %arg_q: memref<2x8x4096x64xf16>, + %arg_k: memref<2x8x4096x64xf16>, + %arg_v: memref<2x8x4096x64xf16>) { + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + + // Launch GPU kernel + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c32, %c1) // Grid: 16 × 32 × 1 (batch×head, seq_tiles, 1) + threads in (%c128, %c1, %c1) // Block: 128 × 1 × 1 + args(%arg_q : memref<2x8x4096x64xf16>, + %arg_k : memref<2x8x4096x64xf16>, + %arg_v : memref<2x8x4096x64xf16>, + %arg_output : memref<2x8x4096x64xf16>) + return + } + + gpu.module @payload_kernel { + gpu.func @payload_kernel(%q: memref<2x8x4096x64xf16>, + %k: memref<2x8x4096x64xf16>, + %v: memref<2x8x4096x64xf16>, + %output: memref<2x8x4096x64xf16>) kernel + attributes { + known_block_size = array<i32: 128, 1, 1>, + known_grid_size = array<i32: 16, 32, 1> + } { + + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + + // Compute batch and head indices from block_id_x + %row_offset = arith.muli %block_id_y, %c128 : index + %batch_idx = arith.floordivsi %block_id_x, %c8 : index + %head_idx = arith.remsi %block_id_x, %c8 : index + + // ... computation from Stage 6 ... + + gpu.return + } + } +} +``` + +**Grid Configuration**: +- 16 blocks in X (2 batches × 8 heads) +- 32 blocks in Y (4096 / 128 = 32 tiles) +- Each block has 128 threads (8 subgroups) + +--- + +## Stage 8: Converting to XeGPU Operations + +Vector operations are converted to Intel XeGPU-specific operations using tensor descriptors.
+ +### Tensor Descriptor Creation + +```mlir +// Create descriptor for reading Q +%q_subview = memref.subview %q[%batch_idx, %head_idx, 0, 0] [1, 1, 4096, 64] [1, 1, 1, 1] + : memref<2x8x4096x64xf16> to memref<4096x64xf16, strided<[64, 1], offset: ?>> + +%q_base_buffer, %q_offset, %q_sizes:2, %q_strides:2 = + memref.extract_strided_metadata %q_subview : memref<4096x64xf16, strided<[64, 1], offset: ?>> + -> memref, index, index, index, index, index + +%q_intptr = memref.extract_aligned_pointer_as_index %q_base_buffer : memref -> index +%q_byte_offset = arith.muli %q_offset, %c2 : index // Multiply by sizeof(f16) +%q_addr = arith.addi %q_intptr, %q_byte_offset : index +%q_addr_i64 = arith.index_cast %q_addr : index to i64 + +// Create XeGPU tensor descriptor for Q +%q_tdesc = xegpu.create_nd_tdesc %q_addr_i64, shape : [4096, 64], strides : [64, 1] : i64 + -> !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> +``` + +### XeGPU Load Operations + +```mlir +// Load Q tile using tensor descriptor at row_offset +%q_vec = xegpu.load_nd %q_tdesc[%row_offset, 0] + : !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> + -> vector<128x64xf16> + +// Load K chunk at kv_chunk_idx +%k_chunk_vec = xegpu.load_nd %k_tdesc[%kv_chunk_idx, 0] + : !xegpu.tensor_desc<16x64xf16, #xegpu.block_tdesc_attr> + -> vector<16x64xf16> +``` + +### XeGPU DPAS (Dot Product Accumulate Systolic) + +The key hardware instruction for matrix multiplication: + +```mlir +// Transpose K chunk for matmul +%k_chunk_t = vector.transpose %k_chunk_vec, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + +// DPAS: Specialized matmul instruction +// Q: [128, 64], K^T: [64, 16], accumulator: [128, 16] +%qk_chunk = xegpu.dpas %q_vec, %k_chunk_t, %zero_accumulator + : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf16> + -> vector<128x16xf16> +``` + +### XeGPU Store Operations + +```mlir +// Create output descriptor (similar to Q descriptor creation) +%output_addr_i64 = arith.index_cast %output_addr : index 
to i64 +%output_tdesc = xegpu.create_nd_tdesc %output_addr_i64, shape : [4096, 64], strides : [64, 1] : i64 + -> !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> + +// Store result at row_offset +xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] + : vector<128x64xf16>, + !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> +``` + +--- + +## Stage 9: Setting XeGPU Layouts + +Assign layouts for distributing to subgroups. + +### Load Layout for Q (128×64) + +```mlir +%q_vec = xegpu.load_nd %q_tdesc[%row_offset, 0] <{ + layout = #xegpu.layout< + sg_layout = [8, 1], // 8 subgroups in row, 1 in column + sg_data = [16, 64], // Each SG handles 16×64 data + inst_data = [8, 16] // Each instruction: 8×16 + > +}> : !xegpu.tensor_desc<128x64xf16, ...> -> vector<128x64xf16> +``` + +**Breakdown**: +- 8 subgroups handle rows: 8 × 16 = 128 rows ✓ +- 1 subgroup handles columns: 1 × 64 = 64 columns ✓ +- Each instruction processes 8×16 elements + +### DPAS Layout (Q @ K^T → Attention Scores) + +```mlir +%qk_chunk = xegpu.dpas %q_vec, %k_chunk_t, %zero_accumulator { + layout_a = #xegpu.layout< + sg_layout = [8, 1], // Q: 8 SGs vertically, 1 horizontally + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [8, 16] // Each DPAS: 8 rows × 16 cols + >, + layout_b = #xegpu.layout< + sg_layout = [1, 8], // K^T: 1 SG vertically, 8 horizontally + sg_data = [64, 16], // Each SG: 64 rows × 16 cols + inst_data = [16, 16], // Each DPAS: 16 × 16 + order = [0, 1] + >, + layout_cd = #xegpu.layout< + sg_layout = [8, 1], // Output: 8 SGs vertically + sg_data = [16, 16], // Each SG: 16 rows × 16 cols + inst_data = [8, 16] // Each DPAS output: 8 × 16 + > +} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf16> + -> vector<128x16xf16> +``` + +**Matrix Multiplication Mapping**: +- Input A (Q): 8 subgroups × (16 rows × 64 cols) = 128×64 +- Input B (K^T): 8 subgroups × (64 rows × 16 cols) = 64×128 (transposed) +- Output C (Scores): 8 subgroups × (16 rows × 16 cols) 
= 128×16 + +### DPAS Layout (Attention @ V → Output) + +```mlir +%O_partial = xegpu.dpas %exp_chunk, %v_chunk, %O_old_corrected { + layout_a = #xegpu.layout< + sg_layout = [8, 1], // Exp attention: 8 SGs vertically + sg_data = [16, 16], // Each SG: 16 rows × 16 cols + inst_data = [8, 16] // Each DPAS: 8 × 16 + >, + layout_b = #xegpu.layout< + sg_layout = [8, 1], // V: 8 SGs vertically (not transposed) + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [16, 16] // Each DPAS: 16 × 16 + >, + layout_cd = #xegpu.layout< + sg_layout = [8, 1], // Output: 8 SGs vertically + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [8, 16] // Each DPAS output: 8 × 16 + > +} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf16> + -> vector<128x64xf16> +``` + +### Store Layout + +```mlir +xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] <{ + layout = #xegpu.layout< + sg_layout = [8, 1], // 8 SGs vertically, 1 horizontally + sg_data = [16, 64], // Each SG stores: 16 rows × 64 cols + inst_data = [8, 16] // Each store instruction: 8 × 16 + > +}> : vector<128x64xf16>, !xegpu.tensor_desc<128x64xf16, ...> +``` + +--- + +## Conclusion + +This lowering flow demonstrates a sophisticated compilation strategy: +1. Start with intuitive high-level operations +2. Progressively expose parallelism through tiling +3. Apply memory-saving algorithms (online softmax) +4. Lower to vector operations for SIMD +5. Map to hardware-specific instructions (DPAS) +6. Optimize data layout for hardware execution + +The result is a highly optimized fused attention kernel that maximizes throughput while minimizing memory footprint. 
diff --git a/docs/softmax_tiling.md b/docs/softmax_tiling.md new file mode 100644 index 00000000..fb20db54 --- /dev/null +++ b/docs/softmax_tiling.md @@ -0,0 +1,319 @@ +# Softmax Tiling and Fusion + +--- + +## Stage 0: Initial Softmax + +**What we have:** A single high-level `linalg.softmax` operation that computes softmax across dimension 1. + +**Softmax formula:** For each row $i$: + +$$\text{softmax}(x)_{i,j} = \frac{\exp(x_{i,j} - \max_k x_{i,k})}{\sum_k \exp(x_{i,k} - \max_k x_{i,k})}$$ + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = linalg.softmax dimension(1) ins(%0 : tensor<64x1024xf32>) outs(%1 : tensor<64x1024xf32>) -> tensor<64x1024xf32> + bufferization.materialize_in_destination %2 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 1: Softmax Decomposition + +**Transform applied:** `transform.structured.decompose_interface` + +**What happens:** The high-level softmax operation is decomposed into 5 primitive `linalg.generic` operations: + +1. **Max reduction** (`%4`): Find the maximum value along each row +2. **Subtract and exp** (`%5`): Compute `exp(x - max)` elementwise +3. **Fill for sum** (`%6`): Initialize accumulator for sum reduction +4. **Sum reduction** (`%7`): Sum the exp values along each row +5. **Final division** (`%8`): Divide each exp value by the sum + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %cst = arith.constant 0xFFC00000 : f32 + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 1. 
Max reduction: compute max along each row + %4 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.maxnumf %in, %out : f32 + linalg.yield %9 : f32 + } -> tensor<64xf32> + + // 2. Subtract max and apply exp: exp(x - max) + %5 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%0, %4 : tensor<64x1024xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %9 = arith.subf %in, %in_1 : f32 + %10 = math.exp %9 : f32 + linalg.yield %10 : f32 + } -> tensor<64x1024xf32> + + %cst_0 = arith.constant 0.000000e+00 : f32 + %6 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 3. Sum reduction: sum of exp values + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%5 : tensor<64x1024xf32>) outs(%6 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.addf %in, %out : f32 + linalg.yield %9 : f32 + } -> tensor<64xf32> + + // 4. 
Final division: exp(x - max) / sum + %8 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%5, %7 : tensor<64x1024xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %9 = arith.divf %in, %in_1 : f32 + linalg.yield %9 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %8 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 2: Elementwise Operation Fusion + +**Transform applied:** `transform.apply_registered_pass "linalg-fuse-elementwise-ops"` with some modifications. + +**What happens:** The compiler fuses operations that can be computed together to reduce memory traffic: + +1. The elementwise `exp(x - max)` operation (`%5`) is fused into the sum reduction (`%6`) +2. The elementwise `exp(x - max)` and division are fused into the final output (`%7`) + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFFC00000 : f32 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 1. Max reduction (unchanged) + %4 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %8 = arith.maxnumf %in, %out : f32 + linalg.yield %8 : f32 + } -> tensor<64xf32> + + %5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 2. 
Fused: exp(x - max) and sum reduction + // This computes exp on-the-fly during the reduction + %6 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)> + ], + iterator_types = ["parallel", "reduction"] + } ins(%0, %4 : tensor<64x1024xf32>, tensor<64xf32>) outs(%5 : tensor<64xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %8 = arith.subf %in, %in_1 : f32 + %9 = math.exp %8 : f32 + %10 = arith.addf %9, %out : f32 + linalg.yield %10 : f32 + } -> tensor<64xf32> + + // 3. Fused: exp(x - max) and division + // This computes exp again and immediately divides by sum + %7 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%0, %4, %6 : tensor<64x1024xf32>, tensor<64xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32): + %8 = arith.subf %in, %in_1 : f32 + %9 = math.exp %8 : f32 + %10 = arith.divf %9, %in_2 : f32 + linalg.yield %10 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %7 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 3: Max-Sum Reduction Fusion (Online Softmax) + + +**What happens:** The max reduction and sum reduction are fused into a single pass over the data using the "online softmax" algorithm. This algorithm maintains running max and running sum values, applying a correction factor when the max changes. + +**Why this works?** + +Consider the reduction loop at an arbitrary stage $i$. We have access to: + +$$m_{i-1} = \max(x_0, x_1, \ldots, x_{i-1})$$ + +$$s_{i-1} = \sum_{j=0}^{i-1} \exp(x_j - m_{i-1})$$ + +When we process element $x_i$, we first compute the updated max: + +$$m_i = \max(m_{i-1}, x_i)$$ + +Now we need to compute the updated sum. 
The challenge is that $s_{i-1}$ was computed with respect to $m_{i-1}$, but we need the sum with respect to the new max $m_i$. + +Using the property $\exp(a - b) = \frac{\exp(a)}{\exp(b)} = \exp(a) \cdot \exp(-b)$, we can rewrite each term: + +$$\exp(x_j - m_i) = \exp(x_j - m_{i-1} - (m_i - m_{i-1})) = \exp(x_j - m_{i-1}) \cdot \exp(m_{i-1} - m_i)$$ + +Since multiplication distributes over addition, we can factor out the correction term from the entire sum: + +$$\sum_{j=0}^{i-1} \exp(x_j - m_i) = \exp(m_{i-1} - m_i) \cdot \sum_{j=0}^{i-1} \exp(x_j - m_{i-1}) = \exp(m_{i-1} - m_i) \cdot s_{i-1}$$ + +Therefore, the corrected sum at stage $i$ is: + +$$s_i = \exp(m_{i-1} - m_i) \cdot s_{i-1} + \exp(x_i - m_i)$$ + +This shows that we can maintain running max and sum values in a single pass, applying a multiplicative correction factor $\exp(m_{i-1} - m_i)$ whenever the max changes. + +```mlir +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @softmax_v2(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFFC00000 : f32 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // Fused max and sum reduction with online correction + %5:2 = linalg.generic { + indexing_maps = [#map, #map1, #map1], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3, %4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %out_max: f32, %out_sum: f32): + // Update max + %new_max = arith.maxnumf %in, %out_max : f32 + + // Compute correction factor for sum when max changes + // correction = exp(old_max - new_max) + %max_diff = arith.subf %out_max, %new_max : 
f32 + %correction = math.exp %max_diff : f32 + %corrected_sum = arith.mulf %out_sum, %correction : f32 + + // Add new contribution: exp(value - new_max) + %val_diff = arith.subf %in, %new_max : f32 + %exp_val = math.exp %val_diff : f32 + %new_sum = arith.addf %corrected_sum, %exp_val : f32 + + linalg.yield %new_max, %new_sum : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + + // Final division to compute softmax + %6 = linalg.generic { + indexing_maps = [#map, #map1, #map1, #map], + iterator_types = ["parallel", "parallel"] + } ins(%0, %5#0, %5#1 : tensor<64x1024xf32>, tensor<64xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_max: f32, %in_sum: f32, %out: f32): + %7 = arith.subf %in, %in_max : f32 + %8 = math.exp %7 : f32 + %9 = arith.divf %8, %in_sum : f32 + linalg.yield %9 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %6 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Does This Always Work? + +**Short answer:** No. The online correction technique only works when the correction can be applied distributively over the aggregation operation. + +### Why Softmax Works + +The key to the online softmax algorithm is that **multiplication distributes over addition**: + +$$c \cdot (a + b) = c \cdot a + c \cdot b$$ + +Combined with the exponential property $\exp(a - b) = \exp(a) \cdot \exp(-b)$, this allows us to apply a multiplicative correction factor to the entire accumulated sum. + +### A Counter-Example: What If There Was No Exp? + +Consider a hypothetical variant where we want to compute: + +$$\text{result}_j = \frac{x_j - m}{s}$$ + +where $m = \max_k x_k$ and $s = \sum_k (x_k - m)$. + +Can we fuse the max and sum reductions using online correction? 
Let's try: + +At stage $i$, we have: +- $m_{i-1} = \max(x_0, x_1, \ldots, x_{i-1})$ +- $s_{i-1} = \sum_{j=0}^{i-1} (x_j - m_{i-1})$ + +When processing $x_i$: +- $m_i = \max(m_{i-1}, x_i)$ + +To correct $s_{i-1}$ for the new max $m_i$, we need: + +$$s_i = \sum_{j=0}^{i-1} (x_j - m_i) + (x_i - m_i)$$ + +Expanding the first term: + +$$\sum_{j=0}^{i-1} (x_j - m_i) = \sum_{j=0}^{i-1} (x_j - m_{i-1} - (m_i - m_{i-1}))$$ + +$$= \sum_{j=0}^{i-1} (x_j - m_{i-1}) - \sum_{j=0}^{i-1} (m_i - m_{i-1})$$ + +$$= s_{i-1} - i \cdot (m_i - m_{i-1})$$ + +**The problem:** To correct $s_{i-1}$, we need to subtract $(m_i - m_{i-1})$ from each of the $i$ terms we've seen so far. This requires knowing the count $i$ and performing $i$ subtractions' worth of correction. **This is not possible at the linalg level** (maybe it's possible when we materialize the loops?) + +**Why subtraction doesn't work:** Subtraction is **not distributive** over addition in the way we need: + +$(a + b) - c \neq (a - c) + (b - c)$ when we want a single correction + +We would need to track both the sum and the count of elements, and the correction becomes additive rather than multiplicative: $s_i = s_{i-1} - i \cdot \Delta m + (x_i - m_i)$ with $\Delta m = m_i - m_{i-1}$, where the correction depends on how many elements we've processed. + +### General Principle + +The online correction technique works when: + +1. **The operation allows factoring out corrections**: The transformation applied to each element (like $\exp(x - m)$) can be rewritten to separate the correction term +2. **The correction is distributive over the reduction**: The correction can be applied to the accumulated result as a whole, not element-by-element +3.
**The correction is independent of cardinality**: We don't need to know how many elements contributed to the partial result diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py new file mode 100644 index 00000000..fa4efc57 --- /dev/null +++ b/examples/xegpu/fused_attention.py @@ -0,0 +1,387 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU fused attention benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( + generate_gpu_fused_attention_payload, +) +from lighthouse.schedule.xegpu import fused_attention_schedule, xegpu_to_binary + + +def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): + """ + Complexity of fused attention operation. + + For each batch and head: + - Q @ K^T: O(n_ctx^2 * n_head) operations + - Softmax: O(n_ctx^2) operations + - Attention @ V: O(n_ctx^2 * n_head) operations + Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head + """ + # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head + flop_count = Z * H * 2 * n_ctx * n_ctx * n_head + # Memory: read Q, K, V and write output + memory_reads = 3 * Z * H * n_ctx * n_head * nbytes + memory_writes = Z * H * n_ctx * n_head * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + output_arr: np.ndarray, + verbose: int = 0, +) -> bool: + """ + Check correctness of fused attention output. 
+ + Reference implementation: + - scores = Q @ K^T / sqrt(n_head) + - attention_weights = softmax(scores, dim=-1) + - output = attention_weights @ V + """ + # Use float32 for computation + Q_f32 = Q.astype(np.float32) + K_f32 = K.astype(np.float32) + V_f32 = V.astype(np.float32) + + Z, H, n_ctx, n_head = Q.shape + scale = 1.0 / np.sqrt(n_head) + + output_ref = np.zeros_like(Q_f32) + + # Compute reference for each batch and head + for z in range(Z): + for h in range(H): + # scores = Q @ K^T / sqrt(n_head) + scores = Q_f32[z, h] @ K_f32[z, h].T * scale + + # softmax along last dimension + max_vals = np.max(scores, axis=1, keepdims=True) + exp_vals = np.exp(scores - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + attention_weights = exp_vals / sum_vals + + # output = attention_weights @ V + output_ref[z, h] = attention_weights @ V_f32[z, h] + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first batch, first head, first 5 rows):") + print(output_ref[0, 0, :5]) + print("Computed solution (first batch, first head, first 5 rows):") + print(output[0, 0, :5]) + + # Check values match reference + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-3) + success = values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not values_ok: + max_diff = np.abs(output - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + +class XeGPUFusedAttention: + """ + Fused attention workload on XeGPU. 
+ + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where: + - Z: batch size + - H: number of heads + - n_ctx: context length + - n_head: head dimension + """ + + def __init__( + self, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: str = "f16", + ): + self.Z = Z + self.H = H + self.n_ctx = n_ctx + self.n_head = n_head + self.shape = (Z, H, n_ctx, n_head) + assert dtype == "f16", "Only f16 type is supported for fused attention" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Initialize Q, K, V with small random values + Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, Q, K, V) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return fused_attention_complexity( + self.Z, self.H, self.n_ctx, self.n_head, nbytes + ) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for fused attention payload.""" + return generate_gpu_fused_attention_payload( + func_name=self.payload_function_name, + Z=self.Z, + H=self.H, + n_ctx=self.n_ctx, + n_head=self.n_head, + dtype=self.elem_type, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for fused attention.""" + schedules = [] + schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name)) + + schedules.append( + fused_attention_schedule( + 
def parse_cli(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command line of the fused attention example.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a plain ``parse_cli()`` call
            always did; passing a list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Fused Attention using MLIR XeGPU",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # All integer options share the same shape: (flag, default, help text).
    # They are registered in this order, which is also the --help order.
    int_options = (
        ("--batch-size", 2, "Batch size (Z)"),
        ("--num-heads", 8, "Number of attention heads (H)"),
        ("--n-ctx", 4096, "Context length (sequence length)"),
        ("--n-head", 64, "Head dimension"),
        ("--wg-rows", 128, "Number of Q*K^T*V rows computed by each work group"),
        ("--sg-rows", 16, "Number of Q*K^T*V rows computed by each subgroup"),
        ("--subgroup-size", 16, "Subgroup size"),
        (
            "--inner-loop-tile-size",
            64,
            "Tile size for the inner reduction dimension (K/V sequence length)",
        ),
        ("--nruns", 1000, "Number of runs to average the execution time."),
        ("--nwarmup", 20, "Number of warm-up iterations before benchmarking."),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument(
        "--check-result",
        action="store_true",
        help="Check the result of the fused attention computation.",
    )
    parser.add_argument(
        "--dump-kernel",
        type=str,
        choices=[
            "initial",
            "outer-tiled",
            "inner-tiled",
            "vectorized",
            "bufferized",
            "gpu-outlining",
            "xegpu-initial",
            "xegpu-wg",
            "final",
        ],
        help="Dump kernel IR at different stages of lowering and exit without "
        "executing the kernel.",
    )
    parser.add_argument(
        "--dump-schedule",
        action="store_true",
        help="Dump transform schedule.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="Increase output verbosity (e.g. print reference and computed solutions).",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: build the workload from CLI options, then either dump
    # IR/schedules or run (optionally verify) and benchmark the kernel.
    args = parse_cli()

    Z, H = args.batch_size, args.num_heads
    n_ctx, n_head = args.n_ctx, args.n_head
    dtype = "f16"

    # Scheduling parameters handed to the transform schedule.
    params = dict(
        batch_size=args.batch_size,
        num_heads=args.num_heads,
        n_ctx=args.n_ctx,
        n_head=args.n_head,
        wg_rows=args.wg_rows,
        sg_rows=args.sg_rows,
        subgroup_size=args.subgroup_size,
        inner_loop_tile_size=args.inner_loop_tile_size,
    )

    with ir.Context(), ir.Location.unknown():
        lh_dialects.register_and_load()
        wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype)

        dump_only = args.dump_kernel or args.dump_schedule
        if dump_only:
            # Dump mode: lower up to the requested stage, print, never execute.
            stage_schedules = wload.schedule_modules(
                stop_at_stage=args.dump_kernel, parameters=params
            )
            pipeline = TransformDriver(stage_schedules)
            payload = pipeline.apply(wload.payload_module())
            if args.dump_kernel:
                print(payload)
            if args.dump_schedule:
                # The printed schedules are the full (non-truncated) ones.
                for schedule_module in wload.schedule_modules(parameters=params):
                    print(schedule_module)
        else:
            pipeline = TransformDriver(wload.schedule_modules(parameters=params))
            payload = pipeline.apply(wload.payload_module())
            runner = Runner(
                payload,
                mem_manager_cls=wload.memory_manager_class,
                shared_libs=wload.shared_libs(),
            )

            if args.check_result:
                # Callback copies kernel argument 0 (the output) back to host.
                result_host_copy = np.zeros(wload.shape, dtype=wload.dtype)
                argument_access_callback = Runner.get_gpu_argument_access_callback(
                    result_host_copy, arg_index=0
                )

                # One execution is enough for the correctness check.
                runner.execute(
                    host_input_buffers=wload._initial_host_arrays,
                    payload_function_name=wload.payload_function_name,
                    argument_access_callback=argument_access_callback,
                )

                # Host arrays are (output, Q, K, V); compare against the
                # host-side reference.
                host_arrays = wload._initial_host_arrays
                Q, K, V = host_arrays[1:4]
                if not check_correctness(
                    Q, K, V, result_host_copy, verbose=args.verbose
                ):
                    raise ValueError("Result mismatch!")
                print("Result is correct. Proceeding to benchmark...")

            times = runner.benchmark(
                host_input_buffers=wload._initial_host_arrays,
                nruns=args.nruns,
                nwarmup=args.nwarmup,
            )
            times *= 1e6  # convert to microseconds
            elapsed = np.mean(times)
            flop_count = wload.get_complexity()[0]
            # elapsed is in microseconds; convert back to seconds for the rate.
            gflops = flop_count / (elapsed * 1e-6) / 1e9

            summary = (
                f"batch-size={Z} "
                f"num-heads={H} "
                f"n-ctx={n_ctx} "
                f"n-head={n_head} "
                f"dt={dtype} "
                f"time(us): {elapsed:.2f} "
                f"GFLOPS: {gflops:.2f} "
            )
            print(summary)
class GenerateFusedAttention(
    TransformExtensionDialect.Operation, name="generate_fused_attention"
):
    """Generate tiled fused attention computation (flash attention optimization).

    Takes Q, K, V loads and scale constant from bufferized IR, and generates an inner
    tiled loop that computes fused attention with online softmax using running max and sum.

    This implements the flash attention algorithm where:
    1. The computation is tiled along the reduction dimension (K/V sequence length)
    2. Online max and sum are maintained across tiles
    3. Output is incrementally updated with rescaled contributions

    Each loop iteration is further unrolled into sub-tiles of 16 K/V rows, so
    ``tile_size`` must be a positive multiple of 16 and must divide the K
    sequence length; otherwise the op fails with a silenceable error instead of
    emitting a loop that skips or over-reads rows.

    Args:
        q_load: Handle to Q load operation (vector.transfer_read)
        k_load: Handle to K load operation (vector.transfer_read)
        v_load: Handle to V load operation (vector.transfer_read)
        scale: Handle to scale constant operation (arith.constant)
        output: Handle to the output operation to replace (vector.contract)
        tile_size: Tile size for the reduction dimension tiling (K/V sequence length)
    """

    q_load: ext.Operand[transform.AnyOpType]
    k_load: ext.Operand[transform.AnyOpType]
    v_load: ext.Operand[transform.AnyOpType]
    scale: ext.Operand[transform.AnyOpType]
    output: ext.Operand[transform.AnyOpType]
    tile_size: ir.IntegerAttr
    new_output: ext.Result[transform.AnyOpType[()]] = ext.infer_result()

    @classmethod
    def attach_interface_impls(cls, ctx=None):
        """Attach the transform and memory-effects interface models to this op."""
        cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
        cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)

    class TransformOpInterfaceModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: "GenerateFusedAttention",
            rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            """Replace the output op with a tiled online-softmax attention loop.

            Returns a silenceable error when a handle does not map to exactly
            one payload op or when the tile size is incompatible with the
            16-row sub-tiling or the K sequence length.
            """
            # Resolve every operand handle to its payload ops.
            payloads = [
                state.get_payload_ops(handle)
                for handle in (op.q_load, op.k_load, op.v_load, op.scale, op.output)
            ]
            if any(len(ops) != 1 for ops in payloads):
                return DiagnosedSilenceableFailure.emit_silenceable_error(
                    "Expected exactly one operation for each operand"
                )
            q_load_op, k_load_op, v_load_op, scale_op, output_op = (
                ops[0] for ops in payloads
            )

            # Extract the scale scalar: first element of the dense constant.
            scale_dense_attr = ir.DenseElementsAttr(scale_op.attributes["value"])
            scale_value = float(np.array(scale_dense_attr).flat[0])

            # Q load result is [wg_rows, d_head]; K load result's first dim is n_ctx.
            q_vector_type = ir.VectorType(q_load_op.results[0].type)
            wg_rows = q_vector_type.shape[0]
            d_head = q_vector_type.shape[1]
            element_type = q_vector_type.element_type
            n_ctx = ir.VectorType(k_load_op.results[0].type).shape[0]
            tile_size_value = ir.IntegerAttr(op.tile_size).value

            # K/V tiles are processed in unrolled chunks of 16 rows.
            k_subtile_size = 16
            if tile_size_value <= 0 or tile_size_value % k_subtile_size != 0:
                # Otherwise floor division below would drop rows (or yield zero
                # chunks and crash on qkt_chunks[0]).
                return DiagnosedSilenceableFailure.emit_silenceable_error(
                    f"tile_size must be a positive multiple of {k_subtile_size}, "
                    f"got {tile_size_value}"
                )
            if n_ctx % tile_size_value != 0:
                # Otherwise the last loop iteration would read past n_ctx.
                return DiagnosedSilenceableFailure.emit_silenceable_error(
                    f"tile_size ({tile_size_value}) must evenly divide the K/V "
                    f"sequence length ({n_ctx})"
                )
            num_k_tiles = tile_size_value // k_subtile_size

            np_dtype = np.float16 if element_type == ir.F16Type.get() else np.float32
            index_type = ir.IndexType.get()

            row_type = ir.VectorType.get([wg_rows], element_type)
            acc_vector_type = ir.VectorType.get([wg_rows, d_head], element_type)
            qkt_chunk_type = ir.VectorType.get(
                [wg_rows, k_subtile_size], element_type
            )
            tile_type = ir.VectorType.get([k_subtile_size, d_head], element_type)
            k_transpose_type = ir.VectorType.get(
                [d_head, k_subtile_size], element_type
            )

            def splat_constant(shape, value):
                """Emit a dense constant of `shape` filled with `value`."""
                vec_type = ir.VectorType.get(shape, element_type)
                attr = ir.DenseElementsAttr.get(
                    np.full(shape, value, dtype=np_dtype), type=vec_type
                )
                return arith.constant(vec_type, attr)

            def broadcast_rows(row_vector, num_cols):
                """Expand [wg_rows] to [wg_rows, num_cols] via broadcast + transpose."""
                bcasted = vector.broadcast(
                    ir.VectorType.get([num_cols, wg_rows], element_type), row_vector
                )
                return vector.transpose(
                    ir.VectorType.get([wg_rows, num_cols], element_type),
                    bcasted,
                    [1, 0],
                )

            def load_tile(load_op, row_offset):
                """Re-issue `load_op` as a [16, d_head] read starting at `row_offset`.

                Memref, padding, permutation map and in_bounds are taken from the
                original transfer_read; only the second-to-last index changes.
                """
                indices = list(load_op.operands[1:-1])
                indices[-2] = row_offset
                return vector.TransferReadOp(
                    tile_type,
                    load_op.operands[0],
                    indices,
                    load_op.attributes.get("permutation_map", None),
                    load_op.operands[-1],
                    in_bounds=load_op.attributes.get("in_bounds", None),
                ).result

            # Matmul-style contraction over (d0, d1, d2): lhs (d0, d2),
            # rhs (d2, d1), result (d0, d1). Hence d0/d1 are parallel and d2 is
            # the reduction. Used for both Q@K^T and P@V.
            d0, d1, d2 = (ir.AffineExpr.get_dim(i) for i in range(3))
            matmul_maps = ir.ArrayAttr.get(
                [
                    ir.AffineMapAttr.get(ir.AffineMap.get(3, 0, [d0, d2])),
                    ir.AffineMapAttr.get(ir.AffineMap.get(3, 0, [d2, d1])),
                    ir.AffineMapAttr.get(ir.AffineMap.get(3, 0, [d0, d1])),
                ]
            )
            parallel = ir.Attribute.parse("#vector.iterator_type<parallel>")
            reduction = ir.Attribute.parse("#vector.iterator_type<reduction>")
            matmul_iterators = ir.ArrayAttr.get([parallel, parallel, reduction])

            # Build the fused attention computation right before the output op.
            with ir.InsertionPoint(output_op):
                # Running max starts at -inf. NOTE: built as f32 and truncated to
                # the element type to avoid issues with representing -inf.
                neg_inf_type = ir.VectorType.get([wg_rows], ir.F32Type.get())
                neg_inf_attr = ir.DenseElementsAttr.get(
                    np.full(wg_rows, float("-inf"), dtype=np.float32),
                    type=neg_inf_type,
                )
                m_i_init = arith.truncf(
                    row_type, arith.constant(neg_inf_type, neg_inf_attr)
                )
                # Running sum and output accumulator start at zero.
                l_i_init = splat_constant([wg_rows], 0)
                acc_init = splat_constant([wg_rows, d_head], 0)
                scale_vector = splat_constant([wg_rows], scale_value)

                # for idx in range(0, n_ctx, tile_size) carrying (m_i, l_i, acc)
                c0 = arith.constant(index_type, 0)
                c_n_ctx = arith.constant(index_type, n_ctx)
                c_tile_size = arith.constant(index_type, tile_size_value)
                loop = scf.ForOp(
                    c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init]
                )

                with ir.InsertionPoint(loop.body):
                    loop_idx = loop.induction_variable
                    m_i = loop.inner_iter_args[0]
                    l_i = loop.inner_iter_args[1]
                    acc = loop.inner_iter_args[2]
                    q_value = q_load_op.results[0]

                    # Row offsets of each 16-row sub-tile within this tile.
                    k_tile_offsets = [
                        arith.constant(index_type, i * k_subtile_size)
                        for i in range(num_k_tiles)
                    ]
                    # Zero accumulator shared by all Q@K^T chunk contractions.
                    qkt_chunk_acc = splat_constant([wg_rows, k_subtile_size], 0)

                    # Q @ K_tile^T per sub-tile (unrolled): [wg_rows, 16] each.
                    qkt_chunks = []
                    for offset in k_tile_offsets:
                        k_tile = load_tile(k_load_op, arith.addi(loop_idx, offset))
                        k_transpose = vector.transpose(
                            k_transpose_type, k_tile, [1, 0]
                        )
                        qkt_chunks.append(
                            vector.contract(
                                qkt_chunk_type,
                                q_value,
                                k_transpose,
                                qkt_chunk_acc,
                                indexing_maps=matmul_maps,
                                iterator_types=matmul_iterators,
                            )
                        )

                    # Row-wise max: elementwise max across chunks, then reduce
                    # [wg_rows, 16] -> [wg_rows].
                    qkt_max_combined = qkt_chunks[0]
                    for chunk in qkt_chunks[1:]:
                        qkt_max_combined = arith.maximumf(qkt_max_combined, chunk)
                    qkt_max = vector.multi_reduction(
                        kind="maxnumf",
                        source=qkt_max_combined,
                        acc=m_i_init,
                        reduction_dims=[1],
                    )

                    # The max is scaled after the reduction; this matches the max
                    # of the scaled scores only for a non-negative scale
                    # (assumed — TODO confirm against the payload's constant).
                    qkt_max_scaled = arith.mulf(qkt_max, scale_vector)
                    m_ij = arith.maximumf(m_i, qkt_max_scaled)

                    # exp(chunk * scale - m_ij) for every chunk.
                    scale_chunk = splat_constant(
                        [wg_rows, k_subtile_size], scale_value
                    )
                    m_ij_2d = broadcast_rows(m_ij, k_subtile_size)
                    qkt_exp_chunks = [
                        math.exp(
                            arith.subf(arith.mulf(chunk, scale_chunk), m_ij_2d)
                        )
                        for chunk in qkt_chunks
                    ]

                    # Row-wise sum of exponentials for this tile.
                    qkt_exp_combined = qkt_exp_chunks[0]
                    for chunk in qkt_exp_chunks[1:]:
                        qkt_exp_combined = arith.addf(qkt_exp_combined, chunk)
                    l_ij = vector.multi_reduction(
                        kind="add",
                        source=qkt_exp_combined,
                        acc=l_i_init,
                        reduction_dims=[1],
                    )

                    # Rescale previous state by alpha = exp(m_i - m_ij).
                    alpha = math.exp(arith.subf(m_i, m_ij))
                    l_i_updated = arith.addf(arith.mulf(l_i, alpha), l_ij)
                    acc_updated = arith.mulf(acc, broadcast_rows(alpha, d_head))

                    # acc += P_chunk @ V_tile for every sub-tile (unrolled).
                    pv_out = acc_updated
                    for offset, qkt_exp in zip(k_tile_offsets, qkt_exp_chunks):
                        v_tile = load_tile(v_load_op, arith.addi(loop_idx, offset))
                        pv_out = vector.contract(
                            acc_vector_type,
                            qkt_exp,
                            v_tile,
                            pv_out,
                            indexing_maps=matmul_maps,
                            iterator_types=matmul_iterators,
                        )

                    scf.yield_([m_ij, l_i_updated, pv_out])

                # Normalize: output = acc / l_i, with l_i broadcast to
                # [wg_rows, d_head].
                with ir.InsertionPoint.after(loop):
                    output_final = arith.divf(
                        loop.results[2], broadcast_rows(loop.results[1], d_head)
                    )

            # Redirect users of the original output to the new value, then drop it.
            output_op.results[0].replace_all_uses_with(output_final)
            rewriter.erase_op(output_op)

            # The returned handle points at the op defining the final output.
            results.set_ops(op.new_output, [output_final.owner])
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: "GenerateFusedAttention") -> bool:
            return False

    class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
        @staticmethod
        def get_effects(op: ir.Operation, effects):
            # Operands 0-3 (q_load, k_load, v_load, scale) are only read
            transform.only_reads_handle(op.op_operands[:4], effects)
            # Operand 4 (output) is consumed and replaced
            transform.consumes_handle(op.op_operands[4:5], effects)
            # Produce new output handle
            transform.produces_handle(op.results, effects)
            # Modify the payload
            transform.modifies_payload(effects)
def generate_fused_attention(
    q_load: ir.Value,
    k_load: ir.Value,
    v_load: ir.Value,
    scale: ir.Value,
    output: ir.Value,
    tile_size: int | ir.IntegerAttr,
) -> ir.Value:
    """Build a ``GenerateFusedAttention`` transform op and return its result handle.

    Args:
        q_load: Handle to Q load operation (vector.transfer_read)
        k_load: Handle to K load operation (vector.transfer_read)
        v_load: Handle to V load operation (vector.transfer_read)
        scale: Handle to scale constant operation (arith.constant)
        output: Handle to output operation to replace (vector.contract)
        tile_size: Tile size for the reduction dimension tiling (K/V sequence
            length); a plain int is wrapped into a signless 64-bit IntegerAttr.

    Returns:
        Handle to the new output operation
    """
    if isinstance(tile_size, ir.IntegerAttr):
        tile_attr = tile_size
    else:
        i64 = ir.IntegerType.get_signless(64)
        tile_attr = ir.IntegerAttr.get(i64, tile_size)

    fused_op = GenerateFusedAttention(
        q_load, k_load, v_load, scale, output, tile_size=tile_attr
    )
    return fused_op.new_output
def generate_gpu_fused_attention_payload(
    func_name: str,
    Z: int,
    H: int,
    n_ctx: int,
    n_head: int,
    dtype: ir.Type,
) -> ir.Module:
    """
    Build the MLIR module holding the (unfused) attention payload function.

    The payload computes
        output = softmax(Q @ K^T / sqrt(n_head)) @ V
    on buffers of shape (Z, H, n_ctx, n_head), after collapsing (Z, H) into a
    single batch dimension.

    Args:
        func_name: Name of the payload function
        Z: Batch size
        H: Number of attention heads
        n_ctx: Context length (sequence length)
        n_head: Head dimension
        dtype: MLIR element type used for every buffer and tensor

    Returns:
        MLIR module containing the fused attention payload function
    """
    module = ir.Module.create()
    buf_4d_t = ir.MemRefType.get((Z, H, n_ctx, n_head), dtype)

    with ir.InsertionPoint(module.body):
        # (Z, H, n_ctx, n_head) -> (Z*H, n_ctx, n_head)
        batch_dim = Z * H
        collapsed_shape_3d = (batch_dim, n_ctx, n_head)
        buf_3d_t = ir.MemRefType.get(collapsed_shape_3d, dtype)

        # Argument order of the payload: (output, Q, K, V)
        @func_cif(buf_4d_t, buf_4d_t, buf_4d_t, buf_4d_t, name=func_name)
        def payload(output, Q_arg, K_arg, V_arg):
            def collapse_batch(buf):
                """Merge the two leading dims of a 4D memref."""
                return memref.collapse_shape(
                    buf_3d_t,
                    buf,
                    reassociation=[[0, 1], [2], [3]],
                )

            # Collapse in the order Q, K, V, output.
            Q_buf = collapse_batch(Q_arg)
            K_buf = collapse_batch(K_arg)
            V_buf = collapse_batch(V_arg)
            out_buf = collapse_batch(output)

            # Inputs become tensors; the output buffer stays a memref.
            Q_3d = emit_buf_to_tensor(Q_buf, restrict=True)
            K_3d = emit_buf_to_tensor(K_buf, restrict=True)
            V_3d = emit_buf_to_tensor(V_buf, restrict=True)

            # Step 1: K^T, (batch, n_ctx, n_head) -> (batch, n_head, n_ctx)
            kt_empty = tensor.empty((batch_dim, n_head, n_ctx), dtype)
            K_t = linalg.transpose(K_3d, outs=[kt_empty], permutation=[0, 2, 1])

            # Step 2: scores = Q @ K^T, shape (batch, n_ctx, n_ctx),
            # accumulated into a zero-filled tensor.
            scores_shape = (batch_dim, n_ctx, n_ctx)
            scores_empty = tensor.empty(scores_shape, dtype)
            zero = arith.constant(dtype, 0.0)
            scores_zeroed = linalg.fill(zero, outs=[scores_empty])
            scores = linalg.batch_matmul(Q_3d, K_t, outs=[scores_zeroed])

            # Step 3: scale by 1/sqrt(n_head) via an elementwise multiply with
            # a tensor filled with the scale constant.
            scale_factor = 1.0 / math.sqrt(n_head)
            scale_const = arith.constant(dtype, scale_factor)
            scale_empty = tensor.empty(scores_shape, dtype)
            scale_full = linalg.fill(scale_const, outs=[scale_empty])
            scaled_empty = tensor.empty(scores_shape, dtype)
            scaled_scores = linalg.mul(scores, scale_full, outs=[scaled_empty])

            # Step 4: softmax along the last dimension (dim=2 in 3D)
            softmax_empty = tensor.empty(scores_shape, dtype)
            weights = linalg.softmax(
                result=[ir.RankedTensorType.get(scores_shape, dtype)],
                input=scaled_scores,
                output=softmax_empty,
                dimension=2,
            )

            # Step 5: weights @ V -> (batch, n_ctx, n_head), zero-initialized
            out_empty = tensor.empty(collapsed_shape_3d, dtype)
            out_zeroed = linalg.fill(zero, outs=[out_empty])
            result_3d = linalg.batch_matmul(weights, V_3d, outs=[out_zeroed])

            # Write the 3D result into the collapsed output memref.
            bufferization.materialize_in_destination(
                None, result_3d, out_buf, restrict=True, writable=True
            )

        # Emit utility functions for GPU memory management
        emit_gpu_util_funcs(dtype, rank=4)

    return module
def fused_attention_schedule(
    stop_at_stage: Optional[str] = None,
    parameters: Optional[dict] = None,
) -> ir.Module:
    """
    Generate transform schedule for fused attention operation.

    Matches the payload ``func.func``, takes its enclosing ``builtin.module``
    and hands that to ``bundle_xegpu_fused_attention_schedule``, which emits
    the actual lowering steps. A ``PipelineInterrupt`` raised by the bundle
    (early stop) is swallowed; the sequence is always terminated with a
    ``transform.yield``.

    Args:
        stop_at_stage: Optional stage name to stop early (for debugging);
            ``None`` is forwarded as an empty string.
        parameters: Dictionary with scheduling parameters, forwarded unchanged
            to the bundle. Must not be None.

    Returns:
        MLIR module containing the transform schedule
    """
    assert parameters is not None, "Schedule parameters must be provided"
    stage = stop_at_stage or ""

    with schedule_boilerplate() as (schedule, named_seq):
        # Locate the payload module through its func.func child.
        any_op = transform.AnyOpType.get()
        payload_func = match(named_seq.bodyTarget, ops={"func.func"})
        module_handle = transform.get_parent_op(
            any_op,
            payload_func,
            op_name="builtin.module",
            deduplicate=True,
        )

        try:
            bundle_xegpu_fused_attention_schedule(
                module_handle,
                parameters=parameters,
                stop_at_stage=stage,
            )
        except PipelineInterrupt:
            # Requested early stop: keep whatever was emitted so far.
            pass
        finally:
            transform.yield_()

    return schedule
+ wg_rows = parameters["wg_rows"] + + tiled_matmul, forall_loop = structured.structured_tile_using_forall( + anytype, + anytype, + last_matmul, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(1, wg_rows, 0, 0), + ) + # Fuse the zero initialization of the output of the last matmul (tensor.empty) into the forall loop. + tiled_matmul_init = transform.get_producer_of_operand( + anytype, forall_loop, operand_number=0 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=tiled_matmul_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Decompose softmax into generic ops + softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) + softmax_op = softmax_ops[0] + structured.structured_decompose_interface(anytype, softmax_op) + transform.apply_cse(func) + canonicalize(func) + + transform.print_( + target=func, + name="After tiling and fusing batch dimension, and softmax decomposition:", + ) + + # Fuse all linalg.generic ops from softmax decomposition (4 ops: max, sub+exp, sum, div) + # Match and fuse in reverse order (from consumer to producer) + generic_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=4) + for generic_op in reversed(generic_ops): + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=generic_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Max and add reductions use linalg.fill to intialize the reduction output. Fuse these fill ops as well. 
+ fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=5) + # Max fill is the third fill op and add fill is the fourth fill op (based on the pattern of decomposition) + max_fill_op = fill_ops[2] + add_fill_op = fill_ops[3] + for fill_op in [max_fill_op, add_fill_op]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + linalg_mul_op = match_and_split(func, ops={"linalg.mul"}, nhandles=1)[0] + first_matmul = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=0 + ) + scale_fill_op = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=1 + ) + transpose_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=1 + ) + matmul_fill_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=2 + ) + for op in [ + linalg_mul_op, + scale_fill_op, + first_matmul, + matmul_fill_op, + transpose_op, + ]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + transform.print_( + target=func, name="After tiling and fustion of batch and head dimensions:" + ) + + if stop_at_stage == "outer-tiled": + raise PipelineInterrupt() + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. 
batch dim of 1) + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() + + transform.print_(target=func, name="After vectorization:") + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + transform.apply_cse(mod) + canonicalize(mod) + + transform.print_(target=mod, name="After bufferization:") + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + + # Extract q, k, v memrefs from the bufferized IR + # Match vector.contract ops to find the q, k, v loads + for_all = match(mod, ops={"scf.forall"}) + func = transform.get_parent_op(anytype, for_all, op_name="func.func") + contract_ops = match_and_split(func, ops={"vector.contract"}, nhandles=2) + + # First vector.contract is Q @ K^T + # Its first operand is the q load (vector.transfer_read) + # Its second operand is the k load (vector.transfer_read) + first_contract = contract_ops[0] + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) + + # # Second vector.contract is 
attention_weights @ V + # # Its second operand is the v load (vector.transfer_read) + second_contract = contract_ops[1] + v_load = transform.get_producer_of_operand( + anytype, second_contract, operand_number=1 + ) + + # Match arith.mulf to get the scale parameter + # The scale is the second operand of arith.mulf (the constant) + mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + + # Generate fused attention computation with inner tiling + # This replaces the second vector.contract (attention_weights @ V) with a tiled + # loop that implements online softmax for efficient memory usage + tile_size = parameters.get( + "inner_loop_tile_size", 64 + ) # Tile size for reduction dimension (K/V sequence length) + generate_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=tile_size, + ) + transform.apply_cse(func) + canonicalize(func) + transform.print_( + target=func, name="After generating fused attention with inner tiling:" + ) + + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + wg_rows = parameters["wg_rows"] + sg_rows = parameters["sg_rows"] + subgroup_size = parameters["subgroup_size"] + num_subgroups = wg_rows // sg_rows + num_threads = num_subgroups * subgroup_size + 
xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + transform.print_(target=mod, name="After GPU outlining:") + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}, nhandles=3) + for alloca in allocas: + # print("Updating address space for alloca:") + update_address_space(alloca, address_space=3) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "loop-invariant-code-motion") + + transform.print_(target=gpu_func, name="After converting vector to xegpu:") + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # Define XeGPU layout parameters + n_head = parameters["n_head"] + q_sg_layout = [num_subgroups, 1] + q_sg_data = [16, n_head] + q_inst_data = [8, 16] + + k_sg_layout = [num_subgroups, 1] + k_sg_data = [16, n_head] + k_inst_data = [16, 16] + + v_sg_layout = k_sg_layout + v_sg_data = k_sg_data + v_inst_data = k_inst_data + + kt_sg_layout = [1, num_subgroups] + kt_sg_data = [n_head, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + + out_sg_layout = q_sg_layout + out_sg_data = q_sg_data + out_inst_data = q_inst_data + + layout_128x16_sg_layout = [num_subgroups, 1] + layout_128x16_sg_data = [16, 16] + 
layout_128x16_inst_data = [8, 16] + + qk_sg_layout = layout_128x16_sg_layout + qk_sg_data = layout_128x16_sg_data + qk_inst_data = layout_128x16_inst_data + + # Set layout attributes for xegpu.store_nd ops. + store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) + + # Set layout for xegpu.load_nd ops (9 total: 1 Q, 4 K, 4 V) + load_nd_ops = match_and_split(gpu_func, ops={"xegpu.load_nd"}, nhandles=9) + + # First load_nd: Q layout + xegpu.set_anchor_layout( + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data + ) + + # Next 4 load_nd ops: K layout + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data, + ) + + # Last 4 load_nd ops: V layout + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + ) + + # Set layout for xegpu.dpas ops (8 total: 4 for Q@K, 4 for P@V) + dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=8) + + # Layouts for first 4 dpas ops (Q@K^T): + for i in range(4): + qk_dpas_op = dpas_ops[i] + # Index 0: Q layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=q_sg_layout, + sg_data=q_sg_data, + inst_data=q_inst_data, + index=0, + ) + # Index 1: K^T layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1, + ) + # Index 2: QK output layout (128x16) + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=layout_128x16_sg_layout, + sg_data=layout_128x16_sg_data, + inst_data=layout_128x16_inst_data, + index=2, + ) + + # Layouts for second 4 dpas ops (P@V): + for i in range(4, 8): + pv_dpas_op = dpas_ops[i] + # Index 0: QK (attention weights) layout + xegpu.set_anchor_layout( + pv_dpas_op, + 
sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0, + ) + # Index 1: V layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + index=1, + ) + # Index 2: Output layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2, + ) + transform.print_(target=gpu_func, name="After setting xegpu layouts:") + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod