Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc,
static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter,
Location loc, Value v,
unsigned bitWidth);
static bool needsA5NoSplitVectorGuard(Operation *op);

static FailureOr<std::string> getTileSplitToken(int64_t split) {
switch (split) {
Expand Down Expand Up @@ -2517,6 +2518,9 @@ struct FuncToEmitC : public OpConversionPattern<func::FuncOp> {
emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"}));
}

std::optional<StringRef> kernelKindMacro = getKernelKindMacro(op);
bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation());

// Inline the original body, then convert region/block argument types to
// match the converted signature (also covers CFG blocks introduced by
// pre-lowering, e.g. scf.while -> cf.br/cf.cond_br).
Expand All @@ -2531,8 +2535,6 @@ struct FuncToEmitC : public OpConversionPattern<func::FuncOp> {
*getTypeConverter(), &entryConv)))
return failure();

std::optional<StringRef> kernelKindMacro = getKernelKindMacro(op);

// Preserve the existing function prologue shape. `kernel_kind` functions are
// emitted with the same macro guard/reset sequence that used to come from
// early pto.section wrapping, but only after SCF pre-lowering has finished.
Expand All @@ -2547,13 +2549,18 @@ struct FuncToEmitC : public OpConversionPattern<func::FuncOp> {
rewriter.create<emitc::VerbatimOp>(op.getLoc(), "set_mask_norm();");
rewriter.create<emitc::VerbatimOp>(op.getLoc(),
"set_vector_mask(-1, -1);");
if (needsNoSplitGuard)
rewriter.create<emitc::VerbatimOp>(
op.getLoc(), "if (get_subblockid() == 0) {");
}
}
}

if (kernelKindMacro) {
Block &lastBlock = emitcFunc.getBody().back();
rewriter.setInsertionPoint(lastBlock.getTerminator());
if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard)
rewriter.create<emitc::VerbatimOp>(op.getLoc(), "}");
std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n";
rewriter.create<emitc::VerbatimOp>(op.getLoc(), endMacro);
}
Expand Down Expand Up @@ -8956,6 +8963,68 @@ class ArithCmpIToEmitC : public OpConversionPattern<arith::CmpIOp> {
//===----------------------------------------------------------------------===//
// Section Op Lowering
//===----------------------------------------------------------------------===//
static bool isA5NoSplitPipeOp(Operation *op) {
if (auto tpush = dyn_cast<pto::TPushOp>(op))
return tpush.getSplit() == 0;
if (auto tpop = dyn_cast<pto::TPopOp>(op))
return tpop.getSplit() == 0;
if (auto tfree = dyn_cast<pto::TFreeOp>(op))
return tfree.getSplit() == 0;
if (auto tpush = dyn_cast<pto::TPushToAivOp>(op))
return tpush.getSplit() == 0;
if (auto tpush = dyn_cast<pto::TPushToAicOp>(op))
return tpush.getSplit() == 0;
if (auto tpop = dyn_cast<pto::TPopFromAicOp>(op))
return tpop.getSplit() == 0;
if (auto tpop = dyn_cast<pto::TPopFromAivOp>(op))
return tpop.getSplit() == 0;
if (auto tfree = dyn_cast<pto::TFreeFromAicOp>(op))
return tfree.getSplit() == 0;
if (auto tfree = dyn_cast<pto::TFreeFromAivOp>(op))
return tfree.getSplit() == 0;
return false;
}

// Reports whether `op`, or any operation reached by walking it, is a
// pto::GetSubBlockIdxOp or pto::GetSubBlockNumOp.
//
// The walk is interrupted at the first match, so the answer is simply whether
// the traversal was cut short.
static bool hasExplicitSubblockControl(Operation *op) {
  WalkResult scan = op->walk([](Operation *candidate) -> WalkResult {
    return isa<pto::GetSubBlockIdxOp, pto::GetSubBlockNumOp>(candidate)
               ? WalkResult::interrupt()
               : WalkResult::advance();
  });
  return scan.wasInterrupted();
}

// Decides whether the code emitted for `op` needs a no-split vector guard.
//
// The result is true only when all of the following hold:
//   * getTargetArch(op) is PTOArch::A5;
//   * `op` is a vector scope: either a pto::SectionVectorOp, or a func::FuncOp
//     carrying a FunctionKernelKindAttr whose kind is
//     FunctionKernelKind::Vector (for a function with that attribute, the
//     attribute alone decides);
//   * nothing under `op` queries the sub-block index or count
//     (see hasExplicitSubblockControl);
//   * at least one operation under `op` satisfies isA5NoSplitPipeOp.
static bool needsA5NoSplitVectorGuard(Operation *op) {
  if (getTargetArch(op) != PTOArch::A5)
    return false;

  bool inVectorScope = isa<pto::SectionVectorOp>(op);
  if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
    auto kindAttr = funcOp->getAttrOfType<FunctionKernelKindAttr>(
        FunctionKernelKindAttr::name);
    // Functions without the attribute keep the section-based answer, which is
    // false for a func::FuncOp.
    if (kindAttr)
      inVectorScope = kindAttr.getKernelKind() == FunctionKernelKind::Vector;
  }

  if (!inVectorScope || hasExplicitSubblockControl(op))
    return false;

  // Stop at the first no-split pipe operation; an interrupted walk means one
  // was found.
  WalkResult pipeScan = op->walk([](Operation *nested) -> WalkResult {
    return isA5NoSplitPipeOp(nested) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });
  return pipeScan.wasInterrupted();
}

