diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index cafdb784c..f43577bf8 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -61,6 +61,7 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createInferPTOLayoutPass(); +/* Creates the function pass declared in Passes.td as pto-canonicalize-subview-for-tload. */ std::unique_ptr createPTOCanonicalizeSubviewForTLoadPass(); // Declare register function void registerPTOPasses(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 37979bf21..b56b5290f 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -57,6 +57,23 @@ def InferPTOLayout : Pass<"pto-infer-layout", "func::FuncOp"> { let dependentDialects = ["pto::PTODialect", "arith::ArithDialect"]; } +def PTOCanonicalizeSubviewForTLoad + : Pass<"pto-canonicalize-subview-for-tload", "func::FuncOp"> { + let summary = "Canonicalize singleton-axis subviews for ND/DN->NZ-like tload"; + let description = [{ + Marks memref.subview operations with a safe singleton-axis permutation when: + - all effective consumers are pto.tload + - source GlobalTensor layout is ND or DN + - destination tile config is NZ-like + + The permutation only moves statically-singleton axes before non-singleton + axes and only when the subview has at most two non-singleton dimensions. + EmitC lowering consumes this marker to avoid ND2NZ static-shape assertion. 
+ }]; + let constructor = "mlir::pto::createPTOCanonicalizeSubviewForTLoadPass()"; + let dependentDialects = ["pto::PTODialect", "memref::MemRefDialect"]; +} + def InferPTOMemScope : Pass<"pto-infer-mem-scope"> { let summary = "Infer memory scope for PTO Ops"; diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b82d227fe..90b0af943 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/SyncCodegen.cpp LoweringSyncToPipe.cpp PTOVerifyTFreePass.cpp + PTOCanonicalizeSubviewForTLoad.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/PTO diff --git a/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp b/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp new file mode 100644 index 000000000..36f94a80f --- /dev/null +++ b/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp @@ -0,0 +1,326 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOCANONICALIZESUBVIEWFORTLOAD +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { +/* Names of the attribute this pass reads ("layout") and of the marker attribute it removes and sets ("pto.singleton_axis_permutation"). */ +static constexpr llvm::StringLiteral kLayoutAttrName = "layout"; +static constexpr llvm::StringLiteral kSingletonAxisPermutationAttrName = + "pto.singleton_axis_permutation"; +/* Returns operand 0 of v's defining op when that op matches the cast kind tested below; otherwise returns v itself. Only one level is peeled. NOTE(review): the matched op type is not visible in this view; the helper name suggests an unrealized conversion cast -- confirm. */ +static Value peelUnrealized(Value v) { + if (auto castOp = v.getDefiningOp()) + return castOp.getOperand(0); + return v; +} +/* Resolves ofr to a compile-time integer when possible. Attribute form: the value of an integer attribute, nullopt for any other attribute. Value form: the value of a recognized constant-defining op (three op kinds are tried in turn); nullopt otherwise. */ +static std::optional extractStaticInt(OpFoldResult ofr) { + if (auto attr = ofr.dyn_cast()) { + if (auto ia = dyn_cast(attr)) + return ia.getInt(); + return std::nullopt; + } + Value v = ofr.get(); + if (auto cIdx = v.getDefiningOp()) + return cIdx.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} +/* Returns the layout stored in op's "layout" attribute, or nullopt when op is null or carries no such attribute of the expected type. */ +static std::optional getLayoutAttrFromOp(Operation *op) { + if (!op) + return std::nullopt; + if (auto attr = op->getAttrOfType(kLayoutAttrName)) + return attr.getLayout(); + return std::nullopt; +} +/* Searches v's chain of defining ops for the nearest op carrying a layout attribute. Ops matched below that expose getSource() are stepped through to their source, cast-like ops to operand 0. The search gives up (nullopt) at a value with no defining op, at an operand-less cast, or at any other op. */ +static std::optional resolveLayoutFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto layout = getLayoutAttrFromOp(def)) + return layout; + if (auto subview = dyn_cast(def)) { + v = peelUnrealized(subview.getSource()); + continue; + } + if (auto reinterpret = dyn_cast(def)) { + v = peelUnrealized(reinterpret.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + v = peelUnrealized(cast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if 
(unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} +/* Layout lookup for a tensor access: a layout attribute on anchor itself wins; otherwise falls back to walking basePtr's chain of defining ops. */ +static std::optional +resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { + if (auto layout = getLayoutAttrFromOp(anchor)) + return layout; + return resolveLayoutFromValueChain(basePtr); +} +/* Searches v's chain of defining ops for a tile buffer config. The op bound to bind yields getConfigAttr(); the op bound to cast yields its optional getConfig() and ends the search even when that is unset; source-forwarding and cast-like ops are stepped through; anything else yields nullopt. */ +static std::optional +resolveTileConfigFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto bind = dyn_cast(def)) + return bind.getConfigAttr(); + if (auto cast = dyn_cast(def)) { + if (auto cfg = cast.getConfig()) + return *cfg; + return std::nullopt; + } + if (auto mcast = dyn_cast(def)) { + v = peelUnrealized(mcast.getSource()); + continue; + } + if (auto rc = dyn_cast(def)) { + v = peelUnrealized(rc.getSource()); + continue; + } + if (auto sv = dyn_cast(def)) { + v = peelUnrealized(sv.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} +/* True when configAttr has BLayout == ColMajor, SLayout == RowMajor and S-fractal size == 512. A field whose attribute is not of the expected kind is read as 0. NOTE(review): 512 is an unexplained literal here; presumably the fractal size that identifies an NZ tile -- confirm and consider naming it. */ +static bool isNZLikeTileConfig(pto::TileBufConfigAttr configAttr) { + int32_t blVal = 0; + if (auto bl = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(bl.getValue()); + + int32_t slVal = 0; + if (auto sl = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(sl.getValue()); + + int32_t fractal = 0; + if (auto fr = dyn_cast(configAttr.getSFractalSize())) + fractal = fr.getInt(); + + return blVal == static_cast(BLayout::ColMajor) && + slVal == static_cast(SLayout::RowMajor) && fractal == 512; +} +/* Returns true if target is reached from v by repeatedly stepping to the input of the view/cast ops matched below. Stops with false at a value with no defining op, at any other op, or after 64 steps. */ +static bool tracesBackThroughViewCasts(Value v, Value target) { + Value cur = peelUnrealized(v); + for (int guard = 0; guard < 64; ++guard) { + if (cur == target) + return true; + Operation *def = cur.getDefiningOp(); + if (!def) + return false; + if (auto mcast = dyn_cast(def)) { + cur = peelUnrealized(mcast.getSource()); + continue; + } + if 
(auto rc = dyn_cast(def)) { + cur = peelUnrealized(rc.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + return false; + cur = peelUnrealized(unrealized.getOperand(0)); + continue; + } + return false; + } + return false; +} +/* Appends to out every user of v, looking through cast/view-like users: such a user is not reported itself, its result(s) are traversed instead. Each value is expanded once and each reported user appears once. */ +static void collectUsersThroughViewCasts(Value v, + SmallVectorImpl &out) { + SmallVector worklist; + llvm::DenseSet visitedValues; + llvm::DenseSet visitedUsers; + worklist.push_back(v); + + while (!worklist.empty()) { + Value cur = worklist.pop_back_val(); + if (!visitedValues.insert(cur).second) + continue; + for (Operation *u : cur.getUsers()) { + if (auto unrealized = dyn_cast(u)) { + for (Value r : unrealized->getResults()) + worklist.push_back(r); + continue; + } + if (auto mcast = dyn_cast(u)) { + worklist.push_back(mcast.getResult()); + continue; + } + if (auto rc = dyn_cast(u)) { + worklist.push_back(rc.getResult()); + continue; + } + if (visitedUsers.insert(u).second) + out.push_back(u); + } + } +} +/* True when load has a dst, its source layout resolves to ND or DN, and its dst tile config resolves and is NZ-like. */ +static bool isNdDnToNzLikeTLoad(pto::TLoadOp load) { + if (!load.getDst()) + return false; + + auto gtLayout = + resolveLayoutForGlobalTensor(load.getOperation(), load.getSrc()); + if (!gtLayout || + (*gtLayout != mlir::pto::Layout::ND && + *gtLayout != mlir::pto::Layout::DN)) + return false; + + auto tileCfg = resolveTileConfigFromValueChain(load.getDst()); + if (!tileCfg) + return false; + return isNZLikeTileConfig(*tileCfg); +} +/* Decides whether sv may be tagged. Any collected user that is not a tload vetoes; a tload whose src does not trace back to sv is skipped; a tload reading sv that is not ND/DN->NZ-like vetoes. Returns true only if at least one qualifying tload was found. */ +static bool shouldCanonicalizeSubviewForNdDnToNz(memref::SubViewOp sv) { + SmallVector users; + collectUsersThroughViewCasts(sv.getResult(), users); + bool sawTarget = false; + + for (Operation *user : users) { + auto load = dyn_cast(user); + if (!load) + return false; + if (!tracesBackThroughViewCasts(load.getSrc(), sv.getResult())) + continue; + if (!isNdDnToNzLikeTLoad(load)) + return false; + sawTarget = true; + } + return sawTarget; +} +/* Builds the axis order that lists statically size-1 axes first and all other axes after, keeping the relative order inside each group. An axis size is taken from the subview's mixed sizes when static, else from a static result-shape extent. Returns nullopt when the result type does not match, rank <= 2, more than two axes are non-singleton, or the order is already the identity. NOTE(review): mixedSizes is indexed with result-dimension indices; for a rank-reducing subview the two may not line up -- confirm such subviews cannot reach this pass. */ +static std::optional> +computeSingletonFirstPermutation(memref::SubViewOp sv) { + auto resTy = 
dyn_cast(sv.getResult().getType()); + if (!resTy) + return std::nullopt; + + const int rank = resTy.getRank(); + if (rank <= 2) + return std::nullopt; + + auto mixedSizes = sv.getMixedSizes(); + auto resShape = resTy.getShape(); + SmallVector staticSingletonDims; + staticSingletonDims.reserve(rank); + /* Classify each axis; an axis with no static extent counts as non-singleton. */ + int nonSingletonCount = 0; + for (int i = 0; i < rank; ++i) { + std::optional staticDim; + if (i < static_cast(mixedSizes.size())) + staticDim = extractStaticInt(mixedSizes[i]); + if (!staticDim && resShape[i] != ShapedType::kDynamic) + staticDim = resShape[i]; + + bool isSingleton = staticDim && *staticDim == 1; + staticSingletonDims.push_back(isSingleton); + if (!isSingleton) + ++nonSingletonCount; + } + + if (nonSingletonCount > 2) + return std::nullopt; + /* Stable partition: singleton axes first, then the remaining axes. */ + SmallVector permutation; + permutation.reserve(rank); + for (int i = 0; i < rank; ++i) { + if (staticSingletonDims[i]) + permutation.push_back(i); + } + for (int i = 0; i < rank; ++i) { + if (!staticSingletonDims[i]) + permutation.push_back(i); + } + /* An identity order is reported as nullopt, so no marker is attached for it. */ + bool changed = false; + for (int i = 0; i < rank; ++i) { + if (permutation[i] != i) { + changed = true; + break; + } + } + if (!changed) + return std::nullopt; + + return permutation; +} +/* Function pass: for each memref.subview in the function, removes any existing pto.singleton_axis_permutation attribute, then attaches it again (as a dense i64 array) only when a non-identity permutation exists and the subview's users qualify. */ +struct PTOCanonicalizeSubviewForTLoadPass + : public mlir::pto::impl::PTOCanonicalizeSubviewForTLoadBase< + PTOCanonicalizeSubviewForTLoadPass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + + func.walk([&](memref::SubViewOp sv) { + sv->removeAttr(kSingletonAxisPermutationAttrName); + /* The marker was cleared unconditionally above, so a previously attached value never survives when the checks below fail. */ + auto perm = computeSingletonFirstPermutation(sv); + if (!perm) + return; + if (!shouldCanonicalizeSubviewForNdDnToNz(sv)) + return; + + sv->setAttr(kSingletonAxisPermutationAttrName, + DenseI64ArrayAttr::get(ctx, *perm)); + }); + } +}; + +} // namespace +/* Factory for the pass above; returns a newly allocated pass instance. */ +std::unique_ptr mlir::pto::createPTOCanonicalizeSubviewForTLoadPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp 
b/lib/PTO/Transforms/PTOToEmitC.cpp index c70bed094..1644108cb 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2917,6 +2917,63 @@ struct SubviewToEmitCPattern : public OpConversionPattern { rewriter.create(loc, u32Ty, srcV, stepV)); } + // 3.0 Apply the singleton-axis reordering precomputed by the preceding canonicalize pass (attribute pto.singleton_axis_permutation); the attribute is ignored unless it is a valid permutation of [0, rank). + if (rank > 2) { + if (auto permAttr = op->getAttrOfType( + "pto.singleton_axis_permutation")) { + if (static_cast(permAttr.size()) == rank) { + SmallVector permutation; + permutation.reserve(rank); + SmallVector seen(rank, 0); + bool validPermutation = true; + for (int64_t idx64 : permAttr.asArrayRef()) { + if (idx64 < 0 || idx64 >= rank) { + validPermutation = false; + break; + } + unsigned idx = static_cast(idx64); + if (seen[idx]) { + validPermutation = false; + break; + } + seen[idx] = 1; + permutation.push_back(idx); + } + + if (validPermutation) { + bool changed = false; + for (int i = 0; i < rank; ++i) { + if (permutation[i] != static_cast(i)) { + changed = true; + break; + } + } + + if (changed) { + SmallVector reorderedShape; + SmallVector reorderedSizes; + SmallVector reorderedStrides; + SmallVector reorderedStrideValues; + reorderedShape.reserve(rank); + reorderedSizes.reserve(rank); + reorderedStrides.reserve(rank); + reorderedStrideValues.reserve(rank); + for (unsigned idx : permutation) { + reorderedShape.push_back(shapeParamsVec[idx]); + reorderedSizes.push_back(sizeValues[idx]); + reorderedStrides.push_back(dummyStrideVec[idx]); + reorderedStrideValues.push_back(strideValues[idx]); + } + shapeParamsVec = std::move(reorderedShape); + sizeValues = std::move(reorderedSizes); + dummyStrideVec = std::move(reorderedStrides); + strideValues = std::move(reorderedStrideValues); + } + } + } + } + } + // 3.1 Right-align to 5 dims: pad shape with 1; existing dims keep their original stride; // the padded leading dims are derived contiguously by the "tight rank-expansion" rule: stride[i] = shape[i+1] * stride[i+1] SmallVector finalShape(5, "1"); diff --git a/test/basic/issue453_partition_view_singleton_axis_reorder.pto 
b/test/basic/issue453_partition_view_singleton_axis_reorder.pto new file mode 100644 index 000000000..8e7260878 --- /dev/null +++ b/test/basic/issue453_partition_view_singleton_axis_reorder.pto @@ -0,0 +1,106 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + // Safe case: only singleton axes are moved. This should canonicalize + // [1, 16, 1, 16] -> [1, 1, 16, 16] before 5D right-align, so emitted + // GlobalTensor shape becomes <1, 1, 1, 16, 16>. + func.func @issue453_singleton_axis_reorder_positive(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Unsafe case: three non-singleton axes (2,16,16). Reordering is not legal + // and must not happen. 
+ func.func @issue453_singleton_axis_reorder_negative(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c2, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c2, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<2x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<2x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Safe case (DN): same singleton-axis movement is legal for DN layout and + // should still canonicalize before ND/DN->NZ-like tload lowering. + func.func @issue453_singleton_axis_reorder_positive_dn(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Guard case #1: ND->ND tload (non NZ-like tile config). Reordering should + // stay disabled even when singleton-axis movement would be legal. 
+ func.func @issue453_singleton_axis_no_reorder_nd_to_nd_tload(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Guard case #2: TStore-only user. Reordering is a tload ND/DN->NZ fix, so + // it must not run for pure store paths. + func.func @issue453_singleton_axis_no_reorder_tstore_only(%dst: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %dst, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tstore ins(%tile : !pto.tile_buf) + outs(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + return + } +} +// Output checks, one group per function above: after the function's label line a GlobalTensor declaration must appear, and no further GlobalTensor may follow before the next function's label. +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive( +// A5: GlobalTensor +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_negative( +// A5: GlobalTensor +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive_dn( +// A5: GlobalTensor +// A5: pto::Layout::DN +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_no_reorder_nd_to_nd_tload( +// A5: GlobalTensor +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void 
issue453_singleton_axis_no_reorder_tstore_only( +// A5: GlobalTensor +// A5-NOT: GlobalTensor diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e0c49c4cd..816758f88 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1141,6 +1141,8 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); pm.addPass(pto::createPTOViewToMemrefPass()); + /* Scheduled right after view-to-memref: the added pass walks memref.subview ops. NOTE(review): presumably those only exist once PTO views have been lowered -- confirm the ordering requirement. */ pm.addNestedPass( + pto::createPTOCanonicalizeSubviewForTLoadPass()); //pm.addPass(createInferPTOMemScopePass()); if (effectiveLevel != PTOBuildLevel::Level3) {