From efa60dbceb89bce17a741624035c00ca0d8697a2 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Mon, 10 Mar 2025 20:43:49 +0000 Subject: [PATCH 1/4] Allow configuration of mxDataGenerator source --- CMakeLists.txt | 3 +++ cmake/Dependencies.cmake | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c539139..066cff06 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -121,6 +121,9 @@ option(BUILD_TESTING "Build rocRoller test clients" ON) option(BUILD_DOCS "Build rocRoller documentation" ON) option(BUILD_VERBOSE "Output additional build information" OFF) +option(MXDATAGENERATOR_SSH "Fetch mxDataGenerator via SSH" OFF) +set(MXDATAGENERATOR_GIT_URL "github.com" CACHE STRING "Base Git URL to fetch mxDataGenerator from.") + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index c47ee114..66d1c8dc 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -441,6 +441,11 @@ endfunction() function(_fetch_mxDataGenerator VERSION HASH) _determine_git_tag(v main) + if(MXDATAGENERATOR_SSH) + set(mxDataGenerator_url "git@${MXDATAGENERATOR_GIT_URL}:ROCm/mxDataGenerator.git") + else() + set(mxDataGenerator_url "https://${MXDATAGENERATOR_GIT_URL}/ROCm/mxDataGenerator.git") + endif() FetchContent_Declare( mxDataGenerator GIT_REPOSITORY git@github.com:ROCm/mxDataGenerator.git From 729a26622c610730e867d0a87893e0f2475855d2 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 11 Mar 2025 19:15:08 +0000 Subject: [PATCH 2/4] actually use computed mxDataGenerator URL --- cmake/Dependencies.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 66d1c8dc..1b161f86 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -448,7 +448,7 @@ function(_fetch_mxDataGenerator VERSION HASH) endif() FetchContent_Declare( mxDataGenerator - GIT_REPOSITORY git@github.com:ROCm/mxDataGenerator.git + GIT_REPOSITORY ${mxDataGenerator_url} GIT_TAG ${GIT_TAG} ) FetchContent_MakeAvailable(mxDataGenerator) From efba1dcd0a33555959f9ffabfa970d5667a893da Mon Sep 17 00:00:00 2001 From: Yoonseo Choi Date: Tue, 11 Mar 2025 21:45:39 -0400 Subject: [PATCH 3/4] LWP100-1465 Common GEMM graph: replacing templatized dataType with member variable --- test/catch/CommandTest.cpp | 8 +- test/catch/IdentifyParallelDimensionsTest.cpp | 4 +- .../catch/KernelGraphRemoveDuplicatesTest.cpp | 12 +- test/common/CommonGraphs.cpp | 239 ++++++++++++++++++ test/common/common/CommonGraphs.hpp | 24 +- test/common/common/CommonGraphs_impl.hpp | 126 --------- test/unit/KernelGraphTest.cpp | 16 +- .../unit/KernelGraphTest/UpdateParameters.cpp | 12 +- 8 files changed, 286 insertions(+), 155 deletions(-) diff --git a/test/catch/CommandTest.cpp b/test/catch/CommandTest.cpp index 85bce25a..2bfd8918 100644 --- a/test/catch/CommandTest.cpp +++ b/test/catch/CommandTest.cpp @@ -38,9 +38,9 @@ namespace CommandTest { SECTION("GEMM/TileAdd") { - auto example1 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example2 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example3 = *rocRollerTest::Graphs::GEMM().getCommand(); + auto example1 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example2 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example3 = *rocRollerTest::Graphs::GEMM(DataType::Half).getCommand(); auto example4 = *rocRollerTest::Graphs::TileDoubleAdd().getCommand(); CHECK(example1 == example2); @@ -58,7 +58,7 @@ namespace CommandTest { SECTION("GEMM") { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto command0 = example.getCommand(); auto yaml = Command::toYAML(*command0); diff --git a/test/catch/IdentifyParallelDimensionsTest.cpp b/test/catch/IdentifyParallelDimensionsTest.cpp index 7f46a0f1..62973404 100644 --- a/test/catch/IdentifyParallelDimensionsTest.cpp +++ b/test/catch/IdentifyParallelDimensionsTest.cpp @@ -74,7 +74,7 @@ TEST_CASE("identifyParallelDimensionSets works for GEMM", "[kernel-graph]") using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto kgraph = KernelGraph::translate(example.getCommand()); @@ -175,7 +175,7 @@ SCENARIO("IdentifyParallelDimensions transformation works for GEMM", "[kernel-gr using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); GIVEN("The initial kernel graph for a GEMM") { diff --git a/test/catch/KernelGraphRemoveDuplicatesTest.cpp b/test/catch/KernelGraphRemoveDuplicatesTest.cpp index 9bb470c5..01812d2a 100644 --- a/test/catch/KernelGraphRemoveDuplicatesTest.cpp +++ b/test/catch/KernelGraphRemoveDuplicatesTest.cpp @@ -42,7 +42,7 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") using namespace rocRoller::KernelGraph::ControlGraph; auto ctx = TestContext::ForDefaultTarget().get(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 128, 32); example.setMFMA(32, 32, 16, 1); @@ -72,16 +72,16 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") // LoadTiled: A A, B B, C C // After removing 2x1 jamming: A, B, C C - CHECK(graph0.control.getElements().to().size() == 6); - CHECK(graph1.control.getElements().to().size() == 4); + CHECK(graph0.control.getElements().to().size() == 3); + CHECK(graph1.control.getElements().to().size() == 3); // StoreLDSTile: A A, B B // After removing 2x1 jamming: A, B - CHECK(graph0.control.getElements().to().size() == 4); + CHECK(graph0.control.getElements().to().size() == 2); CHECK(graph1.control.getElements().to().size() == 2); // LoadLDSTile: A A A A, B B B B // After removing 2x1 jamming: A A A A, B B - CHECK(graph0.control.getElements().to().size() == 8); - CHECK(graph1.control.getElements().to().size() == 6); + CHECK(graph0.control.getElements().to().size() == 4); + CHECK(graph1.control.getElements().to().size() == 4); } diff --git a/test/common/CommonGraphs.cpp b/test/common/CommonGraphs.cpp index c0271b56..5e1b4adf 100644 --- a/test/common/CommonGraphs.cpp +++ b/test/common/CommonGraphs.cpp @@ -193,4 +193,243 @@ namespace rocRollerTest::Graphs return params; } + GEMM::GEMM(DataType ta) + : GEMM(ta, ta) + { + } + GEMM::GEMM(DataType ta, DataType tb) + : GEMM(ta, tb, tb) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc) + : GEMM(ta, tb, tc, tc) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc, DataType td) + : mTa(ta) + , mTb(tb) + , mTc(tc) + , mTd(td) + { + } + + void GEMM::createCommand() + { + m_command = std::make_shared(); + + std::vector oneStridesN + = m_problem.literalStrides ? std::vector({(size_t)1}) : std::vector({}); + + std::vector oneStridesT = m_problem.literalStrides + ? std::vector({(size_t)0, (size_t)1}) + : std::vector({}); + + mTagTensorA = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTa, m_problem.transA == "N" ? oneStridesN : oneStridesT)); // A + + m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorA)); + + mTagTensorB = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTb, m_problem.transB == "N" ? oneStridesN : oneStridesT)); // B + m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorB)); + + mTagTensorC + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTc, oneStridesN)); // C + m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorC)); + + mTagScalarAlpha + = m_command->addOperation(rocRoller::Operations::Scalar(DataType::Float)); // alpha + auto tagLoadAlpha + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarAlpha)); + + mTagScalarBeta = m_command->addOperation(rocRoller::Operations::Scalar(mTc)); // beta + auto tagLoadBeta + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarBeta)); // beta + + auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B + + rocRoller::Operations::T_Execute execute(m_command->getNextTag()); + + auto tagBetaC + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C + + auto tagAlphaAB + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) + + if(m_problem.betaInFma) + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagBetaC, tagAlphaAB)); + // alpha * (A * B) + beta * C + } + else + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); + // alpha * (A * B) + beta * C + } + m_command->addOperation(std::move(execute)); + + mTagTensorD + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTd, oneStridesN)); // D + m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, mTagTensorD)); // D + + if(m_problem.streamK) + { + mTagNumWGs = m_command->allocateTag(); + auto numWGsArg = m_command->allocateArgument(DataType::UInt32, + mTagNumWGs, + ArgumentType::Value, + DataDirection::ReadOnly, + rocRoller::NUMWGS); + } + + mTagScratch = m_command->allocateTag(); + m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + mTagScratch, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::SCRATCH); + } + + CommandPtr GEMM::getCommand() + { + if(!m_command) + createCommand(); + + return m_command; + } + + KernelGraph GEMM::getKernelGraph() + { + return rocRoller::KernelGraph::translate(getCommand()); + } + + void GEMM::setTileSize(int m, int n, int k) + { + m_problem.macM = m; + m_problem.macN = n; + m_problem.macK = k; + } + + void GEMM::setMFMA(int m, int n, int k, int b) + { + m_problem.waveM = m; + m_problem.waveN = n; + m_problem.waveK = k; + m_problem.waveB = b; + } + + void GEMM::setUseLDS(bool a, bool b, bool d) + { + m_problem.loadLDSA = a; + m_problem.loadLDSB = b; + m_problem.storeLDSD = d; + } + + void GEMM::setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps) + { + m_problem.prefetch = prefetch; + m_problem.prefetchInFlight = prefetchInFlight; + m_problem.prefetchLDSFactor = prefetchLDSFactor; + m_problem.prefetchMixMemOps = prefetchMixMemOps; + + m_problem.unrollK = prefetchInFlight; + } + + void GEMM::setProblem(GEMMProblem const& problem) + { + m_problem = problem; + } + + GEMMProblem const& GEMM::getProblem() const + { + return m_problem; + } + + CommandParametersPtr GEMM::getCommandParameters() const + { + using namespace rocRoller::KernelGraph::CoordinateGraph; + + auto params = std::make_shared(); + + params->setManualKernelDimension(2); + + AssertFatal(m_problem.workgroupSizeX % m_problem.wavefrontSize == 0, + "Workgroup Size X must be multiply of wave front size"); + + uint wavetilePerWavefrontM + = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / m_problem.workgroupSizeX; + uint wavetilePerWavefrontN = m_problem.macN / m_problem.waveN / m_problem.workgroupSizeY; + + AssertFatal(m_problem.macM % (m_problem.waveM * wavetilePerWavefrontM) == 0, + "WaveTile size mismatch (M)"); + AssertFatal(m_problem.macN % (m_problem.waveN * wavetilePerWavefrontN) == 0, + "WaveTile size mismatch (N)"); + + uint workgroupSizeX = m_problem.workgroupSizeX * m_problem.workgroupSizeY; + uint workgroupSizeY = 1; + params->setManualWorkgroupSize({workgroupSizeX, workgroupSizeY, 1}); + + auto macTileA + = MacroTile({m_problem.macM, m_problem.macK}, + LayoutType::MATRIX_A, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSA ? MemoryType::LDS : MemoryType::WAVE); + auto macTileB + = MacroTile({m_problem.macK, m_problem.macN}, + LayoutType::MATRIX_B, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSB ? MemoryType::LDS : MemoryType::WAVE); + auto macTileC + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}); + auto macTileD + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.storeLDSD ? MemoryType::LDS : MemoryType::WAVE); + + params->setDimensionInfo(m_tagA, macTileA); + params->setDimensionInfo(m_tagB, macTileB); + params->setDimensionInfo(m_tagC, macTileC); + params->setDimensionInfo(m_tagD, macTileD); + + // uint jammedM + // = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / workgroupSizeX; + // uint jammedN = m_problem.macN / m_problem.waveN / workgroupSizeY; + + Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); + // Log::debug("GEMM jamming {} {}", jammedM, jammedN); + // params->setWaveTilesPerWavefront(jammedM, jammedN); + + params->setManualWavefrontCount( + {static_cast(m_problem.macM / m_problem.waveM / wavetilePerWavefrontM), + static_cast(m_problem.macN / m_problem.waveN / wavetilePerWavefrontN)}); + + params->fuseLoops = m_problem.fuseLoops; + params->tailLoops = m_problem.tailLoops; + params->allowAmbiguousMemoryNodes = m_problem.allowAmbiguousMemoryNodes; + params->unrollK = m_problem.unrollK; + params->packMultipleElementsInto1VGPR = m_problem.packMultipleElementsInto1VGPR; + params->prefetch = m_problem.prefetch; + params->prefetchInFlight = m_problem.prefetchInFlight; + params->prefetchLDSFactor = m_problem.prefetchLDSFactor; + params->prefetchMixMemOps = m_problem.prefetchMixMemOps; + params->transposeMemoryAccess[LayoutType::MATRIX_A] = m_problem.transA == "T"; + params->transposeMemoryAccess[LayoutType::MATRIX_B] = m_problem.transB == "T"; + params->transposeMemoryAccess[LayoutType::None] = true; + + if(m_problem.streamK) + { + params->loopOverOutputTilesDimensions = {0, 1}; + params->streamK = true; + params->streamKTwoTile = m_problem.streamKTwoTile; + } + + return params; + } + } diff --git a/test/common/common/CommonGraphs.hpp b/test/common/common/CommonGraphs.hpp index 4ee7a3a4..03e902c8 100644 --- a/test/common/common/CommonGraphs.hpp +++ b/test/common/common/CommonGraphs.hpp @@ -40,6 +40,8 @@ #include #include +#include + namespace rocRollerTest { namespace Graphs @@ -52,6 +54,7 @@ namespace rocRollerTest using ContextPtr = rocRoller::ContextPtr; using KernelArguments = rocRoller::KernelArguments; using KernelGraph = rocRoller::KernelGraph::KernelGraph; + using DataType = rocRoller::DataType; /** * @brief Graph for linear: alpha x + beta y. @@ -178,11 +181,13 @@ namespace rocRollerTest * - Assign(D = alpha * AB + beta * C) * - StoreTiled(D) */ - template class GEMM { public: - GEMM(); + GEMM(DataType ta); + GEMM(DataType ta, DataType tb); + GEMM(DataType ta, DataType tb, DataType tc); + GEMM(DataType ta, DataType tb, DataType tc, DataType td); CommandPtr getCommand(); KernelGraph getKernelGraph(); @@ -190,9 +195,21 @@ namespace rocRollerTest void setTileSize(int m, int n, int k); void setMFMA(int m, int n, int k, int b); void setUseLDS(bool a, bool b, bool d); + void setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps); + void setProblem(GEMMProblem const& problem); + GEMMProblem const& getProblem() const; CommandParametersPtr getCommandParameters() const; + rocRoller::Operations::OperationTag mTagTensorA, mTagTensorB, mTagTensorC, mTagTensorD, + mTagScalarAlpha, mTagScalarBeta, mTagScalarSeed, mTagScratch; + rocRoller::Operations::OperationTag mTagNumWGs; + + DataType mTa, mTb, mTc, mTd; + private: void createCommand(); @@ -202,7 +219,8 @@ namespace rocRollerTest rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; - CommandPtr m_command; + CommandPtr m_command; + GEMMProblem m_problem; }; /** diff --git a/test/common/common/CommonGraphs_impl.hpp b/test/common/common/CommonGraphs_impl.hpp index ec097436..9a8d6201 100644 --- a/test/common/common/CommonGraphs_impl.hpp +++ b/test/common/common/CommonGraphs_impl.hpp @@ -256,132 +256,6 @@ namespace rocRollerTest::Graphs return rocRoller::KernelGraph::translate(m_command); } - /* - * GEMM - */ - - template - GEMM::GEMM() - { - createCommand(); - } - - template - void GEMM::createCommand() - { - m_command = std::make_shared(); - - auto dataType = TypeInfo::Var.dataType; - - auto tagTensorA = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // A - m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorA)); - - auto tagTensorB = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // B - m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorB)); - - auto tagTensorC = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // C - m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorC)); - - auto tagScalarAlpha - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // alpha - auto tagLoadAlpha - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarAlpha)); - - auto tagScalarBeta - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // beta - auto tagLoadBeta - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarBeta)); // beta - - auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B - - rocRoller::Operations::T_Execute execute(m_command->getNextTag()); - auto tagAlphaAB - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) - auto tagBetaC - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C - m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); - // alpha * (A * B) + beta * C - m_command->addOperation(std::move(execute)); - - auto tagTensorD = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // D - m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, tagTensorD)); // D - } - - template - CommandPtr GEMM::getCommand() - { - return m_command; - } - - template - KernelGraph GEMM::getKernelGraph() - { - return rocRoller::KernelGraph::translate(m_command); - } - - template - void GEMM::setTileSize(int m, int n, int k) - { - m_macM = m; - m_macN = n; - m_macK = k; - } - - template - void GEMM::setMFMA(int m, int n, int k, int b) - { - m_waveM = m; - m_waveN = n; - m_waveK = k; - m_waveB = b; - } - - template - void GEMM::setUseLDS(bool a, bool b, bool d) - { - m_useLDSA = a; - m_useLDSB = b; - m_useLDSD = d; - } - - template - CommandParametersPtr GEMM::getCommandParameters() const - { - using namespace rocRoller::KernelGraph::CoordinateGraph; - - auto params = std::make_shared(); - - auto macTileA = MacroTile({m_macM, m_macK}, - LayoutType::MATRIX_A, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSA ? MemoryType::LDS : MemoryType::WAVE); - auto macTileB = MacroTile({m_macK, m_macN}, - LayoutType::MATRIX_B, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSB ? MemoryType::LDS : MemoryType::WAVE); - auto macTileC = MacroTile( - {m_macM, m_macN}, LayoutType::MATRIX_ACCUMULATOR, {m_waveM, m_waveN, m_waveK, m_waveB}); - - params->setDimensionInfo(m_tagA, macTileA); - params->setDimensionInfo(m_tagB, macTileB); - params->setDimensionInfo(m_tagC, macTileC); - - // Workgroup size - uint wavefrontSize = 64; - uint workgroupSizeX = 2 * wavefrontSize; - uint workgroupSizeY = 4; - - uint jammedM = wavefrontSize * m_macM / m_waveM / workgroupSizeX; - uint jammedN = m_macN / m_waveN / workgroupSizeY; - - Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); - Log::debug("GEMM jamming {} {}", jammedM, jammedN); - - params->setWaveTilesPerWavefront(jammedM, jammedN); - - return params; - } - /* * TileDoubleAdd */ diff --git a/test/unit/KernelGraphTest.cpp b/test/unit/KernelGraphTest.cpp index 8533b907..edfd3e01 100644 --- a/test/unit/KernelGraphTest.cpp +++ b/test/unit/KernelGraphTest.cpp @@ -1085,7 +1085,7 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, LowerTensor) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; @@ -1135,7 +1135,7 @@ namespace KernelGraphTest // Verify that loops have been unrolled auto unrolledForLoops = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(unrolledForLoops.size(), 10); // main: X (Y (K K)) (Y (K K)); epilogue: X (Y Y) + EXPECT_EQ(unrolledForLoops.size(), 5); // main: X (Y (K K)) (Y (K K)); epilogue: X (Y Y) auto kgraphFused = kgraphUnrolled.transform(fuseLoopsTransform); kgraphFused = kgraphFused.transform(removeDuplicatesTransform); @@ -1145,7 +1145,7 @@ namespace KernelGraphTest EXPECT_EQ(fusedForLoops.size(), 5); auto fusedLoads = kgraphFused.control.getNodes().to(); - EXPECT_EQ(fusedLoads.size(), 9); // 1 for A, 4 for B, 4 for C + EXPECT_EQ(fusedLoads.size(), 4); // 1 for A, 4 for B, 4 for C // Verify that single iteration loops have been removed. auto kgraphClean = kgraphFused.transform(cleanLoopsTransform); @@ -1154,13 +1154,13 @@ namespace KernelGraphTest // Verify that there is only a single StoreLDSTile node per K loop auto unrolledStoreLDS = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(unrolledStoreLDS.size(), 4); + EXPECT_EQ(unrolledStoreLDS.size(), 1); // Verify number of ComputeIndexes: A loads; A LDS loads; B loads; C load; D // store: 3 + (2+2) + 3 + 3 + 3 = 12 kgraph1 = kgraph1.transform(addComputeIndexTransform); auto computeIndexes = kgraph1.control.getNodes().to(); - EXPECT_EQ(computeIndexes.size(), 16); + EXPECT_EQ(computeIndexes.size(), 15); // Verify number of Deallocates auto addDeallocate = std::make_shared(); @@ -1169,7 +1169,7 @@ namespace KernelGraphTest EXPECT_EQ(addDeallocates.size(), 16); auto storeLDS = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(storeLDS.size(), 4); + EXPECT_EQ(storeLDS.size(), 1); auto fusedStoreLDS = kgraphFused.control.getNodes().to(); EXPECT_EQ(fusedStoreLDS.size(), 1); @@ -1177,7 +1177,7 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, InlineIncrement) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); @@ -2717,7 +2717,7 @@ namespace KernelGraphTest { using GD = Graph::Direction; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); diff --git a/test/unit/KernelGraphTest/UpdateParameters.cpp b/test/unit/KernelGraphTest/UpdateParameters.cpp index dd101c5c..e72f9960 100644 --- a/test/unit/KernelGraphTest/UpdateParameters.cpp +++ b/test/unit/KernelGraphTest/UpdateParameters.cpp @@ -88,7 +88,7 @@ namespace KernelGraphTest { using namespace rocRoller::KernelGraph; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; @@ -118,16 +118,16 @@ namespace KernelGraphTest // Now apply SetWorkitemCount and try again kgraph = kgraph.transform(std::make_shared(m_context)); - CommandArgumentPtr tensorDsizeX; + CommandArgumentPtr tensorAsizeX; { auto arguments = command->getArguments(); for(auto argument : arguments) { - if(argument->name() == "Tensor_4_size_0") - tensorDsizeX = argument; + if(argument->name() == "Tensor_0_size_0") + tensorAsizeX = argument; } } - ASSERT_NE(tensorDsizeX, nullptr) << "D size not found"; + ASSERT_NE(tensorAsizeX, nullptr) << "A size not found"; workitemCount = m_context->kernel()->workitemCount(); @@ -135,7 +135,7 @@ namespace KernelGraphTest auto workgroupSizeX = Expression::literal(128u); auto expected - = (((tensorDsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; + = (((tensorAsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; EXPECT_TRUE(Expression::identical(expected, workitemCount[0])); } From a6c9f49de30e548cb1d88623af500e7050220591 Mon Sep 17 00:00:00 2001 From: Yoonseo Choi Date: Wed, 12 Mar 2025 23:27:03 +0000 Subject: [PATCH 4/4] Restore jamming factors in common GEMM graph and restore original tests --- .../catch/KernelGraphRemoveDuplicatesTest.cpp | 12 ++++---- test/common/CommonGraphs.cpp | 29 +++++++++++++++---- test/common/common/CommonGraphs.hpp | 4 --- test/unit/KernelGraphTest.cpp | 12 ++++---- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/test/catch/KernelGraphRemoveDuplicatesTest.cpp b/test/catch/KernelGraphRemoveDuplicatesTest.cpp index 01812d2a..595793f0 100644 --- a/test/catch/KernelGraphRemoveDuplicatesTest.cpp +++ b/test/catch/KernelGraphRemoveDuplicatesTest.cpp @@ -44,7 +44,7 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") auto ctx = TestContext::ForDefaultTarget().get(); auto example = rocRollerTest::Graphs::GEMM(DataType::Float); - example.setTileSize(128, 128, 32); + example.setTileSize(128, 64, 32); example.setMFMA(32, 32, 16, 1); example.setUseLDS(true, true, false); @@ -72,16 +72,16 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") // LoadTiled: A A, B B, C C // After removing 2x1 jamming: A, B, C C - CHECK(graph0.control.getElements().to().size() == 3); - CHECK(graph1.control.getElements().to().size() == 3); + CHECK(graph0.control.getElements().to().size() == 6); + CHECK(graph1.control.getElements().to().size() == 4); // StoreLDSTile: A A, B B // After removing 2x1 jamming: A, B - CHECK(graph0.control.getElements().to().size() == 2); + CHECK(graph0.control.getElements().to().size() == 4); CHECK(graph1.control.getElements().to().size() == 2); // LoadLDSTile: A A A A, B B B B // After removing 2x1 jamming: A A A A, B B - CHECK(graph0.control.getElements().to().size() == 4); - CHECK(graph1.control.getElements().to().size() == 4); + CHECK(graph0.control.getElements().to().size() == 8); + CHECK(graph1.control.getElements().to().size() == 6); } diff --git a/test/common/CommonGraphs.cpp b/test/common/CommonGraphs.cpp index 5e1b4adf..644b62dd 100644 --- a/test/common/CommonGraphs.cpp +++ b/test/common/CommonGraphs.cpp @@ -356,12 +356,33 @@ namespace rocRollerTest::Graphs params->setManualKernelDimension(2); + AssertFatal(m_problem.m % m_problem.macM == 0, + "MacroTile size mismatch (M)", + ShowValue(m_problem.m), + ShowValue(m_problem.macM)); + AssertFatal(m_problem.n % m_problem.macN == 0, + "MacroTile size mismatch (N)", + ShowValue(m_problem.n), + ShowValue(m_problem.macN)); + AssertFatal(m_problem.workgroupSizeX % m_problem.wavefrontSize == 0, "Workgroup Size X must be multiply of wave front size"); + AssertFatal(m_problem.macM % m_problem.waveM == 0, + "Macrotile size must be a multiple of wavetile size"); + AssertFatal(m_problem.macN % m_problem.waveN == 0, + "Macrotile size must be a multiple of wavetile size"); + + // i.e. jammedM uint wavetilePerWavefrontM = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / m_problem.workgroupSizeX; + AssertFatal(wavetilePerWavefrontM > 0, + "Wavetiles per wavefront in M should be positive integer"); + + // i.e. jammedN uint wavetilePerWavefrontN = m_problem.macN / m_problem.waveN / m_problem.workgroupSizeY; + AssertFatal(wavetilePerWavefrontN > 0, + "Wavetiles per wavefront in N should be positive integer"); AssertFatal(m_problem.macM % (m_problem.waveM * wavetilePerWavefrontM) == 0, "WaveTile size mismatch (M)"); @@ -397,14 +418,10 @@ namespace rocRollerTest::Graphs params->setDimensionInfo(m_tagC, macTileC); params->setDimensionInfo(m_tagD, macTileD); - // uint jammedM - // = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / workgroupSizeX; - // uint jammedN = m_problem.macN / m_problem.waveN / workgroupSizeY; - Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); - // Log::debug("GEMM jamming {} {}", jammedM, jammedN); - // params->setWaveTilesPerWavefront(jammedM, jammedN); + Log::debug("GEMM jamming {} {}", wavetilePerWavefrontM, wavetilePerWavefrontN); + params->setWaveTilesPerWavefront(wavetilePerWavefrontM, wavetilePerWavefrontN); params->setManualWavefrontCount( {static_cast(m_problem.macM / m_problem.waveM / wavetilePerWavefrontM), static_cast(m_problem.macN / m_problem.waveN / wavetilePerWavefrontN)}); diff --git a/test/common/common/CommonGraphs.hpp b/test/common/common/CommonGraphs.hpp index 03e902c8..ddb4b816 100644 --- a/test/common/common/CommonGraphs.hpp +++ b/test/common/common/CommonGraphs.hpp @@ -213,10 +213,6 @@ namespace rocRollerTest private: void createCommand(); - int m_macM, m_macN, m_macK; - int m_waveM, m_waveN, m_waveK, m_waveB; - bool m_useLDSA = false, m_useLDSB = false, m_useLDSD = false; - rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; CommandPtr m_command; diff --git a/test/unit/KernelGraphTest.cpp b/test/unit/KernelGraphTest.cpp index edfd3e01..aa31ef94 100644 --- a/test/unit/KernelGraphTest.cpp +++ b/test/unit/KernelGraphTest.cpp @@ -1090,7 +1090,7 @@ namespace KernelGraphTest int macK = 16; int waveK = 8; - example.setTileSize(128, 256, macK); + example.setTileSize(128, 128, macK); example.setMFMA(32, 32, waveK, 1); example.setUseLDS(true, false, false); @@ -1135,7 +1135,7 @@ namespace KernelGraphTest // Verify that loops have been unrolled auto unrolledForLoops = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(unrolledForLoops.size(), 5); // main: X (Y (K K)) (Y (K K)); epilogue: X (Y Y) + EXPECT_EQ(unrolledForLoops.size(), 10); // main: X (Y (K K)) (Y (K K)); epilogue: X (Y Y) auto kgraphFused = kgraphUnrolled.transform(fuseLoopsTransform); kgraphFused = kgraphFused.transform(removeDuplicatesTransform); @@ -1145,7 +1145,7 @@ namespace KernelGraphTest EXPECT_EQ(fusedForLoops.size(), 5); auto fusedLoads = kgraphFused.control.getNodes().to(); - EXPECT_EQ(fusedLoads.size(), 4); // 1 for A, 4 for B, 4 for C + EXPECT_EQ(fusedLoads.size(), 9); // 1 for A, 4 for B, 4 for C // Verify that single iteration loops have been removed. auto kgraphClean = kgraphFused.transform(cleanLoopsTransform); @@ -1154,13 +1154,13 @@ namespace KernelGraphTest // Verify that there is only a single StoreLDSTile node per K loop auto unrolledStoreLDS = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(unrolledStoreLDS.size(), 1); + EXPECT_EQ(unrolledStoreLDS.size(), 4); // Verify number of ComputeIndexes: A loads; A LDS loads; B loads; C load; D // store: 3 + (2+2) + 3 + 3 + 3 = 12 kgraph1 = kgraph1.transform(addComputeIndexTransform); auto computeIndexes = kgraph1.control.getNodes().to(); - EXPECT_EQ(computeIndexes.size(), 15); + EXPECT_EQ(computeIndexes.size(), 16); // Verify number of Deallocates auto addDeallocate = std::make_shared(); @@ -1169,7 +1169,7 @@ namespace KernelGraphTest EXPECT_EQ(addDeallocates.size(), 16); auto storeLDS = kgraphUnrolled.control.getNodes().to(); - EXPECT_EQ(storeLDS.size(), 1); + EXPECT_EQ(storeLDS.size(), 4); auto fusedStoreLDS = kgraphFused.control.getNodes().to(); EXPECT_EQ(fusedStoreLDS.size(), 1);