diff --git a/test/catch/CommandTest.cpp b/test/catch/CommandTest.cpp index 85bce25a..2bfd8918 100644 --- a/test/catch/CommandTest.cpp +++ b/test/catch/CommandTest.cpp @@ -38,9 +38,9 @@ namespace CommandTest { SECTION("GEMM/TileAdd") { - auto example1 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example2 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example3 = *rocRollerTest::Graphs::GEMM().getCommand(); + auto example1 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example2 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example3 = *rocRollerTest::Graphs::GEMM(DataType::Half).getCommand(); auto example4 = *rocRollerTest::Graphs::TileDoubleAdd().getCommand(); CHECK(example1 == example2); @@ -58,7 +58,7 @@ namespace CommandTest { SECTION("GEMM") { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto command0 = example.getCommand(); auto yaml = Command::toYAML(*command0); diff --git a/test/catch/IdentifyParallelDimensionsTest.cpp b/test/catch/IdentifyParallelDimensionsTest.cpp index b20ef993..8dc35f4c 100644 --- a/test/catch/IdentifyParallelDimensionsTest.cpp +++ b/test/catch/IdentifyParallelDimensionsTest.cpp @@ -74,7 +74,7 @@ TEST_CASE("identifyParallelDimensionSets works for GEMM", "[kernel-graph]") using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto kgraph = KernelGraph::translate(example.getCommand()); @@ -175,7 +175,7 @@ SCENARIO("IdentifyParallelDimensions transformation works for GEMM", "[kernel-gr using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); GIVEN("The initial kernel graph for a GEMM") { diff --git a/test/catch/KernelGraphRemoveDuplicatesTest.cpp b/test/catch/KernelGraphRemoveDuplicatesTest.cpp index 9bb470c5..595793f0 100644 --- a/test/catch/KernelGraphRemoveDuplicatesTest.cpp +++ b/test/catch/KernelGraphRemoveDuplicatesTest.cpp @@ -42,9 +42,9 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") using namespace rocRoller::KernelGraph::ControlGraph; auto ctx = TestContext::ForDefaultTarget().get(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); - example.setTileSize(128, 128, 32); + example.setTileSize(128, 64, 32); example.setMFMA(32, 32, 16, 1); example.setUseLDS(true, true, false); diff --git a/test/common/CommonGraphs.cpp b/test/common/CommonGraphs.cpp index c0271b56..644b62dd 100644 --- a/test/common/CommonGraphs.cpp +++ b/test/common/CommonGraphs.cpp @@ -193,4 +193,260 @@ namespace rocRollerTest::Graphs return params; } + GEMM::GEMM(DataType ta) + : GEMM(ta, ta) + { + } + GEMM::GEMM(DataType ta, DataType tb) + : GEMM(ta, tb, tb) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc) + : GEMM(ta, tb, tc, tc) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc, DataType td) + : mTa(ta) + , mTb(tb) + , mTc(tc) + , mTd(td) + { + } + + void GEMM::createCommand() + { + m_command = std::make_shared(); + + std::vector oneStridesN + = m_problem.literalStrides ? std::vector({(size_t)1}) : std::vector({}); + + std::vector oneStridesT = m_problem.literalStrides + ? std::vector({(size_t)0, (size_t)1}) + : std::vector({}); + + mTagTensorA = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTa, m_problem.transA == "N" ? oneStridesN : oneStridesT)); // A + + m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorA)); + + mTagTensorB = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTb, m_problem.transB == "N" ? oneStridesN : oneStridesT)); // B + m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorB)); + + mTagTensorC + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTc, oneStridesN)); // C + m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorC)); + + mTagScalarAlpha + = m_command->addOperation(rocRoller::Operations::Scalar(DataType::Float)); // alpha + auto tagLoadAlpha + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarAlpha)); + + mTagScalarBeta = m_command->addOperation(rocRoller::Operations::Scalar(mTc)); // beta + auto tagLoadBeta + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarBeta)); // beta + + auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B + + rocRoller::Operations::T_Execute execute(m_command->getNextTag()); + + auto tagBetaC + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C + + auto tagAlphaAB + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) + + if(m_problem.betaInFma) + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagBetaC, tagAlphaAB)); + // alpha * (A * B) + beta * C + } + else + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); + // alpha * (A * B) + beta * C + } + m_command->addOperation(std::move(execute)); + + mTagTensorD + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTd, oneStridesN)); // D + m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, mTagTensorD)); // D + + if(m_problem.streamK) + { + mTagNumWGs = m_command->allocateTag(); + auto numWGsArg = m_command->allocateArgument(DataType::UInt32, + mTagNumWGs, + ArgumentType::Value, + DataDirection::ReadOnly, + rocRoller::NUMWGS); + } + + mTagScratch = m_command->allocateTag(); + m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + mTagScratch, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::SCRATCH); + } + + CommandPtr GEMM::getCommand() + { + if(!m_command) + createCommand(); + + return m_command; + } + + KernelGraph GEMM::getKernelGraph() + { + return rocRoller::KernelGraph::translate(getCommand()); + } + + void GEMM::setTileSize(int m, int n, int k) + { + m_problem.macM = m; + m_problem.macN = n; + m_problem.macK = k; + } + + void GEMM::setMFMA(int m, int n, int k, int b) + { + m_problem.waveM = m; + m_problem.waveN = n; + m_problem.waveK = k; + m_problem.waveB = b; + } + + void GEMM::setUseLDS(bool a, bool b, bool d) + { + m_problem.loadLDSA = a; + m_problem.loadLDSB = b; + m_problem.storeLDSD = d; + } + + void GEMM::setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps) + { + m_problem.prefetch = prefetch; + m_problem.prefetchInFlight = prefetchInFlight; + m_problem.prefetchLDSFactor = prefetchLDSFactor; + m_problem.prefetchMixMemOps = prefetchMixMemOps; + + m_problem.unrollK = prefetchInFlight; + } + + void GEMM::setProblem(GEMMProblem const& problem) + { + m_problem = problem; + } + + GEMMProblem const& GEMM::getProblem() const + { + return m_problem; + } + + CommandParametersPtr GEMM::getCommandParameters() const + { + using namespace rocRoller::KernelGraph::CoordinateGraph; + + auto params = std::make_shared(); + + params->setManualKernelDimension(2); + + AssertFatal(m_problem.m % m_problem.macM == 0, + "MacroTile size mismatch (M)", + ShowValue(m_problem.m), + ShowValue(m_problem.macM)); + AssertFatal(m_problem.n % m_problem.macN == 0, + "MacroTile size mismatch (N)", + ShowValue(m_problem.n), + ShowValue(m_problem.macN)); + + AssertFatal(m_problem.workgroupSizeX % m_problem.wavefrontSize == 0, + "Workgroup Size X must be multiply of wave front size"); + + AssertFatal(m_problem.macM % m_problem.waveM == 0, + "Macrotile size must be a multiple of wavetile size"); + AssertFatal(m_problem.macN % m_problem.waveN == 0, + "Macrotile size must be a multiple of wavetile size"); + + // i.e. jammedM + uint wavetilePerWavefrontM + = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / m_problem.workgroupSizeX; + AssertFatal(wavetilePerWavefrontM > 0, + "Wavetiles per wavefront in M should be positive integer"); + + // i.e. jammedN + uint wavetilePerWavefrontN = m_problem.macN / m_problem.waveN / m_problem.workgroupSizeY; + AssertFatal(wavetilePerWavefrontN > 0, + "Wavetiles per wavefront in N should be positive integer"); + + AssertFatal(m_problem.macM % (m_problem.waveM * wavetilePerWavefrontM) == 0, + "WaveTile size mismatch (M)"); + AssertFatal(m_problem.macN % (m_problem.waveN * wavetilePerWavefrontN) == 0, + "WaveTile size mismatch (N)"); + + uint workgroupSizeX = m_problem.workgroupSizeX * m_problem.workgroupSizeY; + uint workgroupSizeY = 1; + params->setManualWorkgroupSize({workgroupSizeX, workgroupSizeY, 1}); + + auto macTileA + = MacroTile({m_problem.macM, m_problem.macK}, + LayoutType::MATRIX_A, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSA ? MemoryType::LDS : MemoryType::WAVE); + auto macTileB + = MacroTile({m_problem.macK, m_problem.macN}, + LayoutType::MATRIX_B, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSB ? MemoryType::LDS : MemoryType::WAVE); + auto macTileC + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}); + auto macTileD + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.storeLDSD ? MemoryType::LDS : MemoryType::WAVE); + + params->setDimensionInfo(m_tagA, macTileA); + params->setDimensionInfo(m_tagB, macTileB); + params->setDimensionInfo(m_tagC, macTileC); + params->setDimensionInfo(m_tagD, macTileD); + + Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); + Log::debug("GEMM jamming {} {}", wavetilePerWavefrontM, wavetilePerWavefrontN); + + params->setWaveTilesPerWavefront(wavetilePerWavefrontM, wavetilePerWavefrontN); + params->setManualWavefrontCount( + {static_cast(m_problem.macM / m_problem.waveM / wavetilePerWavefrontM), + static_cast(m_problem.macN / m_problem.waveN / wavetilePerWavefrontN)}); + + params->fuseLoops = m_problem.fuseLoops; + params->tailLoops = m_problem.tailLoops; + params->allowAmbiguousMemoryNodes = m_problem.allowAmbiguousMemoryNodes; + params->unrollK = m_problem.unrollK; + params->packMultipleElementsInto1VGPR = m_problem.packMultipleElementsInto1VGPR; + params->prefetch = m_problem.prefetch; + params->prefetchInFlight = m_problem.prefetchInFlight; + params->prefetchLDSFactor = m_problem.prefetchLDSFactor; + params->prefetchMixMemOps = m_problem.prefetchMixMemOps; + params->transposeMemoryAccess[LayoutType::MATRIX_A] = m_problem.transA == "T"; + params->transposeMemoryAccess[LayoutType::MATRIX_B] = m_problem.transB == "T"; + params->transposeMemoryAccess[LayoutType::None] = true; + + if(m_problem.streamK) + { + params->loopOverOutputTilesDimensions = {0, 1}; + params->streamK = true; + params->streamKTwoTile = m_problem.streamKTwoTile; + } + + return params; + } + } diff --git a/test/common/common/CommonGraphs.hpp b/test/common/common/CommonGraphs.hpp index 4ee7a3a4..ddb4b816 100644 --- a/test/common/common/CommonGraphs.hpp +++ b/test/common/common/CommonGraphs.hpp @@ -40,6 +40,8 @@ #include #include +#include + namespace rocRollerTest { namespace Graphs @@ -52,6 +54,7 @@ namespace rocRollerTest using ContextPtr = rocRoller::ContextPtr; using KernelArguments = rocRoller::KernelArguments; using KernelGraph = rocRoller::KernelGraph::KernelGraph; + using DataType = rocRoller::DataType; /** * @brief Graph for linear: alpha x + beta y. @@ -178,11 +181,13 @@ namespace rocRollerTest * - Assign(D = alpha * AB + beta * C) * - StoreTiled(D) */ - template class GEMM { public: - GEMM(); + GEMM(DataType ta); + GEMM(DataType ta, DataType tb); + GEMM(DataType ta, DataType tb, DataType tc); + GEMM(DataType ta, DataType tb, DataType tc, DataType td); CommandPtr getCommand(); KernelGraph getKernelGraph(); @@ -190,19 +195,28 @@ namespace rocRollerTest void setTileSize(int m, int n, int k); void setMFMA(int m, int n, int k, int b); void setUseLDS(bool a, bool b, bool d); + void setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps); + void setProblem(GEMMProblem const& problem); + GEMMProblem const& getProblem() const; CommandParametersPtr getCommandParameters() const; + rocRoller::Operations::OperationTag mTagTensorA, mTagTensorB, mTagTensorC, mTagTensorD, + mTagScalarAlpha, mTagScalarBeta, mTagScalarSeed, mTagScratch; + rocRoller::Operations::OperationTag mTagNumWGs; + + DataType mTa, mTb, mTc, mTd; + private: void createCommand(); - int m_macM, m_macN, m_macK; - int m_waveM, m_waveN, m_waveK, m_waveB; - bool m_useLDSA = false, m_useLDSB = false, m_useLDSD = false; - rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; - CommandPtr m_command; + CommandPtr m_command; + GEMMProblem m_problem; }; /** diff --git a/test/common/common/CommonGraphs_impl.hpp b/test/common/common/CommonGraphs_impl.hpp index ec097436..9a8d6201 100644 --- a/test/common/common/CommonGraphs_impl.hpp +++ b/test/common/common/CommonGraphs_impl.hpp @@ -256,132 +256,6 @@ namespace rocRollerTest::Graphs return rocRoller::KernelGraph::translate(m_command); } - /* - * GEMM - */ - - template - GEMM::GEMM() - { - createCommand(); - } - - template - void GEMM::createCommand() - { - m_command = std::make_shared(); - - auto dataType = TypeInfo::Var.dataType; - - auto tagTensorA = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // A - m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorA)); - - auto tagTensorB = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // B - m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorB)); - - auto tagTensorC = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // C - m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorC)); - - auto tagScalarAlpha - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // alpha - auto tagLoadAlpha - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarAlpha)); - - auto tagScalarBeta - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // beta - auto tagLoadBeta - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarBeta)); // beta - - auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B - - rocRoller::Operations::T_Execute execute(m_command->getNextTag()); - auto tagAlphaAB - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) - auto tagBetaC - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C - m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); - // alpha * (A * B) + beta * C - m_command->addOperation(std::move(execute)); - - auto tagTensorD = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // D - m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, tagTensorD)); // D - } - - template - CommandPtr GEMM::getCommand() - { - return m_command; - } - - template - KernelGraph GEMM::getKernelGraph() - { - return rocRoller::KernelGraph::translate(m_command); - } - - template - void GEMM::setTileSize(int m, int n, int k) - { - m_macM = m; - m_macN = n; - m_macK = k; - } - - template - void GEMM::setMFMA(int m, int n, int k, int b) - { - m_waveM = m; - m_waveN = n; - m_waveK = k; - m_waveB = b; - } - - template - void GEMM::setUseLDS(bool a, bool b, bool d) - { - m_useLDSA = a; - m_useLDSB = b; - m_useLDSD = d; - } - - template - CommandParametersPtr GEMM::getCommandParameters() const - { - using namespace rocRoller::KernelGraph::CoordinateGraph; - - auto params = std::make_shared(); - - auto macTileA = MacroTile({m_macM, m_macK}, - LayoutType::MATRIX_A, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSA ? MemoryType::LDS : MemoryType::WAVE); - auto macTileB = MacroTile({m_macK, m_macN}, - LayoutType::MATRIX_B, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSB ? MemoryType::LDS : MemoryType::WAVE); - auto macTileC = MacroTile( - {m_macM, m_macN}, LayoutType::MATRIX_ACCUMULATOR, {m_waveM, m_waveN, m_waveK, m_waveB}); - - params->setDimensionInfo(m_tagA, macTileA); - params->setDimensionInfo(m_tagB, macTileB); - params->setDimensionInfo(m_tagC, macTileC); - - // Workgroup size - uint wavefrontSize = 64; - uint workgroupSizeX = 2 * wavefrontSize; - uint workgroupSizeY = 4; - - uint jammedM = wavefrontSize * m_macM / m_waveM / workgroupSizeX; - uint jammedN = m_macN / m_waveN / workgroupSizeY; - - Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); - Log::debug("GEMM jamming {} {}", jammedM, jammedN); - - params->setWaveTilesPerWavefront(jammedM, jammedN); - - return params; - } - /* * TileDoubleAdd */ diff --git a/test/unit/KernelGraphTest.cpp b/test/unit/KernelGraphTest.cpp index 72f602aa..1b084440 100644 --- a/test/unit/KernelGraphTest.cpp +++ b/test/unit/KernelGraphTest.cpp @@ -1250,12 +1250,12 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, LowerTensor) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; - example.setTileSize(128, 256, macK); + example.setTileSize(128, 128, macK); example.setMFMA(32, 32, waveK, 1); example.setUseLDS(true, false, false); @@ -1342,7 +1342,7 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, InlineIncrement) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); @@ -2917,7 +2917,7 @@ namespace KernelGraphTest { using GD = Graph::Direction; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); diff --git a/test/unit/KernelGraphTest/UpdateParameters.cpp b/test/unit/KernelGraphTest/UpdateParameters.cpp index dd101c5c..e72f9960 100644 --- a/test/unit/KernelGraphTest/UpdateParameters.cpp +++ b/test/unit/KernelGraphTest/UpdateParameters.cpp @@ -88,7 +88,7 @@ namespace KernelGraphTest { using namespace rocRoller::KernelGraph; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; @@ -118,16 +118,16 @@ namespace KernelGraphTest // Now apply SetWorkitemCount and try again kgraph = kgraph.transform(std::make_shared(m_context)); - CommandArgumentPtr tensorDsizeX; + CommandArgumentPtr tensorAsizeX; { auto arguments = command->getArguments(); for(auto argument : arguments) { - if(argument->name() == "Tensor_4_size_0") - tensorDsizeX = argument; + if(argument->name() == "Tensor_0_size_0") + tensorAsizeX = argument; } } - ASSERT_NE(tensorDsizeX, nullptr) << "D size not found"; + ASSERT_NE(tensorAsizeX, nullptr) << "A size not found"; workitemCount = m_context->kernel()->workitemCount(); @@ -135,7 +135,7 @@ namespace KernelGraphTest auto workgroupSizeX = Expression::literal(128u); auto expected - = (((tensorDsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; + = (((tensorAsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; EXPECT_TRUE(Expression::identical(expected, workitemCount[0])); }