diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 6d47ec32db41..99b39d6a05d4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -278,6 +278,27 @@ getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, /*needTrans=*/false); } + } else { + auto loadEncoding = dyn_cast( + dyn_cast(loadedValue.getType()).getEncoding()); + auto sizePerThread = loadEncoding.getSizePerThread(); + auto threadsShape = loadEncoding.getThreadsPerWarp(); + auto order = loadEncoding.getOrder(); + + int vecSize = sizePerThread[order[0]]; + int numBanks = 64; + int perPhase = + (32 * numBanks) / (srcTy.getShape()[order[0]] * bitWidth); + if (perPhase == 0) + perPhase = 1; + int maxPhase = threadsShape[order[1]] / perPhase; + vecSize = sizePerThread[order[0]]; + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + loadedValue.getContext(), vecSize, perPhase, maxPhase, order, + loadEncoding.getCTALayout()); + LDBG("Deduced shared encoding candidate from blocked layout: " + << tempAttr); + sharedEncs.push_back(tempAttr); } } // Check that the shared encodings needed by the users are compatible.