diff --git a/mlir/include/mlir/Dialect/LLVMIR/GENXOps.td b/mlir/include/mlir/Dialect/LLVMIR/GENXOps.td index bcdc3beeff2a9..ff9480cc2e0b9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/GENXOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/GENXOps.td @@ -445,7 +445,7 @@ def GENX_Matrix2DBlockLoadOp : GENX_Op<"matrix.2Dblockload">, The 'genx.matrix.2Dblockload' operation loads a submatrix from an array in memory. $ptr - the base address of the memory array $base_width, $base_height, $base_pitch - the shape of the memory array - $x, $y, $tile_width, $tile_height - the starting offets and shape of the submatrix to load + $x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to load $elem_size_in_bits - 32 for f32, bf32; 16 for f16, int16, bf16; 8 for int8, int4, int2 and etc $v_blocks - number of blocks to load $transpose - transpose the submatrix in vector register (useful for 32 bit element types) @@ -492,7 +492,7 @@ def GENX_Matrix2DBlockStoreOp : GENX_Op<"matrix.2Dblockstore">, The 'genx.matrix.2Dblockstore' operation stores to a submatrix from an array in memory. $ptr - the base address of the memory array $base_width, $base_height, $base_pitch - the shape of the memory array - $x, $y, $tile_width, $tile_height - the starting offets and shape of the submatrix to load + $x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to load $elem_size_in_bits - 32 for f32, bf32; 16 for f16, int16, bf16; 8 for int8, int4, int2 and etc $v_blocks - number of blocks to store $transpose - transpose the submatrix in vector register (useful for 32 bit element types) diff --git a/mlir/lib/Dialect/LLVMIR/IR/GENXOps.cpp b/mlir/lib/Dialect/LLVMIR/IR/GENXOps.cpp index f9bfc5b47bbd3..48f36252e0159 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/GENXOps.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/GENXOps.cpp @@ -162,6 +162,30 @@ LogicalResult GENX::Matrix2DBlockLoadOp::verify() { return this->emitOpError( "4th operand (base pitch) should be >= 2nd operand (base width)"); + Type InputElemType = getPtr().getType().getElementType(); + switch (getElemSizeInBits()) { + case 32: + if (!InputElemType.isF32()) // || !InputElemType.isBFloat32Type(); + return this->emitOpError( + "element of size 32 should be of type bf32 or f32"); + break; + + case 16: + if (!InputElemType.isF16() || !InputElemType.isBF16()) + return this->emitOpError( + "element of size 16 should be of type bf16 or f16"); + break; + + case 8: + if (!InputElemType.isa()) // ||!InputElemType.isIntegerTy(8); + return this->emitOpError( + "element of size 8 should be of type int8 or uint8"); + break; + + default: + return this->emitOpError( + "element size should be 8, 16 or 32 bits"); + } return success(); } @@ -185,6 +209,30 @@ LogicalResult GENX::Matrix2DBlockStoreOp::verify() { return this->emitOpError( "4th operand (base pitch) should be >= 2nd operand (base width)"); + Type InputElemType = getPtr().getType().getElementType();//cast; + switch (getElemSizeInBits()) { + case 32: + if (!InputElemType.isF32()) // || !InputElemType.isBFloat32Type(); + return this->emitOpError( + "element of size 32 should be of type bf32 or f32"); + break; + + case 16: + if (!InputElemType.isF16() || !InputElemType.isBF16()) + return this->emitOpError( + "element of size 16 should be of type bf16 or f16"); + break; + + case 8: + if (!InputElemType.isa()) // ||!InputElemType.isIntegerTy(8); + return this->emitOpError( + "element of size 8 should be of type int8 or uint8"); + break; + + default: + return this->emitOpError( + "element size should be 8, 16 or 32 bits"); + } return success(); } diff --git a/mlir/test/Dialect/LLVMIR/genx-invalid.mlir b/mlir/test/Dialect/LLVMIR/genx-invalid.mlir index a708730244e22..43fff7a5f2c5c 100644 --- a/mlir/test/Dialect/LLVMIR/genx-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/genx-invalid.mlir @@ -106,6 +106,38 @@ func.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i3 // ----- +func.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'genx.matrix.2Dblockload' element of size 32 should be of type bf32 or f32}} + %0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi8> + llvm.return +} + +// ----- + +func.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'genx.matrix.2Dblockload' element of size 16 should be of type bf16 or f16}} + %0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + llvm.return +} + +// ----- + +func.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'genx.matrix.2Dblockload' element of size 16 should be of type bf16 or f16}} + %0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi8> + llvm.return +} + +// ----- + +func.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'genx.matrix.2Dblockload' element of size 8 should be of type int8 or uint8}} + %0 = genx.matrix.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + llvm.return +} + +// ----- + func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { // expected-error @+1 {{'genx.matrix.2Dblockstore' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=64:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) @@ -132,6 +164,38 @@ func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i // ----- +func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<16xi8>) { +// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 32 should be of type bf32 or f32}} + genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi8>) + llvm.return +} + +// ----- + +func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { +// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 16 should be of type bf16 or f16}} + genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) + llvm.return +} + +// ----- + +func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<16xi8>) { +// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 16 should be of type bf16 or f16}} + genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi8>) + llvm.return +} + +// ----- + +func.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { +// expected-error @+1 {{'genx.matrix.2Dblockstore' element of size 8 should be of type int8 or uint8}} + genx.matrix.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8:i32, tile_width=4:i32, tile_height=1:i32, v_blocks=1:i32, transpose=false, vnni_transform=false} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) + llvm.return +} + +// ----- + func.func @joint_matrix_load(%ptr : !llvm.ptr, %stride : index) { // expected-error @+1 {{'genx.matrix.load' op scope attribute must have value 'Subgroup'}} %0 = genx.matrix.load %ptr, %stride {memory_access = #genx.memory_access} : (!llvm.ptr, index) -> !genx.jointmatrix<8x16xi32, RowMajor>