diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..6ecf14f66 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -100,9 +100,9 @@ static SmallVector getShapeVec(Type ty); static SmallVector getValidShapeVec(Type ty); static SmallVector getValidShapeVec(Value value); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName); +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName); static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); @@ -1885,7 +1885,7 @@ LogicalResult pto::TAbsOp::verify() { if (failed(verifyVecTileCommon(*this, srcTy, "src")) || failed(verifyVecTileCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); @@ -2139,18 +2139,15 @@ static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name) return success(); } -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName) { +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName) { if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) return op->emitOpError() << "expects " << lhsName << " and " << rhsName << " to be !pto.tile_buf or memref"; if (getElemTy(lhs) != getElemTy(rhs)) return op->emitOpError() << "expects " << lhsName << " and " << rhsName << " to have the same element type"; - if (getShapeVec(lhs) != getShapeVec(rhs)) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same shape"; return success(); } @@ -2264,7 +2261,7 @@ static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) return op->emitOpError() << "expects " << dstName << " to be in the vec address space"; - if (failed(verifyTileBufSameShapeAndElem(op, srcTy, dstTy, srcName, dstName))) + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) return failure(); auto srcValid = getValidShapeVec(srcTy); @@ -2340,7 +2337,7 @@ static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, if (failed(verifyVecTileCommon(op, srcTy, srcName)) || failed(verifyVecTileCommon(op, dstTy, dstName))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(op, srcTy, dstTy, srcName, dstName))) + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) return failure(); if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) return op->emitOpError() << "expects vec tile element types to be supported"; @@ -2355,8 +2352,8 @@ static LogicalResult verifyVecTileBinaryOp(Operation *op, Type src0Ty, Type src1 failed(verifyVecTileCommon(op, src1Ty, "src1")) || failed(verifyVecTileCommon(op, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(op, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(op, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst"))) return failure(); if (!isSupportedVecElemType(getElemTy(src0Ty), allowBf16, allowInt8)) return op->emitOpError() << "expects vec tile element types to be supported"; @@ -2590,8 +2587,8 @@ LogicalResult pto::TAddOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -2610,8 +2607,8 @@ LogicalResult pto::TAddOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -3480,8 +3477,8 @@ LogicalResult mlir::pto::TDivOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -3501,8 +3498,8 @@ LogicalResult mlir::pto::TDivOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4585,7 +4582,7 @@ mlir::LogicalResult mlir::pto::TLReluOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); auto valid = getValidShapeVec(srcTy); @@ -4604,7 +4601,7 @@ mlir::LogicalResult mlir::pto::TLReluOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); Type elemTy = getElemTy(srcTy); @@ -4626,8 +4623,8 @@ mlir::LogicalResult mlir::pto::TMaxOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4647,8 +4644,8 @@ mlir::LogicalResult mlir::pto::TMaxOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4708,8 +4705,8 @@ mlir::LogicalResult mlir::pto::TMinOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -4729,8 +4726,8 @@ mlir::LogicalResult mlir::pto::TMinOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5129,7 +5126,7 @@ LogicalResult TGemvMxAccOp::verify() { if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), getDst().getType(), "a", "b", "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(), + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || failed(verifyTileBufSameValidShape(*this, getCIn().getType(), getDst().getType(), "c_in", "dst"))) @@ -5221,7 +5218,7 @@ LogicalResult TMatmulMxAccOp::verify() { if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), getDst().getType(), "a", "b", "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(), + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || failed(verifyTileBufSameValidShape(*this, getCIn().getType(), getDst().getType(), "c_in", "dst"))) @@ -5499,8 +5496,8 @@ mlir::LogicalResult mlir::pto::TMulOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5520,8 +5517,8 @@ mlir::LogicalResult mlir::pto::TMulOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -5649,7 +5646,7 @@ mlir::LogicalResult mlir::pto::TNegOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); @@ -5667,7 +5664,7 @@ mlir::LogicalResult mlir::pto::TNegOp::verify() { if (failed(verifyVecTileStorage(*this, srcTy, "src")) || failed(verifyVecTileStorage(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst"))) + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) return failure(); auto srcValid = getValidShapeVec(srcTy); @@ -6344,7 +6341,7 @@ mlir::LogicalResult mlir::pto::TReluOp::verify() { if (failed(verifyVecTileCommon(*this, ts, "src")) || failed(verifyVecTileCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); Type elemTy = getElemTy(ts); @@ -6358,7 +6355,7 @@ mlir::LogicalResult mlir::pto::TReluOp::verify() { if (failed(verifyVecTileCommon(*this, ts, "src")) || failed(verifyVecTileCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); Type elemTy = getElemTy(ts); @@ -6383,8 +6380,8 @@ mlir::LogicalResult mlir::pto::TRemOp::verify() { failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6426,8 +6423,8 @@ mlir::LogicalResult mlir::pto::TFModOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6447,8 +6444,8 @@ mlir::LogicalResult mlir::pto::TFModOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -6474,7 +6471,7 @@ mlir::LogicalResult mlir::pto::TRemSOp::verify() { failed(verifyTileBufCommon(*this, tt, "tmp")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, td, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) return failure(); if (getElemTy(tt) != getElemTy(td)) @@ -6516,7 +6513,7 @@ mlir::LogicalResult mlir::pto::TFModSOp::verify() { if (failed(verifyTileBufCommon(*this, srcTy, "src")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) return failure(); if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) @@ -7045,7 +7042,7 @@ mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7076,7 +7073,7 @@ mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7107,7 +7104,7 @@ mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { if (getTmp() && failed(verifyTileBufCommon(*this, getTmp().getType(), "tmp"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (getElemTy(src0Ty) != getElemTy(src1Ty)) return emitOpError("expects src0 and src1 to have the same element type"); @@ -7134,7 +7131,7 @@ mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { failed(verifyTileBufCommon(*this, src1Ty, "src1")) || failed(verifyTileBufCommon(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst"))) + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) return failure(); @@ -7389,7 +7386,7 @@ mlir::LogicalResult mlir::pto::TRowMinOp::verify() { failed(verifyVecTileCommon(*this, tt, "tmp")) || failed(verifyRowReductionDstLayout(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, tt, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, ts, tt, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, ts, tt, "src", "tmp"))) return failure(); if (getElemTy(ts) != getElemTy(td)) @@ -7409,7 +7406,7 @@ mlir::LogicalResult mlir::pto::TRowMinOp::verify() { failed(verifyVecTileCommon(*this, tt, "tmp")) || failed(verifyRowReductionDstLayout(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, ts, tt, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, ts, tt, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, ts, tt, "src", "tmp"))) return failure(); if (getElemTy(ts) != getElemTy(td)) @@ -7468,7 +7465,7 @@ mlir::LogicalResult mlir::pto::TRowProdOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (getElemTy(srcTy) != getElemTy(dstTy)) @@ -7488,7 +7485,7 @@ mlir::LogicalResult mlir::pto::TRowProdOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (getElemTy(srcTy) != getElemTy(dstTy)) @@ -7955,8 +7952,8 @@ mlir::LogicalResult mlir::pto::TSubOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure(); @@ -7975,8 +7972,8 @@ mlir::LogicalResult mlir::pto::TSubOp::verify() { failed(verifyTileBufCommon(*this, t1, "src1")) || failed(verifyTileBufCommon(*this, td, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, t0, t1, "src0", "src1")) || - failed(verifyTileBufSameShapeAndElem(*this, t0, td, "src0", "dst")) || + if (failed(verifyTileBufSameElemType(*this, t0, t1, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, t0, td, "src0", "dst")) || failed(verifyTileBufSameValidShape(*this, t0, t1, "src0", "src1")) || failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst"))) return failure();