template <typename SectionOpTy>
struct SectionToEmitC : public OpConversionPattern<SectionOpTy> {
using OpConversionPattern<SectionOpTy>::OpConversionPattern;
Expand All @@ -8972,6 +9041,7 @@ struct SectionToEmitC : public OpConversionPattern<SectionOpTy> {
matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation());

std::string startMacro = "\n#if defined(" + getMacroName() + ")";
rewriter.create<emitc::VerbatimOp>(loc, startMacro);
Expand All @@ -8984,11 +9054,19 @@ struct SectionToEmitC : public OpConversionPattern<SectionOpTy> {
rewriter.create<emitc::VerbatimOp>(loc, "set_vector_mask(-1, -1);");
}

if (needsNoSplitGuard) {
rewriter.create<emitc::VerbatimOp>(
loc, "if (get_subblockid() == 0) {");
}

Block &innerBlock = op.getBody().front();
if (!innerBlock.empty()) {
rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{});
}

if (needsNoSplitGuard)
rewriter.create<emitc::VerbatimOp>(loc, "}");

std::string endMacro = "#endif // " + getMacroName() + "\n";
rewriter.create<emitc::VerbatimOp>(loc, endMacro);

Expand Down
2 changes: 2 additions & 0 deletions test/basic/tpush_tpop_frontend_lowering_a5.pto
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ module {
// A5: TFREE<TPipe<0, Direction::DIR_BOTH, 1024, 4>, TileSplitAxis::TILE_NO_SPLIT>(

// A5-LABEL: AICORE void vector_kernel(
// A5: if (get_subblockid() == 0) {
// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>(
// A5: Tile<TileType::Vec, float, 16, 16, BLayout::RowMajor, 16, 16, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> {{v[0-9]+}};
// A5: Tile<TileType::Vec, float, 16, 16, BLayout::ColMajor, 16, 16, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> {{v[0-9]+}};
Expand All @@ -75,3 +76,4 @@ module {
// A5: Tile<TileType::Vec, float, 16, 16, BLayout::RowMajor, 16, 16, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> {{v[0-9]+}};
// A5: TNEG(
// A5: TFREE<TPipe<0, Direction::DIR_BOTH, 1024, 4>, TileSplitAxis::TILE_NO_SPLIT>(
// A5: }
Loading
Loading