From 08cdc2f0dd79a653ac72f95286a899ce9bf7ab20 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Wed, 8 Apr 2026 20:35:40 +0800 Subject: [PATCH 1/3] fix: canonicalize singleton partition-view axes for GT lowering --- lib/PTO/Transforms/PTOToEmitC.cpp | 60 +++++++++++++++++++ ..._partition_view_singleton_axis_reorder.pto | 45 ++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 test/basic/issue453_partition_view_singleton_axis_reorder.pto diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index c70bed094..57e030f78 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2864,9 +2864,11 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) SmallVector shapeParamsVec; SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) + SmallVector staticSingletonDims; // 可静态证明为 1 的维度 auto resShape = resTy.getShape(); auto mixedSizes = op.getMixedSizes(); sizeValues.reserve(rank); + staticSingletonDims.reserve(rank); for (int i = 0; i < resTy.getRank(); ++i) { if (resShape[i] == ShapedType::kDynamic) { shapeParamsVec.push_back("-1"); @@ -2879,6 +2881,13 @@ struct SubviewToEmitCPattern : public OpConversionPattern { else sizeValues.push_back( mkU32(resShape[i] == ShapedType::kDynamic ? 1 : resShape[i])); + + std::optional staticDim; + if (i < (int)mixedSizes.size()) + staticDim = extractStaticInt(mixedSizes[i]); + if (!staticDim && resShape[i] != ShapedType::kDynamic) + staticDim = resShape[i]; + staticSingletonDims.push_back(staticDim && *staticDim == 1); } // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) @@ -2917,6 +2926,57 @@ struct SubviewToEmitCPattern : public OpConversionPattern { rewriter.create(loc, u32Ty, srcV, stepV)); } + // 3.0 对可证明安全的分区视图做 singleton 轴前移: + // 仅允许 size==1 的轴跨越其它轴,保证地址集合不变;动态轴视为非 singleton,不重排。 + if (rank > 2) { + int nonSingletonCount = 0; + for (bool isSingleton : staticSingletonDims) { + if (!isSingleton) + ++nonSingletonCount; + } + if (nonSingletonCount <= 2) { + SmallVector permutation; + permutation.reserve(rank); + for (int i = 0; i < rank; ++i) { + if (staticSingletonDims[i]) + permutation.push_back(i); + } + for (int i = 0; i < rank; ++i) { + if (!staticSingletonDims[i]) + permutation.push_back(i); + } + + bool changed = false; + for (int i = 0; i < rank; ++i) { + if (permutation[i] != static_cast(i)) { + changed = true; + break; + } + } + + if (changed) { + SmallVector reorderedShape; + SmallVector reorderedSizes; + SmallVector reorderedStrides; + SmallVector reorderedStrideValues; + reorderedShape.reserve(rank); + reorderedSizes.reserve(rank); + reorderedStrides.reserve(rank); + reorderedStrideValues.reserve(rank); + for (unsigned idx : permutation) { + reorderedShape.push_back(shapeParamsVec[idx]); + reorderedSizes.push_back(sizeValues[idx]); + reorderedStrides.push_back(dummyStrideVec[idx]); + reorderedStrideValues.push_back(strideValues[idx]); + } + shapeParamsVec = std::move(reorderedShape); + sizeValues = std::move(reorderedSizes); + dummyStrideVec = std::move(reorderedStrides); + strideValues = std::move(reorderedStrideValues); + } + } + } + // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] SmallVector finalShape(5, "1"); diff --git a/test/basic/issue453_partition_view_singleton_axis_reorder.pto b/test/basic/issue453_partition_view_singleton_axis_reorder.pto new file mode 100644 index 000000000..0e8f35e81 --- /dev/null +++ b/test/basic/issue453_partition_view_singleton_axis_reorder.pto @@ -0,0 +1,45 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + // Safe case: only singleton axes are moved. This should canonicalize + // [1, 16, 1, 16] -> [1, 1, 16, 16] before 5D right-align, so emitted + // GlobalTensor shape becomes <1, 1, 1, 16, 16>. + func.func @issue453_singleton_axis_reorder_positive(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Unsafe case: three non-singleton axes (2,16,16). Reordering is not legal + // and must not happen. + func.func @issue453_singleton_axis_reorder_negative(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c2, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c2, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<2x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<2x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } +} + +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive( +// A5: GlobalTensor +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_negative( +// A5: GlobalTensor +// A5-NOT: GlobalTensor From 8d0d3ea2c4412954a011b68699f72f57336061c8 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 9 Apr 2026 16:20:09 +0800 Subject: [PATCH 2/3] Restrict singleton-axis canonicalization to ND/DN->NZ tload --- lib/PTO/Transforms/PTOToEmitC.cpp | 137 +++++++++++++++++- ..._partition_view_singleton_axis_reorder.pto | 40 +++++ 2 files changed, 176 insertions(+), 1 deletion(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 57e030f78..2e624384a 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -142,6 +142,140 @@ resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { return resolveLayoutFromValueChain(basePtr); } +static std::optional +resolveTileConfigFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto bind = dyn_cast(def)) + return bind.getConfigAttr(); + if (auto cast = dyn_cast(def)) { + if (auto cfg = cast.getConfig()) + return *cfg; + return std::nullopt; + } + if (auto mcast = dyn_cast(def)) { + v = peelUnrealized(mcast.getSource()); + continue; + } + if (auto sv = dyn_cast(def)) { + v = peelUnrealized(sv.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} + +static bool isNZLikeTileConfig(pto::TileBufConfigAttr configAttr) { + int32_t blVal = 0; + if (auto bl = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(bl.getValue()); + + int32_t slVal = 0; + if (auto sl = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(sl.getValue()); + + int32_t fractal = 0; + if (auto fr = dyn_cast(configAttr.getSFractalSize())) + fractal = fr.getInt(); + + // pto-isa NZ-like tile predicate: + // !isRowMajor && SFractal == RowMajor && SFractalSize == 512 + return blVal == static_cast(BLayout::ColMajor) && + slVal == static_cast(SLayout::RowMajor) && fractal == 512; +} + +static bool tracesBackThroughViewCasts(Value v, Value target) { + Value cur = peelUnrealized(v); + for (int guard = 0; guard < 32; ++guard) { + if (cur == target) + return true; + Operation *def = cur.getDefiningOp(); + if (!def) + return false; + if (auto mcast = dyn_cast(def)) { + cur = peelUnrealized(mcast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + return false; + cur = peelUnrealized(unrealized.getOperand(0)); + continue; + } + return false; + } + return false; +} + +static void collectUsersThroughViewCasts(Value v, + SmallVectorImpl &out) { + auto walk = [&](auto &&self, Value cur) -> void { + for (Operation *u : cur.getUsers()) { + if (auto unrealized = dyn_cast(u)) { + for (Value r : unrealized.getResults()) + self(self, r); + continue; + } + if (auto mcast = dyn_cast(u)) { + self(self, mcast.getResult()); + continue; + } + bool seen = false; + for (Operation *old : out) { + if (old == u) { + seen = true; + break; + } + } + if (!seen) + out.push_back(u); + } + }; + walk(walk, v); +} + +static bool isNdDnToNzLikeTLoad(pto::TLoadOp load) { + if (!load.getDst()) + return false; + + auto gtLayout = + resolveLayoutForGlobalTensor(load.getOperation(), load.getSrc()); + if (!gtLayout || + (*gtLayout != mlir::pto::Layout::ND && + *gtLayout != mlir::pto::Layout::DN)) + return false; + + auto tileCfg = resolveTileConfigFromValueChain(load.getDst()); + if (!tileCfg) + return false; + return isNZLikeTileConfig(*tileCfg); +} + +static bool shouldCanonicalizeSubviewForNdDnToNz(memref::SubViewOp sv) { + SmallVector users; + collectUsersThroughViewCasts(sv.getResult(), users); + bool sawTarget = false; + + for (Operation *user : users) { + auto load = dyn_cast(user); + if (!load) + return false; + if (!tracesBackThroughViewCasts(load.getSrc(), sv.getResult())) + continue; + if (!isNdDnToNzLikeTLoad(load)) + return false; + sawTarget = true; + } + return sawTarget; +} + static std::string layoutToEmitCString(mlir::pto::Layout layout) { switch (layout) { case mlir::pto::Layout::ND: @@ -2928,7 +3062,8 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // 3.0 对可证明安全的分区视图做 singleton 轴前移: // 仅允许 size==1 的轴跨越其它轴,保证地址集合不变;动态轴视为非 singleton,不重排。 - if (rank > 2) { + const bool needCanonicalize = shouldCanonicalizeSubviewForNdDnToNz(op); + if (needCanonicalize && rank > 2) { int nonSingletonCount = 0; for (bool isSingleton : staticSingletonDims) { if (!isSingleton) diff --git a/test/basic/issue453_partition_view_singleton_axis_reorder.pto b/test/basic/issue453_partition_view_singleton_axis_reorder.pto index 0e8f35e81..c3ff31645 100644 --- a/test/basic/issue453_partition_view_singleton_axis_reorder.pto +++ b/test/basic/issue453_partition_view_singleton_axis_reorder.pto @@ -34,6 +34,38 @@ module { outs(%tile : !pto.tile_buf) return } + + // Guard case #1: ND->ND tload (non NZ-like tile config). Reordering should + // stay disabled even when singleton-axis movement would be legal. + func.func @issue453_singleton_axis_no_reorder_nd_to_nd_tload(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + + // Guard case #2: TStore-only user. Reordering is a tload ND/DN->NZ fix, so + // it must not run for pure store paths. + func.func @issue453_singleton_axis_no_reorder_tstore_only(%dst: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %dst, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tstore ins(%tile : !pto.tile_buf) + outs(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + return + } } // A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive( @@ -43,3 +75,11 @@ module { // A5-LABEL: AICORE void issue453_singleton_axis_reorder_negative( // A5: GlobalTensor // A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_no_reorder_nd_to_nd_tload( +// A5: GlobalTensor +// A5-NOT: GlobalTensor + +// A5-LABEL: AICORE void issue453_singleton_axis_no_reorder_tstore_only( +// A5: GlobalTensor +// A5-NOT: GlobalTensor From 350e04846e24ed649894262ae6dba6223b2448d7 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 9 Apr 2026 16:30:54 +0800 Subject: [PATCH 3/3] Extract subview singleton canonicalization into dedicated pass --- include/PTO/Transforms/Passes.h | 1 + include/PTO/Transforms/Passes.td | 17 + lib/PTO/Transforms/CMakeLists.txt | 1 + .../PTOCanonicalizeSubviewForTLoad.cpp | 326 ++++++++++++++++++ lib/PTO/Transforms/PTOToEmitC.cpp | 238 +++---------- ..._partition_view_singleton_axis_reorder.pto | 21 ++ tools/ptoas/ptoas.cpp | 2 + 7 files changed, 418 insertions(+), 188 deletions(-) create mode 100644 lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index cafdb784c..f43577bf8 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -61,6 +61,7 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createInferPTOLayoutPass(); +std::unique_ptr createPTOCanonicalizeSubviewForTLoadPass(); // Declare register function void registerPTOPasses(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 37979bf21..b56b5290f 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -57,6 +57,23 @@ def InferPTOLayout : Pass<"pto-infer-layout", "func::FuncOp"> { let dependentDialects = ["pto::PTODialect", "arith::ArithDialect"]; } +def PTOCanonicalizeSubviewForTLoad + : Pass<"pto-canonicalize-subview-for-tload", "func::FuncOp"> { + let summary = "Canonicalize singleton-axis subviews for ND/DN->NZ-like tload"; + let description = [{ + Marks memref.subview operations with a safe singleton-axis permutation when: + - all effective consumers are pto.tload + - source GlobalTensor layout is ND or DN + - destination tile config is NZ-like + + The permutation only moves statically-singleton axes before non-singleton + axes and only when the subview has at most two non-singleton dimensions. + EmitC lowering consumes this marker to avoid ND2NZ static-shape assertion. + }]; + let constructor = "mlir::pto::createPTOCanonicalizeSubviewForTLoadPass()"; + let dependentDialects = ["pto::PTODialect", "memref::MemRefDialect"]; +} + def InferPTOMemScope : Pass<"pto-infer-mem-scope"> { let summary = "Infer memory scope for PTO Ops"; diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b82d227fe..90b0af943 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/SyncCodegen.cpp LoweringSyncToPipe.cpp PTOVerifyTFreePass.cpp + PTOCanonicalizeSubviewForTLoad.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/PTO diff --git a/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp b/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp new file mode 100644 index 000000000..36f94a80f --- /dev/null +++ b/lib/PTO/Transforms/PTOCanonicalizeSubviewForTLoad.cpp @@ -0,0 +1,326 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOCANONICALIZESUBVIEWFORTLOAD +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static constexpr llvm::StringLiteral kLayoutAttrName = "layout"; +static constexpr llvm::StringLiteral kSingletonAxisPermutationAttrName = + "pto.singleton_axis_permutation"; + +static Value peelUnrealized(Value v) { + if (auto castOp = v.getDefiningOp()) + return castOp.getOperand(0); + return v; +} + +static std::optional extractStaticInt(OpFoldResult ofr) { + if (auto attr = ofr.dyn_cast()) { + if (auto ia = dyn_cast(attr)) + return ia.getInt(); + return std::nullopt; + } + Value v = ofr.get(); + if (auto cIdx = v.getDefiningOp()) + return cIdx.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} + +static std::optional getLayoutAttrFromOp(Operation *op) { + if (!op) + return std::nullopt; + if (auto attr = op->getAttrOfType(kLayoutAttrName)) + return attr.getLayout(); + return std::nullopt; +} + +static std::optional resolveLayoutFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto layout = getLayoutAttrFromOp(def)) + return layout; + if (auto subview = dyn_cast(def)) { + v = peelUnrealized(subview.getSource()); + continue; + } + if (auto reinterpret = dyn_cast(def)) { + v = peelUnrealized(reinterpret.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + v = peelUnrealized(cast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} + +static std::optional +resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { + if (auto layout = getLayoutAttrFromOp(anchor)) + return layout; + return resolveLayoutFromValueChain(basePtr); +} + +static std::optional +resolveTileConfigFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto bind = dyn_cast(def)) + return bind.getConfigAttr(); + if (auto cast = dyn_cast(def)) { + if (auto cfg = cast.getConfig()) + return *cfg; + return std::nullopt; + } + if (auto mcast = dyn_cast(def)) { + v = peelUnrealized(mcast.getSource()); + continue; + } + if (auto rc = dyn_cast(def)) { + v = peelUnrealized(rc.getSource()); + continue; + } + if (auto sv = dyn_cast(def)) { + v = peelUnrealized(sv.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} + +static bool isNZLikeTileConfig(pto::TileBufConfigAttr configAttr) { + int32_t blVal = 0; + if (auto bl = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(bl.getValue()); + + int32_t slVal = 0; + if (auto sl = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(sl.getValue()); + + int32_t fractal = 0; + if (auto fr = dyn_cast(configAttr.getSFractalSize())) + fractal = fr.getInt(); + + return blVal == static_cast(BLayout::ColMajor) && + slVal == static_cast(SLayout::RowMajor) && fractal == 512; +} + +static bool tracesBackThroughViewCasts(Value v, Value target) { + Value cur = peelUnrealized(v); + for (int guard = 0; guard < 64; ++guard) { + if (cur == target) + return true; + Operation *def = cur.getDefiningOp(); + if (!def) + return false; + if (auto mcast = dyn_cast(def)) { + cur = peelUnrealized(mcast.getSource()); + continue; + } + if (auto rc = dyn_cast(def)) { + cur = peelUnrealized(rc.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + return false; + cur = peelUnrealized(unrealized.getOperand(0)); + continue; + } + return false; + } + return false; +} + +static void collectUsersThroughViewCasts(Value v, + SmallVectorImpl &out) { + SmallVector worklist; + llvm::DenseSet visitedValues; + llvm::DenseSet visitedUsers; + worklist.push_back(v); + + while (!worklist.empty()) { + Value cur = worklist.pop_back_val(); + if (!visitedValues.insert(cur).second) + continue; + for (Operation *u : cur.getUsers()) { + if (auto unrealized = dyn_cast(u)) { + for (Value r : unrealized->getResults()) + worklist.push_back(r); + continue; + } + if (auto mcast = dyn_cast(u)) { + worklist.push_back(mcast.getResult()); + continue; + } + if (auto rc = dyn_cast(u)) { + worklist.push_back(rc.getResult()); + continue; + } + if (visitedUsers.insert(u).second) + out.push_back(u); + } + } +} + +static bool isNdDnToNzLikeTLoad(pto::TLoadOp load) { + if (!load.getDst()) + return false; + + auto gtLayout = + resolveLayoutForGlobalTensor(load.getOperation(), load.getSrc()); + if (!gtLayout || + (*gtLayout != mlir::pto::Layout::ND && + *gtLayout != mlir::pto::Layout::DN)) + return false; + + auto tileCfg = resolveTileConfigFromValueChain(load.getDst()); + if (!tileCfg) + return false; + return isNZLikeTileConfig(*tileCfg); +} + +static bool shouldCanonicalizeSubviewForNdDnToNz(memref::SubViewOp sv) { + SmallVector users; + collectUsersThroughViewCasts(sv.getResult(), users); + bool sawTarget = false; + + for (Operation *user : users) { + auto load = dyn_cast(user); + if (!load) + return false; + if (!tracesBackThroughViewCasts(load.getSrc(), sv.getResult())) + continue; + if (!isNdDnToNzLikeTLoad(load)) + return false; + sawTarget = true; + } + return sawTarget; +} + +static std::optional> +computeSingletonFirstPermutation(memref::SubViewOp sv) { + auto resTy = dyn_cast(sv.getResult().getType()); + if (!resTy) + return std::nullopt; + + const int rank = resTy.getRank(); + if (rank <= 2) + return std::nullopt; + + auto mixedSizes = sv.getMixedSizes(); + auto resShape = resTy.getShape(); + SmallVector staticSingletonDims; + staticSingletonDims.reserve(rank); + + int nonSingletonCount = 0; + for (int i = 0; i < rank; ++i) { + std::optional staticDim; + if (i < static_cast(mixedSizes.size())) + staticDim = extractStaticInt(mixedSizes[i]); + if (!staticDim && resShape[i] != ShapedType::kDynamic) + staticDim = resShape[i]; + + bool isSingleton = staticDim && *staticDim == 1; + staticSingletonDims.push_back(isSingleton); + if (!isSingleton) + ++nonSingletonCount; + } + + if (nonSingletonCount > 2) + return std::nullopt; + + SmallVector permutation; + permutation.reserve(rank); + for (int i = 0; i < rank; ++i) { + if (staticSingletonDims[i]) + permutation.push_back(i); + } + for (int i = 0; i < rank; ++i) { + if (!staticSingletonDims[i]) + permutation.push_back(i); + } + + bool changed = false; + for (int i = 0; i < rank; ++i) { + if (permutation[i] != i) { + changed = true; + break; + } + } + if (!changed) + return std::nullopt; + + return permutation; +} + +struct PTOCanonicalizeSubviewForTLoadPass + : public mlir::pto::impl::PTOCanonicalizeSubviewForTLoadBase< + PTOCanonicalizeSubviewForTLoadPass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + + func.walk([&](memref::SubViewOp sv) { + sv->removeAttr(kSingletonAxisPermutationAttrName); + + auto perm = computeSingletonFirstPermutation(sv); + if (!perm) + return; + if (!shouldCanonicalizeSubviewForNdDnToNz(sv)) + return; + + sv->setAttr(kSingletonAxisPermutationAttrName, + DenseI64ArrayAttr::get(ctx, *perm)); + }); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOCanonicalizeSubviewForTLoadPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 2e624384a..1644108cb 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -142,140 +142,6 @@ resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { return resolveLayoutFromValueChain(basePtr); } -static std::optional -resolveTileConfigFromValueChain(Value v) { - v = peelUnrealized(v); - while (Operation *def = v.getDefiningOp()) { - if (auto bind = dyn_cast(def)) - return bind.getConfigAttr(); - if (auto cast = dyn_cast(def)) { - if (auto cfg = cast.getConfig()) - return *cfg; - return std::nullopt; - } - if (auto mcast = dyn_cast(def)) { - v = peelUnrealized(mcast.getSource()); - continue; - } - if (auto sv = dyn_cast(def)) { - v = peelUnrealized(sv.getSource()); - continue; - } - if (auto unrealized = dyn_cast(def)) { - if (unrealized->getNumOperands() == 0) - break; - v = peelUnrealized(unrealized.getOperand(0)); - continue; - } - break; - } - return std::nullopt; -} - -static bool isNZLikeTileConfig(pto::TileBufConfigAttr configAttr) { - int32_t blVal = 0; - if (auto bl = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(bl.getValue()); - - int32_t slVal = 0; - if (auto sl = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(sl.getValue()); - - int32_t fractal = 0; - if (auto fr = dyn_cast(configAttr.getSFractalSize())) - fractal = fr.getInt(); - - // pto-isa NZ-like tile predicate: - // !isRowMajor && SFractal == RowMajor && SFractalSize == 512 - return blVal == static_cast(BLayout::ColMajor) && - slVal == static_cast(SLayout::RowMajor) && fractal == 512; -} - -static bool tracesBackThroughViewCasts(Value v, Value target) { - Value cur = peelUnrealized(v); - for (int guard = 0; guard < 32; ++guard) { - if (cur == target) - return true; - Operation *def = cur.getDefiningOp(); - if (!def) - return false; - if (auto mcast = dyn_cast(def)) { - cur = peelUnrealized(mcast.getSource()); - continue; - } - if (auto unrealized = dyn_cast(def)) { - if (unrealized->getNumOperands() == 0) - return false; - cur = peelUnrealized(unrealized.getOperand(0)); - continue; - } - return false; - } - return false; -} - -static void collectUsersThroughViewCasts(Value v, - SmallVectorImpl &out) { - auto walk = [&](auto &&self, Value cur) -> void { - for (Operation *u : cur.getUsers()) { - if (auto unrealized = dyn_cast(u)) { - for (Value r : unrealized.getResults()) - self(self, r); - continue; - } - if (auto mcast = dyn_cast(u)) { - self(self, mcast.getResult()); - continue; - } - bool seen = false; - for (Operation *old : out) { - if (old == u) { - seen = true; - break; - } - } - if (!seen) - out.push_back(u); - } - }; - walk(walk, v); -} - -static bool isNdDnToNzLikeTLoad(pto::TLoadOp load) { - if (!load.getDst()) - return false; - - auto gtLayout = - resolveLayoutForGlobalTensor(load.getOperation(), load.getSrc()); - if (!gtLayout || - (*gtLayout != mlir::pto::Layout::ND && - *gtLayout != mlir::pto::Layout::DN)) - return false; - - auto tileCfg = resolveTileConfigFromValueChain(load.getDst()); - if (!tileCfg) - return false; - return isNZLikeTileConfig(*tileCfg); -} - -static bool shouldCanonicalizeSubviewForNdDnToNz(memref::SubViewOp sv) { - SmallVector users; - collectUsersThroughViewCasts(sv.getResult(), users); - bool sawTarget = false; - - for (Operation *user : users) { - auto load = dyn_cast(user); - if (!load) - return false; - if (!tracesBackThroughViewCasts(load.getSrc(), sv.getResult())) - continue; - if (!isNdDnToNzLikeTLoad(load)) - return false; - sawTarget = true; - } - return sawTarget; -} - static std::string layoutToEmitCString(mlir::pto::Layout layout) { switch (layout) { case mlir::pto::Layout::ND: @@ -2998,11 +2864,9 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) SmallVector shapeParamsVec; SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) - SmallVector staticSingletonDims; // 可静态证明为 1 的维度 auto resShape = resTy.getShape(); auto mixedSizes = op.getMixedSizes(); sizeValues.reserve(rank); - staticSingletonDims.reserve(rank); for (int i = 0; i < resTy.getRank(); ++i) { if (resShape[i] == ShapedType::kDynamic) { shapeParamsVec.push_back("-1"); @@ -3015,13 +2879,6 @@ struct SubviewToEmitCPattern : public OpConversionPattern { else sizeValues.push_back( mkU32(resShape[i] == ShapedType::kDynamic ? 1 : resShape[i])); - - std::optional staticDim; - if (i < (int)mixedSizes.size()) - staticDim = extractStaticInt(mixedSizes[i]); - if (!staticDim && resShape[i] != ShapedType::kDynamic) - staticDim = resShape[i]; - staticSingletonDims.push_back(staticDim && *staticDim == 1); } // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) @@ -3060,54 +2917,59 @@ struct SubviewToEmitCPattern : public OpConversionPattern { rewriter.create(loc, u32Ty, srcV, stepV)); } - // 3.0 对可证明安全的分区视图做 singleton 轴前移: - // 仅允许 size==1 的轴跨越其它轴,保证地址集合不变;动态轴视为非 singleton,不重排。 - const bool needCanonicalize = shouldCanonicalizeSubviewForNdDnToNz(op); - if (needCanonicalize && rank > 2) { - int nonSingletonCount = 0; - for (bool isSingleton : staticSingletonDims) { - if (!isSingleton) - ++nonSingletonCount; - } - if (nonSingletonCount <= 2) { - SmallVector permutation; - permutation.reserve(rank); - for (int i = 0; i < rank; ++i) { - if (staticSingletonDims[i]) - permutation.push_back(i); - } - for (int i = 0; i < rank; ++i) { - if (!staticSingletonDims[i]) - permutation.push_back(i); - } - - bool changed = false; - for (int i = 0; i < rank; ++i) { - if (permutation[i] != static_cast(i)) { - changed = true; - break; + // 3.0 应用前置 canonicalize pass 预计算的 singleton 轴重排。 + if (rank > 2) { + if (auto permAttr = op->getAttrOfType( + "pto.singleton_axis_permutation")) { + if (static_cast(permAttr.size()) == rank) { + SmallVector permutation; + permutation.reserve(rank); + SmallVector seen(rank, 0); + bool validPermutation = true; + for (int64_t idx64 : permAttr.asArrayRef()) { + if (idx64 < 0 || idx64 >= rank) { + validPermutation = false; + break; + } + unsigned idx = static_cast(idx64); + if (seen[idx]) { + validPermutation = false; + break; + } + seen[idx] = 1; + permutation.push_back(idx); } - } - if (changed) { - SmallVector reorderedShape; - SmallVector reorderedSizes; - SmallVector reorderedStrides; - SmallVector reorderedStrideValues; - reorderedShape.reserve(rank); - reorderedSizes.reserve(rank); - reorderedStrides.reserve(rank); - reorderedStrideValues.reserve(rank); - for (unsigned idx : permutation) { - reorderedShape.push_back(shapeParamsVec[idx]); - reorderedSizes.push_back(sizeValues[idx]); - reorderedStrides.push_back(dummyStrideVec[idx]); - reorderedStrideValues.push_back(strideValues[idx]); + if (validPermutation) { + bool changed = false; + for (int i = 0; i < rank; ++i) { + if (permutation[i] != static_cast(i)) { + changed = true; + break; + } + } + + if (changed) { + SmallVector reorderedShape; + SmallVector reorderedSizes; + SmallVector reorderedStrides; + SmallVector reorderedStrideValues; + reorderedShape.reserve(rank); + reorderedSizes.reserve(rank); + reorderedStrides.reserve(rank); + reorderedStrideValues.reserve(rank); + for (unsigned idx : permutation) { + reorderedShape.push_back(shapeParamsVec[idx]); + reorderedSizes.push_back(sizeValues[idx]); + reorderedStrides.push_back(dummyStrideVec[idx]); + reorderedStrideValues.push_back(strideValues[idx]); + } + shapeParamsVec = std::move(reorderedShape); + sizeValues = std::move(reorderedSizes); + dummyStrideVec = std::move(reorderedStrides); + strideValues = std::move(reorderedStrideValues); + } } - shapeParamsVec = std::move(reorderedShape); - sizeValues = std::move(reorderedSizes); - dummyStrideVec = std::move(reorderedStrides); - strideValues = std::move(reorderedStrideValues); } } } diff --git a/test/basic/issue453_partition_view_singleton_axis_reorder.pto b/test/basic/issue453_partition_view_singleton_axis_reorder.pto index c3ff31645..8e7260878 100644 --- a/test/basic/issue453_partition_view_singleton_axis_reorder.pto +++ b/test/basic/issue453_partition_view_singleton_axis_reorder.pto @@ -35,6 +35,22 @@ module { return } + // Safe case (DN): same singleton-axis movement is legal for DN layout and + // should still canonicalize before ND/DN->NZ-like tload lowering. + func.func @issue453_singleton_axis_reorder_positive_dn(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout} : !pto.tensor_view + %sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<1x16x1x16xf16> + %tile = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>) + outs(%tile : !pto.tile_buf) + return + } + // Guard case #1: ND->ND tload (non NZ-like tile config). Reordering should // stay disabled even when singleton-axis movement would be legal. func.func @issue453_singleton_axis_no_reorder_nd_to_nd_tload(%src: !pto.ptr) { @@ -76,6 +92,11 @@ module { // A5: GlobalTensor // A5-NOT: GlobalTensor +// A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive_dn( +// A5: GlobalTensor +// A5: pto::Layout::DN +// A5-NOT: GlobalTensor + // A5-LABEL: AICORE void issue453_singleton_axis_no_reorder_nd_to_nd_tload( // A5: GlobalTensor // A5-NOT: GlobalTensor diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e0c49c4cd..816758f88 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1141,6 +1141,8 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); pm.addPass(pto::createPTOViewToMemrefPass()); + pm.addNestedPass( + pto::createPTOCanonicalizeSubviewForTLoadPass()); //pm.addPass(createInferPTOMemScopePass()); if (effectiveLevel != PTOBuildLevel::Level3) {