Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions cpu/include/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory

Example — elementwise chain (no reductions):
```
ttc.generic (%src_ptrs, %dst_ptrs) blocks [%c128 : i32]
ttc.generic ins (%src_ptrs, %dst_ptrs) blocks [%c128 : i32]
{tileShape = array<i32: 16>, reductionDims = array<i32>}
body {
^bb0(%iv: i32, %src: tensor<16x!tt.ptr<f32>>, %dst: tensor<16x!tt.ptr<f32>>):
Expand All @@ -226,7 +226,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory

Example — scalar reduction (softmax max pass):
```
%max = ttc.generic (%ptrs, %n_cols) init (%neg_inf : f32)
%max = ttc.generic ins (%ptrs, %n_cols) init (%neg_inf : f32)
blocks [%c128 : i32]
{tileShape = array<i32: 16>, reductionDims = array<i32: 0>}
body {
Expand All @@ -240,7 +240,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory

Example — mixed scalar + tensor result (softmax sum + exp):
```
%sum, %exp_out = ttc.generic (%ptrs, %n_cols) init (%zero : f32)
%sum, %exp_out = ttc.generic ins (%ptrs, %n_cols) init (%zero : f32)
blocks [%c128 : i32]
{tileShape = array<i32: 16>, reductionDims = array<i32: 0>}
body {
Expand Down Expand Up @@ -276,7 +276,7 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory
];

let assemblyFormat = [{
`(` $ins `)` (`init` `(` $init_vals^ `:` type($init_vals) `)`)? `blocks` `[` $blockShape `:` type($blockShape) `]` attr-dict-with-keyword `body` $body `:` functional-type($ins, $results)
(`init` `(` $init_vals^ `:` type($init_vals) `)`)? (`ins` `(` $ins^ `:` type($ins) `)`)? `blocks` `[` $blockShape `:` type($blockShape) `]` attr-dict-with-keyword `body` $body (`->` type($results)^)?
}];

let extraClassDeclaration = [{
Expand All @@ -293,6 +293,8 @@ def TTC_GenericOp : TTC_Op<"generic", [AttrSizedOperandSegments, RecursiveMemory
unsigned getInsArgOffset() {
return getNumInductionVars() + getNumIterArgs();
}

std::string getHeader();
}];
}

Expand Down
7 changes: 7 additions & 0 deletions cpu/lib/Dialect/TritonCPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ LogicalResult GenericOp::verify() {
return success();
}

std::string GenericOp::getHeader() {
std::string s;
llvm::raw_string_ostream os(s);
print(os, OpPrintingFlags().skipRegions());
return s;
}

LogicalResult MakeDynamicRangeOp::verify() {
auto resultTensorTy = cast<RankedTensorType>(getResult().getType());
if (resultTensorTy.getShape().size() != 1)
Expand Down
6 changes: 1 addition & 5 deletions cpu/lib/TritonCPUToLLVM/GenericOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,7 @@ struct GenericOpConversion : public ConvertOpToLLVMPattern<cpu::GenericOp> {
vectorSize *= tileShape[d];
}

// TODO: put this into extraClassDefinitions?
std::string s;
llvm::raw_string_ostream os(s);
op->print(os, OpPrintingFlags().skipRegions());
LDBG("Lowering generic op: " << s
LDBG("Lowering generic op: " << op.getHeader()
<< "\n with vectorSize = " << vectorSize);

SmallVector<ArgInfo> argInfos =
Expand Down
13 changes: 7 additions & 6 deletions test/Analysis/axis-info-generic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
%x_3 = tt.addptr %x, %offsets_1 : tensor<512x!tt.ptr<f32>, #blocked>, tensor<512xi32, #blocked>

// CHECK-COUNT-128: ttc.masked_load {{.*}} -> vector<4xf32>
ttc.generic (%x_3) blocks [%c512_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
ttc.generic ins (%x_3 : tensor<512x!tt.ptr<f32>, #blocked>) blocks [%c512_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
^bb0(%offset:i32, %arg0: tensor<4x!tt.ptr<f32>, #blocked>):
%x_10 = tt.load %arg0 : tensor<4x!tt.ptr<f32>, #blocked>
ttc.yield
}: (tensor<512x!tt.ptr<f32>, #blocked>) -> ()
}
tt.return
}
}
Expand All @@ -32,7 +32,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cpu", "ttg.threads-per-warp" = 1 : i32} {
tt.func public @load_scalar(%x_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c1024_i32 = arith.constant 1024 : i32
ttc.generic (%x_ptr) blocks [%c1024_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
ttc.generic ins (%x_ptr : !tt.ptr<f32>) blocks [%c1024_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
^bb0(%tileOffset: i32, %ptr: !tt.ptr<f32>):
%offsets = ttc.make_dynamic_range %tileOffset : tensor<4xi32, #blocked>
%ptrs = tt.splat %ptr : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>, #blocked>
Expand All @@ -46,7 +46,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
// CHECK: ttc.masked_load {{.*}} -> vector<4xf32>
%ret = tt.load %offset_ptrs : tensor<4x!tt.ptr<f32>, #blocked>
ttc.yield
}: (!tt.ptr<f32>) -> ()
}

tt.return
}
Expand All @@ -69,7 +69,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
// CHECK-NOT: ttc.masked_store {{.*}} : (!llvm.ptr<1>, vector<1xf16>
%c4_i32 = arith.constant 4 : i32
%c8_i32 = arith.constant 8 : i32
ttc.generic (%data, %c_ptr, %stride_cm, %M, %N, %offs_am, %offs_bn) blocks
ttc.generic ins (%data, %c_ptr, %stride_cm, %M, %N, %offs_am, %offs_bn :
tensor<4x8xf16, #blocked>, !tt.ptr<f16>, i32, i32, i32, i32, i32) blocks
[%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array<i32: 1, 8>, reductionDims = array<i32>} body {
^bb0(%tile_m: i32, %tile_n: i32,
%tile_data: tensor<1x8xf16, #blocked>,
Expand Down Expand Up @@ -116,7 +117,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar

tt.store %ptrs, %tile_data, %mask : tensor<1x8x!tt.ptr<f16>, #blocked>
ttc.yield
} : (tensor<4x8xf16, #blocked>, !tt.ptr<f16>, i32, i32, i32, i32, i32) -> ()
}

tt.return
}
Expand Down
16 changes: 8 additions & 8 deletions test/Conversion/generic-op-loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
tt.func public @scale_1d(%scale: f32) {
%c16_i32 = arith.constant 16 : i32

%0 = ttc.generic (%scale) blocks [%c16_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
%0 = ttc.generic ins (%scale : f32) blocks [%c16_i32 : i32] attributes {tileShape = array<i32: 4>, reductionDims = array<i32>} body {
^bb0(%offset: i32, %s: f32):
%cst = arith.constant dense<1.0> : tensor<4xf32, #blocked>
%splat = tt.splat %s : f32 -> tensor<4xf32, #blocked>
%result = arith.mulf %cst, %splat : tensor<4xf32, #blocked>
ttc.yield %result : tensor<4xf32, #blocked>
} : (f32) -> tensor<16xf32, #blocked>
} -> tensor<16xf32, #blocked>
tt.return
}
}
Expand All @@ -40,13 +40,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
tt.func public @scale_2d(%scale: f32) {
%c4_i32 = arith.constant 4 : i32
%c8_i32 = arith.constant 8 : i32
%0 = ttc.generic (%scale) blocks [%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array<i32: 1, 4>, reductionDims = array<i32>} body {
%0 = ttc.generic ins (%scale : f32) blocks [%c4_i32, %c8_i32 : i32, i32] attributes {tileShape = array<i32: 1, 4>, reductionDims = array<i32>} body {
^bb0(%dim0: i32, %dim1: i32, %s: f32):
%cst = arith.constant dense<1.0> : tensor<1x4xf32, #blocked>
%splat = tt.splat %s : f32 -> tensor<1x4xf32, #blocked>
%result = arith.mulf %cst, %splat : tensor<1x4xf32, #blocked>
ttc.yield %result : tensor<1x4xf32, #blocked>
} : (f32) -> tensor<4x8xf32, #blocked>
} -> tensor<4x8xf32, #blocked>
tt.return
}
}
Expand All @@ -63,13 +63,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
tt.func public @scale_3d(%scale: f32) {
%c4_i32 = arith.constant 4 : i32
%c8_i32 = arith.constant 8 : i32
%0 = ttc.generic (%scale) blocks [%c4_i32, %c4_i32, %c8_i32 : i32, i32, i32] attributes {tileShape = array<i32: 1, 1, 4>, reductionDims = array<i32>} body {
%0 = ttc.generic ins (%scale : f32) blocks [%c4_i32, %c4_i32, %c8_i32 : i32, i32, i32] attributes {tileShape = array<i32: 1, 1, 4>, reductionDims = array<i32>} body {
^bb0(%dim0: i32, %dim1: i32, %dim2: i32, %s: f32):
%cst = arith.constant dense<1.0> : tensor<1x1x4xf32, #blocked>
%splat = tt.splat %s : f32 -> tensor<1x1x4xf32, #blocked>
%result = arith.mulf %cst, %splat : tensor<1x1x4xf32, #blocked>
ttc.yield %result : tensor<1x1x4xf32, #blocked>
} : (f32) -> tensor<4x4x8xf32, #blocked>
} -> tensor<4x4x8xf32, #blocked>
tt.return
}
}
Expand All @@ -87,13 +87,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%c8_i32 = arith.constant 8 : i32
%0 = ttc.generic (%scale) blocks [%c2_i32, %c2_i32, %c4_i32, %c8_i32 : i32, i32, i32, i32] attributes {tileShape = array<i32: 1, 1, 1, 4>, reductionDims = array<i32>} body {
%0 = ttc.generic ins (%scale : f32) blocks [%c2_i32, %c2_i32, %c4_i32, %c8_i32 : i32, i32, i32, i32] attributes {tileShape = array<i32: 1, 1, 1, 4>, reductionDims = array<i32>} body {
^bb0(%dim0: i32, %dim1: i32, %dim2: i32, %dim3: i32, %s: f32):
%cst = arith.constant dense<1.0> : tensor<1x1x1x4xf32, #blocked>
%splat = tt.splat %s : f32 -> tensor<1x1x1x4xf32, #blocked>
%result = arith.mulf %cst, %splat : tensor<1x1x1x4xf32, #blocked>
ttc.yield %result : tensor<1x1x1x4xf32, #blocked>
} : (f32) -> tensor<2x2x4x8xf32, #blocked>
} -> tensor<2x2x4x8xf32, #blocked>
tt.return
}
}
Loading