From e1c8124fa67b3c0dae15fd3093c960f3e6b51842 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 18 May 2026 23:01:20 +0000 Subject: [PATCH 01/20] add linalg to xegpu fused attention example --- examples/xegpu/fused_attention.py | 387 +++++++++++++ .../transform/transform_ext/__init__.py | 2 + .../ops/generate_fused_attention.py | 518 ++++++++++++++++++ lighthouse/execution/memory_manager.py | 2 +- .../mlir_gen/gpu_fused_attention_payload.py | 136 +++++ lighthouse/schedule/xegpu/__init__.py | 2 + .../xegpu/fused_attention_schedule.py | 477 ++++++++++++++++ 7 files changed, 1523 insertions(+), 1 deletion(-) create mode 100644 examples/xegpu/fused_attention.py create mode 100644 lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py create mode 100644 lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py create mode 100644 lighthouse/schedule/xegpu/fused_attention_schedule.py diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py new file mode 100644 index 00000000..fa4efc57 --- /dev/null +++ b/examples/xegpu/fused_attention.py @@ -0,0 +1,387 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU fused attention benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( + generate_gpu_fused_attention_payload, +) +from lighthouse.schedule.xegpu import fused_attention_schedule, xegpu_to_binary + + +def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): + """ + Complexity of fused attention operation. + + For each batch and head: + - Q @ K^T: O(n_ctx^2 * n_head) operations + - Softmax: O(n_ctx^2) operations + - Attention @ V: O(n_ctx^2 * n_head) operations + Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head + """ + # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head + flop_count = Z * H * 2 * n_ctx * n_ctx * n_head + # Memory: read Q, K, V and write output + memory_reads = 3 * Z * H * n_ctx * n_head * nbytes + memory_writes = Z * H * n_ctx * n_head * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + output_arr: np.ndarray, + verbose: int = 0, +) -> bool: + """ + Check correctness of fused attention output. + + Reference implementation: + - scores = Q @ K^T / sqrt(n_head) + - attention_weights = softmax(scores, dim=-1) + - output = attention_weights @ V + """ + # Use float32 for computation + Q_f32 = Q.astype(np.float32) + K_f32 = K.astype(np.float32) + V_f32 = V.astype(np.float32) + + Z, H, n_ctx, n_head = Q.shape + scale = 1.0 / np.sqrt(n_head) + + output_ref = np.zeros_like(Q_f32) + + # Compute reference for each batch and head + for z in range(Z): + for h in range(H): + # scores = Q @ K^T / sqrt(n_head) + scores = Q_f32[z, h] @ K_f32[z, h].T * scale + + # softmax along last dimension + max_vals = np.max(scores, axis=1, keepdims=True) + exp_vals = np.exp(scores - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + attention_weights = exp_vals / sum_vals + + # output = attention_weights @ V + output_ref[z, h] = attention_weights @ V_f32[z, h] + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first batch, first head, first 5 rows):") + print(output_ref[0, 0, :5]) + print("Computed solution (first batch, first head, first 5 rows):") + print(output[0, 0, :5]) + + # Check values match reference + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-3) + success = values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not values_ok: + max_diff = np.abs(output - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + +class XeGPUFusedAttention: + """ + Fused attention workload on XeGPU. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where: + - Z: batch size + - H: number of heads + - n_ctx: context length + - n_head: head dimension + """ + + def __init__( + self, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: str = "f16", + ): + self.Z = Z + self.H = H + self.n_ctx = n_ctx + self.n_head = n_head + self.shape = (Z, H, n_ctx, n_head) + assert dtype == "f16", "Only f16 type is supported for fused attention" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Initialize Q, K, V with small random values + Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, Q, K, V) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return fused_attention_complexity( + self.Z, self.H, self.n_ctx, self.n_head, nbytes + ) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for fused attention payload.""" + return generate_gpu_fused_attention_payload( + func_name=self.payload_function_name, + Z=self.Z, + H=self.H, + n_ctx=self.n_ctx, + n_head=self.n_head, + dtype=self.elem_type, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for fused attention.""" + schedules = [] + schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name)) + + schedules.append( + fused_attention_schedule( + stop_at_stage=stop_at_stage, + parameters=parameters, + ) + ) + + if stop_at_stage and stop_at_stage != "final": + return schedules + + schedules.append(xegpu_to_binary()) + + return schedules + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Fused Attention using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size (Z)", + ) + parser.add_argument( + "--num-heads", + type=int, + default=8, + help="Number of attention heads (H)", + ) + parser.add_argument( + "--n-ctx", + type=int, + default=4096, + help="Context length (sequence length)", + ) + parser.add_argument( + "--n-head", + type=int, + default=64, + help="Head dimension", + ) + parser.add_argument( + "--wg-rows", + type=int, + default=128, + help="Number of Q*K^T*V rows computed by each work group", + ) + parser.add_argument( + "--sg-rows", + type=int, + default=16, + help="Number of Q*K^T*V rows computed by each subgroup", + ) + parser.add_argument( + "--subgroup-size", + type=int, + default=16, + help="Subgroup size", + ) + parser.add_argument( + "--inner-loop-tile-size", + type=int, + default=64, + help="Tile size for the inner reduction dimension (K/V sequence length)", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the fused attention computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "outer-tiled", + "inner-tiled", + "vectorized", + "bufferized", + "gpu-outlining", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase output verbosity (e.g. print reference and computed solutions).", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "batch_size": args.batch_size, + "num_heads": args.num_heads, + "n_ctx": args.n_ctx, + "n_head": args.n_head, + "wg_rows": args.wg_rows, + "sg_rows": args.sg_rows, + "subgroup_size": args.subgroup_size, + "inner_loop_tile_size": args.inner_loop_tile_size, + } + + Z = args.batch_size + H = args.num_heads + n_ctx = args.n_ctx + n_head = args.n_head + dtype = "f16" + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype) + + if args.dump_kernel or args.dump_schedule: + pipeline = TransformDriver( + wload.schedule_modules( + stop_at_stage=args.dump_kernel, parameters=params + ) + ) + payload = pipeline.apply(wload.payload_module()) + if args.dump_kernel: + print(payload) + if args.dump_schedule: + for schedule_module in wload.schedule_modules(parameters=params): + print(schedule_module) + else: + pipeline = TransformDriver(wload.schedule_modules(parameters=params)) + payload = pipeline.apply(wload.payload_module()) + runner = Runner( + payload, + mem_manager_cls=wload.memory_manager_class, + shared_libs=wload.shared_libs(), + ) + if args.check_result: + # Setup callback function to copy result from device to host. + result_host_copy = np.zeros(wload.shape, dtype=wload.dtype) + argument_access_callback = Runner.get_gpu_argument_access_callback( + result_host_copy, arg_index=0 + ) + + # Execute kernel once. + runner.execute( + host_input_buffers=wload._initial_host_arrays, + payload_function_name=wload.payload_function_name, + argument_access_callback=argument_access_callback, + ) + + # Compute reference solution on host. + Q, K, V = wload._initial_host_arrays[1:4] + success = check_correctness( + Q, + K, + V, + result_host_copy, + verbose=args.verbose, + ) + if not success: + raise ValueError("Result mismatch!") + else: + print("Result is correct. Proceeding to benchmark...") + + times = runner.benchmark( + host_input_buffers=wload._initial_host_arrays, + nruns=args.nruns, + nwarmup=args.nwarmup, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + print( + f"batch-size={Z} " + f"num-heads={H} " + f"n-ctx={n_ctx} " + f"n-head={n_head} " + f"dt={dtype} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f} " + ) diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index 997522a2..eec36b6e 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -10,11 +10,13 @@ from .ops.get_tileable_consumers import get_tileable_consumers from .ops.get_tiling_sizes import get_tiling_sizes from .ops.update_address_space import update_address_space +from .ops.generate_fused_attention import generate_fused_attention __all__ = [ "TransformExtensionDialect", "convert_func_results_to_args", "extract_handle", + "generate_fused_attention", "get_named_attribute", "get_named_attribute", "get_tileable_consumers", diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py new file mode 100644 index 00000000..747a3376 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -0,0 +1,518 @@ +"""Transform extension to generate fused attention computation.""" + +import numpy as np +from mlir import ir +from mlir.dialects import ext, transform, arith, scf, math, vector +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect + + +class GenerateFusedAttention( + TransformExtensionDialect.Operation, name="generate_fused_attention" +): + """Generate tiled fused attention computation (flash attention optimization). + + Takes Q, K, V loads and scale constant from bufferized IR, and generates an inner + tiled loop that computes fused attention with online softmax using running max and sum. + + This implements the flash attention algorithm where: + 1. The computation is tiled along the reduction dimension (K/V sequence length) + 2. Online max and sum are maintained across tiles + 3. Output is incrementally updated with rescaled contributions + + Args: + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to the output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) + """ + + q_load: ext.Operand[transform.AnyOpType] + k_load: ext.Operand[transform.AnyOpType] + v_load: ext.Operand[transform.AnyOpType] + scale: ext.Operand[transform.AnyOpType] + output: ext.Operand[transform.AnyOpType] + tile_size: ir.IntegerAttr + new_output: ext.Result[transform.AnyOpType[()]] = ext.infer_result() + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "GenerateFusedAttention", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + # Get payload operations + q_load_ops = state.get_payload_ops(op.q_load) + k_load_ops = state.get_payload_ops(op.k_load) + v_load_ops = state.get_payload_ops(op.v_load) + scale_ops = state.get_payload_ops(op.scale) + output_ops = state.get_payload_ops(op.output) + + if ( + len(q_load_ops) != 1 + or len(k_load_ops) != 1 + or len(v_load_ops) != 1 + or len(scale_ops) != 1 + or len(output_ops) != 1 + ): + return DiagnosedSilenceableFailure.emit_silenceable_error( + "Expected exactly one operation for each operand" + ) + + q_load_op = q_load_ops[0] + k_load_op = k_load_ops[0] + v_load_op = v_load_ops[0] + scale_op = scale_ops[0] + output_op = output_ops[0] + + # Extract the scale scalar value from scale_op (arith.constant) + scale_attr = scale_op.attributes["value"] + scale_dense_attr = ir.DenseElementsAttr(scale_attr) + scale_np_array = np.array(scale_dense_attr) + scale_value = float(scale_np_array.flat[0]) + + # Extract wg_rows and d_head from q_load result type + q_load_result = q_load_op.results[0] + q_vector_type = ir.VectorType(q_load_result.type) + wg_rows = q_vector_type.shape[0] + d_head = q_vector_type.shape[1] + + # Get tile size + tile_size_value = ir.IntegerAttr(op.tile_size).value + + # Get element type from q_load result + element_type = q_vector_type.element_type + + # Build the fused attention computation + with ir.InsertionPoint(output_op): + # Define m_i_init: vector of shape [wg_rows] with neg_inf values + # NOTE: We use float32 for the initial neg_inf values and cast to the element type + # to avoid issues with representing -inf. + m_i_vector_type = ir.VectorType.get([wg_rows], element_type) + m_i_vector_type_f32 = ir.VectorType.get([wg_rows], ir.F32Type.get()) + neg_inf_value = float("-inf") + m_i_values = np.full( + wg_rows, + neg_inf_value, + dtype=np.float32, + ) + m_i_init_attr = ir.DenseElementsAttr.get( + m_i_values, type=m_i_vector_type_f32 + ) + m_i_init_f32 = arith.constant(m_i_vector_type_f32, m_i_init_attr) + m_i_init = arith.truncf(m_i_vector_type, m_i_init_f32) + + # Define l_i_init: vector of shape [wg_rows] with zero values + l_i_vector_type = ir.VectorType.get([wg_rows], element_type) + l_i_values = np.zeros( + wg_rows, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + l_i_init_attr = ir.DenseElementsAttr.get( + l_i_values, type=l_i_vector_type + ) + l_i_init = arith.constant(l_i_vector_type, l_i_init_attr) + + # Define acc_init: vector of shape [wg_rows, d_head] with zero values + acc_vector_type = ir.VectorType.get([wg_rows, d_head], element_type) + acc_values = np.zeros( + (wg_rows, d_head), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + acc_init_attr = ir.DenseElementsAttr.get( + acc_values, type=acc_vector_type + ) + acc_init = arith.constant(acc_vector_type, acc_init_attr) + + # Get n_ctx from k_load result type (first dimension size) + k_load_result = k_load_op.results[0] + k_vector_type = ir.VectorType(k_load_result.type) + n_ctx = k_vector_type.shape[0] + # Define scale vector: vector of shape [wg_rows] with the scale value + scale_vector_type = ir.VectorType.get([wg_rows], element_type) + scale_values = np.full( + (wg_rows), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_init_attr = ir.DenseElementsAttr.get( + scale_values, type=scale_vector_type + ) + scale_vector = arith.constant(scale_vector_type, scale_init_attr) + + # Create loop bounds + index_type = ir.IndexType.get() + c0 = arith.constant(index_type, 0) + c_n_ctx = arith.constant(index_type, n_ctx) + c_tile_size = arith.constant(index_type, tile_size_value) + + # Create scf.for loop that iterates from 0 to n_ctx in steps of tile_size + loop = scf.ForOp( + c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init] + ) + + with ir.InsertionPoint(loop.body): + # Get the loop induction variable and iter_args + loop_idx = loop.induction_variable + m_i = loop.inner_iter_args[0] + l_i = loop.inner_iter_args[1] + acc = loop.inner_iter_args[2] + + # Get common values for K/V tiling + k_memref = k_load_op.operands[0] + k_load_indices = list(k_load_op.operands[1:-1]) + padding = k_load_op.operands[-1] + in_bounds = k_load_op.attributes.get("in_bounds", None) + k_perm_map = k_load_op.attributes.get("permutation_map", None) + q_value = q_load_op.results[0] + + # Constants for K/V tiling (tile into chunks of 16) + k_subtile_size = 16 + num_k_tiles = tile_size_value // k_subtile_size + + # Create offset constants for each K tile + k_tile_offsets = [] + for i in range(num_k_tiles): + offset = arith.constant(index_type, i * k_subtile_size) + k_tile_offsets.append(offset) + + # Load and process K tiles (unrolled) + # Each K tile is [16, d_head], transposed to [d_head, 16], contracted to [wg_rows, 16] + qkt_chunks = [] + + # Create affine maps for Q@K contraction (used for all tiles) + affine_d0 = ir.AffineExpr.get_dim(0) + affine_d1 = ir.AffineExpr.get_dim(1) + affine_d2 = ir.AffineExpr.get_dim(2) + + q_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + k_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(q_map), + ir.AffineMapAttr.get(k_map), + ir.AffineMapAttr.get(out_map), + ] + ) + + iterator_types = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) + + # Accumulator for Q@K chunks + qkt_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + qkt_chunk_acc_values = np.zeros( + (wg_rows, k_subtile_size), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + qkt_chunk_acc_attr = ir.DenseElementsAttr.get( + qkt_chunk_acc_values, type=qkt_chunk_type + ) + qkt_chunk_acc = arith.constant(qkt_chunk_type, qkt_chunk_acc_attr) + + for tile_idx in range(num_k_tiles): + # Compute the offset index for this tile + k_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this K tile + k_tile_indices = k_load_indices.copy() + k_tile_indices[-2] = k_tile_offset + + # Load K tile: [16, d_head] + k_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) + k_tile = vector.TransferReadOp( + k_tile_type, + k_memref, + k_tile_indices, + k_perm_map, + padding, + in_bounds=in_bounds, + ).result + + # Transpose K tile: [16, d_head] -> [d_head, 16] + k_transpose_type = ir.VectorType.get( + [d_head, k_subtile_size], element_type + ) + k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) + + # Contract Q @ K_transpose: [wg_rows, d_head] @ [d_head, 16] -> [wg_rows, 16] + qkt_chunk = vector.contract( + qkt_chunk_type, + q_value, + k_transpose, + qkt_chunk_acc, + indexing_maps=indexing_maps, + iterator_types=iterator_types, + ) + qkt_chunks.append(qkt_chunk) + + # Elementwise maximum across all Q@K chunks + # Build tree of maximumf operations + qkt_max_combined = qkt_chunks[0] + for i in range(1, num_k_tiles): + qkt_max_combined = arith.maximumf( + qkt_max_combined, qkt_chunks[i] + ) + + # Final multi_reduction to get row-wise max: [wg_rows, 16] -> [wg_rows] + qkt_max = vector.multi_reduction( + kind="maxnumf", + source=qkt_max_combined, + acc=m_i_init, + reduction_dims=[1], + ) + + # Scale the max: qkt_max_scaled = qkt_max * scale + # Both have shape [wg_rows] + qkt_max_scaled = arith.mulf(qkt_max, scale_vector) + + # Compute m_ij = max(m_i, qkt_max_scaled) + # Both have shape [wg_rows] + m_ij = arith.maximumf(m_i, qkt_max_scaled) + + # Apply softmax to each Q@K chunk + # Scale constant for chunks: [wg_rows, 16] + scale_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + scale_chunk_values = np.full( + (wg_rows, k_subtile_size), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_chunk_attr = ir.DenseElementsAttr.get( + scale_chunk_values, type=scale_chunk_type + ) + scale_chunk = arith.constant(scale_chunk_type, scale_chunk_attr) + + # Broadcast m_ij from [wg_rows] to [wg_rows, 16] + m_ij_bcasted_type = ir.VectorType.get( + [k_subtile_size, wg_rows], element_type + ) + m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) + m_ij_transposed_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + m_ij_transposed = vector.transpose( + m_ij_transposed_type, m_ij_bcasted, [1, 0] + ) + + # Apply softmax to each chunk + qkt_exp_chunks = [] + for qkt_chunk in qkt_chunks: + # Scale: qkt_scaled = qkt_chunk * scale + qkt_scaled = arith.mulf(qkt_chunk, scale_chunk) + + # Center: qkt_centered = qkt_scaled - m_ij_transposed + qkt_centered = arith.subf(qkt_scaled, m_ij_transposed) + + # Exponential: qkt_exp = exp(qkt_centered) + qkt_exp = math.exp(qkt_centered) + qkt_exp_chunks.append(qkt_exp) + + # Elementwise sum across all exp chunks + qkt_exp_combined = qkt_exp_chunks[0] + for i in range(1, num_k_tiles): + qkt_exp_combined = arith.addf( + qkt_exp_combined, qkt_exp_chunks[i] + ) + + # Final multi_reduction to get row-wise sum: [wg_rows, 16] -> [wg_rows] + l_ij = vector.multi_reduction( + kind="add", + source=qkt_exp_combined, + acc=l_i_init, + reduction_dims=[1], + ) + + # Compute alpha = exp(m_i - m_ij) + m_diff = arith.subf(m_i, m_ij) + alpha = math.exp(m_diff) + + # Update l_i: l_i_updated = l_i * alpha + l_ij + l_i_scaled = arith.mulf(l_i, alpha) + l_i_updated = arith.addf(l_i_scaled, l_ij) + + # Broadcast alpha from [wg_rows] to [wg_rows, d_head] + alpha_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) + alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) + alpha_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + alpha_transposed = vector.transpose( + alpha_transposed_type, alpha_bcasted, [1, 0] + ) + + # Update accumulator: acc_updated = acc * alpha_bcasted + acc_updated = arith.mulf(acc, alpha_transposed) + + # Load V tiles and compute attention-weighted values + # Get V load parameters + v_memref = v_load_op.operands[0] + v_load_indices = list(v_load_op.operands[1:-1]) + v_padding = v_load_op.operands[-1] + v_in_bounds = v_load_op.attributes.get("in_bounds", None) + v_perm_map = v_load_op.attributes.get("permutation_map", None) + + # Create affine maps for P@V contraction + qkt_exp_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + v_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + pv_out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps_pv = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(qkt_exp_map), + ir.AffineMapAttr.get(v_map), + ir.AffineMapAttr.get(pv_out_map), + ] + ) + + iterator_types_pv = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) + + # Load and process V tiles (unrolled), accumulating results + pv_out = acc_updated + for tile_idx in range(num_k_tiles): + # Compute the offset index for this V tile + v_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this V tile + v_tile_indices = v_load_indices.copy() + v_tile_indices[-2] = v_tile_offset + + # Load V tile: [16, d_head] + v_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) + v_tile = vector.TransferReadOp( + v_tile_type, + v_memref, + v_tile_indices, + v_perm_map, + v_padding, + in_bounds=v_in_bounds, + ).result + + # Contract qkt_exp_chunk @ v_tile: [wg_rows, 16] @ [16, d_head] -> [wg_rows, d_head] + # Accumulate into pv_out + pv_out = vector.contract( + acc_vector_type, + qkt_exp_chunks[tile_idx], + v_tile, + pv_out, + indexing_maps=indexing_maps_pv, + iterator_types=iterator_types_pv, + ) + + # Yield the updated iter args + scf.yield_([m_ij, l_i_updated, pv_out]) + + # Extract the final accumulator result (3rd output) from the loop + pv_out = loop.results[2] + l_i_out = loop.results[1] + with ir.InsertionPoint.after(loop): + # Normalize the output: output_final = pv_out / l_i_out + # Need to broadcast l_i_out from [wg_rows] to [wg_rows, d_head] + l_i_out_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) + l_i_out_bcasted = vector.broadcast(l_i_out_bcasted_type, l_i_out) + l_i_out_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + l_i_out_transposed = vector.transpose( + l_i_out_transposed_type, l_i_out_bcasted, [1, 0] + ) + output_final = arith.divf(pv_out, l_i_out_transposed) + + # Replace all uses of the original output operation with the final loop result + output_op.results[0].replace_all_uses_with(output_final) + + # Erase the original output operation + rewriter.erase_op(output_op) + + # Return the final output handle + results.set_ops(op.new_output, [output_final.owner]) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "GenerateFusedAttention") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + # Read Q, K, scale, V slices + transform.only_reads_handle(op.op_operands[:4], effects) + # Consume and replace output + transform.consumes_handle(op.op_operands[4:5], effects) + # Produce new output handle + transform.produces_handle(op.results, effects) + # Modify the payload + transform.modifies_payload(effects) + + +def generate_fused_attention( + q_load: ir.Value, + k_load: ir.Value, + v_load: ir.Value, + scale: ir.Value, + output: ir.Value, + tile_size: int | ir.IntegerAttr, +) -> ir.Value: + """Generate fused attention computation with inner tiling on bufferized IR. + + Args: + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) + + Returns: + Handle to the new output operation + """ + if not isinstance(tile_size, ir.IntegerAttr): + tile_size = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), tile_size) + + return GenerateFusedAttention( + q_load, k_load, v_load, scale, output, tile_size=tile_size + ).new_output diff --git a/lighthouse/execution/memory_manager.py b/lighthouse/execution/memory_manager.py index 7b241a8b..5fd1443b 100644 --- a/lighthouse/execution/memory_manager.py +++ b/lighthouse/execution/memory_manager.py @@ -85,7 +85,7 @@ def alloc( ptr_mref = memref_to_ctype(mref) ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] rank = len(shape) - assert rank in (1, 2), "Only 1d or 2d arrays are supported." + assert rank >= 1 and rank <= 5, "Only 1d to 5d arrays are supported." suffix = f"{rank}d_{str(elem_type)}" self.execution_engine.invoke("gpu_alloc_" + suffix, ptr_mref, *ptr_dims) diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py new file mode 100644 index 00000000..73873dc6 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -0,0 +1,136 @@ +"""Generate MLIR payload for GPU fused attention operation.""" + +import math + +from mlir import ir +from mlir.dialects import arith, bufferization, linalg, memref, tensor + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.utils import emit_buf_to_tensor + + +def generate_gpu_fused_attention_payload( + func_name: str, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: ir.Type, +) -> ir.Module: + """ + Generate MLIR module for fused attention payload. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + Args: + func_name: Name of the payload function + Z: Batch size + H: Number of attention heads + n_ctx: Context length (sequence length) + n_head: Head dimension + dtype: MLIR element type (e.g., F32Type) + + Returns: + MLIR module containing the fused attention payload function + """ + mod = ir.Module.create() + shape = (Z, H, n_ctx, n_head) + memref_t = ir.MemRefType.get(shape, dtype) + + with ir.InsertionPoint(mod.body): + # Collapse first 2 dimensions (Z, H) into a batch dimension + # From (Z, H, n_ctx, n_head) to (Z*H, n_ctx, n_head) + batch_dim = Z * H + collapsed_shape_3d = (batch_dim, n_ctx, n_head) + memref_3d_t = ir.MemRefType.get(collapsed_shape_3d, dtype) + + # Function signature: payload(output, Q, K, V) + @func_cif(memref_t, memref_t, memref_t, memref_t, name=func_name) + def payload(output, Q_arg, K_arg, V_arg): + # Collapse memrefs from 4D to 3D + Q_3d_memref = memref.collapse_shape( + memref_3d_t, + Q_arg, + reassociation=[[0, 1], [2], [3]], + ) + K_3d_memref = memref.collapse_shape( + memref_3d_t, + K_arg, + reassociation=[[0, 1], [2], [3]], + ) + V_3d_memref = memref.collapse_shape( + memref_3d_t, + V_arg, + reassociation=[[0, 1], [2], [3]], + ) + output_3d_memref = memref.collapse_shape( + memref_3d_t, + output, + reassociation=[[0, 1], [2], [3]], + ) + + # Convert 3D memrefs to tensors + Q_3d = emit_buf_to_tensor(Q_3d_memref, restrict=True) + K_3d = emit_buf_to_tensor(K_3d_memref, restrict=True) + V_3d = emit_buf_to_tensor(V_3d_memref, restrict=True) + + # Step 1: Transpose K to get K^T + # Permute from (batch_dim, n_ctx, n_head) to (batch_dim, n_head, n_ctx) + kt_shape_3d = (batch_dim, n_head, n_ctx) + kt_init = tensor.empty(kt_shape_3d, dtype) + K_transposed = linalg.transpose(K_3d, outs=[kt_init], permutation=[0, 2, 1]) + + # Step 2: Compute Q @ K^T using batch_matmul + # Q: (batch_dim, n_ctx, n_head) @ K^T: (batch_dim, n_head, n_ctx) + # Result: (batch_dim, n_ctx, n_ctx) + qkt_shape_3d = (batch_dim, n_ctx, n_ctx) + qkt_init = tensor.empty(qkt_shape_3d, dtype) + # Initialize with zeros for matmul accumulation + zero = arith.constant(dtype, 0.0) + qkt_init_filled = linalg.fill(zero, outs=[qkt_init]) + + # Batch matmul: Q @ K^T + qkt = linalg.batch_matmul(Q_3d, K_transposed, outs=[qkt_init_filled]) + + # Step 3: Scale by 1/sqrt(n_head) + scale_factor = 1.0 / math.sqrt(n_head) + scale_const = arith.constant(dtype, scale_factor) + + # Create a tensor filled with the scale factor + scale_tensor_init = tensor.empty(qkt_shape_3d, dtype) + scale_tensor = linalg.fill(scale_const, outs=[scale_tensor_init]) + + # Elementwise multiply qkt with scale tensor + scaled_qkt_init = tensor.empty(qkt_shape_3d, dtype) + scaled_qkt = linalg.mul(qkt, scale_tensor, outs=[scaled_qkt_init]) + + # Step 4: Apply softmax along the last dimension (dim=2 in 3D) + softmax_init = tensor.empty(qkt_shape_3d, dtype) + attention_weights = linalg.softmax( + result=[ir.RankedTensorType.get(qkt_shape_3d, dtype)], + input=scaled_qkt, + output=softmax_init, + dimension=2, + ) + + # Step 5: Multiply attention weights by V using batch_matmul + # attention_weights: (batch_dim, n_ctx, n_ctx) @ V: (batch_dim, n_ctx, n_head) + # Result: (batch_dim, n_ctx, n_head) + output_3d_init = tensor.empty(collapsed_shape_3d, dtype) + output_3d_init_filled = linalg.fill(zero, outs=[output_3d_init]) + + result_3d = linalg.batch_matmul( + attention_weights, V_3d, outs=[output_3d_init_filled] + ) + + # Materialize 3D result back to 3D output memref + bufferization.materialize_in_destination( + None, result_3d, output_3d_memref, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=4) + + return mod diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py index 23d9ef0c..76f4101c 100644 --- a/lighthouse/schedule/xegpu/__init__.py +++ b/lighthouse/schedule/xegpu/__init__.py @@ -1,8 +1,10 @@ from .xegpu_to_binary import xegpu_to_binary from .mlp_schedule import mlp_schedule from .softmax_schedule import softmax_schedule +from .fused_attention_schedule import fused_attention_schedule __all__ = [ + "fused_attention_schedule", "mlp_schedule", "softmax_schedule", "xegpu_to_binary", diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py new file mode 100644 index 00000000..cb99e0e9 --- /dev/null +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -0,0 +1,477 @@ +"""Generate MLIR transform schedule for XeGPU fused attention operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects.transform.vector import ( + apply_patterns_vector_cast_away_vector_leading_one_dim, + apply_patterns_vector_drop_unit_dims_with_shape_cast, +) + +from lighthouse.pipeline.helper import ( + canonicalize, + match, + match_and_split, + PipelineInterrupt, + apply_registered_pass, +) +from lighthouse.schedule import schedule_boilerplate +from lighthouse.dialects.transform.transform_ext import ( + generate_fused_attention, + update_address_space, +) + + +def fused_attention_schedule( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for fused attention operation. + + The schedule performs the following transformations: + 1. Tile the fused attention operation + 2. Vectorize operations + 3. Bufferize tensors + 4. Convert to GPU dialect + 5. Lower to XeGPU operations + + Args: + stop_at_stage: Optional stage name to stop early (for debugging) + parameters: Dictionary with scheduling parameters: + - batch_size: Batch size (Z) + - num_heads: Number of attention heads (H) + - n_ctx: Context length + - n_head: Head dimension + - wg_rows: Number of Q*K^T*V rows computed by each work group + - sg_rows: Number of Q*K^T*V rows computed by each subgroup + - subgroup_size: Size of subgroup + + Returns: + MLIR module containing the transform schedule + """ + assert parameters is not None, "Schedule parameters must be provided" + + with schedule_boilerplate() as (schedule, named_seq): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_seq.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + + try: + bundle_xegpu_fused_attention_schedule( + payload_mod, + parameters=parameters, + stop_at_stage=stop_at_stage or "", + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + return schedule + + +def bundle_xegpu_fused_attention_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering fused attention payload to xegpu wg level.""" + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + # Match all matmul operations - there should be 2: + # 1. Q @ K^T + # 2. attention_weights @ V + matmul_ops = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2) + + # Get the last matmul (attention_weights @ V) + last_matmul = matmul_ops[1] + func = transform.get_parent_op( + anytype, + last_matmul, + op_name="func.func", + deduplicate=True, + ) + + # Tile the last matmul in both batch and M dimensions. + wg_rows = parameters["wg_rows"] + + tiled_matmul, forall_loop = structured.structured_tile_using_forall( + anytype, + anytype, + last_matmul, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(1, wg_rows, 0, 0), + ) + # Fuse the zero initialization of the output of the last matmul (tensor.empty) into the forall loop. + tiled_matmul_init = transform.get_producer_of_operand( + anytype, forall_loop, operand_number=0 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=tiled_matmul_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Decompose softmax into generic ops + softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) + softmax_op = softmax_ops[0] + structured.structured_decompose_interface(anytype, softmax_op) + transform.apply_cse(func) + canonicalize(func) + + # Fuse all linalg.generic ops from softmax decomposition (4 ops: max, sub+exp, sum, div) + # Match and fuse in reverse order (from consumer to producer) + generic_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=4) + for generic_op in reversed(generic_ops): + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=generic_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Max and add reductions use linalg.fill to intialize the reduction output. Fuse these fill ops as well. + fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=5) + # Max fill is the third fill op and add fill is the fourth fill op (based on the pattern of decomposition) + max_fill_op = fill_ops[2] + add_fill_op = fill_ops[3] + for fill_op in [max_fill_op, add_fill_op]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + linalg_mul_op = match_and_split(func, ops={"linalg.mul"}, nhandles=1)[0] + first_matmul = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=0 + ) + scale_fill_op = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=1 + ) + transpose_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=1 + ) + matmul_fill_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=2 + ) + for op in [ + linalg_mul_op, + scale_fill_op, + first_matmul, + matmul_fill_op, + transpose_op, + ]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "outer-tiled": + raise PipelineInterrupt() + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + transform.apply_cse(mod) + canonicalize(mod) + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + + # Extract q, k, v memrefs from the bufferized IR + # Match vector.contract ops to find the q, k, v loads + for_all = match(mod, ops={"scf.forall"}) + func = transform.get_parent_op(anytype, for_all, op_name="func.func") + contract_ops = match_and_split(func, ops={"vector.contract"}, nhandles=2) + + # First vector.contract is Q @ K^T + # Its first operand is the q load (vector.transfer_read) + # Its second operand is the k load (vector.transfer_read) + first_contract = contract_ops[0] + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) + + # # Second vector.contract is attention_weights @ V + # # Its second operand is the v load (vector.transfer_read) + second_contract = contract_ops[1] + v_load = transform.get_producer_of_operand( + anytype, second_contract, operand_number=1 + ) + + # Match arith.mulf to get the scale parameter + # The scale is the second operand of arith.mulf (the constant) + mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # Generate fused attention computation with inner tiling + # This replaces the second vector.contract (attention_weights @ V) with a tiled + # loop that implements online softmax for efficient memory usage + tile_size = parameters.get( + "inner_loop_tile_size", 64 + ) # Tile size for reduction dimension (K/V sequence length) + generate_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=tile_size, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + wg_rows = parameters["wg_rows"] + sg_rows = parameters["sg_rows"] + subgroup_size = parameters["subgroup_size"] + num_subgroups = wg_rows // sg_rows + num_threads = num_subgroups * subgroup_size + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}, nhandles=3) + for alloca in allocas: + # print("Updating address space for alloca:") + update_address_space(alloca, address_space=3) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "loop-invariant-code-motion") + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # Define XeGPU layout parameters + q_sg_layout = [8, 1] + q_sg_data = [16, 64] + q_inst_data = [8, 16] + + k_sg_layout = [8, 1] + k_sg_data = [16, 64] + k_inst_data = [16, 16] + + v_sg_layout = k_sg_layout + v_sg_data = k_sg_data + v_inst_data = k_inst_data + + kt_sg_layout = [1, 8] + kt_sg_data = [64, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + + out_sg_layout = q_sg_layout + out_sg_data = q_sg_data + out_inst_data = q_inst_data + + layout_128x16_sg_layout = [8, 1] + layout_128x16_sg_data = [16, 16] + layout_128x16_inst_data = [8, 16] + + qk_sg_layout = layout_128x16_sg_layout + qk_sg_data = layout_128x16_sg_data + qk_inst_data = layout_128x16_inst_data + + # Set layout attributes for xegpu.store_nd ops. + store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) + + # Set layout for xegpu.load_nd ops (9 total: 1 Q, 4 K, 4 V) + load_nd_ops = match_and_split(gpu_func, ops={"xegpu.load_nd"}, nhandles=9) + + # First load_nd: Q layout + xegpu.set_anchor_layout( + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data + ) + + # Next 4 load_nd ops: K layout + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data, + ) + + # Last 4 load_nd ops: V layout + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + ) + + # Set layout for xegpu.dpas ops (8 total: 4 for Q@K, 4 for P@V) + dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=8) + + # Layouts for first 4 dpas ops (Q@K^T): + for i in range(4): + qk_dpas_op = dpas_ops[i] + # Index 0: Q layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=q_sg_layout, + sg_data=q_sg_data, + inst_data=q_inst_data, + index=0, + ) + # Index 1: K^T layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1, + ) + # Index 2: QK output layout (128x16) + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=layout_128x16_sg_layout, + sg_data=layout_128x16_sg_data, + inst_data=layout_128x16_inst_data, + index=2, + ) + + # Layouts for second 4 dpas ops (P@V): + for i in range(4, 8): + pv_dpas_op = dpas_ops[i] + # Index 0: QK (attention weights) layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0, + ) + # Index 1: V layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + index=1, + ) + # Index 2: Output layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2, + ) + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod From 589b01466e35e5b37e695ff88602a4b5b514ad52 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 18 May 2026 23:10:18 +0000 Subject: [PATCH 02/20] remove hard coded layout params --- .../schedule/xegpu/fused_attention_schedule.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index cb99e0e9..6524fba6 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -347,20 +347,21 @@ def bundle_xegpu_fused_attention_schedule( raise PipelineInterrupt() # Define XeGPU layout parameters - q_sg_layout = [8, 1] - q_sg_data = [16, 64] + n_head = parameters["n_head"] + q_sg_layout = [num_subgroups, 1] + q_sg_data = [16, n_head] q_inst_data = [8, 16] - k_sg_layout = [8, 1] - k_sg_data = [16, 64] + k_sg_layout = [num_subgroups, 1] + k_sg_data = [16, n_head] k_inst_data = [16, 16] v_sg_layout = k_sg_layout v_sg_data = k_sg_data v_inst_data = k_inst_data - kt_sg_layout = [1, 8] - kt_sg_data = [64, 16] + kt_sg_layout = [1, num_subgroups] + kt_sg_data = [n_head, 16] kt_inst_data = [16, 16] kt_order = [0, 1] @@ -368,7 +369,7 @@ def bundle_xegpu_fused_attention_schedule( out_sg_data = q_sg_data out_inst_data = q_inst_data - layout_128x16_sg_layout = [8, 1] + layout_128x16_sg_layout = [num_subgroups, 1] layout_128x16_sg_data = [16, 16] layout_128x16_inst_data = [8, 16] From 54a2b4aa58d2f3746c1284d42b4000300c4323af Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 17:59:26 +0000 Subject: [PATCH 03/20] add print msgs --- .../xegpu/fused_attention_schedule.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 6524fba6..e350e195 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -88,6 +88,8 @@ def bundle_xegpu_fused_attention_schedule( ) -> ir.Value[transform.AnyOpType]: """Schedule for lowering fused attention payload to xegpu wg level.""" + transform.print_(target=mod, name="Initial standard attention:") + if stop_at_stage == "initial": raise PipelineInterrupt() @@ -137,6 +139,11 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) + transform.print_( + target=func, + name="After tiling and fusing batch dimension, and softmax decomposition:", + ) + # Fuse all linalg.generic ops from softmax decomposition (4 ops: max, sub+exp, sum, div) # Match and fuse in reverse order (from consumer to producer) generic_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=4) @@ -194,6 +201,10 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) + transform.print_( + target=func, name="After tiling and fustion of batch and head dimensions:" + ) + if stop_at_stage == "outer-tiled": raise PipelineInterrupt() @@ -209,6 +220,8 @@ def bundle_xegpu_fused_attention_schedule( apply_patterns_vector_cast_away_vector_leading_one_dim() apply_patterns_vector_drop_unit_dims_with_shape_cast() + transform.print_(target=func, name="After vectorization:") + if stop_at_stage == "vectorized": raise PipelineInterrupt() @@ -223,6 +236,10 @@ def bundle_xegpu_fused_attention_schedule( ).result transform.apply_cse(mod) canonicalize(mod) + + transform.print_(target=mod, name="After bufferization:") + if stop_at_stage == "bufferized": + raise PipelineInterrupt() # fold memref.subviews into vector.transfer_read/write ops mod = apply_registered_pass(mod, "fold-memref-alias-ops") transform.apply_cse(mod) @@ -268,9 +285,6 @@ def bundle_xegpu_fused_attention_schedule( mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) - if stop_at_stage == "bufferized": - raise PipelineInterrupt() - # Generate fused attention computation with inner tiling # This replaces the second vector.contract (attention_weights @ V) with a tiled # loop that implements online softmax for efficient memory usage @@ -287,6 +301,9 @@ def bundle_xegpu_fused_attention_schedule( ) transform.apply_cse(func) canonicalize(func) + transform.print_( + target=func, name="After generating fused attention with inner tiling:" + ) if stop_at_stage == "inner-tiled": raise PipelineInterrupt() @@ -320,6 +337,8 @@ def bundle_xegpu_fused_attention_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) + transform.print_(target=mod, name="After GPU outlining:") + if stop_at_stage == "gpu-outlining": raise PipelineInterrupt() @@ -343,6 +362,8 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(gpu_func) gpu_func = apply_registered_pass(gpu_func, "loop-invariant-code-motion") + transform.print_(target=gpu_func, name="After converting vector to xegpu:") + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() @@ -471,7 +492,7 @@ def bundle_xegpu_fused_attention_schedule( inst_data=out_inst_data, index=2, ) - + transform.print_(target=gpu_func, name="After setting xegpu layouts:") if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 74fdaa31a1a0a2946af9d9288cdae212605725e9 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 18:19:18 +0000 Subject: [PATCH 04/20] save initial doc --- docs/fused_attention_lowering.md | 798 +++++++++++++++++++++++++++++++ 1 file changed, 798 insertions(+) create mode 100644 docs/fused_attention_lowering.md diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md new file mode 100644 index 00000000..ba974329 --- /dev/null +++ b/docs/fused_attention_lowering.md @@ -0,0 +1,798 @@ +# Fused Attention Kernel Lowering Flow + +This document describes the multi-stage lowering process for standard attention kernels in MLIR, showing how high-level operations are progressively transformed into hardware-specific GPU code. + +--- + +## Stage 1: Initial Standard Attention + +**Input shape**: `2x8x4096x64xf16` (batch × heads × sequence × head_dim) + +### Key Operations + +```mlir +// Reshape from 4D to 3D: collapse batch and head dimensions +%q_3d = memref.collapse_shape %arg_q [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> +%k_3d = memref.collapse_shape %arg_k [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> +%v_3d = memref.collapse_shape %arg_v [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + +// Transpose K: [16, 4096, 64] -> [16, 64, 4096] +%k_transposed = linalg.transpose ins(%k_3d : tensor<16x4096x64xf16>) + outs(%empty_kt : tensor<16x64x4096xf16>) permutation = [0, 2, 1] + +// Q @ K^T: [16, 4096, 64] @ [16, 64, 4096] -> [16, 4096, 4096] +%qk_scores = linalg.batch_matmul ins(%q_3d, %k_transposed : ...) + outs(%qk_init : tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16> + +// Scale by 1/sqrt(d_k) = 0.125 +%qk_scaled = linalg.mul ins(%qk_scores, %scale_factor : tensor<16x4096x4096xf16>, ...) + -> tensor<16x4096x4096xf16> + +// Softmax over last dimension +%attention_weights = linalg.softmax dimension(2) ins(%qk_scaled : ...) + -> tensor<16x4096x4096xf16> + +// Attention @ V: [16, 4096, 4096] @ [16, 4096, 64] -> [16, 4096, 64] +%output = linalg.batch_matmul ins(%attention_weights, %v_3d : ...) + outs(%output_init : tensor<16x4096x64xf16>) -> tensor<16x4096x64xf16> +``` + +**Characteristics**: +- Materializes full attention matrix `[16, 4096, 4096]` in memory +- Sequential operations with clear dependencies +- Single monolithic softmax operation + +--- + +## Stage 2: Tiling and Softmax Decomposition + +### Major Changes + +#### Softmax Decomposition +The atomic `linalg.softmax` is decomposed into explicit operations: + +```mlir +// 1. Find max value per row (for numerical stability) +%max_per_row = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%qk_scaled) outs(%max_init) { +^bb0(%score: f16, %current_max: f16): + %new_max = arith.maxnumf %score, %current_max : f16 + linalg.yield %new_max : f16 +} -> tensor<16x4096xf16> + +// 2. Compute exp(x - max) for each element +%exp_scores = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel"] +} ins(%qk_scaled, %max_per_row) outs(%exp_init) { +^bb0(%score: f16, %max_val: f16, %out: f16): + %centered = arith.subf %score, %max_val : f16 + %exp_val = math.exp %centered : f16 + linalg.yield %exp_val : f16 +} -> tensor<16x4096x4096xf16> + +// 3. Sum exp values per row +%sum_per_row = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%exp_scores) outs(%sum_init) { +^bb0(%exp_val: f16, %current_sum: f16): + %new_sum = arith.addf %exp_val, %current_sum : f16 + linalg.yield %new_sum : f16 +} -> tensor<16x4096xf16> + +// 4. Normalize: divide each element by sum +%attention_weights = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel"] +} ins(%exp_scores, %sum_per_row) outs(%norm_init) { +^bb0(%exp_val: f16, %sum: f16, %out: f16): + %normalized = arith.divf %exp_val, %sum : f16 + linalg.yield %normalized : f16 +} -> tensor<16x4096x4096xf16> +``` + +#### Tiling the Output Dimension + +The second matmul (attention @ V) is tiled into 32 tiles of size 128: + +```mlir +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) + shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Extract 128 rows of attention weights: [1, 128, 4096] + %attention_tile = tensor.extract_slice %attention_weights[%batch_head_idx, %row_offset, 0] + [1, 128, 4096] [1, 1, 1] : tensor<16x4096x4096xf16> to tensor<1x128x4096xf16> + + // Extract all of V: [1, 4096, 64] + %v_tile = tensor.extract_slice %v_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Compute partial result: [1, 128, 4096] @ [1, 4096, 64] -> [1, 128, 64] + %partial_output = linalg.batch_matmul ins(%attention_tile, %v_tile) + outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> + } +} +``` + +**Key Insight**: Still computes the full `4096×4096` attention matrix before the tiled second matmul. + +--- + +## Stage 3: Tiling Batch and Head Dimensions + +### Fusion of Operations + +The entire attention computation is now fused into a single parallel loop: + +```mlir +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) + shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Extract Q tile: [1, 128, 64] + %q_tile = tensor.extract_slice %q_3d[%batch_head_idx, %row_offset, 0] [1, 128, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x128x64xf16> + + // Extract full K: [1, 4096, 64] + %k_tile = tensor.extract_slice %k_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Transpose K within tile + %k_tile_transposed = linalg.transpose ins(%k_tile : tensor<1x4096x64xf16>) + outs(%kt_init : tensor<1x64x4096xf16>) permutation = [0, 2, 1] + + // Q @ K^T: [1, 128, 64] @ [1, 64, 4096] -> [1, 128, 4096] + %qk_scores = linalg.batch_matmul ins(%q_tile, %k_tile_transposed : ...) + outs(%qk_init : tensor<1x128x4096xf16>) -> tensor<1x128x4096xf16> + + // Scale + %qk_scaled = linalg.mul ins(%qk_scores, %scale_factor : ...) + -> tensor<1x128x4096xf16> + + // Softmax decomposition (max, exp, sum, normalize) + %max_per_row = linalg.generic { ... } // max reduction -> [1, 128] + %exp_scores = linalg.generic { ... } // exp(x - max) + %sum_per_row = linalg.generic { ... } // sum reduction + %attention_weights = linalg.generic { ... } // normalize + + // Extract V: [1, 4096, 64] + %v_tile = tensor.extract_slice %v_3d[%batch_head_idx, 0, 0] [1, 4096, 64] [1, 1, 1] + : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> + + // Attention @ V: [1, 128, 4096] @ [1, 4096, 64] -> [1, 128, 64] + %partial_output = linalg.batch_matmul ins(%attention_weights, %v_tile : ...) + outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> + } +} +``` + +**Key Change**: Each workgroup now processes: +- 128 rows of Q +- All of K and V (still materializes `128×4096` attention matrix) +- Produces 128 rows of output + +**Parallelism**: 16 × 32 = 512 independent workgroups + +--- + +## Stage 4: Vectorization + +Linalg operations are converted to vector operations for SIMD execution. + +### First Matmul (Q @ K^T) + +```mlir +// Read K: [4096, 64] +%k_vec = vector.transfer_read %k_3d[%batch_head_idx, %c0, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<4096x64xf16> + +// Read Q tile: [128, 64] +%q_vec = vector.transfer_read %q_3d[%batch_head_idx, %row_offset, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<128x64xf16> + +// Contract (matmul): [128, 64] @ [4096, 64]^T -> [128, 4096] +%qk_scores_vec = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: reduce over d2 + affine_map<(d0, d1, d2) -> (d1, d2)>, // K^T: reduce over d2 + affine_map<(d0, d1, d2) -> (d0, d1)> // Output + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} %q_vec, %k_vec, %zero_init : vector<128x64xf16>, vector<4096x64xf16> into vector<128x4096xf16> +``` + +### Softmax with Vector Reductions + +```mlir +// Scale by 1/sqrt(d_k) +%qk_scaled_vec = arith.mulf %qk_scores_vec, %scale_factor_vec : vector<128x4096xf16> + +// Reshape for reduction: [128, 4096] -> [1, 128, 4096] +%qk_3d = vector.shape_cast %qk_scaled_vec : vector<128x4096xf16> to vector<1x128x4096xf16> + +// Max reduction along last dimension (across sequence length) +%max_per_row_vec = vector.multi_reduction , %qk_3d, %neg_inf_init [2] + : vector<1x128x4096xf16> to vector<1x128xf16> + +// Broadcast max back to full shape for subtraction +%max_broadcast_3d = vector.broadcast %max_per_row_vec : vector<1x128xf16> to vector<4096x1x128xf16> +%max_broadcast_2d = vector.shape_cast %max_broadcast_3d : vector<4096x1x128xf16> to vector<4096x128xf16> +%max_broadcast = vector.transpose %max_broadcast_2d, [1, 0] : vector<4096x128xf16> to vector<128x4096xf16> + +// Exp of (x - max) for numerical stability +%centered_scores = arith.subf %qk_scaled_vec, %max_broadcast : vector<128x4096xf16> +%exp_scores_vec = math.exp %centered_scores : vector<128x4096xf16> + +// Sum reduction to get denominator +%exp_3d = vector.shape_cast %exp_scores_vec : vector<128x4096xf16> to vector<1x128x4096xf16> +%sum_per_row_vec = vector.multi_reduction , %exp_3d, %zero_init [2] + : vector<1x128x4096xf16> to vector<1x128xf16> + +// Broadcast sum for normalization +%sum_broadcast_3d = vector.broadcast %sum_per_row_vec : vector<1x128xf16> to vector<4096x1x128xf16> +%sum_broadcast_2d = vector.shape_cast %sum_broadcast_3d : vector<4096x1x128xf16> to vector<4096x128xf16> +%sum_broadcast = vector.transpose %sum_broadcast_2d, [1, 0] : vector<4096x128xf16> to vector<128x4096xf16> + +// Normalize to get attention weights +%attention_weights_vec = arith.divf %exp_scores_vec, %sum_broadcast : vector<128x4096xf16> +``` + +### Second Matmul (Attention @ V) + +```mlir +// Read V: [4096, 64] +%v_vec = vector.transfer_read %v_3d[%batch_head_idx, %c0, %c0], %poison {in_bounds = [true, true]} + : tensor<16x4096x64xf16>, vector<4096x64xf16> + +// Contract: [128, 4096] @ [4096, 64] -> [128, 64] +%output_vec = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, // Attention weights: reduce over sequence + affine_map<(d0, d1, d2) -> (d2, d1)>, // V: reduce over sequence + affine_map<(d0, d1, d2) -> (d0, d1)> // Output + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} %attention_weights_vec, %v_vec, %zero_output_init : + vector<128x4096xf16>, vector<4096x64xf16> into vector<128x64xf16> +``` + +**Result**: All operations now use vector types, ready for SIMD hardware. + +--- + +## Stage 5: Bufferization + +Tensors are converted to memrefs (in-place memory buffers): + +```mlir +func.func @payload(%arg_output: memref<2x8x4096x64xf16>, + %arg_q: memref<2x8x4096x64xf16>, + %arg_k: memref<2x8x4096x64xf16>, + %arg_v: memref<2x8x4096x64xf16>) { + // Constants remain as vectors + %zero_vec_128x64 = arith.constant dense<0.000000e+00> : vector<128x64xf16> + %poison = ub.poison : f16 + + // Collapse shapes on memrefs + %q_3d = memref.collapse_shape %arg_q [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %k_3d = memref.collapse_shape %arg_k [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %v_3d = memref.collapse_shape %arg_v [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + %output_3d = memref.collapse_shape %arg_output [[0, 1], [2], [3]] + : memref<2x8x4096x64xf16> into memref<16x4096x64xf16> + + scf.forall (%batch_head_idx, %tile_idx) in (16, 32) { + %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) + + // Direct reads from memrefs + %k_vec = vector.transfer_read %k_3d[%batch_head_idx, %c0, %c0], %poison + : memref<16x4096x64xf16>, vector<4096x64xf16> + + // ... computation ... + + // Create subview for output + %output_subview = memref.subview %output_3d[%batch_head_idx, %row_offset, 0] [1, 128, 64] [1, 1, 1] + : memref<16x4096x64xf16> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>> + + // Direct write to memref + vector.transfer_write %output_vec, %output_subview[%c0, %c0, %c0] {in_bounds = [true, true]} + : vector<128x64xf16>, memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>> + } +} +``` + +**Key Change**: No more tensor abstractions; direct memory operations. + +--- + +## Stage 6: Inner Tiling for Fused Attention (Online Softmax) + +**This is the critical optimization** - implements "online" softmax to avoid materializing the full attention matrix. + +### The Online Softmax Algorithm + +Instead of computing the full `128×4096` attention matrix, we process K/V in chunks of 64 and incrementally update: +- Running maximum `m_i` +- Running sum of exponentials `l_i` +- Partial output `O_i` + +```mlir +%final_max:3 = scf.for %kv_chunk_idx = %c0 to %c4096 step %c64 + iter_args(%m_old = %neg_inf_init, %l_old = %zero_init, %O_old = %zero_output_init) + -> (vector<128xf16>, vector<128xf16>, vector<128x64xf16>) { + // m_old = running maximum across previous chunks + // l_old = running sum of exponentials across previous chunks + // O_old = running partial output across previous chunks +``` + +#### Process 64 columns of K at a time (4 chunks of 16) + +```mlir + // Chunk 0: columns [0:16] of K + %k_chunk_0 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %kv_chunk_idx, %c0], %poison + : memref<2x8x4096x64xf16>, vector<16x64xf16> + %k_chunk_0_t = vector.transpose %k_chunk_0, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_0 = vector.contract { ... } %q_vec, %k_chunk_0_t, %zero_init : + vector<128x64xf16>, vector<64x16xf16> -> vector<128x16xf16> + + // Chunk 1: columns [16:32] of K + %k_offset_16 = arith.addi %kv_chunk_idx, %c16 : index + %k_chunk_1 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_16, %c0], %poison + %k_chunk_1_t = vector.transpose %k_chunk_1, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_1 = vector.contract { ... } %q_vec, %k_chunk_1_t, %zero_init : ... -> vector<128x16xf16> + + // Chunk 2: columns [32:48] of K + %k_offset_32 = arith.addi %kv_chunk_idx, %c32 : index + %k_chunk_2 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_32, %c0], %poison + %k_chunk_2_t = vector.transpose %k_chunk_2, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_2 = vector.contract { ... } %q_vec, %k_chunk_2_t, %zero_init : ... -> vector<128x16xf16> + + // Chunk 3: columns [48:64] of K + %k_offset_48 = arith.addi %kv_chunk_idx, %c48 : index + %k_chunk_3 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %k_offset_48, %c0], %poison + %k_chunk_3_t = vector.transpose %k_chunk_3, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + %qk_chunk_3 = vector.contract { ... } %q_vec, %k_chunk_3_t, %zero_init : ... -> vector<128x16xf16> +``` + +#### Find new maximum across all 4 chunks + +```mlir + // Find max across all 4 chunks + %max_01 = arith.maximumf %qk_chunk_0, %qk_chunk_1 : vector<128x16xf16> + %max_012 = arith.maximumf %max_01, %qk_chunk_2 : vector<128x16xf16> + %max_0123 = arith.maximumf %max_012, %qk_chunk_3 : vector<128x16xf16> + %max_chunk_per_row = vector.multi_reduction , %max_0123, %neg_inf_init [1] : + vector<128x16xf16> -> vector<128xf16> + + // Scale and update running maximum + %max_chunk_scaled = arith.mulf %max_chunk_per_row, %scale_factor_vec : vector<128xf16> + %m_new = arith.maximumf %m_old, %max_chunk_scaled : vector<128xf16> +``` + +#### Compute exponentials for each chunk + +```mlir + // Broadcast m_new for subtraction from all chunks + %m_new_3d = vector.broadcast %m_new : vector<128xf16> to vector<16x128xf16> + %m_new_broadcast = vector.transpose %m_new_3d, [1, 0] : vector<16x128xf16> to vector<128x16xf16> + + // exp(chunk_0 * scale - m_new) + %qk_chunk_0_scaled = arith.mulf %qk_chunk_0, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_0_centered = arith.subf %qk_chunk_0_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_0 = math.exp %qk_chunk_0_centered : vector<128x16xf16> + + // exp(chunk_1 * scale - m_new) + %qk_chunk_1_scaled = arith.mulf %qk_chunk_1, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_1_centered = arith.subf %qk_chunk_1_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_1 = math.exp %qk_chunk_1_centered : vector<128x16xf16> + + // exp(chunk_2 * scale - m_new) + %qk_chunk_2_scaled = arith.mulf %qk_chunk_2, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_2_centered = arith.subf %qk_chunk_2_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_2 = math.exp %qk_chunk_2_centered : vector<128x16xf16> + + // exp(chunk_3 * scale - m_new) + %qk_chunk_3_scaled = arith.mulf %qk_chunk_3, %scale_factor_2d : vector<128x16xf16> + %qk_chunk_3_centered = arith.subf %qk_chunk_3_scaled, %m_new_broadcast : vector<128x16xf16> + %exp_chunk_3 = math.exp %qk_chunk_3_centered : vector<128x16xf16> +``` + +#### Update sum of exponentials + +```mlir + // Sum exponentials across the 4 chunks + %sum_01 = arith.addf %exp_chunk_0, %exp_chunk_1 : vector<128x16xf16> + %sum_012 = arith.addf %sum_01, %exp_chunk_2 : vector<128x16xf16> + %sum_0123 = arith.addf %sum_012, %exp_chunk_3 : vector<128x16xf16> + %l_chunk = vector.multi_reduction , %sum_0123, %zero_init [1] : + vector<128x16xf16> -> vector<128xf16> + + // Correction factor for previous chunks: exp(m_old - m_new) + %m_delta = arith.subf %m_old, %m_new : vector<128xf16> + %correction_factor = math.exp %m_delta : vector<128xf16> + + // Update running sum: l_new = l_old * correction + l_chunk + %l_old_corrected = arith.mulf %l_old, %correction_factor : vector<128xf16> + %l_new = arith.addf %l_old_corrected, %l_chunk : vector<128xf16> +``` + +#### Update partial output + +```mlir + // Rescale old output by correction factor + %correction_3d = vector.broadcast %correction_factor : vector<128xf16> to vector<64x128xf16> + %correction_broadcast = vector.transpose %correction_3d, [1, 0] : vector<64x128xf16> to vector<128x64xf16> + %O_old_corrected = arith.mulf %O_old, %correction_broadcast : vector<128x64xf16> + + // Load corresponding 64 rows of V (chunk 0: rows [0:16]) + %v_chunk_0 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %kv_chunk_idx, %c0], %poison + : memref<2x8x4096x64xf16>, vector<16x64xf16> + + // Accumulate: O += exp_chunk_0 @ V[0:16, :] + %O_partial_0 = vector.contract { ... } %exp_chunk_0, %v_chunk_0, %O_old_corrected : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_1 @ V[16:32, :] + %v_chunk_1 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_16, %c0], %poison + %O_partial_1 = vector.contract { ... } %exp_chunk_1, %v_chunk_1, %O_partial_0 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_2 @ V[32:48, :] + %v_chunk_2 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_32, %c0], %poison + %O_partial_2 = vector.contract { ... } %exp_chunk_2, %v_chunk_2, %O_partial_1 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + // Accumulate: O += exp_chunk_3 @ V[48:64, :] + %v_chunk_3 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_48, %c0], %poison + %O_new = vector.contract { ... } %exp_chunk_3, %v_chunk_3, %O_partial_2 : + vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> + + scf.yield %m_new, %l_new, %O_new : vector<128xf16>, vector<128xf16>, vector<128x64xf16> +} +``` + +#### Final normalization + +```mlir +// Extract final values from loop +%m_final = %final_max#0 : vector<128xf16> +%l_final = %final_max#1 : vector<128xf16> +%O_accumulated = %final_max#2 : vector<128x64xf16> + +// Broadcast sum to full output shape for normalization +%l_final_3d = vector.broadcast %l_final : vector<128xf16> to vector<64x128xf16> +%l_final_broadcast = vector.transpose %l_final_3d, [1, 0] : vector<64x128xf16> to vector<128x64xf16> + +// Normalize: O_final = O_accumulated / l_final +%output_normalized = arith.divf %O_accumulated, %l_final_broadcast : vector<128x64xf16> + +// Write result back to output buffer +vector.transfer_write %output_normalized, %output_4d[%batch_idx, %head_idx, %row_offset, %c0] + {in_bounds = [true, true]} : vector<128x64xf16>, memref<2x8x4096x64xf16> +``` + +### Memory Savings + +**Before**: `128 × 4096 × 2 bytes = 1 MB` per workgroup + +**After**: +- `128 × 64 × 2 bytes = 16 KB` for partial QK^T (8 chunks of 128×16) +- `128 × 64 × 2 bytes = 16 KB` for partial output +- **Total: 32 KB** per workgroup + +**Reduction: 96.875%** — this enables processing much longer sequences! + +--- + +## Stage 7: GPU Outlining + +The computation is extracted into a separate GPU kernel module: + +```mlir +module attributes {gpu.container_module} { + func.func @payload(%arg_output: memref<2x8x4096x64xf16>, + %arg_q: memref<2x8x4096x64xf16>, + %arg_k: memref<2x8x4096x64xf16>, + %arg_v: memref<2x8x4096x64xf16>) { + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + + // Launch GPU kernel + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c32, %c1) // Grid: 16 × 32 × 1 (batch×head, seq_tiles, 1) + threads in (%c128, %c1, %c1) // Block: 128 × 1 × 1 + args(%arg_q : memref<2x8x4096x64xf16>, + %arg_k : memref<2x8x4096x64xf16>, + %arg_v : memref<2x8x4096x64xf16>, + %arg_output : memref<2x8x4096x64xf16>) + return + } + + gpu.module @payload_kernel { + gpu.func @payload_kernel(%q: memref<2x8x4096x64xf16>, + %k: memref<2x8x4096x64xf16>, + %v: memref<2x8x4096x64xf16>, + %output: memref<2x8x4096x64xf16>) kernel + attributes { + known_block_size = array, + known_grid_size = array + } { + + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + + // Compute batch and head indices from block_id_x + %row_offset = arith.muli %block_id_y, %c128 : index + %batch_idx = arith.floordivsi %block_id_x, %c8 : index + %head_idx = arith.remsi %block_id_x, %c8 : index + + // ... computation from Stage 6 ... + + gpu.return + } + } +} +``` + +**Grid Configuration**: +- 16 blocks in X (2 batches × 8 heads) +- 32 blocks in Y (4096 / 128 = 32 tiles) +- Each block has 128 threads + +--- + +## Stage 8: Converting to XeGPU Operations + +Vector operations are converted to Intel XeGPU-specific operations using tensor descriptors. + +### Tensor Descriptor Creation + +```mlir +// Create descriptor for reading Q +%q_subview = memref.subview %q[%batch_idx, %head_idx, 0, 0] [1, 1, 4096, 64] [1, 1, 1, 1] + : memref<2x8x4096x64xf16> to memref<4096x64xf16, strided<[64, 1], offset: ?>> + +%q_base_buffer, %q_offset, %q_sizes:2, %q_strides:2 = + memref.extract_strided_metadata %q_subview : memref<4096x64xf16, strided<[64, 1], offset: ?>> + -> memref, index, index, index, index, index + +%q_intptr = memref.extract_aligned_pointer_as_index %q_base_buffer : memref -> index +%q_byte_offset = arith.muli %q_offset, %c2 : index // Multiply by sizeof(f16) +%q_addr = arith.addi %q_intptr, %q_byte_offset : index +%q_addr_i64 = arith.index_cast %q_addr : index to i64 + +// Create XeGPU tensor descriptor for Q +%q_tdesc = xegpu.create_nd_tdesc %q_addr_i64, shape : [4096, 64], strides : [64, 1] : i64 + -> !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> +``` + +### XeGPU Load Operations + +```mlir +// Load Q tile using tensor descriptor at row_offset +%q_vec = xegpu.load_nd %q_tdesc[%row_offset, 0] + : !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> + -> vector<128x64xf16> + +// Load K chunk at kv_chunk_idx +%k_chunk_vec = xegpu.load_nd %k_tdesc[%kv_chunk_idx, 0] + : !xegpu.tensor_desc<16x64xf16, #xegpu.block_tdesc_attr> + -> vector<16x64xf16> +``` + +### XeGPU DPAS (Dot Product Accumulate Systolic) + +The key hardware instruction for matrix multiplication: + +```mlir +// Transpose K chunk for matmul +%k_chunk_t = vector.transpose %k_chunk_vec, [1, 0] : vector<16x64xf16> to vector<64x16xf16> + +// DPAS: Specialized matmul instruction +// Q: [128, 64], K^T: [64, 16], accumulator: [128, 16] +%qk_chunk = xegpu.dpas %q_vec, %k_chunk_t, %zero_accumulator + : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf16> + -> vector<128x16xf16> +``` + +**DPAS Operation**: +- Performs `C = A @ B + C` in hardware +- Optimized for f16 on Intel GPUs +- Systolic array execution + +### XeGPU Store Operations + +```mlir +// Create output descriptor (similar to Q descriptor creation) +%output_addr_i64 = arith.index_cast %output_addr : index to i64 +%output_tdesc = xegpu.create_nd_tdesc %output_addr_i64, shape : [4096, 64], strides : [64, 1] : i64 + -> !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> + +// Store result at row_offset +xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] + : vector<128x64xf16>, + !xegpu.tensor_desc<128x64xf16, #xegpu.block_tdesc_attr> +``` + +--- + +## Stage 9: Setting XeGPU Layouts + +The final stage adds hardware-specific layout information for optimal data distribution across subgroups. + +### Layout Specification + +Layouts describe how data is distributed across: +- **Subgroups (SG)**: Groups of threads that execute together +- **SG Data**: Data assigned to each subgroup +- **Inst Data**: Data processed by a single instruction + +### Load Layout for Q (128×64) + +```mlir +%q_vec = xegpu.load_nd %q_tdesc[%row_offset, 0] <{ + layout = #xegpu.layout< + sg_layout = [8, 1], // 8 subgroups in row, 1 in column + sg_data = [16, 64], // Each SG handles 16×64 data + inst_data = [8, 16] // Each instruction: 8×16 + > +}> : !xegpu.tensor_desc<128x64xf16, ...> -> vector<128x64xf16> +``` + +**Breakdown**: +- 8 subgroups handle rows: 8 × 16 = 128 rows ✓ +- 1 subgroup handles columns: 1 × 64 = 64 columns ✓ +- Each instruction processes 8×16 elements + +### DPAS Layout (Q @ K^T → Attention Scores) + +```mlir +%qk_chunk = xegpu.dpas %q_vec, %k_chunk_t, %zero_accumulator { + layout_a = #xegpu.layout< + sg_layout = [8, 1], // Q: 8 SGs vertically, 1 horizontally + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [8, 16] // Each DPAS: 8 rows × 16 cols + >, + layout_b = #xegpu.layout< + sg_layout = [1, 8], // K^T: 1 SG vertically, 8 horizontally + sg_data = [64, 16], // Each SG: 64 rows × 16 cols + inst_data = [16, 16], // Each DPAS: 16 × 16 + order = [0, 1] + >, + layout_cd = #xegpu.layout< + sg_layout = [8, 1], // Output: 8 SGs vertically + sg_data = [16, 16], // Each SG: 16 rows × 16 cols + inst_data = [8, 16] // Each DPAS output: 8 × 16 + > +} : vector<128x64xf16>, vector<64x16xf16>, vector<128x16xf16> + -> vector<128x16xf16> +``` + +**Matrix Multiplication Mapping**: +- Input A (Q): 8 subgroups × (16 rows × 64 cols) = 128×64 +- Input B (K^T): 8 subgroups × (64 rows × 16 cols) = 64×128 (transposed) +- Output C (Scores): 8 subgroups × (16 rows × 16 cols) = 128×16 + +### DPAS Layout (Attention @ V → Output) + +```mlir +%O_partial = xegpu.dpas %exp_chunk, %v_chunk, %O_old_corrected { + layout_a = #xegpu.layout< + sg_layout = [8, 1], // Exp attention: 8 SGs vertically + sg_data = [16, 16], // Each SG: 16 rows × 16 cols + inst_data = [8, 16] // Each DPAS: 8 × 16 + >, + layout_b = #xegpu.layout< + sg_layout = [8, 1], // V: 8 SGs vertically (not transposed) + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [16, 16] // Each DPAS: 16 × 16 + >, + layout_cd = #xegpu.layout< + sg_layout = [8, 1], // Output: 8 SGs vertically + sg_data = [16, 64], // Each SG: 16 rows × 64 cols + inst_data = [8, 16] // Each DPAS output: 8 × 16 + > +} : vector<128x16xf16>, vector<16x64xf16>, vector<128x64xf16> + -> vector<128x64xf16> +``` + +### Store Layout + +```mlir +xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] <{ + layout = #xegpu.layout< + sg_layout = [8, 1], // 8 SGs vertically, 1 horizontally + sg_data = [16, 64], // Each SG stores: 16 rows × 64 cols + inst_data = [8, 16] // Each store instruction: 8 × 16 + > +}> : vector<128x64xf16>, !xegpu.tensor_desc<128x64xf16, ...> +``` + +### GPU Target Attribute + +```mlir +gpu.module @payload_kernel [#xevm.target] { + // O = 3 specifies optimization level 3 + gpu.func @payload_kernel(...) kernel { ... } +} +``` + +--- + +## Summary of Optimizations + +### Memory Efficiency +1. **Dimension collapse**: `2×8×4096×64` → `16×4096×64` (eliminate indexing overhead) +2. **Online softmax**: Avoid materializing `128×4096` attention matrix (96.9% memory reduction) +3. **Tiled computation**: Process K/V in chunks of 64 + +### Parallelism +1. **Batch/head parallelism**: 16 independent blocks (2 batches × 8 heads) +2. **Sequence parallelism**: 32 blocks for 4096 / 128 +3. **Thread parallelism**: 128 threads per block +4. **Total**: 16 × 32 × 128 = 65,536 concurrent threads + +### Hardware Utilization +1. **Vectorization**: SIMD operations on vectors +2. **XeGPU DPAS**: Hardware-accelerated matrix multiply +3. **Tensor descriptors**: Efficient 2D block loads/stores +4. **Layout optimization**: Data distribution matches hardware subgroup structure + +### Numerical Stability +1. **Max subtraction**: Prevent overflow in exp() +2. **Online updates**: Incrementally correct max/sum as new data arrives +3. **Final normalization**: Divide by accumulated sum + +--- + +## Performance Considerations + +### What Fits in Register File +Per workgroup (128 threads, 8 subgroups): +- Q tile: `128×64×2B = 16 KB` ✓ +- Partial attention scores: `128×16×2B = 4 KB` (4 chunks) ✓ +- Partial output: `128×64×2B = 16 KB` ✓ +- Statistics: `128×2×2B = 512 B` (max + sum) ✓ +- **Total: ~36 KB** — fits in GPU register file + +### Memory Access Pattern +1. **Q**: Load once, reuse for all 64 chunks of K +2. **K**: Stream through in chunks of 64×64 +3. **V**: Stream through in chunks of 64×64 +4. **Output**: Accumulate in registers, write once + +### Compute Intensity +- **FLOPs**: ~2 × 128 × 64 × 4096 × 2 (two matmuls) ≈ 134 MFLOP per workgroup +- **Memory**: ~128 × 64 × 2 (Q) + 4096 × 64 × 2 × 2 (K+V) ≈ 1 MB +- **Arithmetic intensity**: 134 MFLOP / 1 MB ≈ **134 FLOP/byte** — excellent! + +--- + +## Conclusion + +This lowering flow demonstrates a sophisticated compilation strategy: +1. Start with intuitive high-level operations +2. Progressively expose parallelism through tiling +3. Apply memory-saving algorithms (online softmax) +4. Lower to vector operations for SIMD +5. Map to hardware-specific instructions (DPAS) +6. Optimize data layout for hardware execution + +The result is a highly optimized fused attention kernel that maximizes throughput while minimizing memory footprint. From 1bedc5b84b96c247a13a56a967aced7a3402c9df Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:10:56 +0000 Subject: [PATCH 05/20] save refined version of doc --- docs/fused_attention_lowering.md | 147 +++++++------------------------ 1 file changed, 33 insertions(+), 114 deletions(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index ba974329..3e8b8700 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -24,7 +24,7 @@ This document describes the multi-stage lowering process for standard attention outs(%empty_kt : tensor<16x64x4096xf16>) permutation = [0, 2, 1] // Q @ K^T: [16, 4096, 64] @ [16, 64, 4096] -> [16, 4096, 4096] -%qk_scores = linalg.batch_matmul ins(%q_3d, %k_transposed : ...) +%qk_scores = linalg.batch_matmul ins(%q_3d, %k_transposed : ...) outs(%qk_init : tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16> // Scale by 1/sqrt(d_k) = 0.125 @@ -36,22 +36,16 @@ This document describes the multi-stage lowering process for standard attention -> tensor<16x4096x4096xf16> // Attention @ V: [16, 4096, 4096] @ [16, 4096, 64] -> [16, 4096, 64] -%output = linalg.batch_matmul ins(%attention_weights, %v_3d : ...) +%output = linalg.batch_matmul ins(%attention_weights, %v_3d : ...) outs(%output_init : tensor<16x4096x64xf16>) -> tensor<16x4096x64xf16> ``` -**Characteristics**: -- Materializes full attention matrix `[16, 4096, 4096]` in memory -- Sequential operations with clear dependencies -- Single monolithic softmax operation - --- ## Stage 2: Tiling and Softmax Decomposition -### Major Changes -#### Softmax Decomposition +### Softmax Decomposition The atomic `linalg.softmax` is decomposed into explicit operations: ```mlir @@ -93,17 +87,17 @@ The atomic `linalg.softmax` is decomposed into explicit operations: } -> tensor<16x4096x4096xf16> ``` -#### Tiling the Output Dimension +### Tiling the Output Dimension The second matmul (attention @ V) is tiled into 32 tiles of size 128: ```mlir -%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) // Extract 128 rows of attention weights: [1, 128, 4096] - %attention_tile = tensor.extract_slice %attention_weights[%batch_head_idx, %row_offset, 0] + %attention_tile = tensor.extract_slice %attention_weights[%batch_head_idx, %row_offset, 0] [1, 128, 4096] [1, 1, 1] : tensor<16x4096x4096xf16> to tensor<1x128x4096xf16> // Extract all of V: [1, 4096, 64] @@ -115,14 +109,12 @@ The second matmul (attention @ V) is tiled into 32 tiles of size 128: outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> scf.forall.in_parallel { - tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> } } ``` -**Key Insight**: Still computes the full `4096×4096` attention matrix before the tiled second matmul. - --- ## Stage 3: Tiling Batch and Head Dimensions @@ -132,7 +124,7 @@ The second matmul (attention @ V) is tiled into 32 tiles of size 128: The entire attention computation is now fused into a single parallel loop: ```mlir -%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) +%output = scf.forall (%batch_head_idx, %tile_idx) in (16, 32) shared_outs(%out_accumulator = %output_init) -> (tensor<16x4096x64xf16>) { %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) @@ -145,7 +137,7 @@ The entire attention computation is now fused into a single parallel loop: : tensor<16x4096x64xf16> to tensor<1x4096x64xf16> // Transpose K within tile - %k_tile_transposed = linalg.transpose ins(%k_tile : tensor<1x4096x64xf16>) + %k_tile_transposed = linalg.transpose ins(%k_tile : tensor<1x4096x64xf16>) outs(%kt_init : tensor<1x64x4096xf16>) permutation = [0, 2, 1] // Q @ K^T: [1, 128, 64] @ [1, 64, 4096] -> [1, 128, 4096] @@ -171,18 +163,18 @@ The entire attention computation is now fused into a single parallel loop: outs(%partial_init : tensor<1x128x64xf16>) -> tensor<1x128x64xf16> scf.forall.in_parallel { - tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] + tensor.parallel_insert_slice %partial_output into %out_accumulator[%batch_head_idx, %row_offset, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16> into tensor<16x4096x64xf16> } } ``` -**Key Change**: Each workgroup now processes: +Each workgroup now processes: - 128 rows of Q - All of K and V (still materializes `128×4096` attention matrix) - Produces 128 rows of output -**Parallelism**: 16 × 32 = 512 independent workgroups +16 × 32 = 512 independent workgroups --- @@ -265,12 +257,10 @@ Linalg operations are converted to vector operations for SIMD execution. ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind -} %attention_weights_vec, %v_vec, %zero_output_init : +} %attention_weights_vec, %v_vec, %zero_output_init : vector<128x4096xf16>, vector<4096x64xf16> into vector<128x64xf16> ``` -**Result**: All operations now use vector types, ready for SIMD hardware. - --- ## Stage 5: Bufferization @@ -278,7 +268,7 @@ Linalg operations are converted to vector operations for SIMD execution. Tensors are converted to memrefs (in-place memory buffers): ```mlir -func.func @payload(%arg_output: memref<2x8x4096x64xf16>, +func.func @payload(%arg_output: memref<2x8x4096x64xf16>, %arg_q: memref<2x8x4096x64xf16>, %arg_k: memref<2x8x4096x64xf16>, %arg_v: memref<2x8x4096x64xf16>) { @@ -298,7 +288,7 @@ func.func @payload(%arg_output: memref<2x8x4096x64xf16>, scf.forall (%batch_head_idx, %tile_idx) in (16, 32) { %row_offset = affine.apply affine_map<(d0) -> (d0 * 128)>(%tile_idx) - + // Direct reads from memrefs %k_vec = vector.transfer_read %k_3d[%batch_head_idx, %c0, %c0], %poison : memref<16x4096x64xf16>, vector<4096x64xf16> @@ -316,13 +306,14 @@ func.func @payload(%arg_output: memref<2x8x4096x64xf16>, } ``` -**Key Change**: No more tensor abstractions; direct memory operations. - --- ## Stage 6: Inner Tiling for Fused Attention (Online Softmax) -**This is the critical optimization** - implements "online" softmax to avoid materializing the full attention matrix. +This is the **Flash Attention** optimization. + +1. Implements "online" softmax to avoid materializing the full attention matrix. +2. Tile the K and V loads into `16 x d_head` and interleave with DPAS for lower register preassure. ### The Online Softmax Algorithm @@ -347,7 +338,7 @@ Instead of computing the full `128×4096` attention matrix, we process K/V in ch %k_chunk_0 = vector.transfer_read %k_4d[%batch_idx, %head_idx, %kv_chunk_idx, %c0], %poison : memref<2x8x4096x64xf16>, vector<16x64xf16> %k_chunk_0_t = vector.transpose %k_chunk_0, [1, 0] : vector<16x64xf16> to vector<64x16xf16> - %qk_chunk_0 = vector.contract { ... } %q_vec, %k_chunk_0_t, %zero_init : + %qk_chunk_0 = vector.contract { ... } %q_vec, %k_chunk_0_t, %zero_init : vector<128x64xf16>, vector<64x16xf16> -> vector<128x16xf16> // Chunk 1: columns [16:32] of K @@ -376,7 +367,7 @@ Instead of computing the full `128×4096` attention matrix, we process K/V in ch %max_01 = arith.maximumf %qk_chunk_0, %qk_chunk_1 : vector<128x16xf16> %max_012 = arith.maximumf %max_01, %qk_chunk_2 : vector<128x16xf16> %max_0123 = arith.maximumf %max_012, %qk_chunk_3 : vector<128x16xf16> - %max_chunk_per_row = vector.multi_reduction , %max_0123, %neg_inf_init [1] : + %max_chunk_per_row = vector.multi_reduction , %max_0123, %neg_inf_init [1] : vector<128x16xf16> -> vector<128xf16> // Scale and update running maximum @@ -419,7 +410,7 @@ Instead of computing the full `128×4096` attention matrix, we process K/V in ch %sum_01 = arith.addf %exp_chunk_0, %exp_chunk_1 : vector<128x16xf16> %sum_012 = arith.addf %sum_01, %exp_chunk_2 : vector<128x16xf16> %sum_0123 = arith.addf %sum_012, %exp_chunk_3 : vector<128x16xf16> - %l_chunk = vector.multi_reduction , %sum_0123, %zero_init [1] : + %l_chunk = vector.multi_reduction , %sum_0123, %zero_init [1] : vector<128x16xf16> -> vector<128xf16> // Correction factor for previous chunks: exp(m_old - m_new) @@ -444,22 +435,22 @@ Instead of computing the full `128×4096` attention matrix, we process K/V in ch : memref<2x8x4096x64xf16>, vector<16x64xf16> // Accumulate: O += exp_chunk_0 @ V[0:16, :] - %O_partial_0 = vector.contract { ... } %exp_chunk_0, %v_chunk_0, %O_old_corrected : + %O_partial_0 = vector.contract { ... } %exp_chunk_0, %v_chunk_0, %O_old_corrected : vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> // Accumulate: O += exp_chunk_1 @ V[16:32, :] %v_chunk_1 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_16, %c0], %poison - %O_partial_1 = vector.contract { ... } %exp_chunk_1, %v_chunk_1, %O_partial_0 : + %O_partial_1 = vector.contract { ... } %exp_chunk_1, %v_chunk_1, %O_partial_0 : vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> // Accumulate: O += exp_chunk_2 @ V[32:48, :] %v_chunk_2 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_32, %c0], %poison - %O_partial_2 = vector.contract { ... } %exp_chunk_2, %v_chunk_2, %O_partial_1 : + %O_partial_2 = vector.contract { ... } %exp_chunk_2, %v_chunk_2, %O_partial_1 : vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> // Accumulate: O += exp_chunk_3 @ V[48:64, :] %v_chunk_3 = vector.transfer_read %v_4d[%batch_idx, %head_idx, %k_offset_48, %c0], %poison - %O_new = vector.contract { ... } %exp_chunk_3, %v_chunk_3, %O_partial_2 : + %O_new = vector.contract { ... } %exp_chunk_3, %v_chunk_3, %O_partial_2 : vector<128x16xf16>, vector<16x64xf16> -> vector<128x64xf16> scf.yield %m_new, %l_new, %O_new : vector<128xf16>, vector<128xf16>, vector<128x64xf16> @@ -482,30 +473,19 @@ Instead of computing the full `128×4096` attention matrix, we process K/V in ch %output_normalized = arith.divf %O_accumulated, %l_final_broadcast : vector<128x64xf16> // Write result back to output buffer -vector.transfer_write %output_normalized, %output_4d[%batch_idx, %head_idx, %row_offset, %c0] +vector.transfer_write %output_normalized, %output_4d[%batch_idx, %head_idx, %row_offset, %c0] {in_bounds = [true, true]} : vector<128x64xf16>, memref<2x8x4096x64xf16> ``` -### Memory Savings - -**Before**: `128 × 4096 × 2 bytes = 1 MB` per workgroup - -**After**: -- `128 × 64 × 2 bytes = 16 KB` for partial QK^T (8 chunks of 128×16) -- `128 × 64 × 2 bytes = 16 KB` for partial output -- **Total: 32 KB** per workgroup - -**Reduction: 96.875%** — this enables processing much longer sequences! - --- ## Stage 7: GPU Outlining -The computation is extracted into a separate GPU kernel module: +`scf.forall` loop is distirbuted to workgroups. ```mlir module attributes {gpu.container_module} { - func.func @payload(%arg_output: memref<2x8x4096x64xf16>, + func.func @payload(%arg_output: memref<2x8x4096x64xf16>, %arg_q: memref<2x8x4096x64xf16>, %arg_k: memref<2x8x4096x64xf16>, %arg_v: memref<2x8x4096x64xf16>) { @@ -518,7 +498,7 @@ module attributes {gpu.container_module} { gpu.launch_func @payload_kernel::@payload_kernel blocks in (%c16, %c32, %c1) // Grid: 16 × 32 × 1 (batch×head, seq_tiles, 1) threads in (%c128, %c1, %c1) // Block: 128 × 1 × 1 - args(%arg_q : memref<2x8x4096x64xf16>, + args(%arg_q : memref<2x8x4096x64xf16>, %arg_k : memref<2x8x4096x64xf16>, %arg_v : memref<2x8x4096x64xf16>, %arg_output : memref<2x8x4096x64xf16>) @@ -526,7 +506,7 @@ module attributes {gpu.container_module} { } gpu.module @payload_kernel { - gpu.func @payload_kernel(%q: memref<2x8x4096x64xf16>, + gpu.func @payload_kernel(%q: memref<2x8x4096x64xf16>, %k: memref<2x8x4096x64xf16>, %v: memref<2x8x4096x64xf16>, %output: memref<2x8x4096x64xf16>) kernel @@ -554,7 +534,7 @@ module attributes {gpu.container_module} { **Grid Configuration**: - 16 blocks in X (2 batches × 8 heads) - 32 blocks in Y (4096 / 128 = 32 tiles) -- Each block has 128 threads +- Each block has 128 threads (8 subgroups) --- @@ -612,11 +592,6 @@ The key hardware instruction for matrix multiplication: -> vector<128x16xf16> ``` -**DPAS Operation**: -- Performs `C = A @ B + C` in hardware -- Optimized for f16 on Intel GPUs -- Systolic array execution - ### XeGPU Store Operations ```mlir @@ -635,14 +610,7 @@ xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] ## Stage 9: Setting XeGPU Layouts -The final stage adds hardware-specific layout information for optimal data distribution across subgroups. - -### Layout Specification - -Layouts describe how data is distributed across: -- **Subgroups (SG)**: Groups of threads that execute together -- **SG Data**: Data assigned to each subgroup -- **Inst Data**: Data processed by a single instruction +Assign layouts for distributing to subgroups. ### Load Layout for Q (128×64) @@ -736,55 +704,6 @@ gpu.module @payload_kernel [#xevm.target] { --- -## Summary of Optimizations - -### Memory Efficiency -1. **Dimension collapse**: `2×8×4096×64` → `16×4096×64` (eliminate indexing overhead) -2. **Online softmax**: Avoid materializing `128×4096` attention matrix (96.9% memory reduction) -3. **Tiled computation**: Process K/V in chunks of 64 - -### Parallelism -1. **Batch/head parallelism**: 16 independent blocks (2 batches × 8 heads) -2. **Sequence parallelism**: 32 blocks for 4096 / 128 -3. **Thread parallelism**: 128 threads per block -4. **Total**: 16 × 32 × 128 = 65,536 concurrent threads - -### Hardware Utilization -1. **Vectorization**: SIMD operations on vectors -2. **XeGPU DPAS**: Hardware-accelerated matrix multiply -3. **Tensor descriptors**: Efficient 2D block loads/stores -4. **Layout optimization**: Data distribution matches hardware subgroup structure - -### Numerical Stability -1. **Max subtraction**: Prevent overflow in exp() -2. **Online updates**: Incrementally correct max/sum as new data arrives -3. **Final normalization**: Divide by accumulated sum - ---- - -## Performance Considerations - -### What Fits in Register File -Per workgroup (128 threads, 8 subgroups): -- Q tile: `128×64×2B = 16 KB` ✓ -- Partial attention scores: `128×16×2B = 4 KB` (4 chunks) ✓ -- Partial output: `128×64×2B = 16 KB` ✓ -- Statistics: `128×2×2B = 512 B` (max + sum) ✓ -- **Total: ~36 KB** — fits in GPU register file - -### Memory Access Pattern -1. **Q**: Load once, reuse for all 64 chunks of K -2. **K**: Stream through in chunks of 64×64 -3. **V**: Stream through in chunks of 64×64 -4. **Output**: Accumulate in registers, write once - -### Compute Intensity -- **FLOPs**: ~2 × 128 × 64 × 4096 × 2 (two matmuls) ≈ 134 MFLOP per workgroup -- **Memory**: ~128 × 64 × 2 (Q) + 4096 × 64 × 2 × 2 (K+V) ≈ 1 MB -- **Arithmetic intensity**: 134 MFLOP / 1 MB ≈ **134 FLOP/byte** — excellent! - ---- - ## Conclusion This lowering flow demonstrates a sophisticated compilation strategy: From bddad4006582f8e524c0016d11339deea39b45f1 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:49:13 +0000 Subject: [PATCH 06/20] add diagram --- docs/stage6.md | 271 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 docs/stage6.md diff --git a/docs/stage6.md b/docs/stage6.md new file mode 100644 index 00000000..fe02acf0 --- /dev/null +++ b/docs/stage6.md @@ -0,0 +1,271 @@ +# Stage 6: Online Softmax (Fused Attention) + +Process K/V in chunks to avoid materializing the full attention matrix. + +## Matrix Layout Overview + +```mermaid +graph TB + subgraph Input["Input Matrices"] + Q["Q Tile
128 × 64
(solid)"] + K["K Matrix (full context)
4096 × 64
(load in chunks)"] + V["V Matrix (full context)
4096 × 64
(load in chunks)"] + end + + subgraph Process["Chunked Processing"] + K0["K Chunk 0
64 × 64"] + K1["K Chunk 1
64 × 64"] + K2["K Chunk 2
64 × 64"] + Kdots["..."] + K63["K Chunk 63
64 × 64"] + end + + subgraph Output["Output & State"] + O["Output O
128 × 64
(accumulated)"] + State["Running State
m_i (max)
l_i (sum)"] + end + + Q -->|Q @ K^T| K0 + K -->|Split into
64 chunks| K0 + K --> K1 + K --> K2 + K --> Kdots + K --> K63 + + K0 -->|scores| V + V -->|V Chunk 0| O + K0 -.->|update| State + State -.->|correction| O + + style Q fill:#bbdefb,stroke:#1976d2,stroke-width:3px + style K fill:#e0e0e0,stroke:#757575,stroke-width:2px,stroke-dasharray: 5 5 + style V fill:#e0e0e0,stroke:#757575,stroke-width:2px,stroke-dasharray: 5 5 + style O fill:#c8e6c9,stroke:#388e3c,stroke-width:3px + style State fill:#fff9c4,stroke:#f57f17,stroke-width:2px + style K0 fill:#ffcdd2,stroke:#c62828,stroke-width:2px +``` + +## Online Softmax Algorithm Flow + +```mermaid +flowchart TD + Start([Start: Loop over chunks i=0..63]) --> Init[Initialize:
m_0 = -∞
l_0 = 0
O_0 = 0] + + Init --> LoadK[Load K chunk i
64 × 64] + LoadK --> QK[Compute Q @ K^T
128 × 64 @ 64 × 64
= 128 × 64 scores] + + QK --> Scale[Scale scores
scores * 1/√d_k] + + Scale --> MaxChunk[Find max in chunk
m_chunk = max(scores)] + MaxChunk --> UpdateMax[Update global max
m_new = max(m_old, m_chunk)] + + UpdateMax --> ExpScores[Compute exponentials
exp(scores - m_new)] + ExpScores --> SumChunk[Sum exponentials
l_chunk = Σ exp(scores)] + + SumChunk --> Correction[Compute correction
α = exp(m_old - m_new)] + Correction --> UpdateSum[Update sum
l_new = l_old * α + l_chunk] + + UpdateSum --> LoadV[Load V chunk i
64 × 64] + LoadV --> RescaleO[Rescale old output
O_old * α] + RescaleO --> AccumO[Accumulate output
O_new = O_old * α + exp(scores) @ V] + + AccumO --> Check{More
chunks?} + Check -->|Yes| LoadK + Check -->|No| Normalize[Final normalization
O_final = O_accumulated / l_final] + + Normalize --> End([End: Output ready]) + + style Start fill:#e1bee7,stroke:#7b1fa2,stroke-width:2px + style End fill:#e1bee7,stroke:#7b1fa2,stroke-width:2px + style UpdateMax fill:#ffcdd2,stroke:#c62828,stroke-width:2px + style UpdateSum fill:#ffcdd2,stroke:#c62828,stroke-width:2px + style AccumO fill:#c8e6c9,stroke:#388e3c,stroke-width:2px + style Normalize fill:#c8e6c9,stroke:#388e3c,stroke-width:2px +``` + +## Chunk Processing Detail + +```mermaid +sequenceDiagram + participant Q as Q Tile
(128×64) + participant K as K Chunks
(64×64 each) + participant V as V Chunks
(64×64 each) + participant State as Running State
(m_i, l_i) + participant O as Output O
(128×64) + + Note over Q,O: Initialize: m_0=-∞, l_0=0, O_0=0 + + loop For each chunk i = 0..63 + K->>Q: Load K chunk i + Q->>Q: Compute Q @ K^T → scores (128×64) + Q->>Q: Scale scores + Q->>State: Find max → m_chunk + State->>State: m_new = max(m_old, m_chunk) + Q->>Q: exp(scores - m_new) + Q->>State: Sum exp → l_chunk + State->>State: α = exp(m_old - m_new) + State->>State: l_new = l_old * α + l_chunk + V->>O: Load V chunk i + O->>O: O_old * α (rescale) + O->>O: O += exp(scores) @ V + State->>State: Update m_old, l_old + end + + Note over O: Final: O_final = O / l_final +``` + +## Visual Representation of Sliding Window + +``` +Iteration 0: +┌─────────┐ ┌────┬────┬────┬─────────────────────────┐ +│ │ │ 0 │ 1 │ 2 │ ... (64 chunks total) │ +│ Q │ @ ├────┴────┴────┴─────────────────────────┤ +│ (128×64)│ │ K Matrix (4096 × 64) │ +│ │ │ [Chunk 0 highlighted] │ +└─────────┘ └──────────────────────────────────────────┘ + ↓ + ┌─────────────────────┐ + │ Partial scores │ + │ Update m_0, l_0 │ + │ Accumulate to O_0 │ + └─────────────────────┘ + +Iteration 1: +┌─────────┐ ┌────┬────┬────┬─────────────────────────┐ +│ │ │ │ 1 │ 2 │ ... │ +│ Q │ @ ├────┴────┴────┴─────────────────────────┤ +│ (128×64)│ │ K Matrix (4096 × 64) │ +│ │ │ [Chunk 1 highlighted] │ +└─────────┘ └──────────────────────────────────────────┘ + ↓ + ┌─────────────────────┐ + │ Partial scores │ + │ Update m_1, l_1 │ + │ Accumulate to O_1 │ + └─────────────────────┘ + +... continues for 64 iterations ... +``` + +## Key Benefits + +```mermaid +mindmap + root((Online
Softmax)) + Memory Efficiency + No 128×4096 matrix + Only 128×64 scores + ~32x reduction + Numerical Stability + Running max update + Prevents overflow + Standard softmax trick + Parallelism + Each workgroup independent + 512 workgroups total + 16 batch×head × 32 tiles + Hardware Optimization + Tiles fit in registers + DPAS instruction reuse + Reduced memory bandwidth +``` + +## Implementation Details + +### Sub-chunking K and V + +Each 64-column chunk is further divided into 4 sub-chunks of 16 columns: + +```mermaid +graph LR + subgraph "K/V Chunk (64 cols)" + K0[Sub 0
16 cols] + K1[Sub 1
16 cols] + K2[Sub 2
16 cols] + K3[Sub 3
16 cols] + end + + K0 --> DPAS[DPAS Operations
128×64 @ 64×16] + K1 --> DPAS + K2 --> DPAS + K3 --> DPAS + + DPAS --> Scores[4 partial scores
each 128×16] + Scores --> Max[Find max across all 4] + Max --> Update[Update state & output] + + style DPAS fill:#b3e5fc,stroke:#0277bd,stroke-width:2px + style Update fill:#c8e6c9,stroke:#388e3c,stroke-width:2px +``` + +## State Variables + +| Variable | Shape | Purpose | +|----------|-------|---------| +| `m_i` | `128×1` | Running maximum value per row | +| `l_i` | `128×1` | Running sum of exponentials per row | +| `O_i` | `128×64` | Running output accumulator | +| `Q` | `128×64` | Query tile (constant per workgroup) | +| `K_chunk` | `64×64` | Current K chunk being processed | +| `V_chunk` | `64×64` | Current V chunk being processed | + +## Comparison: Standard vs Online Softmax + +```mermaid +graph TB + subgraph Standard["Standard Attention (Stage 4-5)"] + SQ[Q: 128×64] + SK[K: 4096×64] + SQK["Q@K^T
128×4096
⚠️ Full matrix"] + SSoft[Softmax
128×4096] + SV[V: 4096×64] + SO[O: 128×64] + + SQ --> SQK + SK --> SQK + SQK --> SSoft + SSoft --> SO + SV --> SO + end + + subgraph Online["Online Softmax (Stage 6)"] + OQ[Q: 128×64] + OK[K chunks:
64×64] + OQK["Q@K^T
128×64
✓ Chunk only"] + OState[State:
m_i, l_i] + OV[V chunks:
64×64] + OO[O: 128×64
accumulated] + + OQ --> OQK + OK --> OQK + OQK --> OState + OState -.correction.-> OO + OV --> OO + OQK -.exp.-> OO + end + + style SQK fill:#ffcdd2,stroke:#c62828,stroke-width:3px + style OQK fill:#c8e6c9,stroke:#388e3c,stroke-width:3px +``` + +## Mathematical Formulation + +For each chunk $i$: + +```math +\begin{align} +\text{scores}_i &= Q \cdot K_i^T \cdot \frac{1}{\sqrt{d_k}} \\ +m_i &= \max(m_{i-1}, \max(\text{scores}_i)) \\ +\alpha_i &= \exp(m_{i-1} - m_i) \\ +l_i &= l_{i-1} \cdot \alpha_i + \sum \exp(\text{scores}_i - m_i) \\ +O_i &= O_{i-1} \cdot \alpha_i + \exp(\text{scores}_i - m_i) \cdot V_i \\ +O_{\text{final}} &= \frac{O_n}{l_n} +\end{align} +``` + +Where: +- $m_i$ is the running maximum +- $l_i$ is the running sum of exponentials +- $\alpha_i$ is the correction factor for previous chunks +- $O_i$ is the accumulated output From 971ace6f75df5dfee6462722771f652acd2c644d Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:51:16 +0000 Subject: [PATCH 07/20] cleanup --- docs/fused_attention_lowering.md | 2 - docs/stage6.md | 271 ------------------------------- 2 files changed, 273 deletions(-) delete mode 100644 docs/stage6.md diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 3e8b8700..97d37cfe 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -310,8 +310,6 @@ func.func @payload(%arg_output: memref<2x8x4096x64xf16>, ## Stage 6: Inner Tiling for Fused Attention (Online Softmax) -This is the **Flash Attention** optimization. - 1. Implements "online" softmax to avoid materializing the full attention matrix. 2. Tile the K and V loads into `16 x d_head` and interleave with DPAS for lower register preassure. diff --git a/docs/stage6.md b/docs/stage6.md deleted file mode 100644 index fe02acf0..00000000 --- a/docs/stage6.md +++ /dev/null @@ -1,271 +0,0 @@ -# Stage 6: Online Softmax (Fused Attention) - -Process K/V in chunks to avoid materializing the full attention matrix. - -## Matrix Layout Overview - -```mermaid -graph TB - subgraph Input["Input Matrices"] - Q["Q Tile
128 × 64
(solid)"] - K["K Matrix (full context)
4096 × 64
(load in chunks)"] - V["V Matrix (full context)
4096 × 64
(load in chunks)"] - end - - subgraph Process["Chunked Processing"] - K0["K Chunk 0
64 × 64"] - K1["K Chunk 1
64 × 64"] - K2["K Chunk 2
64 × 64"] - Kdots["..."] - K63["K Chunk 63
64 × 64"] - end - - subgraph Output["Output & State"] - O["Output O
128 × 64
(accumulated)"] - State["Running State
m_i (max)
l_i (sum)"] - end - - Q -->|Q @ K^T| K0 - K -->|Split into
64 chunks| K0 - K --> K1 - K --> K2 - K --> Kdots - K --> K63 - - K0 -->|scores| V - V -->|V Chunk 0| O - K0 -.->|update| State - State -.->|correction| O - - style Q fill:#bbdefb,stroke:#1976d2,stroke-width:3px - style K fill:#e0e0e0,stroke:#757575,stroke-width:2px,stroke-dasharray: 5 5 - style V fill:#e0e0e0,stroke:#757575,stroke-width:2px,stroke-dasharray: 5 5 - style O fill:#c8e6c9,stroke:#388e3c,stroke-width:3px - style State fill:#fff9c4,stroke:#f57f17,stroke-width:2px - style K0 fill:#ffcdd2,stroke:#c62828,stroke-width:2px -``` - -## Online Softmax Algorithm Flow - -```mermaid -flowchart TD - Start([Start: Loop over chunks i=0..63]) --> Init[Initialize:
m_0 = -∞
l_0 = 0
O_0 = 0] - - Init --> LoadK[Load K chunk i
64 × 64] - LoadK --> QK[Compute Q @ K^T
128 × 64 @ 64 × 64
= 128 × 64 scores] - - QK --> Scale[Scale scores
scores * 1/√d_k] - - Scale --> MaxChunk[Find max in chunk
m_chunk = max(scores)] - MaxChunk --> UpdateMax[Update global max
m_new = max(m_old, m_chunk)] - - UpdateMax --> ExpScores[Compute exponentials
exp(scores - m_new)] - ExpScores --> SumChunk[Sum exponentials
l_chunk = Σ exp(scores)] - - SumChunk --> Correction[Compute correction
α = exp(m_old - m_new)] - Correction --> UpdateSum[Update sum
l_new = l_old * α + l_chunk] - - UpdateSum --> LoadV[Load V chunk i
64 × 64] - LoadV --> RescaleO[Rescale old output
O_old * α] - RescaleO --> AccumO[Accumulate output
O_new = O_old * α + exp(scores) @ V] - - AccumO --> Check{More
chunks?} - Check -->|Yes| LoadK - Check -->|No| Normalize[Final normalization
O_final = O_accumulated / l_final] - - Normalize --> End([End: Output ready]) - - style Start fill:#e1bee7,stroke:#7b1fa2,stroke-width:2px - style End fill:#e1bee7,stroke:#7b1fa2,stroke-width:2px - style UpdateMax fill:#ffcdd2,stroke:#c62828,stroke-width:2px - style UpdateSum fill:#ffcdd2,stroke:#c62828,stroke-width:2px - style AccumO fill:#c8e6c9,stroke:#388e3c,stroke-width:2px - style Normalize fill:#c8e6c9,stroke:#388e3c,stroke-width:2px -``` - -## Chunk Processing Detail - -```mermaid -sequenceDiagram - participant Q as Q Tile
(128×64) - participant K as K Chunks
(64×64 each) - participant V as V Chunks
(64×64 each) - participant State as Running State
(m_i, l_i) - participant O as Output O
(128×64) - - Note over Q,O: Initialize: m_0=-∞, l_0=0, O_0=0 - - loop For each chunk i = 0..63 - K->>Q: Load K chunk i - Q->>Q: Compute Q @ K^T → scores (128×64) - Q->>Q: Scale scores - Q->>State: Find max → m_chunk - State->>State: m_new = max(m_old, m_chunk) - Q->>Q: exp(scores - m_new) - Q->>State: Sum exp → l_chunk - State->>State: α = exp(m_old - m_new) - State->>State: l_new = l_old * α + l_chunk - V->>O: Load V chunk i - O->>O: O_old * α (rescale) - O->>O: O += exp(scores) @ V - State->>State: Update m_old, l_old - end - - Note over O: Final: O_final = O / l_final -``` - -## Visual Representation of Sliding Window - -``` -Iteration 0: -┌─────────┐ ┌────┬────┬────┬─────────────────────────┐ -│ │ │ 0 │ 1 │ 2 │ ... (64 chunks total) │ -│ Q │ @ ├────┴────┴────┴─────────────────────────┤ -│ (128×64)│ │ K Matrix (4096 × 64) │ -│ │ │ [Chunk 0 highlighted] │ -└─────────┘ └──────────────────────────────────────────┘ - ↓ - ┌─────────────────────┐ - │ Partial scores │ - │ Update m_0, l_0 │ - │ Accumulate to O_0 │ - └─────────────────────┘ - -Iteration 1: -┌─────────┐ ┌────┬────┬────┬─────────────────────────┐ -│ │ │ │ 1 │ 2 │ ... │ -│ Q │ @ ├────┴────┴────┴─────────────────────────┤ -│ (128×64)│ │ K Matrix (4096 × 64) │ -│ │ │ [Chunk 1 highlighted] │ -└─────────┘ └──────────────────────────────────────────┘ - ↓ - ┌─────────────────────┐ - │ Partial scores │ - │ Update m_1, l_1 │ - │ Accumulate to O_1 │ - └─────────────────────┘ - -... continues for 64 iterations ... -``` - -## Key Benefits - -```mermaid -mindmap - root((Online
Softmax)) - Memory Efficiency - No 128×4096 matrix - Only 128×64 scores - ~32x reduction - Numerical Stability - Running max update - Prevents overflow - Standard softmax trick - Parallelism - Each workgroup independent - 512 workgroups total - 16 batch×head × 32 tiles - Hardware Optimization - Tiles fit in registers - DPAS instruction reuse - Reduced memory bandwidth -``` - -## Implementation Details - -### Sub-chunking K and V - -Each 64-column chunk is further divided into 4 sub-chunks of 16 columns: - -```mermaid -graph LR - subgraph "K/V Chunk (64 cols)" - K0[Sub 0
16 cols] - K1[Sub 1
16 cols] - K2[Sub 2
16 cols] - K3[Sub 3
16 cols] - end - - K0 --> DPAS[DPAS Operations
128×64 @ 64×16] - K1 --> DPAS - K2 --> DPAS - K3 --> DPAS - - DPAS --> Scores[4 partial scores
each 128×16] - Scores --> Max[Find max across all 4] - Max --> Update[Update state & output] - - style DPAS fill:#b3e5fc,stroke:#0277bd,stroke-width:2px - style Update fill:#c8e6c9,stroke:#388e3c,stroke-width:2px -``` - -## State Variables - -| Variable | Shape | Purpose | -|----------|-------|---------| -| `m_i` | `128×1` | Running maximum value per row | -| `l_i` | `128×1` | Running sum of exponentials per row | -| `O_i` | `128×64` | Running output accumulator | -| `Q` | `128×64` | Query tile (constant per workgroup) | -| `K_chunk` | `64×64` | Current K chunk being processed | -| `V_chunk` | `64×64` | Current V chunk being processed | - -## Comparison: Standard vs Online Softmax - -```mermaid -graph TB - subgraph Standard["Standard Attention (Stage 4-5)"] - SQ[Q: 128×64] - SK[K: 4096×64] - SQK["Q@K^T
128×4096
⚠️ Full matrix"] - SSoft[Softmax
128×4096] - SV[V: 4096×64] - SO[O: 128×64] - - SQ --> SQK - SK --> SQK - SQK --> SSoft - SSoft --> SO - SV --> SO - end - - subgraph Online["Online Softmax (Stage 6)"] - OQ[Q: 128×64] - OK[K chunks:
64×64] - OQK["Q@K^T
128×64
✓ Chunk only"] - OState[State:
m_i, l_i] - OV[V chunks:
64×64] - OO[O: 128×64
accumulated] - - OQ --> OQK - OK --> OQK - OQK --> OState - OState -.correction.-> OO - OV --> OO - OQK -.exp.-> OO - end - - style SQK fill:#ffcdd2,stroke:#c62828,stroke-width:3px - style OQK fill:#c8e6c9,stroke:#388e3c,stroke-width:3px -``` - -## Mathematical Formulation - -For each chunk $i$: - -```math -\begin{align} -\text{scores}_i &= Q \cdot K_i^T \cdot \frac{1}{\sqrt{d_k}} \\ -m_i &= \max(m_{i-1}, \max(\text{scores}_i)) \\ -\alpha_i &= \exp(m_{i-1} - m_i) \\ -l_i &= l_{i-1} \cdot \alpha_i + \sum \exp(\text{scores}_i - m_i) \\ -O_i &= O_{i-1} \cdot \alpha_i + \exp(\text{scores}_i - m_i) \cdot V_i \\ -O_{\text{final}} &= \frac{O_n}{l_n} -\end{align} -``` - -Where: -- $m_i$ is the running maximum -- $l_i$ is the running sum of exponentials -- $\alpha_i$ is the correction factor for previous chunks -- $O_i$ is the accumulated output From 35b1245d8d8af42edfba0f89eaa978d51e437f19 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:56:27 +0000 Subject: [PATCH 08/20] add note about softmax decompose --- docs/fused_attention_lowering.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 97d37cfe..61304b45 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -48,6 +48,11 @@ This document describes the multi-stage lowering process for standard attention ### Softmax Decomposition The atomic `linalg.softmax` is decomposed into explicit operations: +**Why decompose so early? :** Softmax operation did not work well tile and fuse. +If last matmul is tiled and then if we try to fuse softmax into the `scf.forall` +It does not **bubble-up** the `tensor.extract_slice`. instead entire softmax is +done first (including parallel dims) and the extracted from the softmax result. + ```mlir // 1. Find max value per row (for numerical stability) %max_per_row = linalg.generic { From cc0400be0ede1e800ca35f638acb38e088864fa7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:58:44 +0000 Subject: [PATCH 09/20] add note about softmax decompose --- docs/fused_attention_lowering.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 61304b45..0f7b51de 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -48,10 +48,12 @@ This document describes the multi-stage lowering process for standard attention ### Softmax Decomposition The atomic `linalg.softmax` is decomposed into explicit operations: -**Why decompose so early? :** Softmax operation did not work well tile and fuse. +**Why decompose so early? :** +Softmax operation did not work well with tile and fuse. If last matmul is tiled and then if we try to fuse softmax into the `scf.forall` It does not **bubble-up** the `tensor.extract_slice`. instead entire softmax is -done first (including parallel dims) and the extracted from the softmax result. +done first (including parallel dims) and then extractions are done from the result to +feed into the last matmul. ```mlir // 1. Find max value per row (for numerical stability) From b81b69c41e88b4976d8cc9c7b595c1b791c3d6ed Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 20:59:22 +0000 Subject: [PATCH 10/20] add note about softmax decompose --- docs/fused_attention_lowering.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 0f7b51de..4600da94 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -48,7 +48,7 @@ This document describes the multi-stage lowering process for standard attention ### Softmax Decomposition The atomic `linalg.softmax` is decomposed into explicit operations: -**Why decompose so early? :** +#### **Why decompose so early? :** Softmax operation did not work well with tile and fuse. If last matmul is tiled and then if we try to fuse softmax into the `scf.forall` It does not **bubble-up** the `tensor.extract_slice`. instead entire softmax is From 1cd66f80dc811960afad85e498bb8efa888c772c Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 21:04:25 +0000 Subject: [PATCH 11/20] add note about fused loop generation stage --- docs/fused_attention_lowering.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 4600da94..4969c521 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -183,6 +183,12 @@ Each workgroup now processes: 16 × 32 = 512 independent workgroups +#### Why not introduce the fused attention loop at this point (at linalg level)? + +I tried to do this. This materializes the partial max and sum vectors in the fused +attention algorithm as SLM memory buffers. There is no way to promote these to +registers. At this point, vector to xegpu does not support 1D SLM access also. +Because of these resons delayed the introduction of fused loop until after bufferization. --- ## Stage 4: Vectorization From 8bd44fe97d2db8033814d860d334b16c3057412e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 21:05:27 +0000 Subject: [PATCH 12/20] add note about fused loop generation stage --- docs/fused_attention_lowering.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 4969c521..54850d25 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -189,6 +189,7 @@ I tried to do this. This materializes the partial max and sum vectors in the fus attention algorithm as SLM memory buffers. There is no way to promote these to registers. At this point, vector to xegpu does not support 1D SLM access also. Because of these resons delayed the introduction of fused loop until after bufferization. + --- ## Stage 4: Vectorization From 7e0175f04c9e425ab0874ebf638a65e6b030531c Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 21:07:15 +0000 Subject: [PATCH 13/20] add note about fused loop generation stage --- docs/fused_attention_lowering.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 54850d25..93244407 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -185,9 +185,9 @@ Each workgroup now processes: #### Why not introduce the fused attention loop at this point (at linalg level)? -I tried to do this. This materializes the partial max and sum vectors in the fused +Tried this approach. This materializes the partial max and sum vectors in the fused attention algorithm as SLM memory buffers. There is no way to promote these to -registers. At this point, vector to xegpu does not support 1D SLM access also. +registers (*using only upstream magic?*). At this point, vector to xegpu does not support 1D SLM access also. Because of these resons delayed the introduction of fused loop until after bufferization. --- From f6b20705cdfda701fad7385b44637851d17c33ab Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 21:08:03 +0000 Subject: [PATCH 14/20] add note about fused loop generation stage --- docs/fused_attention_lowering.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 93244407..44d4bbc1 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -187,7 +187,7 @@ Each workgroup now processes: Tried this approach. This materializes the partial max and sum vectors in the fused attention algorithm as SLM memory buffers. There is no way to promote these to -registers (*using only upstream magic?*). At this point, vector to xegpu does not support 1D SLM access also. +registers (*using only upstream magic?*). At this point, vector to xegpu does not support 1D SLM access also (**WIP**). Because of these resons delayed the introduction of fused loop until after bufferization. --- From 5500db62cb40a86167182429eebd4eb3d5805daa Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 21:16:19 +0000 Subject: [PATCH 15/20] remove gpu target attribute --- docs/fused_attention_lowering.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/fused_attention_lowering.md b/docs/fused_attention_lowering.md index 44d4bbc1..d9d4f17f 100644 --- a/docs/fused_attention_lowering.md +++ b/docs/fused_attention_lowering.md @@ -705,15 +705,6 @@ xegpu.store_nd %output_normalized, %output_tdesc[%row_offset, 0] <{ }> : vector<128x64xf16>, !xegpu.tensor_desc<128x64xf16, ...> ``` -### GPU Target Attribute - -```mlir -gpu.module @payload_kernel [#xevm.target] { - // O = 3 specifies optimization level 3 - gpu.func @payload_kernel(...) kernel { ... } -} -``` - --- ## Conclusion From 3a10839620fe6934a5ad0fb106ead1ac8e18a690 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 22:59:04 +0000 Subject: [PATCH 16/20] softmax tiling example --- docs/softmax_tiling.md | 241 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 docs/softmax_tiling.md diff --git a/docs/softmax_tiling.md b/docs/softmax_tiling.md new file mode 100644 index 00000000..ab8e48af --- /dev/null +++ b/docs/softmax_tiling.md @@ -0,0 +1,241 @@ +# Softmax Tiling and Fusion + +--- + +## Stage 0: Initial Softmax + +**What we have:** A single high-level `linalg.softmax` operation that computes softmax across dimension 1. + +**Softmax formula:** For each row $i$: + +$$\text{softmax}(x)_{i,j} = \frac{\exp(x_{i,j} - \max_k x_{i,k})}{\sum_k \exp(x_{i,k} - \max_k x_{i,k})}$$ + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = linalg.softmax dimension(1) ins(%0 : tensor<64x1024xf32>) outs(%1 : tensor<64x1024xf32>) -> tensor<64x1024xf32> + bufferization.materialize_in_destination %2 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 1: Softmax Decomposition + +**Transform applied:** `transform.structured.decompose_interface` + +**What happens:** The high-level softmax operation is decomposed into 5 primitive `linalg.generic` operations: + +1. **Max reduction** (`%4`): Find the maximum value along each row +2. **Subtract and exp** (`%5`): Compute `exp(x - max)` elementwise +3. **Fill for sum** (`%6`): Initialize accumulator for sum reduction +4. **Sum reduction** (`%7`): Sum the exp values along each row +5. **Final division** (`%8`): Divide each exp value by the sum + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %cst = arith.constant 0xFFC00000 : f32 + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 1. Max reduction: compute max along each row + %4 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.maxnumf %in, %out : f32 + linalg.yield %9 : f32 + } -> tensor<64xf32> + + // 2. Subtract max and apply exp: exp(x - max) + %5 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%0, %4 : tensor<64x1024xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %9 = arith.subf %in, %in_1 : f32 + %10 = math.exp %9 : f32 + linalg.yield %10 : f32 + } -> tensor<64x1024xf32> + + %cst_0 = arith.constant 0.000000e+00 : f32 + %6 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 3. Sum reduction: sum of exp values + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%5 : tensor<64x1024xf32>) outs(%6 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.addf %in, %out : f32 + linalg.yield %9 : f32 + } -> tensor<64xf32> + + // 4. Final division: exp(x - max) / sum + %8 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%5, %7 : tensor<64x1024xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %9 = arith.divf %in, %in_1 : f32 + linalg.yield %9 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %8 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 2: Elementwise Operation Fusion + +**Transform applied:** `transform.apply_registered_pass "linalg-fuse-elementwise-ops"` with some modifications. + +**What happens:** The compiler fuses operations that can be computed together to reduce memory traffic: + +1. The elementwise `exp(x - max)` operation (`%5`) is fused into the sum reduction (`%6`) +2. The elementwise `exp(x - max)` and division are fused into the final output (`%7`) + +```mlir +func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFFC00000 : f32 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 1. Max reduction (unchanged) + %4 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %8 = arith.maxnumf %in, %out : f32 + linalg.yield %8 : f32 + } -> tensor<64xf32> + + %5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // 2. Fused: exp(x - max) and sum reduction + // This computes exp on-the-fly during the reduction + %6 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)> + ], + iterator_types = ["parallel", "reduction"] + } ins(%0, %4 : tensor<64x1024xf32>, tensor<64xf32>) outs(%5 : tensor<64xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %8 = arith.subf %in, %in_1 : f32 + %9 = math.exp %8 : f32 + %10 = arith.addf %9, %out : f32 + linalg.yield %10 : f32 + } -> tensor<64xf32> + + // 3. Fused: exp(x - max) and division + // This computes exp again and immediately divides by sum + %7 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%0, %4, %6 : tensor<64x1024xf32>, tensor<64xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32): + %8 = arith.subf %in, %in_1 : f32 + %9 = math.exp %8 : f32 + %10 = arith.divf %9, %in_2 : f32 + linalg.yield %10 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %7 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` + +--- + +## Stage 3: Max-Sum Reduction Fusion (Online Softmax) + + +**What happens:** The max reduction and sum reduction are fused into a single pass over the data using the "online softmax" algorithm. This algorithm maintains running max and running sum values, applying a correction factor when the max changes. + +**Algorithm:** +``` +For each element x in the row: + new_max = max(current_max, x) + correction = exp(old_max - new_max) + corrected_sum = current_sum * correction + new_sum = corrected_sum + exp(x - new_max) +``` + +```mlir +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @softmax_v2(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFFC00000 : f32 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<64x1024xf32> to tensor<64x1024xf32> + %1 = tensor.empty() : tensor<64x1024xf32> + %2 = tensor.empty() : tensor<64xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64xf32>) -> tensor<64xf32> + + // Fused max and sum reduction with online correction + %5:2 = linalg.generic { + indexing_maps = [#map, #map1, #map1], + iterator_types = ["parallel", "reduction"] + } ins(%0 : tensor<64x1024xf32>) outs(%3, %4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %out_max: f32, %out_sum: f32): + // Update max + %new_max = arith.maxnumf %in, %out_max : f32 + + // Compute correction factor for sum when max changes + // correction = exp(old_max - new_max) + %max_diff = arith.subf %out_max, %new_max : f32 + %correction = math.exp %max_diff : f32 + %corrected_sum = arith.mulf %out_sum, %correction : f32 + + // Add new contribution: exp(value - new_max) + %val_diff = arith.subf %in, %new_max : f32 + %exp_val = math.exp %val_diff : f32 + %new_sum = arith.addf %corrected_sum, %exp_val : f32 + + linalg.yield %new_max, %new_sum : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + + // Final division to compute softmax + %6 = linalg.generic { + indexing_maps = [#map, #map1, #map1, #map], + iterator_types = ["parallel", "parallel"] + } ins(%0, %5#0, %5#1 : tensor<64x1024xf32>, tensor<64xf32>, tensor<64xf32>) outs(%1 : tensor<64x1024xf32>) { + ^bb0(%in: f32, %in_max: f32, %in_sum: f32, %out: f32): + %7 = arith.subf %in, %in_max : f32 + %8 = math.exp %7 : f32 + %9 = arith.divf %8, %in_sum : f32 + linalg.yield %9 : f32 + } -> tensor<64x1024xf32> + + bufferization.materialize_in_destination %6 in restrict writable %arg1 : (tensor<64x1024xf32>, memref<64x1024xf32>) -> () + return +} +``` From 119c4fe515a37325fa456a04f1be1ad50cf4842d Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 23:07:55 +0000 Subject: [PATCH 17/20] softmax tiling example --- docs/softmax_tiling.md | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/docs/softmax_tiling.md b/docs/softmax_tiling.md index ab8e48af..661f69c1 100644 --- a/docs/softmax_tiling.md +++ b/docs/softmax_tiling.md @@ -178,14 +178,33 @@ func.func @softmax_v0(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { **What happens:** The max reduction and sum reduction are fused into a single pass over the data using the "online softmax" algorithm. This algorithm maintains running max and running sum values, applying a correction factor when the max changes. -**Algorithm:** -``` -For each element x in the row: - new_max = max(current_max, x) - correction = exp(old_max - new_max) - corrected_sum = current_sum * correction - new_sum = corrected_sum + exp(x - new_max) -``` +**Why this works?** + +Consider the reduction loop at an arbitrary stage $i$. We have access to: + +$$m_{i-1} = \max(x_0, x_1, \ldots, x_{i-1})$$ + +$$s_{i-1} = \sum_{j=0}^{i-1} \exp(x_j - m_{i-1})$$ + +When we process element $x_i$, we first compute the updated max: + +$$m_i = \max(m_{i-1}, x_i)$$ + +Now we need to compute the updated sum. The challenge is that $s_{i-1}$ was computed with respect to $m_{i-1}$, but we need the sum with respect to the new max $m_i$. + +Using the property $\exp(a - b) = \frac{\exp(a)}{\exp(b)} = \exp(a) \cdot \exp(-b)$, we can rewrite each term: + +$$\exp(x_j - m_i) = \exp(x_j - m_{i-1} - (m_i - m_{i-1})) = \exp(x_j - m_{i-1}) \cdot \exp(m_{i-1} - m_i)$$ + +Since multiplication distributes over addition, we can factor out the correction term from the entire sum: + +$$\sum_{j=0}^{i-1} \exp(x_j - m_i) = \exp(m_{i-1} - m_i) \cdot \sum_{j=0}^{i-1} \exp(x_j - m_{i-1}) = \exp(m_{i-1} - m_i) \cdot s_{i-1}$$ + +Therefore, the corrected sum at stage $i$ is: + +$$s_i = \exp(m_{i-1} - m_i) \cdot s_{i-1} + \exp(x_i - m_i)$$ + +This shows that we can maintain running max and sum values in a single pass, applying a multiplicative correction factor $\exp(m_{i-1} - m_i)$ whenever the max changes. ```mlir #map = affine_map<(d0, d1) -> (d0, d1)> From 413569026939ef7c25f0708c7a52c6293734d2ed Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 May 2026 23:25:31 +0000 Subject: [PATCH 18/20] add general comments --- docs/softmax_tiling.md | 59 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/docs/softmax_tiling.md b/docs/softmax_tiling.md index 661f69c1..fb20db54 100644 --- a/docs/softmax_tiling.md +++ b/docs/softmax_tiling.md @@ -258,3 +258,62 @@ func.func @softmax_v2(%arg0: memref<64x1024xf32>, %arg1: memref<64x1024xf32>) { return } ``` + +--- + +## Does This Always Work? + +**Short answer:** No. The online correction technique only works when the correction can be applied distributively over the aggregation operation. + +### Why Softmax Works + +The key to the online softmax algorithm is that **multiplication distributes over addition**: + +$$c \cdot (a + b) = c \cdot a + c \cdot b$$ + +Combined with the exponential property $\exp(a - b) = \exp(a) \cdot \exp(-b)$, this allows us to apply a multiplicative correction factor to the entire accumulated sum. + +### A Counter-Example: What If There Was No Exp? + +Consider a hypothetical variant where we want to compute: + +$$\text{result}_j = \frac{x_j - m}{s}$$ + +where $m = \max_k x_k$ and $s = \sum_k (x_k - m)$. + +Can we fuse the max and sum reductions using online correction? Let's try: + +At stage $i$, we have: +- $m_{i-1} = \max(x_0, x_1, \ldots, x_{i-1})$ +- $s_{i-1} = \sum_{j=0}^{i-1} (x_j - m_{i-1})$ + +When processing $x_i$: +- $m_i = \max(m_{i-1}, x_i)$ + +To correct $s_{i-1}$ for the new max $m_i$, we need: + +$$s_i = \sum_{j=0}^{i-1} (x_j - m_i) + (x_i - m_i)$$ + +Expanding the first term: + +$$\sum_{j=0}^{i-1} (x_j - m_i) = \sum_{j=0}^{i-1} (x_j - m_{i-1} - (m_i - m_{i-1}))$$ + +$$= \sum_{j=0}^{i-1} (x_j - m_{i-1}) - \sum_{j=0}^{i-1} (m_i - m_{i-1})$$ + +$$= s_{i-1} - i \cdot (m_i - m_{i-1})$$ + +**The problem:** To correct $s_{i-1}$, we need to subtract $(m_i - m_{i-1})$ from each of the $i$ terms we've seen so far. This requires knowing the count $i$ and performing $i$ subtractions worth of correction. **This is not possibel at linalg level** (maybe its possible when we materialize the loops?) + +**Why subtraction doesn't work:** Subtraction is **not distributive** over addition in the way we need: + +$(a + b) - c \neq (a - c) + (b - c)$ when we want a single correction + +We would need to track both the sum and the count of elements, and the correction becomes additive rather than multiplicative: $s_i = s_{i-1} - i \cdot \Delta m$, where the correction depends on how many elements we've processed. + +### General Principle + +The online correction technique works when: + +1. **The operation allows factoring out corrections**: The transformation applied to each element (like $\exp(x - m)$) can be rewritten to separate the correction term +2. **The correction is distributive over the reduction**: The correction can be applied to the accumulated result as a whole, not element-by-element +3. **The correction is independent of cardinality**: We don't need to know how many elements contributed to the partial result From dabeba4190f5338414d6f6e2f5cece1ed98ce149 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 20 May 2026 05:27:14 +0000 Subject: [PATCH 19/20] add attention tiling analysis --- docs/attention_tiling_analysis.md | 456 ++++++++++++++++++++++++++++++ 1 file changed, 456 insertions(+) create mode 100644 docs/attention_tiling_analysis.md diff --git a/docs/attention_tiling_analysis.md b/docs/attention_tiling_analysis.md new file mode 100644 index 00000000..c1d46a95 --- /dev/null +++ b/docs/attention_tiling_analysis.md @@ -0,0 +1,456 @@ +# Attention Kernel Transformation Analysis + +This document describes the step-by-step transformation of an attention kernel implementation through MLIR compiler passes. The attention mechanism computes: `attention(Q, K, V) = softmax(Q × K^T) × V` + +--- + +## Stage 1: Initial Attention Computation + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + + // Allocate output tensors + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + // Transpose K: 4096x64 -> 64x4096 + %transposed = linalg.transpose ins(%1 : tensor<4096x64xf16>) + outs(%3 : tensor<64x4096xf16>) + permutation = [1, 0] + + // First matmul: Q × K^T + %cst = arith.constant 0.000000e+00 : f16 + %7 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %8 = linalg.matmul ins(%0, %transposed : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%7 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // Softmax operation (high-level) + %9 = linalg.softmax dimension(1) ins(%8 : tensor<128x4096xf16>) + outs(%5 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // Second matmul: softmax_out × V + %10 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %11 = linalg.matmul ins(%9, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%10 : tensor<128x64xf16>) -> tensor<128x64xf16> + + bufferization.materialize_in_destination %11 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 2: After Softmax Decomposition + +Softmax decomposition into constituent operations + +This is the numerically stable softmax: `softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))` + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + %transposed = linalg.transpose ins(%1 : tensor<4096x64xf16>) + outs(%3 : tensor<64x4096xf16>) + permutation = [1, 0] + + // First matmul (unchanged) + %cst = arith.constant 0.000000e+00 : f16 + %7 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %8 = linalg.matmul ins(%0, %transposed : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%7 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // === SOFTMAX DECOMPOSITION BEGINS === + + // Step 1: Find max along dimension 1 (reduction over 4096 elements) + %9 = tensor.empty() : tensor<128xf16> + %cst_0 = arith.constant 0xFE00 : f16 // -inf in f16 + %10 = linalg.fill ins(%cst_0 : f16) outs(%9 : tensor<128xf16>) -> tensor<128xf16> + %11 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%8 : tensor<128x4096xf16>) outs(%10 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %18 = arith.maxnumf %in, %out : f16 + linalg.yield %18 : f16 + } -> tensor<128xf16> + + // Step 2: Subtract max and compute exp + %12 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%8, %11 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %18 = arith.subf %in, %in_2 : f16 + %19 = math.exp %18 : f16 + linalg.yield %19 : f16 + } -> tensor<128x4096xf16> + + // Step 3: Sum exponentials (reduction) + %cst_1 = arith.constant 0.000000e+00 : f16 + %13 = linalg.fill ins(%cst_1 : f16) outs(%9 : tensor<128xf16>) -> tensor<128xf16> + %14 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%12 : tensor<128x4096xf16>) outs(%13 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %18 = arith.addf %in, %out : f16 + linalg.yield %18 : f16 + } -> tensor<128xf16> + + // Step 4: Divide by sum for normalization + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%12, %14 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %18 = arith.divf %in, %in_2 : f16 + linalg.yield %18 : f16 + } -> tensor<128x4096xf16> + + // === SOFTMAX DECOMPOSITION ENDS === + + // Second matmul (unchanged) + %16 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %17 = linalg.matmul ins(%15, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%16 : tensor<128x64xf16>) -> tensor<128x64xf16> + + bufferization.materialize_in_destination %17 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 3: After Matmul and Transpose Generalization + +Convert named operations (`linalg.matmul`, `linalg.transpose`) to `linalg.generic` + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<64x4096xf16> + %4 = tensor.empty() : tensor<128x4096xf16> + %5 = tensor.empty() : tensor<128x4096xf16> + %6 = tensor.empty() : tensor<128x64xf16> + + // Transpose converted to generic form + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, // Input: swapped (4096×64 accessed as d1,d0) + affine_map<(d0, d1) -> (d0, d1)>], // Output: 64×4096 (normal d0,d1) + iterator_types = ["parallel", "parallel"] + } ins(%1 : tensor<4096x64xf16>) outs(%3 : tensor<64x4096xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<64x4096xf16> + + // First matmul converted to generic form + %cst = arith.constant 0.000000e+00 : f16 + %8 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %9 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: 128×64 + affine_map<(d0, d1, d2) -> (d2, d1)>, // K^T: 64×4096 + affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×4096 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%0, %7 : tensor<128x64xf16>, tensor<64x4096xf16>) + outs(%8 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %19 = arith.mulf %in, %in_2 : f16 + %20 = arith.addf %out, %19 : f16 + linalg.yield %20 : f16 + } -> tensor<128x4096xf16> + + // Softmax decomposition (unchanged) + // .... + + // Second matmul converted to generic form + %17 = linalg.fill ins(%cst : f16) outs(%6 : tensor<128x64xf16>) -> tensor<128x64xf16> + %18 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Softmax: 128×4096 + affine_map<(d0, d1, d2) -> (d2, d1)>, // V: 4096×64 + affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×64 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%16, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%17 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_2: f16, %out: f16): + %19 = arith.mulf %in, %in_2 : f16 + %20 = arith.addf %out, %19 : f16 + linalg.yield %20 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %18 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 4: After Elementwise Fusion (Final) + +Fusion of generic ops. + +### Notes +This is the most optimized form. Multiple transformations occur: +1. **Transpose + Matmul fusion**: The transpose operation is fused into the first matmul, directly reading K with transposed indexing +2. **Softmax reduction fusion**: The exp computation and sum reduction are fused into a single operation +3. **Final matmul fusion**: The normalization division is fused with the final matmul, computing `(exp(x - max) / sum) × V` in one pass + +```mlir +func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + + // === FUSED: Transpose + First Matmul === + // Previously: separate transpose (4096×64 → 64×4096) + matmul + // Now: matmul directly reads K with transposed indexing + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // Q: 128×64 + affine_map<(d0, d1, d2) -> (d1, d2)>, // K: 4096×64 (transpose on-the-fly!) + affine_map<(d0, d1, d2) -> (d0, d1)>], // Out: 128×4096 + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) + outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // Max reduction (unchanged) + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%6 : tensor<128x4096xf16>) outs(%8 : tensor<128xf16>) { + ^bb0(%in: f16, %out: f16): + %14 = arith.maxnumf %in, %out : f16 + linalg.yield %14 : f16 + } -> tensor<128xf16> + + // === FUSED: exp and sum reduction === + // Previously: separate exp operation + sum reduction + // Now: compute exp(x - max) and accumulate sum in one pass + %10 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %11 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, // Input: QK^T + affine_map<(d0, d1) -> (d0)>, // Max values + affine_map<(d0, d1) -> (d0)>], // Sum accumulator (output) + iterator_types = ["parallel", "reduction"] + } ins(%6, %9 : tensor<128x4096xf16>, tensor<128xf16>) + outs(%10 : tensor<128xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 // x - max + %15 = math.exp %14 : f16 // exp(x - max) + %16 = arith.addf %15, %out : f16 // accumulate sum + linalg.yield %16 : f16 + } -> tensor<128xf16> + + // === FUSED: softmax normalization + second matmul === + // Previously: separate divide operation + matmul + // Now: compute (exp(x - max) / sum) * V in one fused kernel + %12 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %13 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // QK^T scores + affine_map<(d0, d1, d2) -> (d0)>, // Max values + affine_map<(d0, d1, d2) -> (d0)>, // Sum values + affine_map<(d0, d1, d2) -> (d2, d1)>, // V matrix + affine_map<(d0, d1, d2) -> (d0, d1)>], // Output + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%6, %9, %11, %2 : tensor<128x4096xf16>, tensor<128xf16>, + tensor<128xf16>, tensor<4096x64xf16>) + outs(%12 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_1: f16, %in_2: f16, %in_3: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 // x - max + %15 = math.exp %14 : f16 // exp(x - max) + %16 = arith.divf %15, %in_2 : f16 // exp(x - max) / sum [softmax] + %17 = arith.mulf %16, %in_3 : f16 // softmax * V + %18 = arith.addf %out, %17 : f16 // accumulate result + linalg.yield %18 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %13 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 5: Online Softmax Optimization (Max-Sum Fusion) + +The previous stage still computes max and sum in two separate passes over the attention scores. We can fuse these into a single pass using the **online softmax algorithm**. + +```mlir +func.func @attention_max_sum_fused(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // First generic: Q @ K^T + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // === FUSED: Online Softmax (max + sum in single pass) === + // Takes: QK^T scores + // Returns: (max, sum) computed simultaneously + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %10:2 = linalg.generic {indexing_maps = [#map3, #map4, #map4], iterator_types = ["parallel", "reduction"]} + ins(%6 : tensor<128x4096xf16>) outs(%8, %9 : tensor<128xf16>, tensor<128xf16>) { + ^bb0(%in: f16, %out_max: f16, %out_sum: f16): + %14 = arith.maxnumf %in, %out_max : f16 // new_max = max(in, old_max) + %15 = arith.subf %out_max, %14 : f16 // old_max - new_max + %16 = math.exp %15 : f16 // correction_factor = exp(old_max - new_max) + %17 = arith.mulf %out_sum, %16 : f16 // rescaled_sum = old_sum * correction_factor + %18 = arith.subf %in, %14 : f16 // in - new_max + %19 = math.exp %18 : f16 // exp(in - new_max) + %20 = arith.addf %17, %19 : f16 // new_sum = rescaled_sum + exp(in - new_max) + linalg.yield %14, %20 : f16, f16 + } -> (tensor<128xf16>, tensor<128xf16>) + + // Final generic: compute attention output using max and sum + %11 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %12 = linalg.generic {indexing_maps = [#map, #map5, #map5, #map6, #map2], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%6, %10#0, %10#1, %2 : tensor<128x4096xf16>, tensor<128xf16>, tensor<128xf16>, tensor<4096x64xf16>) + outs(%11 : tensor<128x64xf16>) { + ^bb0(%in: f16, %in_1: f16, %in_2: f16, %in_3: f16, %out: f16): + %14 = arith.subf %in, %in_1 : f16 + %15 = math.exp %14 : f16 + %16 = arith.divf %15, %in_2 : f16 + %17 = arith.mulf %16, %in_3 : f16 + %18 = arith.addf %out, %17 : f16 + linalg.yield %18 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %12 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` + +--- + +## Stage 6: Fully Fused Online Attention (Max-Sum-Matmul Fusion) + +The ultimate optimization fuses **all three operations** (max, sum, and final matmul) into a single generic operation. This computes the attention output in one pass over the sequence dimension. + +### Key Insight +Not only does the sum need rescaling when max changes, but the **accumulated output** must also be rescaled by the same correction factor to maintain correctness. + +```mlir +func.func @attention_final(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, + %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) { + %cst = arith.constant 0xFE00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %0 = bufferization.to_tensor %arg0 restrict writable : memref<128x64xf16> to tensor<128x64xf16> + %1 = bufferization.to_tensor %arg1 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %2 = bufferization.to_tensor %arg2 restrict writable : memref<4096x64xf16> to tensor<4096x64xf16> + %3 = tensor.empty() : tensor<128x4096xf16> + %4 = tensor.empty() : tensor<128x64xf16> + %5 = linalg.fill ins(%cst_0 : f16) outs(%3 : tensor<128x4096xf16>) -> tensor<128x4096xf16> + + // First generic: Q @ K^T + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%0, %1 : tensor<128x64xf16>, tensor<4096x64xf16>) outs(%5 : tensor<128x4096xf16>) { + ^bb0(%in: f16, %in_1: f16, %out: f16): + %14 = arith.mulf %in, %in_1 : f16 + %15 = arith.addf %out, %14 : f16 + linalg.yield %15 : f16 + } -> tensor<128x4096xf16> + + // === FULLY FUSED: Online attention (max + sum + matmul in single pass) === + // Takes: QK^T scores and V matrix + // Returns: (max, sum, attention_output) computed simultaneously + %7 = tensor.empty() : tensor<128xf16> + %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %9 = linalg.fill ins(%cst_0 : f16) outs(%7 : tensor<128xf16>) -> tensor<128xf16> + %10 = linalg.fill ins(%cst_0 : f16) outs(%4 : tensor<128x64xf16>) -> tensor<128x64xf16> + %11:3 = linalg.generic {indexing_maps = [#map, #map6, #map5, #map5, #map2], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%6, %2 : tensor<128x4096xf16>, tensor<4096x64xf16>) + outs(%8, %9, %10 : tensor<128xf16>, tensor<128xf16>, tensor<128x64xf16>) { + ^bb0(%in_qk: f16, %in_v: f16, %out_max: f16, %out_sum: f16, %out_acc: f16): + // Compute new max + %14 = arith.maxnumf %in_qk, %out_max : f16 + // Compute correction factor: exp(old_max - new_max) + %15 = arith.subf %out_max, %14 : f16 + %16 = math.exp %15 : f16 + // Rescale old sum with correction factor + %17 = arith.mulf %out_sum, %16 : f16 + // Rescale old acc with correction factor (CRITICAL!) + %18 = arith.mulf %out_acc, %16 : f16 + // Compute exp(in_qk - new_max) + %19 = arith.subf %in_qk, %14 : f16 + %20 = math.exp %19 : f16 + // Update sum + %21 = arith.addf %17, %20 : f16 + // Compute weighted value + %22 = arith.mulf %20, %in_v : f16 + // Update acc + %23 = arith.addf %18, %22 : f16 + linalg.yield %14, %21, %23 : f16, f16, f16 + } -> (tensor<128xf16>, tensor<128xf16>, tensor<128x64xf16>) + + // Final normalization: divide acc by sum + %12 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%11#2, %11#1 : tensor<128x64xf16>, tensor<128xf16>) outs(%4 : tensor<128x64xf16>) { + ^bb0(%in_acc: f16, %in_sum: f16, %out: f16): + %14 = arith.divf %in_acc, %in_sum : f16 + linalg.yield %14 : f16 + } -> tensor<128x64xf16> + + bufferization.materialize_in_destination %12 in restrict writable %arg3 : + (tensor<128x64xf16>, memref<128x64xf16>) -> () + return +} +``` From ad8f6bf99b150ea00d1782bc82ba774aafa2fa46 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 20 May 2026 05:33:46 +0000 Subject: [PATCH 20/20] add attention tiling analysis --- docs/attention_tiling_analysis.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/attention_tiling_analysis.md b/docs/attention_tiling_analysis.md index c1d46a95..e9172a7a 100644 --- a/docs/attention_tiling_analysis.md +++ b/docs/attention_tiling_analysis.md @@ -317,6 +317,11 @@ func.func @attention(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, The previous stage still computes max and sum in two separate passes over the attention scores. We can fuse these into a single pass using the **online softmax algorithm**. +### Chanages we need in upstream? + +This require improve the **GenericOp Fusion** to support complex reductions that are **Fusible** (reduction error can be corrected i.e. first reduction can be factored out from each term in the second reduction and corrected.) + + ```mlir func.func @attention_max_sum_fused(%arg0: memref<128x64xf16>, %arg1: memref<4096x64xf16>, %arg2: memref<4096x64xf16>, %arg3: memref<128x64xf16>) {