diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 7f63a107a..3fe29465b 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1384,6 +1384,7 @@ def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [ I32Attr:$slot_num, OptionalAttr:$local_slot_num, OptionalAttr:$flag_base, + OptionalAttr:$nosplit, AnyType:$gm_addr, AnyType:$local_addr, Optional:$peer_local_addr @@ -1398,6 +1399,7 @@ def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [ `slot_num` `=` $slot_num (`,` `local_slot_num` `=` $local_slot_num^)? (`,` `flag_base` `=` $flag_base^)? + (`,` `nosplit` `=` $nosplit^)? `}` `(` $gm_addr `:` type($gm_addr) `,` $local_addr `:` type($local_addr) (`,` $peer_local_addr^ `:` type($peer_local_addr))? `)` @@ -1415,6 +1417,7 @@ def InitializeL2LPipeOp : PTO_Op<"initialize_l2l_pipe", [ I32Attr:$slot_size, I32Attr:$slot_num, OptionalAttr:$flag_base, + OptionalAttr:$nosplit, AnyType:$local_addr, Optional:$peer_local_addr ); @@ -1427,6 +1430,7 @@ def InitializeL2LPipeOp : PTO_Op<"initialize_l2l_pipe", [ `slot_size` `=` $slot_size `,` `slot_num` `=` $slot_num (`,` `flag_base` `=` $flag_base^)? + (`,` `nosplit` `=` $nosplit^)? `}` `(` $local_addr `:` type($local_addr) (`,` $peer_local_addr^ `:` type($peer_local_addr))? `)` diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 37979bf21..6b2fd6e56 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -145,6 +145,9 @@ def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> Runs after `pto-plan-memory`. Assumes `pto.reserve_buffer` base addresses have already been planned, then: - aligns missing `flag_base` attrs for peer internal pipe init ops + - infers implicit `nosplit = true` for internal pipe init ops when any + downstream `pto.tpush` / `pto.tpop` / `pto.tfree` user on the same + logical pipe has `split = 0` - rejects internal pipe init ops without explicit `flag_base` when their `local_addr` cannot be traced back to `pto.reserve_buffer` / `pto.import_reserved_buffer` diff --git a/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp b/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp index d0ab9f34c..529c42918 100644 --- a/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp +++ b/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp @@ -55,6 +55,7 @@ static FailureOr lowerFrontendInitOp(InitOpT initOp, if (arch == PTOArch::A5) { auto pipe = rewriter.create( loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, IntegerAttr{}, + BoolAttr{}, localAddr, /*peer_local_addr=*/Value{}); return pipe.getPipe(); } @@ -67,7 +68,7 @@ static FailureOr lowerFrontendInitOp(InitOpT initOp, auto localSlotNumAttr = rewriter.getI32IntegerAttr(slotNum); auto pipe = rewriter.create( loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, localSlotNumAttr, - IntegerAttr{}, initOp.getGmSlotBuffer(), localAddr, + IntegerAttr{}, BoolAttr{}, initOp.getGmSlotBuffer(), localAddr, /*peer_local_addr=*/Value{}); return pipe.getPipe(); }; @@ -101,6 +102,7 @@ static FailureOr lowerFrontendInitOp(InitOpT initOp, if (arch == PTOArch::A5) { auto pipe = rewriter.create( loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, IntegerAttr{}, + BoolAttr{}, c2vAddr, v2cAddr); handles.c2vPipe = pipe.getPipe(); handles.v2cPipe = pipe.getPipe(); @@ -113,7 +115,8 @@ static FailureOr lowerFrontendInitOp(InitOpT initOp, auto localSlotNumAttr = rewriter.getI32IntegerAttr(4); auto pipe = rewriter.create( loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, localSlotNumAttr, - IntegerAttr{}, initOp.getGmSlotBuffer(), c2vAddr, v2cAddr); + IntegerAttr{}, BoolAttr{}, initOp.getGmSlotBuffer(), c2vAddr, + v2cAddr); handles.c2vPipe = pipe.getPipe(); handles.v2cPipe = pipe.getPipe(); handles.anchorOp = pipe.getOperation(); diff --git a/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp b/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp index 95fc17936..ec91b897a 100644 --- a/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp +++ b/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp @@ -55,6 +55,7 @@ struct PipeInitInfo { Operation *op = nullptr; func::FuncOp funcOp; int8_t dirMask = 0; + bool inferredNoSplit = false; }; template static Value getLocalAddrOperand(InitOpT op) { @@ -72,6 +73,35 @@ static void setFlagBaseAttr(InitOpT op, IntegerAttr attr) { op->setAttr("flag_base", attr); } +template +static void setNoSplitAttr(InitOpT op, BoolAttr attr) { + op->setAttr("nosplit", attr); +} + +template static Value getPipeResult(InitOpT op) { + return op.getPipe(); +} + +static bool inferNoSplitFromPipeUsers(Value pipe) { + for (Operation *user : pipe.getUsers()) { + if (auto pushOp = dyn_cast(user)) { + if (pushOp.getSplit() == 0) + return true; + continue; + } + if (auto popOp = dyn_cast(user)) { + if (popOp.getSplit() == 0) + return true; + continue; + } + if (auto freeOp = dyn_cast(user)) { + if (freeOp.getSplit() == 0) + return true; + } + } + return false; +} + static ReserveBufferOp findReserveBufferByName(func::FuncOp funcOp, StringRef name) { // Reserve-buffer lookup is name-based because import_reserved_buffer only @@ -140,6 +170,7 @@ struct PTOResolveReservedBuffersPass info.op = initOp.getOperation(); info.funcOp = initOp->template getParentOfType(); info.dirMask = initOp.getDirMask(); + info.inferredNoSplit = inferNoSplitFromPipeUsers(getPipeResult(initOp)); // Record one address into the keyed maps. Returns true when the // address comes from reserve_buffer / import_reserved_buffer. @@ -187,6 +218,7 @@ struct PTOResolveReservedBuffersPass } OpBuilder builder(moduleOp.getContext()); + std::set groupedNoSplitResolved; for (const auto &it : keyedInits) { const auto &inits = it.second; // flag_base is always 0: single-direction pipes use flag pair 0/1; @@ -221,18 +253,44 @@ struct PTOResolveReservedBuffersPass chosenBase = desiredBase; auto flagBaseAttr = builder.getI32IntegerAttr(*chosenBase); + bool groupNoSplit = false; + for (const PipeInitInfo &info : inits) { + if (info.inferredNoSplit) { + groupNoSplit = true; + break; + } + } for (const PipeInitInfo &info : inits) { if (auto initOp = dyn_cast(info.op)) { if (!getFlagBaseAttr(initOp)) setFlagBaseAttr(initOp, flagBaseAttr); + if (groupNoSplit) + setNoSplitAttr(initOp, builder.getBoolAttr(true)); + groupedNoSplitResolved.insert(info.op); continue; } auto initOp = cast(info.op); if (!getFlagBaseAttr(initOp)) setFlagBaseAttr(initOp, flagBaseAttr); + if (groupNoSplit) + setNoSplitAttr(initOp, builder.getBoolAttr(true)); + groupedNoSplitResolved.insert(info.op); } } + moduleOp.walk([&](InitializeL2LPipeOp initOp) { + if (groupedNoSplitResolved.count(initOp.getOperation())) + return; + if (inferNoSplitFromPipeUsers(initOp.getPipe())) + setNoSplitAttr(initOp, builder.getBoolAttr(true)); + }); + moduleOp.walk([&](InitializeL2G2LPipeOp initOp) { + if (groupedNoSplitResolved.count(initOp.getOperation())) + return; + if (inferNoSplitFromPipeUsers(initOp.getPipe())) + setNoSplitAttr(initOp, builder.getBoolAttr(true)); + }); + return success(); } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 9a1bb48be..586ff549e 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -372,12 +372,14 @@ getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, int32_t slotSize, int32_t slotNum, - std::optional localSlotNum) { + std::optional localSlotNum, + bool nosplit) { std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + ", " + std::to_string(slotSize) + ", " + std::to_string(slotNum); if (localSlotNum) token += ", " + std::to_string(*localSlotNum); + token += nosplit ? ", true" : ", false"; token += ">"; return token; } @@ -395,8 +397,9 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, ? initOp.getLocalSlotNumAttr().getInt() : initOp.getSlotNum(); return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), - initOp.getSlotNum(), localSlotNum); + initOp.getSlotSize(), initOp.getSlotNum(), + localSlotNum, initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); } if (auto initOp = dyn_cast(op)) { @@ -407,8 +410,9 @@ static FailureOr buildTPipeTokenFromInitOp(Operation *op, if (failed(dirTok)) return failure(); return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), - initOp.getSlotNum(), std::nullopt); + initOp.getSlotSize(), initOp.getSlotNum(), + std::nullopt, initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); } return failure(); diff --git a/test/basic/tpush_tpop_emitc.pto b/test/basic/tpush_tpop_emitc.pto index 8a6841d70..2074bc90d 100644 --- a/test/basic/tpush_tpop_emitc.pto +++ b/test/basic/tpush_tpop_emitc.pto @@ -36,8 +36,8 @@ module { // A3: const int32_t {{v[0-9]+}} = 16; // A3: const int64_t {{v[0-9]+}} = 0; // A3: #if defined(__DAV_CUBE__) -// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8>( -// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>( +// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A3: #endif // __DAV_CUBE__ // A3-LABEL: AICORE void vector_pop_gm( @@ -45,8 +45,8 @@ module { // A3: #if defined(__DAV_VEC__) // A3: set_mask_norm(); // A3: set_vector_mask(-1, -1); -// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8>( +// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8, false>( // A3: Tile {{v[0-9]+}}; -// A3: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>( -// A3: TFREE, TileSplitAxis::TILE_LEFT_RIGHT>( +// A3: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>( +// A3: TFREE, TileSplitAxis::TILE_LEFT_RIGHT>( // A3: #endif // __DAV_VEC__ diff --git a/test/basic/tpush_tpop_frontend_lowering_a3.pto b/test/basic/tpush_tpop_frontend_lowering_a3.pto index 4cc086057..188b8b03e 100644 --- a/test/basic/tpush_tpop_frontend_lowering_a3.pto +++ b/test/basic/tpush_tpop_frontend_lowering_a3.pto @@ -61,32 +61,32 @@ module { } // A3-LABEL: AICORE void cube_kernel(__gm__ float* -// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4>( -// A3: TPUSH +// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4, true>( +// A3: TPUSH // A3: Tile {{v[0-9]+}}; -// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A3: Tile {{v[0-9]+}}; // A3: TMOV( -// A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // A3-LABEL: AICORE void vector_kernel(__gm__ float* -// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4>( +// A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4, true>( // A3: Tile {{v[0-9]+}}; -// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A3: Tile {{v[0-9]+}}; // A3: TNEG( -// A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3-LABEL: AICORE void cube_kernel(__gm__ float* -// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3: set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // SYNC-A3: Tile // SYNC-A3: wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // SYNC-A3: TMOV( // SYNC-A3-LABEL: AICORE void vector_kernel(__gm__ float* -// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); // SYNC-A3: Tile // SYNC-A3: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); diff --git a/test/basic/tpush_tpop_frontend_lowering_a5.pto b/test/basic/tpush_tpop_frontend_lowering_a5.pto index 429345221..a98ec8957 100644 --- a/test/basic/tpush_tpop_frontend_lowering_a5.pto +++ b/test/basic/tpush_tpop_frontend_lowering_a5.pto @@ -57,21 +57,21 @@ module { } // A5-LABEL: AICORE void cube_kernel( -// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>( -// A5: TPUSH +// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, true>( +// A5: TPUSH // A5: Tile {{v[0-9]+}}; -// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A5: Tile {{v[0-9]+}}; // A5: TMOV( -// A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // A5-LABEL: AICORE void vector_kernel( -// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>( +// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, true>( // A5: Tile {{v[0-9]+}}; // A5: Tile {{v[0-9]+}}; // A5: TMOV( -// A5: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A5: Tile {{v[0-9]+}}; // A5: TNEG( -// A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>( diff --git a/test/basic/tpush_tpop_frontend_nosplit_a5.pto b/test/basic/tpush_tpop_frontend_nosplit_a5.pto new file mode 100644 index 000000000..55eb28e74 --- /dev/null +++ b/test/basic/tpush_tpop_frontend_nosplit_a5.pto @@ -0,0 +1,52 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @cube_kernel() attributes {pto.kernel_kind = #pto.kernel_kind} { + %v2c_local = pto.reserve_buffer { + name = "v2c_fifo", + size = 4096, + location = #pto.address_space, + auto = true + } -> i32 + %c2v_import = pto.import_reserved_buffer { + name = "c2v_fifo", + peer_func = @vector_kernel + } -> i32 + pto.aic_initialize_pipe {dir_mask = 3, slot_size = 1024} + (c2v_consumer_buf = %c2v_import : i32, + v2c_consumer_buf = %v2c_local : i32) + + %acc_tile = pto.alloc_tile : !pto.tile_buf + pto.tpush_to_aiv(%acc_tile : !pto.tile_buf) {split = 1} + return + } + + func.func @vector_kernel() attributes {pto.kernel_kind = #pto.kernel_kind} { + %c2v_local = pto.reserve_buffer { + name = "c2v_fifo", + size = 4096, + location = #pto.address_space, + auto = true + } -> i32 + %v2c_import = pto.import_reserved_buffer { + name = "v2c_fifo", + peer_func = @cube_kernel + } -> i32 + pto.aiv_initialize_pipe {dir_mask = 3, slot_size = 1024} + (c2v_consumer_buf = %c2v_local : i32, + v2c_consumer_buf = %v2c_import : i32) + + %recv_tile = pto.tpop_from_aic {split = 0} + -> !pto.tile_buf + pto.tfree_from_aic {split = 1} + return + } +} + +// A5-LABEL: AICORE void cube_kernel( +// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, true>( +// A5: TPUSH + +// A5-LABEL: AICORE void vector_kernel( +// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, true>( +// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>(