diff --git a/client/include/DataParallelGEMMSolution.hpp b/client/include/DataParallelGEMMSolution.hpp index 182295c9..9e9a0b7a 100644 --- a/client/include/DataParallelGEMMSolution.hpp +++ b/client/include/DataParallelGEMMSolution.hpp @@ -549,7 +549,7 @@ namespace rocRoller // predicates // unrollK size match predicates - if(params->unrollX <= 1 && params->unrollY <= 1 && !params->streamK) + if(params->tailLoops && !params->streamK) { auto unrollKPredicate = (aSizeExps[1] % macKExp == zero); setComment(unrollKPredicate, "K must be a multiple of macK."); diff --git a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp index aebc18ad..ce375f8e 100644 --- a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp +++ b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp @@ -70,6 +70,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndSubDimension const& a, TypeAndSubDimension const& b) + { + return a.id == b.id && a.subdimension == b.subdimension; + } + struct TypeAndNaryArgument { std::string id; @@ -89,6 +94,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndNaryArgument const& a, TypeAndNaryArgument const& b) + { + return a.id == b.id && a.argument == b.argument; + } + enum class ComputeIndexArgument : int { TARGET = 0, @@ -117,6 +127,11 @@ namespace rocRoller::KernelGraph return a.argument < b.argument; } + bool inline operator==(ComputeIndex const& a, ComputeIndex const& b) + { + return a.argument == b.argument && a.index == b.index; + } + using ConnectionSpec = std::variant postConstraints() const override; + std::string name() const override { return "FuseLoops"; } diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index af11039a..737f92e3 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -56,49 +56,145 @@ namespace rocRoller using GD = rocRoller::Graph::Direction; /** - * @brief Find a path from a node to a ForLoopOp using only Sequence edges - * - * Returns an empty vector if no path is found. + * @brief Gather all for loops to fuse below the starting node * * @param graph * @param start - * @return std::vector + * @return std::vector> */ - std::vector pathToForLoop(KernelGraph& graph, int start) + std::vector> gatherForLoops(KernelGraph& graph, int start) { - // Find the first ForLoop under the node - auto allForLoops - = graph.control - .findNodes( - start, - [&](int tag) -> bool { - return isOperation(graph.control.getElement(tag)); - }, - GD::Downstream) - .to(); - - if(allForLoops.empty()) - return {}; - - auto firstForLoop = allForLoops[0]; - - // Find all of the nodes in between the node and the first for loop - auto pathToLoopWithEdges = graph.control - .path(std::vector{start}, - std::vector{firstForLoop}) - .to(); - - // Filter out only the nodes - std::vector pathToLoop; - std::copy_if(pathToLoopWithEdges.begin(), - pathToLoopWithEdges.end(), - std::back_inserter(pathToLoop), - [&](int tag) -> bool { - return graph.control.getElementType(tag) - == Graph::ElementType::Node; - }); - - return pathToLoop; + auto bodies = graph.control.getOutputNodeIndices(start).to(); + auto isForLoopOp = graph.control.isElemType(); + auto isSetCoordinate = graph.control.isElemType(); + auto isSequence = graph.control.isElemType(); + auto isBody = graph.control.isElemType(); + + // Get a set of all ForLoops and SetCoordinates which contain (tail) ForLoops + auto maybeForLoops + = graph.control.depthFirstVisit(bodies, isSequence, GD::Downstream) + .filter([&](int tag) -> bool { + if(isForLoopOp(tag)) + { + return true; + } + if(isSetCoordinate(tag)) + { + return graph.mapper.get(tag) > 0; + } + return false; + }) + .to(); + + // Filter the previous set of nodes to only the ForLoops under consideration + std::vector forLoops; + for(auto const& maybeForLoop : maybeForLoops) + { + if(isForLoopOp(maybeForLoop)) + { + forLoops.insert(forLoops.begin(), maybeForLoop); + } + else + { + // The filter means that this is now a SetCoordinate connected to an Unroll dimension + // Tail loops are always downstream of these + std::optional tag = maybeForLoop; + while(tag && isSetCoordinate(*tag)) + { + tag = only(graph.control.getOutputNodeIndices(*tag)); + } + if(tag && isForLoopOp(*tag)) + { + forLoops.push_back(*tag); + } + } + } + + // Determine sets of loops to be fused together + // Currently loops should only be fused if they're the "same" loop (from unrolling) + std::vector> loopGroupsToFuse; + while(!forLoops.empty()) + { + std::unordered_set loopGroup; + Expression::ExpressionPtr loopIncrement; + Expression::ExpressionPtr loopLength; + std::map baseLoopContents; + for(auto const& loop : forLoops) + { + if(loopGroup.count(loop) != 0) + continue; + + auto loopDim = getSize(std::get(graph.coordinates.getElement( + graph.mapper.get(loop, NaryArgument::DEST)))); + + // Loops to be fused must have the same length + if(loopLength) + { + if(!identical(loopDim, loopLength)) + continue; + } + else + { + loopLength = loopDim; + } + + // Loops to be fused must have the same increment value + auto [dataTag, increment] = getForLoopIncrement(graph, loop); + if(loopIncrement) + { + if(!identical(loopIncrement, increment)) + continue; + } + else + { + loopIncrement = increment; + } + + // Loop similarity heuristic + // We only want to fuse loops which are "the same" or similar. + // Currently this is just counting the number of each type of node in the body. + // We may want to replace this with a better heuristic in the future. + auto loopContentsGenerator = graph.control.nodesInBody(loop); + std::map loopContents; + std::for_each(loopContentsGenerator.begin(), + loopContentsGenerator.end(), + [&](int tag) -> void { + auto type = graph.control.get(tag)->index(); + if(loopContents.contains(type)) + { + loopContents[type]++; + } + else + { + loopContents[type] = 1; + } + }); + + if(!baseLoopContents.empty()) + { + if(baseLoopContents != loopContents) + continue; + } + else + { + baseLoopContents = loopContents; + } + + loopGroup.insert(loop); + } + + for(auto loop : loopGroup) + { + std::erase(forLoops, loop); + } + + if(loopGroup.size() > 1) + { + loopGroupsToFuse.push_back(loopGroup); + } + } + + return loopGroupsToFuse; } /** @@ -192,207 +288,270 @@ namespace rocRoller } } - void fuseLoops(KernelGraph& graph, int tag) + /** + * @brief General routine to fuse one node into another + * + * @param graph + * @param fusedNodeTag + * @param nodeTag + */ + void fuseNode(KernelGraph& graph, int fusedNodeTag, int nodeTag) { - rocRoller::Log::getLogger()->debug("KernelGraph::fuseLoops({})", tag); - - auto dontWalkPastForLoop = [&](int tag) -> bool { - for(auto neighbour : graph.control.getNeighbours(tag, GD::Downstream)) + // Move operations that come after `nodeTag` to follow `fusedNodeTag`. + for(auto const& child : + graph.control.getOutputNodeIndices(nodeTag).to()) + { + if(fusedNodeTag != child) { - if(graph.control.get(neighbour)) + graph.control.addElement(Sequence(), {fusedNodeTag}, {child}); + } + graph.control.deleteElement(std::vector{nodeTag}, + std::vector{child}); + // Make sure we don't introduce a cycle + std::unordered_set toDelete; + for(auto descSeqOfChild : + filter(graph.control.isElemType(), + graph.control.depthFirstVisit(child, GD::Downstream))) + { + if(graph.control.getNeighbours(descSeqOfChild) + .to() + .contains(fusedNodeTag)) { - return false; + toDelete.insert(descSeqOfChild); } } - return true; - }; - - // Find all of the paths from the top of one of a body to a - // ForLoopOp. - auto bodies = graph.control.getOutputNodeIndices(tag).to(); - std::vector> paths; - for(auto const& body : bodies) + for(auto edge : toDelete) + { + graph.control.deleteElement(edge); + } + } + + // Fuse the bodies + for(auto const& child : + graph.control.getOutputNodeIndices(nodeTag).to()) { - auto path = pathToForLoop(graph, body); - if(!path.empty()) - paths.push_back(path); + graph.control.addElement(Body(), {fusedNodeTag}, {child}); + graph.control.deleteElement(std::vector{nodeTag}, + std::vector{child}); } - // See if any of the ForLoopOps that were found in paths - // should be fused together. - std::unordered_set forLoopsToFuse; - Expression::ExpressionPtr loopIncrement; - Expression::ExpressionPtr loopLength; - for(auto const& path : paths) + // Make sure dependencies are satisfied + for(auto const& parent : + graph.control.getInputNodeIndices(nodeTag).to()) { - auto forLoop = path.back(); - if(forLoopsToFuse.count(forLoop) != 0) - return; + auto descOfFusedLoop + = graph.control + .depthFirstVisit(fusedNodeTag, + graph.control.isElemType(), + GD::Downstream) + .to(); + + if(!descOfFusedLoop.contains(parent)) + { + graph.control.addElement(Sequence(), {parent}, {fusedNodeTag}); + } + graph.control.deleteElement(std::vector{parent}, + std::vector{nodeTag}); + } + + for(auto const& parent : + graph.control.getInputNodeIndices(nodeTag).to()) + { + graph.control.addElement(Body(), {parent}, {fusedNodeTag}); + graph.control.deleteElement(std::vector{parent}, + std::vector{nodeTag}); + } + } + + /** + * @brief Visitor to determine if two nodes are the "same" operation of the purposes of fusion + * + */ + struct IsSameOperationVisitor + { + template + bool operator()(int, OpA const&, int, OpB const&) + { + return false; + } - // Check to see if loops are all the same length - auto forLoopDim = getSize(std::get(graph.coordinates.getElement( - graph.mapper.get(forLoop, NaryArgument::DEST)))); - if(loopLength) + bool operator()(int tagA, SetCoordinate const& A, int tagB, SetCoordinate const& B) + { + auto connA = graph.mapper.getConnections(tagA); + auto connB = graph.mapper.getConnections(tagB); + + if(connA.size() != connB.size()) { - if(!identical(forLoopDim, loopLength)) - return; + return false; } - else + for(auto iterA = connA.begin(), iterB = connB.begin(); iterA != connA.end(); + iterA++, iterB++) { - loopLength = forLoopDim; + if(iterA->coordinate != iterB->coordinate) + return false; + if(iterA->connection != iterB->connection) + return false; } + return identical(A.value, B.value); + } + + bool call(int tagA, int tagB) + { + return std::visit(*this, + std::variant(tagA), + graph.control.getNode(tagA), + std::variant(tagB), + graph.control.getNode(tagB)); + } - // Check to see if loops are incremented by the same value - auto [dataTag, increment] = getForLoopIncrement(graph, forLoop); - if(loopIncrement) + KernelGraph const& graph; + }; + + /** + * @brief Walks up the tree, fusing nodes which contain the start in their body and are the same operation + * + * @param graph + * @param tag + */ + void fuseScopes(KernelGraph& graph, int tag) + { + auto parentsWithEdges + = graph.control.getInputNodeIndices(tag).template to(); + std::set> nodeSetsToMerge; + IsSameOperationVisitor visitor{graph}; + + for(auto const& A : parentsWithEdges) + { + std::set sameAsThis; + for(auto const& B : parentsWithEdges) { - if(!identical(loopIncrement, increment)) - return; + if(A == B) + continue; + if(visitor.call(A, B)) + { + sameAsThis.insert(B); + } } - else + if(!sameAsThis.empty()) { - loopIncrement = increment; + sameAsThis.insert(A); + nodeSetsToMerge.insert(sameAsThis); } + } - forLoopsToFuse.insert(forLoop); + for(auto mergeSet : nodeSetsToMerge) + { + auto mergedNodeTag = *mergeSet.begin(); + for(auto const& nodeTag : mergeSet) + { + if(nodeTag == mergedNodeTag) + continue; + fuseNode(graph, mergedNodeTag, nodeTag); + + graph.control.deleteElement(nodeTag); + graph.mapper.purge(nodeTag); + } + fuseScopes(graph, mergedNodeTag); } + } - if(forLoopsToFuse.size() <= 1) - return; + void fuseLoops(KernelGraph& graph, int tag) + { + rocRoller::Log::getLogger()->debug("KernelGraph::fuseLoops({})", tag); - auto fusedLoopTag = *forLoopsToFuse.begin(); - - auto fusedLoopBodyChildren - = graph.control.getOutputNodeIndices(fusedLoopTag).to(); - - auto initializeGroups = [&]() { - std::set> groups; - auto nodes - = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - fusedLoopBodyChildren, dontWalkPastForLoop, GD::Downstream)) - .template to(); - if(not nodes.empty()) - groups.emplace(getFirstAndLastNodes(graph, nodes)); - return groups; - }; - - auto groups_loads = initializeGroups.template operator()(); - auto groups_ldsLoads = initializeGroups.template operator()(); - auto groups_stores = initializeGroups.template operator()(); - auto groups_ldsStores = initializeGroups.template operator()(); - - for(auto const& forLoopTag : forLoopsToFuse) + auto loopGroupsToFuse = gatherForLoops(graph, tag); + for(auto forLoopsToFuse : loopGroupsToFuse) { - if(forLoopTag == fusedLoopTag) - continue; + if(forLoopsToFuse.size() <= 1) + return; - for(auto const& child : - graph.control.getOutputNodeIndices(forLoopTag).to()) - { - if(fusedLoopTag != child) + auto dontWalkPastForLoop = [&](int tag) -> bool { + for(auto neighbour : graph.control.getNeighbours(tag, GD::Downstream)) { - graph.control.addElement(Sequence(), {fusedLoopTag}, {child}); - } - graph.control.deleteElement(std::vector{forLoopTag}, - std::vector{child}); - std::unordered_set toDelete; - for(auto descSeqOfChild : - filter(graph.control.isElemType(), - graph.control.depthFirstVisit(child, GD::Downstream))) - { - if(graph.control.getNeighbours(descSeqOfChild) - .to() - .contains(fusedLoopTag)) + if(graph.control.get(neighbour)) { - toDelete.insert(descSeqOfChild); + return false; } } - for(auto edge : toDelete) - { - graph.control.deleteElement(edge); - } - } - - // - // Extract the memory nodes in forLoopTag, which will be used - // at the end to order with memory nodes in fusedLoopTag. - // - std::vector loads; - std::vector ldsLoads; - std::vector stores; - std::vector ldsStores; + return true; + }; + + auto fusedLoopTag = *forLoopsToFuse.begin(); + + auto fusedLoopBodyChildren + = graph.control.getOutputNodeIndices(fusedLoopTag).to(); + + auto initializeGroups = [&]() { + std::set> groups; + auto nodes + = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + fusedLoopBodyChildren, dontWalkPastForLoop, GD::Downstream)) + .template to(); + if(not nodes.empty()) + groups.emplace(getFirstAndLastNodes(graph, nodes)); + return groups; + }; + + auto groups_loads = initializeGroups.template operator()(); + auto groups_ldsLoads = initializeGroups.template operator()(); + auto groups_stores = initializeGroups.template operator()(); + auto groups_ldsStores = initializeGroups.template operator()(); + + for(auto const& forLoopTag : forLoopsToFuse) { - auto children = graph.control.getOutputNodeIndices(forLoopTag) - .to(); + if(forLoopTag == fusedLoopTag) + continue; + + // + // Extract the memory nodes in forLoopTag, which will be used + // at the end to order with memory nodes in fusedLoopTag. + // + std::vector loads; + std::vector ldsLoads; + std::vector stores; + std::vector ldsStores; + { + auto children = graph.control.getOutputNodeIndices(forLoopTag) + .to(); - loads = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) - .to(); - ldsLoads = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) - .to(); - stores = filter(graph.control.isElemType(), + loads = filter(graph.control.isElemType(), graph.control.depthFirstVisit( children, dontWalkPastForLoop, GD::Downstream)) - .to(); - ldsStores = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) .to(); - } - - for(auto const& child : - graph.control.getOutputNodeIndices(forLoopTag).to()) - { - graph.control.addElement(Body(), {fusedLoopTag}, {child}); - graph.control.deleteElement(std::vector{forLoopTag}, - std::vector{child}); - } - - for(auto const& parent : - graph.control.getInputNodeIndices(forLoopTag).to()) - { - auto descOfFusedLoop - = graph.control - .depthFirstVisit(fusedLoopTag, - graph.control.isElemType(), - GD::Downstream) - .to(); - - if(!descOfFusedLoop.contains(parent)) - { - graph.control.addElement(Sequence(), {parent}, {fusedLoopTag}); + ldsLoads = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); + stores = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); + ldsStores = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); } - graph.control.deleteElement(std::vector{parent}, - std::vector{forLoopTag}); - } - for(auto const& parent : - graph.control.getInputNodeIndices(forLoopTag).to()) - { - graph.control.addElement(Body(), {parent}, {fusedLoopTag}); - graph.control.deleteElement(std::vector{parent}, - std::vector{forLoopTag}); + fuseNode(graph, fusedLoopTag, forLoopTag); + purgeFor(graph, forLoopTag); + + // + // Order the memory nodes in forLoopTag with memory + // nodes in fusedLoopTag. + // + // An important assumption here is the memory nodes of + // forLoopTag should be ordered totally already, and + // orderGroups leverages this fact to connect the first and + // last nodes with other groups to achieve total ordering. + // + orderGroups(graph, groups_loads, loads); + orderGroups(graph, groups_ldsLoads, ldsLoads); + orderGroups(graph, groups_stores, stores); + orderGroups(graph, groups_ldsStores, ldsStores); } - purgeFor(graph, forLoopTag); - - // - // Order the memory nodes in forLoopTag with memory - // nodes in fusedLoopTag. - // - // An important assumption here is the memory nodes of - // forLoopTag should be ordered totally already, and - // orderGroups leverages this fact to connect the first and - // last nodes with other groups to achieve total ordering. - // - orderGroups(graph, groups_loads, loads); - orderGroups(graph, groups_ldsLoads, ldsLoads); - orderGroups(graph, groups_stores, stores); - orderGroups(graph, groups_ldsStores, ldsStores); + fuseScopes(graph, fusedLoopTag); } } } @@ -414,5 +573,45 @@ namespace rocRoller return newGraph; } + + /** + * @brief Ensure no node is a child of two nodes which are not contained within each other + * + * @param graph + * @return ConstraintStatus + */ + ConstraintStatus BodyOfOnlyOneNode(const KernelGraph& graph) + { + ConstraintStatus retval; + for(auto tag : graph.control.leaves().to()) + { + auto containing = graph.control.nodesContaining(tag).to(); + for(auto contA : containing) + { + for(auto contB : containing) + { + if(contA == contB) + { + continue; + } + auto order = graph.control.compareNodes(UseCacheIfAvailable, contA, contB); + if(!(order == NodeOrdering::LeftInBodyOfRight + || order == NodeOrdering::RightInBodyOfLeft)) + { + retval.combine( + false, + concatenate( + "Nodes ", contA, " and ", contB, "have intersecting bodies.")); + } + } + } + } + return retval; + } + + std::vector FuseLoops::postConstraints() const + { + return {BodyOfOnlyOneNode}; + } } } diff --git a/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp b/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp index c4bbd1b9..235101e8 100644 --- a/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp +++ b/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp @@ -358,6 +358,27 @@ namespace rocRoller } } + // If there exists a reindexing "chain" A -> B -> C, then processing + // B -> C before A -> B causes the old references to A to be left + // dangling. To avoid this, we collapse the chain so A -> C and B -> C, + // which matches the behaviour we would see if A -> B was processed first. + for(auto [oldTag, newTag] : expressionReindexer.coordinates) + { + while(expressionReindexer.coordinates.count(newTag) > 0 && newTag != oldTag) + { + auto nextTag = expressionReindexer.coordinates.at(newTag); + if(nextTag != newTag) + { + newTag = expressionReindexer.coordinates.at(newTag); + } + else + { + break; + } + } + expressionReindexer.coordinates[oldTag] = newTag; + } + // Update tile references auto kernel = *graph.control.roots().begin(); reindexExpressions(graph, kernel, expressionReindexer); diff --git a/lib/source/KernelGraph/Transformations/UnrollLoops.cpp b/lib/source/KernelGraph/Transformations/UnrollLoops.cpp index c95c942a..73a6673d 100644 --- a/lib/source/KernelGraph/Transformations/UnrollLoops.cpp +++ b/lib/source/KernelGraph/Transformations/UnrollLoops.cpp @@ -622,22 +622,6 @@ namespace rocRoller } } - { - auto containingForLoops = graph.control.nodesContaining(loop).filter( - graph.control.isElemType()); - for(auto containingLoop : containingForLoops) - { - if(getUnrollAmount(graph, containingLoop, m_params) > 1) - { - Log::debug("Not adding tail loop for {} because it is contained by {} " - "which is also unrolled.", - loop, - containingLoop); - return std::nullopt; - } - } - } - auto loopSizeType = resultVariableType(loopSize); auto amount = Expression::literal(unrollAmount, loopSizeType); auto loopSizeRoundedDown = (loopSize / amount) * amount; diff --git a/scripts/lib/rrperf/rrsuites.py b/scripts/lib/rrperf/rrsuites.py index 1d8233fe..b3e3aed1 100644 --- a/scripts/lib/rrperf/rrsuites.py +++ b/scripts/lib/rrperf/rrsuites.py @@ -430,6 +430,26 @@ def tail_loop_reproducer(): ) +def jammed_tail_loop(): + yield mkGEMM( + M=256, + N=256, + K=24, + trans_A="T", + trans_B="N", + wave_m=32, + wave_n=32, + wave_k=2, + wave_b=1, + mac_m=64, + mac_n=64, + mac_k=8, + unroll_x=2, + unroll_y=2, + # unroll_k=2 implicit due to prefetchInFlight=2 + ) + + def guidepost_1(): yield mkGEMM( HGEMM_7680x8448x8448, diff --git a/test/unit/GEMMFusion.cpp b/test/unit/GEMMFusion.cpp index 62fe3301..626b9352 100644 --- a/test/unit/GEMMFusion.cpp +++ b/test/unit/GEMMFusion.cpp @@ -77,7 +77,7 @@ namespace GEMMDriverTest AssertFatal(M % gemm.macM == 0, "MacroTile size mismatch (M)"); AssertFatal(N % gemm.macN == 0, "MacroTile size mismatch (N)"); - if(gemm.unrollK > 0) + if(gemm.unrollK > 0 && !gemm.tailLoops) { AssertFatal(K % (gemm.macK * gemm.unrollK) == 0, "MacroTile size mismatch (K unroll)"); diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index 4049b2ad..9137a979 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -1280,6 +1280,29 @@ namespace GEMMDriverTest } } + TEST_P(GEMMTestGPU, GPU_BasicGEMMJammedUnrollKTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + gemm.m = 64; + gemm.n = 128; + gemm.transA = "T"; + gemm.transB = "N"; + gemm.loadLDSA = false; + gemm.loadLDSB = false; + gemm.storeLDSD = false; + gemm.fuseLoops = true; + gemm.tailLoops = true; + gemm.unrollY = 2; + gemm.unrollK = 4; + gemm.macK = 8; + for(auto k : {8, 16, 24, 32, 40, 48, 56, 64}) + { + gemm.k = k; + basicGEMM(gemm); + } + } + TEST_P(GEMMTestGPU, GPU_BasicGEMMUnrollKLDS) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -3077,7 +3100,7 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "ds_write"), 0); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X2) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); GEMMProblem gemm; @@ -3102,7 +3125,61 @@ namespace GEMMDriverTest basicGEMM(gemm); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X1) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2UnrollK) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 64; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.loadLDSA = false; + gemm.storeLDSD = false; + gemm.fuseLoops = false; + + basicGEMM(gemm); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.loadLDSA = false; + gemm.storeLDSD = false; + gemm.fuseLoops = false; + + basicGEMM(gemm); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); GEMMProblem gemm; @@ -3134,7 +3211,7 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X1UnrollK) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1UnrollK) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); GEMMProblem gemm; @@ -3161,12 +3238,44 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 20); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 22); EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1X2) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 128; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.transA = "T"; + gemm.transB = "N"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 22); + EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); + EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); GEMMProblem gemm; @@ -3195,7 +3304,7 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1X2UnrollK) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2UnrollK) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); GEMMProblem gemm; @@ -3221,7 +3330,38 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 24); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 26); + EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); + EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 128; + gemm.macK = 16; + + gemm.unrollK = 4; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 4 * gemm.wavefrontSize; + gemm.workgroupSizeY = 2; + + gemm.transA = "T"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 26); EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } @@ -3273,6 +3413,32 @@ namespace GEMMDriverTest basicGEMM(gemm); } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x8UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 4 * gemm.wavefrontSize; + gemm.workgroupSizeY = 1; + + gemm.storeLDSD = false; + + basicGEMM(gemm); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x4) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -3331,7 +3497,40 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 6); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 9); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x4UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.prefetchInFlight = 2; + gemm.prefetchLDSFactor = 2; + gemm.prefetchMixMemOps = true; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 2; + + gemm.storeLDSD = false; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 9); } TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed4x2) @@ -3391,7 +3590,38 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 12); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 15); + } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed4x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 64; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 4; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 1 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.storeLDSD = false; + + gemm.transB = "N"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 15); } TEST_P(GEMMTestGPU, GPU_BasicGEMMFP16AllLDS)