Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/GENXOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def GENX_Matrix2DBlockLoadOp : GENX_Op<"matrix.2Dblockload">,
The 'genx.matrix.2Dblockload' operation loads a submatrix from an array in memory.
$ptr - the base address of the memory array
$base_width, $base_height, $base_pitch - the shape of the memory array
$x, $y, $tile_width, $tile_height - the starting offets and shape of the submatrix to load
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to load
$elem_size_in_bits - 32 for f32, bf32; 16 for f16, int16, bf16; 8 for int8, int4, int2 and etc
$v_blocks - number of blocks to load
$transpose - transpose the submatrix in vector register (useful for 32 bit element types)
Expand Down Expand Up @@ -492,7 +492,7 @@ def GENX_Matrix2DBlockStoreOp : GENX_Op<"matrix.2Dblockstore">,
The 'genx.matrix.2Dblockstore' operation stores to a submatrix from an array in memory.
$ptr - the base address of the memory array
$base_width, $base_height, $base_pitch - the shape of the memory array
$x, $y, $tile_width, $tile_height - the starting offets and shape of the submatrix to load
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to load
$elem_size_in_bits - 32 for f32, bf32; 16 for f16, int16, bf16; 8 for int8, int4, int2 and etc
$v_blocks - number of blocks to store
$transpose - transpose the submatrix in vector register (useful for 32 bit element types)
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/GENXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,30 @@ LogicalResult GENX::Matrix2DBlockLoadOp::verify() {
return this->emitOpError(
"4th operand (base pitch) should be >= 2nd operand (base width)");

Type InputElemType = getPtr().getType().getElementType();
switch (getElemSizeInBits()) {
case 32:
if (!InputElemType.isF32()) // || !InputElemType.isBFloat32Type();
return this->emitOpError(
"element of size 32 should be of type bf32 or f32");
break;

case 16:
if (!InputElemType.isF16() || !InputElemType.isBF16())
return this->emitOpError(
"element of size 16 should be of type bf16 or f16");
break;

case 8:
if (!InputElemType.isa<IntegerType>()) // ||!InputElemType.isIntegerTy(8);
return this->emitOpError(
"element of size 8 should be of type int8 or uint8");
break;

default:
return this->emitOpError(
"element size should be 8, 16 or 32 bits");
}
return success();
}

Expand All @@ -185,6 +209,30 @@ LogicalResult GENX::Matrix2DBlockStoreOp::verify() {
return this->emitOpError(
"4th operand (base pitch) should be >= 2nd operand (base width)");

Type InputElemType = getPtr().getType().getElementType();//cast<GENX::MatrixElemType>;
switch (getElemSizeInBits()) {
case 32:
if (!InputElemType.isF32()) // || !InputElemType.isBFloat32Type();
return this->emitOpError(
"element of size 32 should be of type bf32 or f32");
break;

case 16:
if (!InputElemType.isF16() || !InputElemType.isBF16())
return this->emitOpError(
"element of size 16 should be of type bf16 or f16");
break;

case 8:
if (!InputElemType.isa<IntegerType>()) // ||!InputElemType.isIntegerTy(8);
return this->emitOpError(
"element of size 8 should be of type int8 or uint8");
break;

default:
return this->emitOpError(
"element size should be 8, 16 or 32 bits");
}
return success();
}

Expand Down
64 changes: 64 additions & 0 deletions mlir/test/Dialect/LLVMIR/genx-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,38 @@ func.func @matrix_2Dblockload(%ptr : !llvm.ptr<i32>, %base_height : i32, %x : i3

// -----

func.func @matrix_2Dblockload(%ptr : !llvm.ptr<i8>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// expected-error @+1 {{'genx.matrix.2Dblockload' element of size 32 should be of type bf32 or f32}}
%0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i8>, i32, i32, i32, i32, i32) -> vector<4xi8>
llvm.return
}

// -----

func.func @matrix_2Dblockload(%ptr : !llvm.ptr<i32>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// expected-error @+1 {{'genx.matrix.2Dblockload' element of size 16 should be of type bf16 or f16}}
%0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i32>, i32, i32, i32, i32, i32) -> vector<4xi32>
llvm.return
}

// -----

func.func @matrix_2Dblockload(%ptr : !llvm.ptr<i8>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// expected-error @+1 {{'genx.matrix.2Dblockload' element of size 16 should be of type bf16 or f16}}
%0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i8>, i32, i32, i32, i32, i32) -> vector<4xi8>
llvm.return
}

// -----

func.func @matrix_2Dblockload(%ptr : !llvm.ptr<i32>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// expected-error @+1 {{'genx.matrix.2Dblockload' element of size 8 should be of type int8 or uint8}}
%0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i32>, i32, i32, i32, i32, i32) -> vector<4xi32>
llvm.return
}

// -----

func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i32>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) {
// expected-error @+1 {{'genx.matrix.2Dblockstore' op expecting 'elem_size_in_bits' to be 8, 16, or 32}}
genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=64:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i32>, i32, i32, i32, i32, i32, vector<4xi32>)
Expand All @@ -132,6 +164,38 @@ func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i32>, %base_height : i32, %x : i

// -----

func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i8>, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<16xi8>) {
// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 32 should be of type bf32 or f32}}
genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i8>, i32, i32, i32, i32, i32, vector<4xi8>)
llvm.return
}

// -----

func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i32>, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) {
// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 16 should be of type bf16 or f16}}
genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i32>, i32, i32, i32, i32, i32, vector<4xi32>)
llvm.return
}

// -----

func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i8>, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<16xi8>) {
// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 16 should be of type bf16 or f16}}
genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i8>, i32, i32, i32, i32, i32, vector<4xi8>)
llvm.return
}

// -----

func.func @matrix_2Dblockstore(%ptr : !llvm.ptr<i32>, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) {
// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 8 should be of type int8 or uint8}}
genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr<i32>, i32, i32, i32, i32, i32, vector<4xi32>)
llvm.return
}

// -----

func.func @joint_matrix_load(%ptr : !llvm.ptr<i32>, %stride : index) {
// expected-error @+1 {{'genx.matrix.load' op scope attribute must have value 'Subgroup'}}
%0 = genx.matrix.load <Workgroup> <RowMajor> %ptr, %stride {memory_access = #genx.memory_access<Volatile>} : (!llvm.ptr<i32>, index) -> !genx.jointmatrix<8x16xi32, RowMajor>
Expand Down