diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index e19bb6229..a4a79c7c0 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -6086,7 +6086,7 @@ pto.tscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>) ##### `pto.mgather` - Gather-Load from Global Memory -**Summary:** Loads elements from global memory into a tile using per-element indices. +**Summary:** Loads elements from a global table into a VEC tile using per-element indices. Supports an optional A5-only out-of-bounds mode that lowers to the corresponding `MGATHER<...>` template overload. **Semantics:** @@ -6096,18 +6096,34 @@ dst[i, j] = mem[idx[i, j]] **Arguments:** -| Name | Type | Description | -|------|------|-------------| -| `mem` | `AnyMemRef/pto.tile_buf` | Source memory | -| `idx` | `pto.tile_buf` | Index tile | -| `dst` | `pto.tile_buf` | Destination tile | +| Name | Type | Default | Description | +|------|------|---------|-------------| +| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global source table | +| `idx` | `pto.tile_buf` | `NA` | Index tile | +| `dst` | `pto.tile_buf` | `NA` | Destination VEC tile | +| `gatherOob` | `#pto` | `undefined` | A5-only out-of-bounds mode (`undefined/clamp/wrap/zero`) | **Results:** None. Writes into `dst` via DPS pattern. **Constraints & Verification:** -- Index interpretation is target-defined. The CPU simulator treats indices as linear element indices into `src.data()`. -- No bounds checks are enforced on `indexes` by the CPU simulator. +- **Types (data and indices)** + - `mem` and `dst` must have the **same element type**. Supported element types: `i8`/`i16`/`i32`/`f16`/`bf16`/`f32`. On **A5** targets, `float8_e4m3` / `float8_e5m2` family element types are also supported. + - `idx` element type must be signless `i32`. + +- **Tile / memory roles** + - `dst` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `mem` must denote a GlobalTensor in GM memory. + - `mem` must use `ND` layout when layout can be inferred. + +- **Shape** + - `dst row == idx row`. + - `idx column == 1` or `idx column == dst column`. + - If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`. + +- **Out-of-bounds mode** + - Default `gatherOob = undefined` lowers to the default `MGATHER(dst, mem, idx)` overload. + - Non-default `gatherOob` values are only supported on **A5** and lower to `MGATHER(dst, mem, idx)`. **Hardware Mapping:** @@ -6118,13 +6134,17 @@ dst[i, j] = mem[idx[i, j]] ```mlir pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>) + +pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) + {gatherOob = #pto} ``` --- ##### `pto.mscatter` - Scatter-Store to Global Memory -**Summary:** Stores elements from a tile into global memory using per-element indices. +**Summary:** Stores elements from a VEC tile into a global table using per-element indices. Supports optional A5-only atomic and out-of-bounds modes that lower to the corresponding `MSCATTER<...>` template overload family. **Semantics:** @@ -6134,18 +6154,41 @@ mem[idx[i, j]] = src[i, j] **Arguments:** -| Name | Type | Description | -|------|------|-------------| -| `src` | `pto.tile_buf` | Source tile | -| `idx` | `pto.tile_buf` | Index tile | -| `mem` | `AnyMemRef/pto.tile_buf` | Destination memory | +| Name | Type | Default | Description | +|------|------|---------|-------------| +| `src` | `pto.tile_buf` | `NA` | Source VEC tile | +| `idx` | `pto.tile_buf` | `NA` | Index tile | +| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global destination table | +| `scatterAtomicOp` | `#pto` | `none` | A5-only atomic mode (`none/add/max/min`) | +| `scatterOob` | `#pto` | `undefined` | A5-only out-of-bounds mode (`undefined/skip/clamp/wrap`) | **Results:** None. Writes into `mem` via DPS pattern. **Constraints & Verification:** -- Index interpretation is target-defined. The CPU simulator treats indices as linear element indices into `dst.data()`. -- No bounds checks are enforced on `indexes` by the CPU simulator. +- **Types (data and indices)** + - `src` and `mem` must have the **same element type**. Supported element types: `i8`/`i16`/`i32`/`f16`/`bf16`/`f32`. On **A5** targets, `float8_e4m3` / `float8_e5m2` family element types are also supported. + - `idx` element type must be signless `i32`. + +- **Tile / memory roles** + - `src` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `mem` must denote a GlobalTensor in GM memory. + - `mem` must use `ND` layout when layout can be inferred. + +- **Shape** + - `src row == idx row`. + - `idx column == 1` or `idx column == src column`. + - If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`. + +- **Atomic modes** + - Default `scatterAtomicOp = none` lowers to the default `MSCATTER(mem, src, idx)` overload. + - Non-default `scatterAtomicOp` values are only supported on **A5**. + - `add` requires `i32`/`f16`/`f32`. + - `max`/`min` require signless `i32` or `f32`. + +- **Out-of-bounds modes** + - Default `scatterOob = undefined` lowers to the 1-template-parameter `MSCATTER(mem, src, idx)` form when only atomic is specified, or to the default overload when both attrs are default. + - Non-default `scatterOob` values are only supported on **A5** and lower to `MSCATTER(mem, src, idx)`. **Hardware Mapping:** @@ -6156,6 +6199,15 @@ mem[idx[i, j]] = src[i, j] ```mlir pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>) outs(%mem : memref<...>) + +pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%mem : memref<...>) + {scatterAtomicOp = #pto} + +pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%mem : memref<...>) + {scatterAtomicOp = #pto, + scatterOob = #pto} ``` --- diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 1a975fea1..688a68fc6 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -404,6 +404,42 @@ def PTO_ReluPreModeAttr : EnumAttr, + I32EnumAttrCase<"Clamp", 1, "clamp">, + I32EnumAttrCase<"Wrap", 2, "wrap">, + I32EnumAttrCase<"Zero", 3, "zero"> + ]>; + +def PTO_GatherOOBAttr : EnumAttr { + let summary = "MGATHER out-of-bounds handling mode"; +} + +def PTO_ScatterAtomicOpEnum : PTO_I32Enum< + "ScatterAtomicOp", "PTO MSCATTER atomic mode", [ + I32EnumAttrCase<"None", 0, "none">, + I32EnumAttrCase<"Add", 1, "add">, + I32EnumAttrCase<"Max", 2, "max">, + I32EnumAttrCase<"Min", 3, "min"> + ]>; + +def PTO_ScatterAtomicOpAttr : EnumAttr { + let summary = "MSCATTER atomic mode"; +} + +def PTO_ScatterOOBEnum : PTO_I32Enum< + "ScatterOOB", "PTO MSCATTER out-of-bounds mode", [ + I32EnumAttrCase<"Undefined", 0, "undefined">, + I32EnumAttrCase<"Skip", 1, "skip">, + I32EnumAttrCase<"Clamp", 2, "clamp">, + I32EnumAttrCase<"Wrap", 3, "wrap"> + ]>; + +def PTO_ScatterOOBAttr : EnumAttr { + let summary = "MSCATTER out-of-bounds handling mode"; +} + def PTO_AccToVecMode_Enum : PTO_I32Enum<"AccToVecMode", "TMOV acc-to-vec mode", [ I32EnumAttrCase<"SingleModeVec0", 0, "single_mode_vec0">, I32EnumAttrCase<"SingleModeVec1", 1, "single_mode_vec1">, diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 193fe44c5..1250a44cf 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2166,7 +2166,8 @@ def MGatherOp : PTO_TOp<"mgather", [ let arguments = (ins PTODpsType:$mem, PTODpsType:$idx, - PTODpsType:$dst); + PTODpsType:$dst, + DefaultValuedAttr:$gatherOob); let results = (outs); @@ -2261,7 +2262,9 @@ def MScatterOp : PTO_TOp<"mscatter", [ let arguments = (ins PTODpsType:$src, PTODpsType:$idx, - PTODpsType:$mem // outs target + PTODpsType:$mem, // outs target + DefaultValuedAttr:$scatterAtomicOp, + DefaultValuedAttr:$scatterOob ); let results = (outs); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..e18c32e93 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2114,6 +2114,107 @@ static bool isSupportedVecElemType(Type ty, bool allowBf16, return false; } +static bool isSupportedMGatherMScatterIndexElemType(Type ty) { + auto it = dyn_cast(ty); + if (!it || it.getWidth() != 32) + return false; + return it.isSignless(); +} + +static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { + if (isSupportedVecElemType(ty, /*allowBf16=*/true, /*allowInt8=*/true)) + return true; + if (!isTargetArchA5(op)) + return false; + return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); +} + +static bool isSupportedMScatterAtomicPayloadElemType(Type ty, + pto::ScatterAtomicOp atomic) { + auto intTy = dyn_cast(ty); + switch (atomic) { + case pto::ScatterAtomicOp::None: + return true; + case pto::ScatterAtomicOp::Add: + return ty.isF16() || ty.isF32() || + (intTy && intTy.getWidth() == 32 && intTy.isSignless()); + case pto::ScatterAtomicOp::Max: + case pto::ScatterAtomicOp::Min: + return ty.isF32() || + (intTy && intTy.getWidth() == 32 && intTy.isSignless()); + } + llvm_unreachable("unknown ScatterAtomicOp"); +} + +static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, + Value memValue, + Type dataElemTy, + StringRef dataOperandLabel) { + Type memTy = memValue.getType(); + Type memElem = getElemTy(memTy); + if (!memElem || memElem != dataElemTy) + return op->emitOpError() << "expects mem element type to match " + << dataOperandLabel << " element type"; + + if (isa(memTy)) { + if (auto layout = getLogicalViewLayout(memValue)) { + if (*layout != pto::Layout::ND) + return op->emitOpError( + "expects mem partition view to use ND logical layout when layout " + "can be inferred"); + } + return success(); + } + + if (auto mr = dyn_cast(memTy)) { + auto as = getPTOMemorySpaceEnum(mr); + if (!as || (*as != pto::AddressSpace::GM && + *as != pto::AddressSpace::Zero)) + return op->emitOpError( + "expects mem memref to use GM or zero address space"); + if (mr.getRank() == 5) { + auto shape = mr.getShape(); + bool allStatic = true; + for (int64_t d : shape) + if (d == ShapedType::kDynamic) + allStatic = false; + if (allStatic && (shape[0] != 1 || shape[1] != 1 || shape[2] != 1)) + return op->emitOpError( + "expects rank-5 GM memref leading dimensions to be [1,1,1,...] " + "(GlobalTensor table shape)"); + } + return success(); + } + + return op->emitOpError( + "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); +} + +static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, + Type idxTy, + StringRef dataName) { + auto dataShape = getShapeVec(dataTy); + auto idxShape = getShapeVec(idxTy); + if (dataShape.size() != 2 || idxShape.size() != 2) + return op->emitOpError() << "expects " << dataName + << " and idx to be rank-2"; + + if (dataShape[0] != ShapedType::kDynamic && + idxShape[0] != ShapedType::kDynamic && dataShape[0] != idxShape[0]) + return op->emitOpError() << "expects " << dataName + << " and idx static row dimensions to match"; + + int64_t dataCols = dataShape[1]; + int64_t idxCols = idxShape[1]; + if (idxCols != ShapedType::kDynamic && dataCols != ShapedType::kDynamic && + idxCols != 1 && idxCols != dataCols) + return op->emitOpError() << "expects idx cols to be 1 or equal to " + << dataName << " cols"; + + return success(); +} + static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name) { auto tb = dyn_cast(ty); if (tb) { @@ -5293,13 +5394,54 @@ LogicalResult TGetValOp::verify() { LogicalResult MScatterOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); - int64_t srcrank = getPTOTypeRank(getSrc().getType()); - int64_t memrank = getPTOTypeRank(getMem().getType()); - int64_t idxrank = getPTOTypeRank(getIdx().getType()); - - if (memrank == -1 || idxrank == -1 || srcrank == -1) { - return emitOpError("src, idx, mem does not support PTO type"); + + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mscatter is only supported on A5 targets"); + + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type memTy = getMem().getType(); + + if (getPTOTypeRank(srcTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(memTy) == -1) + return emitOpError("expects src, idx, and mem to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, idxTy, "idx"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type idxElem = getElemTy(idxTy); + if (!srcElem || !idxElem) + return emitOpError("failed to resolve element types for src or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), srcElem)) + return emitOpError( + "expects src element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return emitOpError("expects idx element type to be signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), srcElem, + "src"))) + return failure(); + + if (getScatterAtomicOp() != pto::ScatterAtomicOp::None || + getScatterOob() != pto::ScatterOOB::Undefined) { + if (!isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default scatterAtomicOp/scatterOob only on A5 targets"); } + + if (!isSupportedMScatterAtomicPayloadElemType(srcElem, getScatterAtomicOp())) + return emitOpError( + "expects scatterAtomicOp-compatible src element type: add supports " + "i32/ui32/f16/f32, max/min support signless i32/f32"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), srcTy, idxTy, "src"))) + return failure(); + return success(); } @@ -5307,13 +5449,46 @@ LogicalResult MScatterOp::verify() { LogicalResult MGatherOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); - int64_t memrank = getPTOTypeRank(getMem().getType()); - int64_t idxrank = getPTOTypeRank(getIdx().getType()); - int64_t dstrank = getPTOTypeRank(getDst().getType()); - if (memrank == -1 || idxrank == -1 || memrank == -1) { - return emitOpError("mem, idx and dst does not support PTO type"); - } + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mgather is only supported on A5 targets"); + + Type memTy = getMem().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + + if (getPTOTypeRank(memTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(dstTy) == -1) + return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || + failed(verifyNDStyleVecTile(*this, idxTy, "idx"))) + return failure(); + + Type dstElem = getElemTy(dstTy); + Type idxElem = getElemTy(idxTy); + if (!dstElem || !idxElem) + return emitOpError("failed to resolve element types for dst or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), dstElem)) + return emitOpError( + "expects dst element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return emitOpError("expects idx element type to be signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), dstElem, + "dst"))) + return failure(); + + if (getGatherOob() != pto::GatherOOB::Undefined && + !isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default gatherOob only on A5 targets"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), dstTy, idxTy, "dst"))) + return failure(); return success(); } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f091e3a9b..6cf321842 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -99,6 +99,14 @@ static Value peelUnrealized(Value v) { return v; } +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, Operation *anchor); + +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor); + static std::optional getLayoutAttrFromOp(Operation *op) { if (!op) return std::nullopt; @@ -2344,8 +2352,7 @@ struct ArithTruncIToEmitC : public OpConversionPattern { } }; //===----------------------------------------------------------------------===// -// pto.mgather lowering -> MGATHER(dst, mem, idx) -// %dst = pto.mgather %mem, %idx : memref<...>, memref<...> -> memref<...> +// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) //===----------------------------------------------------------------------===// struct PTOMGatherToMGATHER : public OpConversionPattern { @@ -2353,17 +2360,39 @@ struct PTOMGatherToMGATHER : public OpConversionPattern { LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); Value mem = peelUnrealized(adaptor.getMem()); + Value idx = peelUnrealized(adaptor.getIdx()); Value dst = peelUnrealized(adaptor.getDst()); - // pto-isa currently has no NPU implementation for MGATHER/MSCATTER. - // Fallback to a smoke-friendly lowering to keep compile/run coverage. + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + ArrayAttr templateArgs; + if (op.getGatherOob() != pto::GatherOOB::Undefined) { + auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { + switch (mode) { + case pto::GatherOOB::Undefined: + return "pto::GatherOOB::Undefined"; + case pto::GatherOOB::Clamp: + return "pto::GatherOOB::Clamp"; + case pto::GatherOOB::Wrap: + return "pto::GatherOOB::Wrap"; + case pto::GatherOOB::Zero: + return "pto::GatherOOB::Zero"; + } + llvm_unreachable("unknown GatherOOB"); + }; + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))}); + } + rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, mem}); + op.getLoc(), TypeRange{}, "MGATHER", + ArrayAttr{}, templateArgs, + ValueRange{dst, memArg, idx}); - if (op->getNumResults() == 0) { + if (op->getNumResults() == 0) { rewriter.eraseOp(op); } else { rewriter.replaceOp(op, dst); @@ -3259,6 +3288,28 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, return gtInst.getResult(0); } +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor) { + auto mrTy = dyn_cast(originalType); + if (!mrTy) + return loweredValue; + + bool isGlobal = true; + if (auto asAttr = + dyn_cast_or_null(mrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) + return loweredValue; + + if (Value gt = + buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) + return gt; + return loweredValue; +} + static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, Location loc, Value value) { auto *ctx = rewriter.getContext(); @@ -4774,15 +4825,57 @@ struct PTOMScatterToMSCATTER : public OpConversionPattern { LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); Value mem = peelUnrealized(adaptor.getMem()); - // pto-isa currently has no NPU implementation for MGATHER/MSCATTER. - // Fallback to a smoke-friendly lowering to keep compile/run coverage. + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { + switch (atomic) { + case pto::ScatterAtomicOp::None: + return "pto::ScatterAtomicOp::None"; + case pto::ScatterAtomicOp::Add: + return "pto::ScatterAtomicOp::Add"; + case pto::ScatterAtomicOp::Max: + return "pto::ScatterAtomicOp::Max"; + case pto::ScatterAtomicOp::Min: + return "pto::ScatterAtomicOp::Min"; + } + llvm_unreachable("unknown ScatterAtomicOp"); + }; + auto scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { + switch (mode) { + case pto::ScatterOOB::Undefined: + return "pto::ScatterOOB::Undefined"; + case pto::ScatterOOB::Skip: + return "pto::ScatterOOB::Skip"; + case pto::ScatterOOB::Clamp: + return "pto::ScatterOOB::Clamp"; + case pto::ScatterOOB::Wrap: + return "pto::ScatterOOB::Wrap"; + } + llvm_unreachable("unknown ScatterOOB"); + }; + + SmallVector templateArgVec; + if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || + op.getScatterOob() != pto::ScatterOOB::Undefined) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + if (op.getScatterOob() != pto::ScatterOOB::Undefined) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); + } + ArrayAttr templateArgs = + templateArgVec.empty() ? ArrayAttr{} : rewriter.getArrayAttr(templateArgVec); + rewriter.create( - op.getLoc(), TypeRange{}, "TSTORE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{mem, src}); + op.getLoc(), TypeRange{}, "MSCATTER", + ArrayAttr{}, templateArgs, + ValueRange{memArg, src, idx}); rewriter.eraseOp(op); return success(); diff --git a/test/samples/runop.sh b/test/samples/runop.sh index e0af12d8e..0ba87ebcf 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -298,6 +298,10 @@ process_one_dir() { if [[ "$base" == "test_intercore_sync_a3_missing_setffts" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" == "a3" ]]; then expect_fail=1 fi + if [[ ("$base" == "mgather" || "$base" == "mscatter") && \ + "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" == "a3" ]]; then + expect_fail=1 + fi mlir="${out_subdir}/${base}-pto-ir.pto" cpp="${out_subdir}/${base}-pto.cpp" @@ -349,6 +353,14 @@ process_one_dir() { continue fi fi + if [[ ("$base" == "mgather" || "$base" == "mscatter") && \ + "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" == "a3" ]]; then + if ! grep -Eq "pto\\.m(gather|scatter) is only supported on A5 targets" "${ptoas_log}"; then + echo -e "${A}(${base}.py)\tFAIL\texpected A5-only diagnostic not found" + overall=1 + continue + fi + fi echo -e "${A}(${base}.py)\tXFAIL\tptoas failed as expected" continue fi