diff --git a/cpu/include/Dialect/TritonCPU/IR/TritonCPUOps.td b/cpu/include/Dialect/TritonCPU/IR/TritonCPUOps.td index 4b2d9265..65daf960 100644 --- a/cpu/include/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/cpu/include/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -213,7 +213,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory Example — elementwise chain (no reductions): ``` - ttc.generic (%src_ptrs, %dst_ptrs) blocks [%c128 : i32] + ttc.generic ins (%src_ptrs, %dst_ptrs) blocks [%c128 : i32] {tileShape = array, reductionDims = array} body { ^bb0(%iv: i32, %src: tensor<16x!tt.ptr>, %dst: tensor<16x!tt.ptr>): @@ -226,7 +226,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory Example — scalar reduction (softmax max pass): ``` - %max = ttc.generic (%ptrs, %n_cols) init (%neg_inf : f32) + %max = ttc.generic ins (%ptrs, %n_cols) init (%neg_inf : f32) blocks [%c128 : i32] {tileShape = array, reductionDims = array} body { @@ -240,7 +240,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory Example — mixed scalar + tensor result (softmax sum + exp): ``` - %sum, %exp_out = ttc.generic (%ptrs, %n_cols) init (%zero : f32) + %sum, %exp_out = ttc.generic ins (%ptrs, %n_cols) init (%zero : f32) blocks [%c128 : i32] {tileShape = array, reductionDims = array} body { @@ -276,7 +276,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory ]; let assemblyFormat = [{ - `(` $ins `)` (`init` `(` $init_vals^ `:` type($init_vals) `)`)? `blocks` `[` $blockShape `:` type($blockShape) `]` attr-dict-with-keyword `body` $body `:` functional-type($ins, $results) + (`init` `(` $init_vals^ `:` type($init_vals) `)`)? (`ins` `(` $ins^ `:` type($ins) `)`)? `blocks` `[` $blockShape `:` type($blockShape) `]` attr-dict-with-keyword `body` $body (`->` type($results)^)? }]; let extraClassDeclaration = [{ @@ -293,6 +293,8 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory unsigned getInsArgOffset() { return getNumInductionVars() + getNumIterArgs(); } + + std::string getHeader(); }]; } diff --git a/cpu/lib/Dialect/TritonCPU/IR/Ops.cpp b/cpu/lib/Dialect/TritonCPU/IR/Ops.cpp index fb4abd4d..6f7f8f47 100644 --- a/cpu/lib/Dialect/TritonCPU/IR/Ops.cpp +++ b/cpu/lib/Dialect/TritonCPU/IR/Ops.cpp @@ -135,6 +135,13 @@ LogicalResult GenericOp::verify() { return success(); } +std::string GenericOp::getHeader() { + std::string s; + llvm::raw_string_ostream os(s); + print(os, OpPrintingFlags().skipRegions()); + return s; +} + LogicalResult MakeDynamicRangeOp::verify() { auto resultTensorTy = cast(getResult().getType()); if (resultTensorTy.getShape().size() != 1) diff --git a/cpu/lib/TritonCPUToLLVM/GenericOpToLLVM.cpp b/cpu/lib/TritonCPUToLLVM/GenericOpToLLVM.cpp index 37042b5f..79f9f868 100644 --- a/cpu/lib/TritonCPUToLLVM/GenericOpToLLVM.cpp +++ b/cpu/lib/TritonCPUToLLVM/GenericOpToLLVM.cpp @@ -776,11 +776,7 @@ struct GenericOpConversion : public ConvertOpToLLVMPattern { vectorSize *= tileShape[d]; } - // TODO: put this into extraClassDefinitions? - std::string s; - llvm::raw_string_ostream os(s); - op->print(os, OpPrintingFlags().skipRegions()); - LDBG("Lowering generic op: " << s + LDBG("Lowering generic op: " << op.getHeader() << "\n with vectorSize = " << vectorSize); SmallVector argInfos = diff --git a/test/Analysis/axis-info-generic.mlir b/test/Analysis/axis-info-generic.mlir index 7e772543..af37ed4b 100644 --- a/test/Analysis/axis-info-generic.mlir +++ b/test/Analysis/axis-info-generic.mlir @@ -17,11 +17,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar %x_3 = tt.addptr %x, %offsets_1 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> // CHECK-COUNT-128: ttc.masked_load {{.*}} -> vector<4xf32> - ttc.generic (%x_3) blocks [%c512_i32 : i32] attributes {tileShape = array, reductionDims = array} body { + ttc.generic ins (%x_3 : tensor<512x!tt.ptr, #blocked>) blocks [%c512_i32 : i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%offset:i32, %arg0: tensor<4x!tt.ptr, #blocked>): %x_10 = tt.load %arg0 : tensor<4x!tt.ptr, #blocked> ttc.yield - }: (tensor<512x!tt.ptr, #blocked>) -> () + } tt.return } } @@ -32,7 +32,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cpu", "ttg.threads-per-warp" = 1 : i32} { tt.func public @load_scalar(%x_ptr: !tt.ptr {tt.divisibility = 16 : i32}) { %c1024_i32 = arith.constant 1024 : i32 - ttc.generic (%x_ptr) blocks [%c1024_i32 : i32] attributes {tileShape = array, reductionDims = array} body { + ttc.generic ins (%x_ptr : !tt.ptr) blocks [%c1024_i32 : i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%tileOffset: i32, %ptr: !tt.ptr): %offsets = ttc.make_dynamic_range %tileOffset : tensor<4xi32, #blocked> %ptrs = tt.splat %ptr : !tt.ptr -> tensor<4x!tt.ptr, #blocked> @@ -46,7 +46,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: ttc.masked_load {{.*}} -> vector<4xf32> %ret = tt.load %offset_ptrs : tensor<4x!tt.ptr, #blocked> ttc.yield - }: (!tt.ptr) -> () + } tt.return } @@ -69,7 +69,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK-NOT: ttc.masked_store {{.*}} : (!llvm.ptr<1>, vector<1xf16> %c4_i32 = arith.constant 4 : i32 %c8_i32 = arith.constant 8 : i32 - ttc.generic (%data, %c_ptr, %stride_cm, %M, %N, %offs_am, %offs_bn) blocks + ttc.generic ins (%data, %c_ptr, %stride_cm, %M, %N, %offs_am, %offs_bn : + tensor<4x8xf16, #blocked>, !tt.ptr, i32, i32, i32, i32, i32) blocks [%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%tile_m: i32, %tile_n: i32, %tile_data: tensor<1x8xf16, #blocked>, @@ -116,7 +117,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.store %ptrs, %tile_data, %mask : tensor<1x8x!tt.ptr, #blocked> ttc.yield - } : (tensor<4x8xf16, #blocked>, !tt.ptr, i32, i32, i32, i32, i32) -> () + } tt.return } diff --git a/test/Conversion/generic-op-loops.mlir b/test/Conversion/generic-op-loops.mlir index 2989c94d..b18be3ac 100644 --- a/test/Conversion/generic-op-loops.mlir +++ b/test/Conversion/generic-op-loops.mlir @@ -17,13 +17,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @scale_1d(%scale: f32) { %c16_i32 = arith.constant 16 : i32 - %0 = ttc.generic (%scale) blocks [%c16_i32 : i32] attributes {tileShape = array, reductionDims = array} body { + %0 = ttc.generic ins (%scale : f32) blocks [%c16_i32 : i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%offset: i32, %s: f32): %cst = arith.constant dense<1.0> : tensor<4xf32, #blocked> %splat = tt.splat %s : f32 -> tensor<4xf32, #blocked> %result = arith.mulf %cst, %splat : tensor<4xf32, #blocked> ttc.yield %result : tensor<4xf32, #blocked> - } : (f32) -> tensor<16xf32, #blocked> + } -> tensor<16xf32, #blocked> tt.return } } @@ -40,13 +40,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @scale_2d(%scale: f32) { %c4_i32 = arith.constant 4 : i32 %c8_i32 = arith.constant 8 : i32 - %0 = ttc.generic (%scale) blocks [%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array, reductionDims = array} body { + %0 = ttc.generic ins (%scale : f32) blocks [%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%dim0: i32, %dim1: i32, %s: f32): %cst = arith.constant dense<1.0> : tensor<1x4xf32, #blocked> %splat = tt.splat %s : f32 -> tensor<1x4xf32, #blocked> %result = arith.mulf %cst, %splat : tensor<1x4xf32, #blocked> ttc.yield %result : tensor<1x4xf32, #blocked> - } : (f32) -> tensor<4x8xf32, #blocked> + } -> tensor<4x8xf32, #blocked> tt.return } } @@ -63,13 +63,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @scale_3d(%scale: f32) { %c4_i32 = arith.constant 4 : i32 %c8_i32 = arith.constant 8 : i32 - %0 = ttc.generic (%scale) blocks [%c4_i32, %c4_i32, %c8_i32 : i32, i32, i32] attributes {tileShape = array, reductionDims = array} body { + %0 = ttc.generic ins (%scale : f32) blocks [%c4_i32, %c4_i32, %c8_i32 : i32, i32, i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%dim0: i32, %dim1: i32, %dim2: i32, %s: f32): %cst = arith.constant dense<1.0> : tensor<1x1x4xf32, #blocked> %splat = tt.splat %s : f32 -> tensor<1x1x4xf32, #blocked> %result = arith.mulf %cst, %splat : tensor<1x1x4xf32, #blocked> ttc.yield %result : tensor<1x1x4xf32, #blocked> - } : (f32) -> tensor<4x4x8xf32, #blocked> + } -> tensor<4x4x8xf32, #blocked> tt.return } } @@ -87,13 +87,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar %c2_i32 = arith.constant 2 : i32 %c4_i32 = arith.constant 4 : i32 %c8_i32 = arith.constant 8 : i32 - %0 = ttc.generic (%scale) blocks [%c2_i32, %c2_i32, %c4_i32, %c8_i32 : i32, i32, i32, i32] attributes {tileShape = array, reductionDims = array} body { + %0 = ttc.generic ins (%scale : f32) blocks [%c2_i32, %c2_i32, %c4_i32, %c8_i32 : i32, i32, i32, i32] attributes {tileShape = array, reductionDims = array} body { ^bb0(%dim0: i32, %dim1: i32, %dim2: i32, %dim3: i32, %s: f32): %cst = arith.constant dense<1.0> : tensor<1x1x1x4xf32, #blocked> %splat = tt.splat %s : f32 -> tensor<1x1x1x4xf32, #blocked> %result = arith.mulf %cst, %splat : tensor<1x1x1x4xf32, #blocked> ttc.yield %result : tensor<1x1x1x4xf32, #blocked> - } : (f32) -> tensor<2x2x4x8xf32, #blocked> + } -> tensor<2x2x4x8xf32, #blocked> tt.return } }