From a8ea682ffeab86edace1a121f8d90ea3a42dc40e Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 18 Sep 2025 15:22:18 +0000 Subject: [PATCH] Move quantized dot operands in LDS in pipeliner This change stores quantized operands in LDS instead of storing them in registers between loop iterations. --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 6d47ec32db41..99b39d6a05d4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -278,6 +278,27 @@ getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, /*needTrans=*/false); } + } else { + auto loadEncoding = dyn_cast( + dyn_cast(loadedValue.getType()).getEncoding()); + auto sizePerThread = loadEncoding.getSizePerThread(); + auto threadsShape = loadEncoding.getThreadsPerWarp(); + auto order = loadEncoding.getOrder(); + + int vecSize = sizePerThread[order[0]]; + int numBanks = 64; + int perPhase = + (32 * numBanks) / (srcTy.getShape()[order[0]] * bitWidth); + if (perPhase == 0) + perPhase = 1; + int maxPhase = threadsShape[order[1]] / perPhase; + vecSize = sizePerThread[order[0]]; + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + loadedValue.getContext(), vecSize, perPhase, maxPhase, order, + loadEncoding.getCTALayout()); + LDBG("Deduced shared encoding candidate from blocked layout: " + << tempAttr); + sharedEncs.push_back(tempAttr); } } // Check that the shared encodings needed by the users are compatible.