From e328e91174745c29177639510edb141178cff68c Mon Sep 17 00:00:00 2001 From: FangRui Date: Thu, 9 Apr 2026 17:10:12 +0800 Subject: [PATCH] fix: delete tile shape verifier helper, inplaced by validshape check --- lib/PTO/IR/PTO.cpp | 137 +++++++++++++++++++-------------------------- 1 file changed, 57 insertions(+), 80 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..e9dc87882 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -100,9 +100,9 @@ static SmallVector getShapeVec(Type ty); static SmallVector getValidShapeVec(Type ty); static SmallVector getValidShapeVec(Value value); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName); +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName); static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); @@ -115,10 +115,6 @@ static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, StringRef dstName = "dst", bool allowBf16 = true, bool allowInt8 = true); -static LogicalResult verifyVecTileBinaryOp(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy, - bool allowBf16 = true, - bool allowInt8 = true); static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name); static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, StringRef name); @@ -1885,7 +1881,7 @@ LogicalResult pto::TAbsOp::verify() { if (failed(verifyVecTileCommon(*this, srcTy, "src")) || failed(verifyVecTileCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); @@ -2139,18 +2135,15 @@ static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name) return success(); } -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName) { +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName) { if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) return op->emitOpError() << "expects " << lhsName << " and " << rhsName << " to be !pto.tile_buf or memref"; if (getElemTy(lhs) != getElemTy(rhs)) return op->emitOpError() << "expects " << lhsName << " and " << rhsName << " to have the same element type"; - if (getShapeVec(lhs) != getShapeVec(rhs)) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same shape"; return success(); } @@ -2264,7 +2257,7 @@ static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) return op->emitOpError() << "expects " << dstName << " to be in the vec address space"; - if (failed(verifyTileBufSameShapeAndElem(op, srcTy, dstTy, srcName, dstName))) + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) return failure(); auto srcValid = getValidShapeVec(srcTy); @@ -2340,29 +2333,13 @@ static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, if (failed(verifyVecTileCommon(op, srcTy, srcName)) || failed(verifyVecTileCommon(op, dstTy, dstName))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(op, srcTy, dstTy, srcName, dstName))) + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) return failure(); if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) return op->emitOpError() << "expects vec tile element types to be supported"; return success(); } -static LogicalResult verifyVecTileBinaryOp(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy, - bool allowBf16, - bool allowInt8) { - if (failed(verifyVecTileCommon(op, src0Ty, "src0")) || - failed(verifyVecTileCommon(op, src1Ty, "src1")) || - failed(verifyVecTileCommon(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameShapeAndElem(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(op, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (!isSupportedVecElemType(getElemTy(src0Ty), allowBf16, allowInt8)) - return op->emitOpError() << "expects vec tile element types to be supported"; - return success(); -} - static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, StringRef name) { if (failed(verifyTileBufCommon(op, ty, name))) @@ -2590,8 +2567,8 @@ LogicalResult pto::TAddOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -2610,8 +2587,8 @@ LogicalResult pto::TAddOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -3480,8 +3457,8 @@ LogicalResult mlir::pto::TDivOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -3501,8 +3478,8 @@ LogicalResult mlir::pto::TDivOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4585,7 +4562,7 @@ mlir::LogicalResult mlir::pto::TLReluOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); auto valid = getValidShapeVec(srcTy); @@ -4604,7 +4581,7 @@ mlir::LogicalResult mlir::pto::TLReluOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); Type elemTy = getElemTy(srcTy); @@ -4626,8 +4603,8 @@ mlir::LogicalResult mlir::pto::TMaxOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4647,8 +4624,8 @@ mlir::LogicalResult mlir::pto::TMaxOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4708,8 +4685,8 @@ mlir::LogicalResult mlir::pto::TMinOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4729,8 +4706,8 @@ mlir::LogicalResult mlir::pto::TMinOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5129,7 +5106,7 @@ LogicalResult TGemvMxAccOp::verify() { if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), getDst().getType(), "a", "b", "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(), + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || failed(verifyTileBufSameValidShape(*this, getCIn().getType(), getDst().getType(), "c_in", "dst"))) @@ -5221,7 +5198,7 @@ LogicalResult TMatmulMxAccOp::verify() { if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), getDst().getType(), "a", "b", "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(), + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || failed(verifyTileBufSameValidShape(*this, getCIn().getType(), getDst().getType(), "c_in", "dst"))) @@ -5499,8 +5476,8 @@ mlir::LogicalResult mlir::pto::TMulOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5520,8 +5497,8 @@ mlir::LogicalResult mlir::pto::TMulOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5649,7 +5626,7 @@ mlir::LogicalResult mlir::pto::TNegOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); @@ -5667,7 +5644,7 @@ mlir::LogicalResult mlir::pto::TNegOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst"))) + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) return failure(); auto srcValid = getValidShapeVec(srcTy); @@ -6344,7 +6321,7 @@ mlir::LogicalResult mlir::pto::TReluOp::verify() { if (failed(verifyVecTileCommon(*this, ts, "src")) || failed(verifyVecTileCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); Type elemTy = getElemTy(ts); @@ -6358,7 +6335,7 @@ mlir::LogicalResult mlir::pto::TReluOp::verify() { if (failed(verifyVecTileCommon(*this, ts, "src")) || failed(verifyVecTileCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); Type elemTy = getElemTy(ts); @@ -6383,8 +6360,8 @@ mlir::LogicalResult mlir::pto::TRemOp::verify() { failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6426,8 +6403,8 @@ mlir::LogicalResult mlir::pto::TFModOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6447,8 +6424,8 @@ mlir::LogicalResult mlir::pto::TFModOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6474,7 +6451,7 @@ mlir::LogicalResult mlir::pto::TRemSOp::verify() { failed(verifyTileBufCommon(*this, tt, "tmp")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); if (getElemTy(tt) != getElemTy(td)) @@ -6516,7 +6493,7 @@ mlir::LogicalResult mlir::pto::TFModSOp::verify() { if (failed(verifyTileBufCommon(*this, srcTy, "src")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) @@ -7045,7 +7022,7 @@ mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7076,7 +7053,7 @@ mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7107,7 +7084,7 @@ mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7134,7 +7111,7 @@ mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -7389,7 +7366,7 @@ mlir::LogicalResult mlir::pto::TRowMinOp::verify() { failed(verifyVecTileCommon(*this, tt, "tmp")) || failed(verifyRowReductionDstLayout(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, tt, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, ts, tt, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, ts, tt, "src", "tmp"))) return failure(); if (getElemTy(ts) != getElemTy(td)) @@ -7409,7 +7386,7 @@ mlir::LogicalResult mlir::pto::TRowMinOp::verify() { failed(verifyVecTileCommon(*this, tt, "tmp")) || failed(verifyRowReductionDstLayout(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, tt, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, ts, tt, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, ts, tt, "src", "tmp"))) return failure(); if (getElemTy(ts) != getElemTy(td)) @@ -7468,7 +7445,7 @@ mlir::LogicalResult mlir::pto::TRowProdOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (getElemTy(srcTy) != getElemTy(dstTy)) @@ -7488,7 +7465,7 @@ mlir::LogicalResult mlir::pto::TRowProdOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (getElemTy(srcTy) != getElemTy(dstTy)) @@ -7955,8 +7932,8 @@ mlir::LogicalResult mlir::pto::TSubOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -7975,8 +7952,8 @@ mlir::LogicalResult mlir::pto::TSubOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure();