From 5fe7fd2a12f202125bdc6f34b2e4f40790172b24 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Fri, 14 Feb 2025 16:52:02 +0000 Subject: [PATCH 01/17] Initial work on FuseLoops --- lib/include/rocRoller/Expression.hpp | 2 + lib/include/rocRoller/Graph/Hypergraph.hpp | 11 +- .../rocRoller/Graph/Hypergraph_impl.hpp | 10 + .../KernelGraph/ControlGraph/Operation.hpp | 2 + .../KernelGraph/ControlToCoordinateMapper.hpp | 15 + .../KernelGraph/Transforms/FuseLoops.hpp | 5 +- lib/source/Expression_toString.cpp | 162 +++++++ .../KernelGraph/ControlGraph/Operation.cpp | 9 + lib/source/KernelGraph/KernelGraph.cpp | 22 +- .../KernelGraph/Transformations/FuseLoops.cpp | 455 +++++++++++------- .../Transformations/UnrollLoops.cpp | 16 - test/common/common/GEMMProblem.hpp | 2 + test/unit/GEMMTest.cpp | 58 ++- 13 files changed, 559 insertions(+), 210 deletions(-) diff --git a/lib/include/rocRoller/Expression.hpp b/lib/include/rocRoller/Expression.hpp index dff6a728..b77e8459 100644 --- a/lib/include/rocRoller/Expression.hpp +++ b/lib/include/rocRoller/Expression.hpp @@ -738,6 +738,8 @@ namespace rocRoller std::string toString(ExpressionPtr const& expr); std::string toString(Expression const& expr); + std::string toShortString(ExpressionPtr const& expr); + std::string toShortString(Expression const& expr); std::ostream& operator<<(std::ostream&, ExpressionPtr const&); std::ostream& operator<<(std::ostream&, Expression const&); std::ostream& operator<<(std::ostream&, std::vector const&); diff --git a/lib/include/rocRoller/Graph/Hypergraph.hpp b/lib/include/rocRoller/Graph/Hypergraph.hpp index 04c9648e..f71d1bff 100644 --- a/lib/include/rocRoller/Graph/Hypergraph.hpp +++ b/lib/include/rocRoller/Graph/Hypergraph.hpp @@ -301,13 +301,20 @@ namespace rocRoller std::string toDOT(std::string const& prefix = "", bool standalone = true) const; - static bool identity(Edge const&) + static bool edgeIdentity(Edge const&) { return true; } + static inline void nodeIdentity(Node&, int) {} + + 
template Transform> + std::string toDOT(std::string const& prefix = "", + bool standalone = true, + Transform nodeTransform = nodeIdentity) const; + template Predicate> - std::string toDOT(Predicate edgePredicate = identity) const; + std::string toDOT(Predicate edgePredicate = edgeIdentity) const; template requires(std::constructible_from || std::constructible_from) diff --git a/lib/include/rocRoller/Graph/Hypergraph_impl.hpp b/lib/include/rocRoller/Graph/Hypergraph_impl.hpp index 139306ba..4269a96f 100644 --- a/lib/include/rocRoller/Graph/Hypergraph_impl.hpp +++ b/lib/include/rocRoller/Graph/Hypergraph_impl.hpp @@ -884,6 +884,15 @@ namespace rocRoller template std::string Hypergraph::toDOT(std::string const& prefix, bool standalone) const + { + return toDOT(prefix, standalone, nodeIdentity); + } + + template + template Transform> + std::string Hypergraph::toDOT(std::string const& prefix, + bool standalone, + Transform nodeTransform) const { std::ostringstream msg; @@ -896,6 +905,7 @@ namespace rocRoller if(getElementType(pair.second) == ElementType::Node) { auto x = std::get(pair.second); + nodeTransform(x, pair.first); msg << toString(x) << "(" << pair.first << ")\""; } else diff --git a/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp b/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp index f1f8715b..5453db43 100644 --- a/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp +++ b/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp @@ -42,8 +42,10 @@ namespace rocRoller SetCoordinate(Expression::ExpressionPtr value); Expression::ExpressionPtr value; + std::string coordName; std::string name() const; + std::string toString() const; }; /** diff --git a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp index dc7b66a8..241d5c08 100644 --- a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp +++ 
b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp @@ -44,6 +44,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndSubDimension const& a, TypeAndSubDimension const& b) + { + return a.id == b.id && a.subdimension == b.subdimension; + } + struct TypeAndNaryArgument { std::string id; @@ -63,6 +68,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndNaryArgument const& a, TypeAndNaryArgument const& b) + { + return a.id == b.id && a.argument == b.argument; + } + enum class ComputeIndexArgument : int { TARGET = 0, @@ -91,6 +101,11 @@ namespace rocRoller::KernelGraph return a.argument < b.argument; } + bool inline operator==(ComputeIndex const& a, ComputeIndex const& b) + { + return a.argument == b.argument && a.index == b.index; + } + using ConnectionSpec = std::variant postConstraints() const override; + std::string name() const override { return "FuseLoops"; } diff --git a/lib/source/Expression_toString.cpp b/lib/source/Expression_toString.cpp index 619861bf..57807af3 100644 --- a/lib/source/Expression_toString.cpp +++ b/lib/source/Expression_toString.cpp @@ -156,5 +156,167 @@ namespace rocRoller auto visitor = ExpressionToStringVisitor(); return visitor.call(expr); } + + struct ExpressionToShortStringVisitor + { + + template + std::string operator()(Expr const& expr) const + { + return concatenate(ExpressionInfo::name(), + "(", + call(expr.lhs), + ", ", + call(expr.r1hs), + ", ", + call(expr.r2hs), + ")"); + } + + template + std::string operator()(Expr const& expr) const + { + return concatenate( + ExpressionInfo::name(), "(", call(expr.lhs), ", ", call(expr.rhs), ")"); + } + template + std::string operator()(Expr const& expr) const + { + return concatenate(ExpressionInfo::name(), "(", call(expr.arg), ")"); + } + std::string operator()(Register::ValuePtr const& expr) const + { + // This allows an unallocated register value to be rendered into a string which + // improves 
debugging by allowing the string representation of that expression + // to be put into the source file as a comment. + // Trying to generate the code for the expression will throw an exception. + + std::string tostr = "UNALLOCATED"; + if(expr->canUseAsOperand()) + tostr = expr->toString(); + + // The call() function appends the result type, so add ":" to separate the + // value from the type. + return tostr; + } + std::string operator()(CommandArgumentPtr const& expr) const + { + if(expr) + return concatenate("CommandArgument(", expr->name(), ")"); + else + return "CommandArgument(nullptr)"; + } + + std::string operator()(CommandArgumentValue const& expr) const + { + return std::visit([](auto const& val) { return concatenate(val); }, expr); + } + + std::string operator()(AssemblyKernelArgumentPtr const& expr) const + { + // The call() function appends the result type, so add ":" to separate the + // value from the type. + return expr->name; + } + + std::string operator()(WaveTilePtr const& expr) const + { + return "WaveTile"; + } + + std::string operator()(DataFlowTag const& expr) const + { + return concatenate("DataFlowTag(", expr.tag, ")"); + } + +#define HANDLE_INFIX_OP(TYPE, INFIX) \ + std::string operator()(TYPE const& expr) const \ + { \ + return concatenate("(", call(expr.lhs), INFIX, call(expr.rhs), ")"); \ + } + HANDLE_INFIX_OP(Add, "+"); + HANDLE_INFIX_OP(Subtract, "-"); + HANDLE_INFIX_OP(Multiply, "*"); + HANDLE_INFIX_OP(Divide, "/"); + HANDLE_INFIX_OP(Modulo, "%"); + HANDLE_INFIX_OP(ShiftL, "<<"); + HANDLE_INFIX_OP(ArithmeticShiftR, ">>"); + HANDLE_INFIX_OP(BitwiseAnd, "&"); + HANDLE_INFIX_OP(BitwiseOr, "|"); + HANDLE_INFIX_OP(BitwiseXor, "^"); + HANDLE_INFIX_OP(GreaterThan, ">"); + HANDLE_INFIX_OP(GreaterThanEqual, ">="); + HANDLE_INFIX_OP(LessThan, "<"); + HANDLE_INFIX_OP(LessThanEqual, "<="); + HANDLE_INFIX_OP(Equal, "=="); + HANDLE_INFIX_OP(NotEqual, "!="); + HANDLE_INFIX_OP(LogicalAnd, "&&"); + HANDLE_INFIX_OP(LogicalOr, "||"); + + std::string 
operator()(AddShiftL const& expr) const + { + return concatenate( + "((", call(expr.lhs), "+", call(expr.r1hs), ")<<", call(expr.r2hs), ")"); + } + + std::string operator()(ShiftLAdd const& expr) const + { + return concatenate( + "((", call(expr.lhs), "<<", call(expr.r1hs), ")+", call(expr.r2hs), ")"); + } + + std::string operator()(MultiplyAdd const& expr) const + { + return concatenate( + "((", call(expr.lhs), "*", call(expr.r1hs), ")+", call(expr.r2hs), ")"); + } + + std::string operator()(Conditional const& expr) const + { + return concatenate( + "(", call(expr.lhs), "?", call(expr.r1hs), ":", call(expr.r2hs), ")"); + } + + std::string operator()(Negate const& expr) const + { + return concatenate("(-", call(expr.arg), ")"); + } + + std::string operator()(BitwiseNegate const& expr) const + { + return concatenate("(~", call(expr.arg), ")"); + } + + std::string operator()(Convert const& expr) const + { + return concatenate( + "((", TypeAbbrev(resultVariableType(expr)), ")", call(expr.arg), ")"); + } + + std::string call(Expression const& expr) const + { + return std::visit(*this, expr); + } + + std::string call(ExpressionPtr expr) const + { + if(!expr) + return "nullptr"; + + return call(*expr); + } + }; + + std::string toShortString(Expression const& expr) + { + auto visitor = ExpressionToShortStringVisitor(); + return visitor.call(expr); + } + + std::string toShortString(ExpressionPtr const& expr) + { + auto visitor = ExpressionToShortStringVisitor(); + return visitor.call(expr); + } } } diff --git a/lib/source/KernelGraph/ControlGraph/Operation.cpp b/lib/source/KernelGraph/ControlGraph/Operation.cpp index b58aae32..d3d23afa 100644 --- a/lib/source/KernelGraph/ControlGraph/Operation.cpp +++ b/lib/source/KernelGraph/ControlGraph/Operation.cpp @@ -9,6 +9,15 @@ namespace rocRoller::KernelGraph::ControlGraph { } + std::string SetCoordinate::toString() const + { + if(coordName != "") + { + return concatenate(name(), " ", coordName, ": ", toShortString(value)); + } + 
return concatenate(name(), ": ", value); + } + std::string ForLoopOp::toString() const { return concatenate(name(), " ", loopName, ": ", condition); diff --git a/lib/source/KernelGraph/KernelGraph.cpp b/lib/source/KernelGraph/KernelGraph.cpp index 56333b77..c6f69243 100644 --- a/lib/source/KernelGraph/KernelGraph.cpp +++ b/lib/source/KernelGraph/KernelGraph.cpp @@ -7,6 +7,7 @@ namespace rocRoller { namespace KernelGraph { + std::string KernelGraph::toDOT(bool drawMappings, std::string title) const { std::stringstream ss; @@ -19,7 +20,26 @@ namespace rocRoller ss << coordinates.toDOT("coord", false); ss << "subgraph clusterCF {"; ss << "label = \"Control Graph\";" << std::endl; - ss << control.toDOT("cntrl", false); + + auto setCoordRename = [&](ControlGraph::Operation& op, int index) { + std::visit( + rocRoller::overloaded{[&](ControlGraph::SetCoordinate& setCoord) { + auto connected = only(mapper.getConnections(index)); + if(!connected) + return; + auto connectedIndex = (*connected).coordinate; + if(connectedIndex < 0) + return; + std::stringstream coordss; + coordss << name(coordinates.getNode(connectedIndex)) + << "(" << connectedIndex << ")"; + setCoord.coordName = coordss.str(); + }, + [](auto&) {}}, + op); + }; + + ss << control.toDOT("cntrl", false, setCoordRename); ss << "}" << std::endl; if(drawMappings) { diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 9e1cf347..5b839cde 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -27,218 +27,312 @@ namespace rocRoller { using GD = rocRoller::Graph::Direction; - /** - * @brief Find a path from a node to a ForLoopOp using only Sequence edges - * - * Returns an empty vector if no path is found. 
- * - * @param graph - * @param start - * @return std::vector - */ - std::vector pathToForLoop(KernelGraph& graph, int start) + std::vector> gatherForLoops(KernelGraph& graph, int start) { - // Find the first ForLoop under the node - auto allForLoops - = graph.control - .findNodes( - start, - [&](int tag) -> bool { - return isOperation(graph.control.getElement(tag)); - }, - GD::Downstream) - .to(); - - if(allForLoops.empty()) - return {}; - - auto firstForLoop = allForLoops[0]; - - // Find all of the nodes in between the node and the first for loop - auto pathToLoopWithEdges = graph.control - .path(std::vector{start}, - std::vector{firstForLoop}) - .to(); - - // Filter out only the nodes - std::vector pathToLoop; - std::copy_if(pathToLoopWithEdges.begin(), - pathToLoopWithEdges.end(), - std::back_inserter(pathToLoop), - [&](int tag) -> bool { - return graph.control.getElementType(tag) - == Graph::ElementType::Node; - }); - - return pathToLoop; - } - - void fuseLoops(KernelGraph& graph, int tag) - { - rocRoller::Log::getLogger()->debug("KernelGraph::fuseLoops({})", tag); - - auto dontWalkPastForLoop = [&](int tag) -> bool { - for(auto neighbour : graph.control.getNeighbours(tag, GD::Downstream)) + auto bodies = graph.control.getOutputNodeIndices(start).to(); + auto isForLoopOp = graph.control.isElemType(); + auto isSetCoordinate = graph.control.isElemType(); + auto isSequence = graph.control.isElemType(); + auto isBody = graph.control.isElemType(); + + auto maybeForLoops + = graph.control.depthFirstVisit(bodies, isSequence, GD::Downstream) + .filter([&](int tag) -> bool { + if(isForLoopOp(tag)) + { + return true; + } + if(isSetCoordinate(tag)) + { + return graph.mapper.get(tag) > 0; + } + return false; + }) + .to(); + + std::vector forLoops; + for(auto const& maybeForLoop : maybeForLoops) + { + if(isForLoopOp(maybeForLoop)) { - if(graph.control.get(neighbour)) + forLoops.insert(forLoops.begin(), maybeForLoop); + } + else + { + // The filter means that this is now a 
SetCoordinate connected to an Unroll dimension + // Tail loops are always downstream of these + std::optional tag = maybeForLoop; + while(tag && isSetCoordinate(*tag)) { - return false; + tag = only(graph.control.getOutputNodeIndices(*tag)); + } + if(tag && isForLoopOp(*tag)) + { + forLoops.push_back(*tag); } } - return true; - }; - - // Find all of the paths from the top of one of a body to a - // ForLoopOp. - auto bodies = graph.control.getOutputNodeIndices(tag).to(); - std::vector> paths; - for(auto const& body : bodies) - { - auto path = pathToForLoop(graph, body); - if(!path.empty()) - paths.push_back(path); } - // See if any of the ForLoopOps that were found in paths - // should be fused together. - std::unordered_set forLoopsToFuse; - Expression::ExpressionPtr loopIncrement; - Expression::ExpressionPtr loopLength; - for(auto const& path : paths) + std::vector> loopGroupsToFuse; + while(!forLoops.empty()) { - auto forLoop = path.back(); - if(forLoopsToFuse.count(forLoop) != 0) - return; + std::unordered_set loopGroup; + Expression::ExpressionPtr loopIncrement; + Expression::ExpressionPtr loopLength; + for(auto const& loop : forLoops) + { + if(loopGroup.count(loop) != 0) + continue; + + auto loopDim = getSize(std::get(graph.coordinates.getElement( + graph.mapper.get(loop, NaryArgument::DEST)))); + + // Loops to be fused must have the same length + if(loopLength) + { + if(!identical(loopDim, loopLength)) + continue; + } + else + { + loopLength = loopDim; + } + + // Loops to be fused must have the same increment value + auto [dataTag, increment] = getForLoopIncrement(graph, loop); + if(loopIncrement) + { + if(!identical(loopIncrement, increment)) + continue; + } + else + { + loopIncrement = increment; + } + loopGroup.insert(loop); + } - // Check to see if loops are all the same length - auto dimTag = graph.mapper.get(forLoop, NaryArgument::DEST); - auto forLoopDim = getSize(graph.coordinates.getNode(dimTag)); - if(loopLength) + for(auto loop : loopGroup) { - 
if(!identical(forLoopDim, loopLength)) - return; + std::erase(forLoops, loop); } - else + + if(loopGroup.size() > 1) { - loopLength = forLoopDim; + loopGroupsToFuse.push_back(loopGroup); } + } - // Check to see if loops are incremented by the same value - auto [dataTag, increment] = getForLoopIncrement(graph, forLoop); - if(loopIncrement) + return loopGroupsToFuse; + } + + void fuseNode(KernelGraph& graph, int fusedNodeTag, int nodeTag) + { + for(auto const& child : + graph.control.getOutputNodeIndices(nodeTag).to()) + { + if(fusedNodeTag != child) { - if(!identical(loopIncrement, increment)) - return; + graph.control.addElement(Sequence(), {fusedNodeTag}, {child}); } - else + graph.control.deleteElement(std::vector{nodeTag}, + std::vector{child}); + std::unordered_set toDelete; + for(auto descSeqOfChild : + filter(graph.control.isElemType(), + graph.control.depthFirstVisit(child, GD::Downstream))) { - loopIncrement = increment; + if(graph.control.getNeighbours(descSeqOfChild) + .to() + .contains(fusedNodeTag)) + { + toDelete.insert(descSeqOfChild); + } } + for(auto edge : toDelete) + { + graph.control.deleteElement(edge); + } + } - forLoopsToFuse.insert(forLoop); + for(auto const& child : + graph.control.getOutputNodeIndices(nodeTag).to()) + { + graph.control.addElement(Body(), {fusedNodeTag}, {child}); + graph.control.deleteElement(std::vector{nodeTag}, + std::vector{child}); } - if(forLoopsToFuse.size() <= 1) - return; + for(auto const& parent : + graph.control.getInputNodeIndices(nodeTag).to()) + { + auto descOfFusedLoop + = graph.control + .depthFirstVisit(fusedNodeTag, + graph.control.isElemType(), + GD::Downstream) + .to(); + + if(!descOfFusedLoop.contains(parent)) + { + graph.control.addElement(Sequence(), {parent}, {fusedNodeTag}); + } + graph.control.deleteElement(std::vector{parent}, + std::vector{nodeTag}); + } - auto fusedLoopTag = *forLoopsToFuse.begin(); + for(auto const& parent : + graph.control.getInputNodeIndices(nodeTag).to()) + { + 
graph.control.addElement(Body(), {parent}, {fusedNodeTag}); + graph.control.deleteElement(std::vector{parent}, + std::vector{nodeTag}); + } + } + + struct IsSameOperationVisitor + { + template + bool operator()(int, OpA const&, int, OpB const&) + { + return false; + } - for(auto const& forLoopTag : forLoopsToFuse) + bool operator()(int tagA, SetCoordinate const& A, int tagB, SetCoordinate const& B) { - if(forLoopTag == fusedLoopTag) - continue; + auto connA = graph.mapper.getConnections(tagA); + auto connB = graph.mapper.getConnections(tagB); - for(auto const& child : - graph.control.getOutputNodeIndices(forLoopTag).to()) + if(connA.size() != connB.size()) { - if(fusedLoopTag != child) - { - graph.control.addElement(Sequence(), {fusedLoopTag}, {child}); - } - graph.control.deleteElement(std::vector{forLoopTag}, - std::vector{child}); - std::unordered_set toDelete; - for(auto descSeqOfChild : - filter(graph.control.isElemType(), - graph.control.depthFirstVisit(child, GD::Downstream))) - { - if(graph.control.getNeighbours(descSeqOfChild) - .to() - .contains(fusedLoopTag)) - { - toDelete.insert(descSeqOfChild); - } - } - for(auto edge : toDelete) - { - graph.control.deleteElement(edge); - } + return false; } - - for(auto const& child : - graph.control.getOutputNodeIndices(forLoopTag).to()) + for(auto iterA = connA.begin(), iterB = connB.begin(); iterA != connA.end(); + iterA++, iterB++) { - graph.control.addElement(Body(), {fusedLoopTag}, {child}); - graph.control.deleteElement(std::vector{forLoopTag}, - std::vector{child}); + if(iterA->coordinate != iterB->coordinate) + return false; + if(iterA->connection != iterB->connection) + return false; } + return identical(A.value, B.value); + } + + bool call(int tagA, int tagB) + { + return std::visit(*this, + std::variant(tagA), + graph.control.getNode(tagA), + std::variant(tagB), + graph.control.getNode(tagB)); + } + + KernelGraph const& graph; + }; + + void fuseScopes(KernelGraph& graph, int tag) + { + auto 
parentsWithEdges + = graph.control.getInputNodeIndices(tag).template to(); + std::set> nodeSetsToMerge; + IsSameOperationVisitor visitor{graph}; - for(auto const& parent : - graph.control.getInputNodeIndices(forLoopTag).to()) + for(auto const& A : parentsWithEdges) + { + std::set sameAsThis; + for(auto const& B : parentsWithEdges) { - auto descOfFusedLoop - = graph.control - .depthFirstVisit(fusedLoopTag, - graph.control.isElemType(), - GD::Downstream) - .to(); - - if(!descOfFusedLoop.contains(parent)) + if(A == B) + continue; + if(visitor.call(A, B)) { - graph.control.addElement(Sequence(), {parent}, {fusedLoopTag}); + sameAsThis.insert(B); } - graph.control.deleteElement(std::vector{parent}, - std::vector{forLoopTag}); } - - for(auto const& parent : - graph.control.getInputNodeIndices(forLoopTag).to()) + if(!sameAsThis.empty()) { - graph.control.addElement(Body(), {parent}, {fusedLoopTag}); - graph.control.deleteElement(std::vector{parent}, - std::vector{forLoopTag}); + sameAsThis.insert(A); + nodeSetsToMerge.insert(sameAsThis); } + } - purgeFor(graph, forLoopTag); + for(auto mergeSet : nodeSetsToMerge) + { + auto mergedNodeTag = *mergeSet.begin(); + for(auto const& nodeTag : mergeSet) + { + if(nodeTag == mergedNodeTag) + continue; + fuseNode(graph, mergedNodeTag, nodeTag); + + graph.control.deleteElement(nodeTag); + graph.mapper.purge(nodeTag); + } + fuseScopes(graph, mergedNodeTag); } + } - auto children - = graph.control.getOutputNodeIndices(fusedLoopTag).to(); + void fuseLoops(KernelGraph& graph, int tag) + { + rocRoller::Log::getLogger()->debug("KernelGraph::fuseLoops({})", tag); + + auto loopGroupsToFuse = gatherForLoops(graph, tag); + for(auto forLoopsToFuse : loopGroupsToFuse) + { + if(forLoopsToFuse.size() <= 1) + return; + + auto dontWalkPastForLoop = [&](int tag) -> bool { + for(auto neighbour : graph.control.getNeighbours(tag, GD::Downstream)) + { + if(graph.control.get(neighbour)) + { + return false; + } + } + return true; + }; + auto fusedLoopTag = 
*forLoopsToFuse.begin(); - auto loads = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) - .to(); + for(auto const& forLoopTag : forLoopsToFuse) + { + if(forLoopTag == fusedLoopTag) + continue; - auto ldsLoads = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) - .to(); + fuseNode(graph, fusedLoopTag, forLoopTag); + purgeFor(graph, forLoopTag); + } - auto stores = filter(graph.control.isElemType(), - graph.control.depthFirstVisit( - children, dontWalkPastForLoop, GD::Downstream)) - .to(); + fuseScopes(graph, fusedLoopTag); - auto ldsStores = filter(graph.control.isElemType(), + auto children + = graph.control.getOutputNodeIndices(fusedLoopTag).to(); + + auto loads = filter(graph.control.isElemType(), graph.control.depthFirstVisit( children, dontWalkPastForLoop, GD::Downstream)) .to(); - orderMemoryNodes(graph, loads, true); - orderMemoryNodes(graph, ldsLoads, true); - orderMemoryNodes(graph, stores, true); - orderMemoryNodes(graph, ldsStores, true); + auto ldsLoads = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); + + auto stores = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); + + auto ldsStores = filter(graph.control.isElemType(), + graph.control.depthFirstVisit( + children, dontWalkPastForLoop, GD::Downstream)) + .to(); + + orderMemoryNodes(graph, loads, true); + orderMemoryNodes(graph, ldsLoads, true); + orderMemoryNodes(graph, stores, true); + orderMemoryNodes(graph, ldsStores, true); + } } } @@ -259,5 +353,36 @@ namespace rocRoller return newGraph; } + + ConstraintStatus BodyOfOnlyOneNode(const KernelGraph& graph) + { + ConstraintStatus retval; + for(auto tag : graph.control.leaves().to()) + { + auto containing = graph.control.nodesContaining(tag).to(); + for(auto 
contA : containing) + { + for(auto contB : containing) + { + if(contA == contB) + { + continue; + } + auto order = graph.control.compareNodes(contA, contB); + retval.combine( + order == NodeOrdering::LeftInBodyOfRight + || order == NodeOrdering::RightInBodyOfLeft, + concatenate( + "Nodes ", contA, " and ", contB, "have intersecting bodies.")); + } + } + } + return retval; + } + + std::vector FuseLoops::postConstraints() const + { + return {BodyOfOnlyOneNode}; + } } } diff --git a/lib/source/KernelGraph/Transformations/UnrollLoops.cpp b/lib/source/KernelGraph/Transformations/UnrollLoops.cpp index ea843be2..ea3eb825 100644 --- a/lib/source/KernelGraph/Transformations/UnrollLoops.cpp +++ b/lib/source/KernelGraph/Transformations/UnrollLoops.cpp @@ -579,22 +579,6 @@ namespace rocRoller } } - { - auto containingForLoops = graph.control.nodesContaining(loop).filter( - graph.control.isElemType()); - for(auto containingLoop : containingForLoops) - { - if(getUnrollAmount(graph, containingLoop, m_params) > 1) - { - Log::debug("Not adding tail loop for {} because it is contained by {} " - "which is also unrolled.", - loop, - containingLoop); - return std::nullopt; - } - } - } - auto loopSizeType = resultVariableType(loopSize); auto amount = Expression::literal(unrollAmount, loopSizeType); auto loopSizeRoundedDown = (loopSize / amount) * amount; diff --git a/test/common/common/GEMMProblem.hpp b/test/common/common/GEMMProblem.hpp index 5d6fe552..f6289ccc 100644 --- a/test/common/common/GEMMProblem.hpp +++ b/test/common/common/GEMMProblem.hpp @@ -34,6 +34,8 @@ struct GEMMProblem std::string transB = "T"; // Unroll Sizes + unsigned int unrollM = 0; + unsigned int unrollN = 0; unsigned int unrollK = 0; bool loadLDSA = true; diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index 94d7ab92..0a53304a 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -95,20 +95,26 @@ namespace GEMMDriverTest float alpha = gemm.alpha; float beta = gemm.beta; - AssertFatal(M 
% gemm.macM == 0, - "MacroTile size mismatch (M)", - ShowValue(M), - ShowValue(gemm.macM)); - AssertFatal(N % gemm.macN == 0, - "MacroTile size mismatch (N)", - ShowValue(N), - ShowValue(gemm.macN)); - - if(gemm.unrollK > 0 && !gemm.tailLoops) - { - AssertFatal(K % (gemm.macK * gemm.unrollK) == 0, - "MacroTile size mismatch (K unroll)"); - } +#define AssertMacroTileMatch(D) \ + if(gemm.unroll##D > 0 && !gemm.tailLoops) \ + { \ + AssertFatal(D % (gemm.mac##D * gemm.unroll##D) == 0, \ + "MacroTile size mismatch (" #D " unroll)", \ + ShowValue(D), \ + ShowValue(gemm.mac##D), \ + ShowValue(gemm.unroll##D)); \ + } \ + else \ + { \ + AssertFatal(D % gemm.mac##D == 0, \ + "MacroTile size mismatch (" #D ")", \ + ShowValue(D), \ + ShowValue(gemm.mac##D)); \ + } + + AssertMacroTileMatch(M); + AssertMacroTileMatch(N); + AssertMacroTileMatch(K); auto bpeA = DataTypeInfo::Get(dataTypeA).elementBytes; auto bpeB = DataTypeInfo::Get(dataTypeB).elementBytes; @@ -352,6 +358,8 @@ namespace GEMMDriverTest params->fuseLoops = gemm.fuseLoops; params->tailLoops = gemm.tailLoops; params->allowAmbiguousMemoryNodes = gemm.allowAmbiguousMemoryNodes; + params->unrollX = gemm.unrollM; + params->unrollY = gemm.unrollN; params->unrollK = gemm.unrollK; params->packMultipleElementsInto1VGPR = gemm.packMultipleElementsInto1VGPR; params->prefetch = gemm.prefetch; @@ -2272,7 +2280,7 @@ namespace GEMMDriverTest basicGEMM(gemm); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X2) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2) { GEMMProblem gemm; @@ -2296,7 +2304,7 @@ namespace GEMMDriverTest basicGEMM(gemm); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X1) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1) { GEMMProblem gemm; @@ -2327,7 +2335,7 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2X1UnrollK) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1UnrollK) { 
GEMMProblem gemm; @@ -2339,7 +2347,9 @@ namespace GEMMDriverTest gemm.macN = 128; gemm.macK = 16; - gemm.unrollK = 2; + gemm.unrollK = 2; + gemm.tailLoops = true; + gemm.allowAmbiguousMemoryNodes = true; gemm.waveK = 8; @@ -2353,12 +2363,12 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 20); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 22); EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1X2) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2) { GEMMProblem gemm; @@ -2386,7 +2396,7 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } - TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1X2UnrollK) + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2UnrollK) { GEMMProblem gemm; @@ -2411,7 +2421,7 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 24); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 26); EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } @@ -2517,7 +2527,7 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 6); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 9); } TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed4x2) @@ -2575,7 +2585,7 @@ namespace GEMMDriverTest std::string generatedCode = m_context->instructions()->toString(); - EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 12); + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 15); } TEST_P(GEMMTestGPU, GPU_BasicGEMMLiteralStrides) From 
b940af0ada8cfe946099d6a1f2587e730b473138 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Thu, 27 Feb 2025 17:15:37 +0000 Subject: [PATCH 02/17] Fix graphconstraint output --- .../KernelGraph/Transformations/FuseLoops.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 5b839cde..511b54f4 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -369,11 +369,14 @@ namespace rocRoller continue; } auto order = graph.control.compareNodes(contA, contB); - retval.combine( - order == NodeOrdering::LeftInBodyOfRight - || order == NodeOrdering::RightInBodyOfLeft, - concatenate( - "Nodes ", contA, " and ", contB, "have intersecting bodies.")); + if(!(order == NodeOrdering::LeftInBodyOfRight + || order == NodeOrdering::RightInBodyOfLeft)) + { + retval.combine( + false, + concatenate( + "Nodes ", contA, " and ", contB, "have intersecting bodies.")); + } } } } From da7529918867371be24a491d4b708685b29af9e3 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 8 Apr 2025 21:11:46 +0000 Subject: [PATCH 03/17] fix compilation issue --- lib/source/Expression_toString.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lib/source/Expression_toString.cpp b/lib/source/Expression_toString.cpp index 1c4ee124..0d2460ca 100644 --- a/lib/source/Expression_toString.cpp +++ b/lib/source/Expression_toString.cpp @@ -224,6 +224,22 @@ namespace rocRoller // value from the type. 
return tostr; } + + std::string operator()(ScaledMatrixMultiply const& expr) const + { + return concatenate("ScaledMatrixMultiply(", + call(expr.matA), + ", ", + call(expr.matB), + ", ", + call(expr.matC), + ", ", + call(expr.scaleA), + ", ", + call(expr.scaleB), + ")"); + } + std::string operator()(CommandArgumentPtr const& expr) const { if(expr) From 9c60015baaf5fb14cfb63fcb78dd94547b56a0cb Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 9 Apr 2025 20:23:47 +0000 Subject: [PATCH 04/17] Add heuristic to restrict loops that can be merged --- .../KernelGraph/Transformations/FuseLoops.cpp | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 15f0d94d..30fa7329 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -105,6 +105,7 @@ namespace rocRoller std::unordered_set loopGroup; Expression::ExpressionPtr loopIncrement; Expression::ExpressionPtr loopLength; + std::map baseLoopContents; for(auto const& loop : forLoops) { if(loopGroup.count(loop) != 0) @@ -135,6 +136,37 @@ namespace rocRoller { loopIncrement = increment; } + + // Loop similarity heuristic + // We only want to fuse loops which are "the same" or similar. + // Currently this is just counting the number of each type of node in the body. + // We may want to replace this with a better heuristic in the future. 
+ auto loopContentsGenerator = graph.control.nodesInBody(loop); + std::map loopContents; + std::for_each(loopContentsGenerator.begin(), + loopContentsGenerator.end(), + [&](int tag) -> void { + auto type = graph.control.get(tag)->index(); + if(loopContents.contains(type)) + { + loopContents[type]++; + } + else + { + loopContents[type] = 1; + } + }); + + if(!baseLoopContents.empty()) + { + if(baseLoopContents != loopContents) + continue; + } + else + { + baseLoopContents = loopContents; + } + loopGroup.insert(loop); } From 1183ce3530dc1821df71fb7f3afd6b394430d1da Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 9 Apr 2025 20:24:12 +0000 Subject: [PATCH 05/17] Add UnrollK version of jammed 2x2 test --- test/unit/GEMMTest.cpp | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index 07a919ac..e456757f 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -2730,6 +2730,34 @@ namespace GEMMDriverTest basicGEMM(gemm); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2UnrollK) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 64; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + gemm.tailLoops = true; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.loadLDSA = false; + gemm.storeLDSD = false; + gemm.fuseLoops = false; + + basicGEMM(gemm); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); From ce3f3ed8212c1782d34bca5c96bb8c5c47a643cb Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 9 Apr 2025 20:24:50 +0000 Subject: [PATCH 06/17] Ensure no reindexing chains in RemoveDuplicates transformation --- .../Transformations/RemoveDuplicates.cpp | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git 
a/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp b/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp index f0507f4e..47e43b52 100644 --- a/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp +++ b/lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp @@ -354,6 +354,27 @@ namespace rocRoller } } + // If there exists a reindexing "chain" A -> B -> C, then processing + // B -> C before A -> B causes the old references to A to be left + // dangling. To avoid this, we collapse the chain so A -> C and B -> C, + // which matches the behaviour we would see if A -> B was processed first. + for(auto [oldTag, newTag] : expressionReindexer.coordinates) + { + while(expressionReindexer.coordinates.count(newTag) > 0 && newTag != oldTag) + { + auto nextTag = expressionReindexer.coordinates.at(newTag); + if(nextTag != newTag) + { + newTag = expressionReindexer.coordinates.at(newTag); + } + else + { + break; + } + } + expressionReindexer.coordinates[oldTag] = newTag; + } + // Update tile references auto kernel = *graph.control.roots().begin(); reindexExpressions(graph, kernel, expressionReindexer); From c593b46b71924c382682e780afc440b698c7438e Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 9 Apr 2025 21:22:41 +0000 Subject: [PATCH 07/17] Remove redundant gemm arguments --- test/unit/GEMMTest.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index 23497c7d..d6f3f3fc 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -2737,8 +2737,7 @@ namespace GEMMDriverTest gemm.macN = 256; gemm.macK = 16; - gemm.unrollK = 2; - gemm.tailLoops = true; + gemm.unrollK = 2; gemm.waveK = 8; @@ -2797,9 +2796,7 @@ namespace GEMMDriverTest gemm.macN = 128; gemm.macK = 16; - gemm.unrollK = 2; - gemm.tailLoops = true; - gemm.allowAmbiguousMemoryNodes = true; + gemm.unrollK = 2; gemm.waveK = 8; From cf06225503e433c6d82f7b3fc0f20fe68b7c78b5 Mon Sep 17 00:00:00 2001 From: 
Lauren Wrubleski Date: Wed, 9 Apr 2025 21:34:03 +0000 Subject: [PATCH 08/17] Revert unintended merge changes --- test/catch/RegisterTagManagerTest.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/catch/RegisterTagManagerTest.cpp b/test/catch/RegisterTagManagerTest.cpp index 1bd33479..e2b35688 100644 --- a/test/catch/RegisterTagManagerTest.cpp +++ b/test/catch/RegisterTagManagerTest.cpp @@ -38,11 +38,11 @@ TEST_CASE("RegisterTagManager RegisterExpressionAttributes toString works", "[codegen][kernel-graph]") { auto expected = R"({ - t.dataType = Int32 - t.unitStride = 0 - t.elementBlockSize = 0 - t.elementBlockStride = nullptr - t.trLoadPairStride = nullptr + t.dataType = Int32 + t.unitStride = 0 + t.elementBlockSize = 0 + t.elementBlockStride = nullptr + t.trLoadPairStride = nullptr })"; CHECK(toString(RegisterExpressionAttributes{DataType::Int32, false}) == expected); From 0c19abc4689b29c1cc7da0f9941128573d7e5a36 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 9 Apr 2025 21:36:29 +0000 Subject: [PATCH 09/17] Formatting fixes --- test/catch/RegisterTagManagerTest.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/catch/RegisterTagManagerTest.cpp b/test/catch/RegisterTagManagerTest.cpp index e2b35688..1bd33479 100644 --- a/test/catch/RegisterTagManagerTest.cpp +++ b/test/catch/RegisterTagManagerTest.cpp @@ -38,11 +38,11 @@ TEST_CASE("RegisterTagManager RegisterExpressionAttributes toString works", "[codegen][kernel-graph]") { auto expected = R"({ - t.dataType = Int32 - t.unitStride = 0 - t.elementBlockSize = 0 - t.elementBlockStride = nullptr - t.trLoadPairStride = nullptr + t.dataType = Int32 + t.unitStride = 0 + t.elementBlockSize = 0 + t.elementBlockStride = nullptr + t.trLoadPairStride = nullptr })"; CHECK(toString(RegisterExpressionAttributes{DataType::Int32, false}) == expected); From 41945660ffe3a1a8bdadfc61966ad09461bcdab0 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: 
Thu, 10 Apr 2025 17:04:57 +0000 Subject: [PATCH 10/17] Update compareNodes call --- lib/source/KernelGraph/Transformations/FuseLoops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 6c7f0fb6..1ca1234e 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -428,7 +428,7 @@ namespace rocRoller { continue; } - auto order = graph.control.compareNodes(contA, contB); + auto order = graph.control.compareNodes(UseCacheIfAvailable, contA, contB); if(!(order == NodeOrdering::LeftInBodyOfRight || order == NodeOrdering::RightInBodyOfLeft)) { From 09dacb3e0206a9aea3ccd0b44392ebfdc53dfc37 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 15 Apr 2025 15:44:28 +0000 Subject: [PATCH 11/17] Adjust checks on K size to accommodate correct unrolling --- client/include/DataParallelGEMMSolution.hpp | 2 +- test/unit/GEMMFusion.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/include/DataParallelGEMMSolution.hpp b/client/include/DataParallelGEMMSolution.hpp index 2b5147d3..1cf5fef4 100644 --- a/client/include/DataParallelGEMMSolution.hpp +++ b/client/include/DataParallelGEMMSolution.hpp @@ -475,7 +475,7 @@ namespace rocRoller // predicates // unrollK size match predicates - if(params->unrollX <= 1 && params->unrollY <= 1 && !params->streamK) + if(params->tailLoops && !params->streamK) { auto unrollKPredicate = (aSizeExps[1] % macKExp == zero); setComment(unrollKPredicate, "K must be a multiple of macK."); diff --git a/test/unit/GEMMFusion.cpp b/test/unit/GEMMFusion.cpp index 62fe3301..626b9352 100644 --- a/test/unit/GEMMFusion.cpp +++ b/test/unit/GEMMFusion.cpp @@ -77,7 +77,7 @@ namespace GEMMDriverTest AssertFatal(M % gemm.macM == 0, "MacroTile size mismatch (M)"); AssertFatal(N % gemm.macN == 0, "MacroTile size mismatch (N)"); - if(gemm.unrollK 
> 0) + if(gemm.unrollK > 0 && !gemm.tailLoops) { AssertFatal(K % (gemm.macK * gemm.unrollK) == 0, "MacroTile size mismatch (K unroll)"); From 8455ee9b1b5c11f53758dad4fdab79a97155b751 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 15 Apr 2025 15:50:29 +0000 Subject: [PATCH 12/17] Add additional tests --- scripts/lib/rrperf/rrsuites.py | 18 +++ test/unit/GEMMTest.cpp | 202 +++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+) diff --git a/scripts/lib/rrperf/rrsuites.py b/scripts/lib/rrperf/rrsuites.py index 2177794f..2ece6977 100644 --- a/scripts/lib/rrperf/rrsuites.py +++ b/scripts/lib/rrperf/rrsuites.py @@ -363,6 +363,24 @@ def tail_loop_reproducer(): mac_k=8, ) +def jammed_tail_loop(): + yield mkGEMM( + M=256, + N=256, + K=24, + trans_A="T", + trans_B="N", + wave_m=32, + wave_n=32, + wave_k=2, + wave_b=1, + mac_m=64, + mac_n=64, + mac_k=8, + unroll_x=2, + unroll_y=2, + # unroll_k=2 implicit due to prefetchInFlight=2 + ) def guidepost_1(): yield mkGEMM( diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index a692612f..25beee81 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -1186,6 +1186,29 @@ namespace GEMMDriverTest } } + TEST_P(GEMMTestGPU, GPU_BasicGEMMJammedUnrollKTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + gemm.m = 64; + gemm.n = 128; + gemm.transA = "T"; + gemm.transB = "N"; + gemm.loadLDSA = false; + gemm.loadLDSB = false; + gemm.storeLDSD = false; + gemm.fuseLoops = true; + gemm.tailLoops = true; + gemm.unrollY = 2; + gemm.unrollK = 4; + gemm.macK = 8; + for(auto k : {8, 16, 24, 32, 40, 48, 56, 64}) + { + gemm.k = k; + basicGEMM(gemm); + } + } + TEST_P(GEMMTestGPU, GPU_BasicGEMMUnrollKLDS) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -2756,6 +2779,33 @@ namespace GEMMDriverTest basicGEMM(gemm); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n 
= 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.loadLDSA = false; + gemm.storeLDSD = false; + gemm.fuseLoops = false; + + basicGEMM(gemm); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -2820,6 +2870,38 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x1UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 128; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.transA = "T"; + gemm.transB = "N"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 22); + EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); + EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -2880,6 +2962,37 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 128; + gemm.macK = 16; + + gemm.unrollK = 4; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 4 * gemm.wavefrontSize; + gemm.workgroupSizeY = 2; + + gemm.transA = "T"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b64"), 26); + 
EXPECT_EQ(countSubstring(generatedCode, "ds_read_b128"), 8); + EXPECT_EQ(countSubstring(generatedCode, "buffer_store_dwordx4"), 8); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x8) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -2927,6 +3040,31 @@ namespace GEMMDriverTest basicGEMM(gemm); } + + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed1x8UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 4 * gemm.wavefrontSize; + gemm.workgroupSizeY = 1; + + gemm.storeLDSD = false; + + basicGEMM(gemm); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x4) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -2988,6 +3126,39 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 9); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed2x4UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 80; + + gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 2; + + gemm.prefetchInFlight = 2; + gemm.prefetchLDSFactor = 2; + gemm.prefetchMixMemOps = true; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 2 * gemm.wavefrontSize; + gemm.workgroupSizeY = 2; + + gemm.storeLDSD = false; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 9); + } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed4x2) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); @@ -3048,6 +3219,37 @@ namespace GEMMDriverTest EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 15); } + TEST_P(GEMMJammedTestGPU, GPU_BasicGEMMFP16Jammed4x2UnrollKUseTailLoop) + { + REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); + GEMMProblem gemm; + + gemm.m = 256; + gemm.n = 512; + gemm.k = 64; + + 
gemm.macM = 128; + gemm.macN = 256; + gemm.macK = 16; + + gemm.unrollK = 4; + + gemm.waveK = 8; + + gemm.workgroupSizeX = 1 * gemm.wavefrontSize; + gemm.workgroupSizeY = 4; + + gemm.storeLDSD = false; + + gemm.transB = "N"; + + basicGEMM(gemm); + + std::string generatedCode = m_context->instructions()->toString(); + + EXPECT_EQ(countSubstring(generatedCode, "ds_write_b128"), 15); + } + TEST_P(GEMMTestGPU, GPU_BasicGEMMLiteralStrides) { REQUIRE_ARCH_CAP(GPUCapability::HasMFMA); From 5c1be1fa71ca2a68358d4dc617df21460d2b5406 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 15 Apr 2025 16:13:39 +0000 Subject: [PATCH 13/17] Remove debug changes to graph printing I think these changes would actually be useful, but they're not appropriate to include in this PR so I'll make a new PR for those in the near future. --- lib/include/rocRoller/Expression.hpp | 2 - lib/include/rocRoller/Graph/Hypergraph.hpp | 11 +- .../rocRoller/Graph/Hypergraph_impl.hpp | 10 - .../KernelGraph/ControlGraph/Operation.hpp | 2 - .../KernelGraph/ControlToCoordinateMapper.hpp | 15 -- lib/source/Expression_toString.cpp | 178 ------------------ .../KernelGraph/ControlGraph/Operation.cpp | 9 - lib/source/KernelGraph/KernelGraph.cpp | 22 +-- 8 files changed, 3 insertions(+), 246 deletions(-) diff --git a/lib/include/rocRoller/Expression.hpp b/lib/include/rocRoller/Expression.hpp index 1276e0f7..8e8fa36e 100644 --- a/lib/include/rocRoller/Expression.hpp +++ b/lib/include/rocRoller/Expression.hpp @@ -731,8 +731,6 @@ namespace rocRoller std::string toString(ExpressionPtr const& expr); std::string toString(Expression const& expr); - std::string toShortString(ExpressionPtr const& expr); - std::string toShortString(Expression const& expr); std::ostream& operator<<(std::ostream&, ExpressionPtr const&); std::ostream& operator<<(std::ostream&, Expression const&); std::ostream& operator<<(std::ostream&, std::vector const&); diff --git a/lib/include/rocRoller/Graph/Hypergraph.hpp 
b/lib/include/rocRoller/Graph/Hypergraph.hpp index 9d6195be..18fdc0fb 100644 --- a/lib/include/rocRoller/Graph/Hypergraph.hpp +++ b/lib/include/rocRoller/Graph/Hypergraph.hpp @@ -325,20 +325,13 @@ namespace rocRoller std::string toDOT(std::string const& prefix = "", bool standalone = true) const; - static bool edgeIdentity(Edge const&) + static bool identity(Edge const&) { return true; } - static inline void nodeIdentity(Node&, int) {} - - template Transform> - std::string toDOT(std::string const& prefix = "", - bool standalone = true, - Transform nodeTransform = nodeIdentity) const; - template Predicate> - std::string toDOT(Predicate edgePredicate = edgeIdentity) const; + std::string toDOT(Predicate edgePredicate = identity) const; template requires(std::constructible_from || std::constructible_from) diff --git a/lib/include/rocRoller/Graph/Hypergraph_impl.hpp b/lib/include/rocRoller/Graph/Hypergraph_impl.hpp index 4b6d9cd0..807c6143 100644 --- a/lib/include/rocRoller/Graph/Hypergraph_impl.hpp +++ b/lib/include/rocRoller/Graph/Hypergraph_impl.hpp @@ -899,15 +899,6 @@ namespace rocRoller template std::string Hypergraph::toDOT(std::string const& prefix, bool standalone) const - { - return toDOT(prefix, standalone, nodeIdentity); - } - - template - template Transform> - std::string Hypergraph::toDOT(std::string const& prefix, - bool standalone, - Transform nodeTransform) const { std::ostringstream msg; @@ -920,7 +911,6 @@ namespace rocRoller if(getElementType(pair.second) == ElementType::Node) { auto x = std::get(pair.second); - nodeTransform(x, pair.first); msg << toString(x) << "(" << pair.first << ")\""; } else diff --git a/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp b/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp index 0fadbf44..721570e3 100644 --- a/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp +++ b/lib/include/rocRoller/KernelGraph/ControlGraph/Operation.hpp @@ -68,10 +68,8 @@ namespace rocRoller explicit 
SetCoordinate(Expression::ExpressionPtr value); Expression::ExpressionPtr value; - std::string coordName; std::string name() const; - std::string toString() const; }; /** diff --git a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp index ce375f8e..aebc18ad 100644 --- a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp +++ b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp @@ -70,11 +70,6 @@ namespace rocRoller::KernelGraph return a.id < b.id; } - bool inline operator==(TypeAndSubDimension const& a, TypeAndSubDimension const& b) - { - return a.id == b.id && a.subdimension == b.subdimension; - } - struct TypeAndNaryArgument { std::string id; @@ -94,11 +89,6 @@ namespace rocRoller::KernelGraph return a.id < b.id; } - bool inline operator==(TypeAndNaryArgument const& a, TypeAndNaryArgument const& b) - { - return a.id == b.id && a.argument == b.argument; - } - enum class ComputeIndexArgument : int { TARGET = 0, @@ -127,11 +117,6 @@ namespace rocRoller::KernelGraph return a.argument < b.argument; } - bool inline operator==(ComputeIndex const& a, ComputeIndex const& b) - { - return a.argument == b.argument && a.index == b.index; - } - using ConnectionSpec = std::variant - std::string operator()(Expr const& expr) const - { - return concatenate(ExpressionInfo::name(), - "(", - call(expr.lhs), - ", ", - call(expr.r1hs), - ", ", - call(expr.r2hs), - ")"); - } - - template - std::string operator()(Expr const& expr) const - { - return concatenate( - ExpressionInfo::name(), "(", call(expr.lhs), ", ", call(expr.rhs), ")"); - } - template - std::string operator()(Expr const& expr) const - { - return concatenate(ExpressionInfo::name(), "(", call(expr.arg), ")"); - } - std::string operator()(Register::ValuePtr const& expr) const - { - // This allows an unallocated register value to be rendered into a string which - // improves debugging by allowing the string representation 
of that expression - // to be put into the source file as a comment. - // Trying to generate the code for the expression will throw an exception. - - std::string tostr = "UNALLOCATED"; - if(expr->canUseAsOperand()) - tostr = expr->toString(); - - // The call() function appends the result type, so add ":" to separate the - // value from the type. - return tostr; - } - - std::string operator()(ScaledMatrixMultiply const& expr) const - { - return concatenate("ScaledMatrixMultiply(", - call(expr.matA), - ", ", - call(expr.matB), - ", ", - call(expr.matC), - ", ", - call(expr.scaleA), - ", ", - call(expr.scaleB), - ")"); - } - - std::string operator()(CommandArgumentPtr const& expr) const - { - if(expr) - return concatenate("CommandArgument(", expr->name(), ")"); - else - return "CommandArgument(nullptr)"; - } - - std::string operator()(CommandArgumentValue const& expr) const - { - return std::visit([](auto const& val) { return concatenate(val); }, expr); - } - - std::string operator()(AssemblyKernelArgumentPtr const& expr) const - { - // The call() function appends the result type, so add ":" to separate the - // value from the type. 
- return expr->name; - } - - std::string operator()(WaveTilePtr const& expr) const - { - return "WaveTile"; - } - - std::string operator()(DataFlowTag const& expr) const - { - return concatenate("DataFlowTag(", expr.tag, ")"); - } - -#define HANDLE_INFIX_OP(TYPE, INFIX) \ - std::string operator()(TYPE const& expr) const \ - { \ - return concatenate("(", call(expr.lhs), INFIX, call(expr.rhs), ")"); \ - } - HANDLE_INFIX_OP(Add, "+"); - HANDLE_INFIX_OP(Subtract, "-"); - HANDLE_INFIX_OP(Multiply, "*"); - HANDLE_INFIX_OP(Divide, "/"); - HANDLE_INFIX_OP(Modulo, "%"); - HANDLE_INFIX_OP(ShiftL, "<<"); - HANDLE_INFIX_OP(ArithmeticShiftR, ">>"); - HANDLE_INFIX_OP(BitwiseAnd, "&"); - HANDLE_INFIX_OP(BitwiseOr, "|"); - HANDLE_INFIX_OP(BitwiseXor, "^"); - HANDLE_INFIX_OP(GreaterThan, ">"); - HANDLE_INFIX_OP(GreaterThanEqual, ">="); - HANDLE_INFIX_OP(LessThan, "<"); - HANDLE_INFIX_OP(LessThanEqual, "<="); - HANDLE_INFIX_OP(Equal, "=="); - HANDLE_INFIX_OP(NotEqual, "!="); - HANDLE_INFIX_OP(LogicalAnd, "&&"); - HANDLE_INFIX_OP(LogicalOr, "||"); - - std::string operator()(AddShiftL const& expr) const - { - return concatenate( - "((", call(expr.lhs), "+", call(expr.r1hs), ")<<", call(expr.r2hs), ")"); - } - - std::string operator()(ShiftLAdd const& expr) const - { - return concatenate( - "((", call(expr.lhs), "<<", call(expr.r1hs), ")+", call(expr.r2hs), ")"); - } - - std::string operator()(MultiplyAdd const& expr) const - { - return concatenate( - "((", call(expr.lhs), "*", call(expr.r1hs), ")+", call(expr.r2hs), ")"); - } - - std::string operator()(Conditional const& expr) const - { - return concatenate( - "(", call(expr.lhs), "?", call(expr.r1hs), ":", call(expr.r2hs), ")"); - } - - std::string operator()(Negate const& expr) const - { - return concatenate("(-", call(expr.arg), ")"); - } - - std::string operator()(BitwiseNegate const& expr) const - { - return concatenate("(~", call(expr.arg), ")"); - } - - std::string operator()(Convert const& expr) const - { - return concatenate( 
- "((", TypeAbbrev(resultVariableType(expr)), ")", call(expr.arg), ")"); - } - - std::string call(Expression const& expr) const - { - return std::visit(*this, expr); - } - - std::string call(ExpressionPtr expr) const - { - if(!expr) - return "nullptr"; - - return call(*expr); - } - }; - - std::string toShortString(Expression const& expr) - { - auto visitor = ExpressionToShortStringVisitor(); - return visitor.call(expr); - } - - std::string toShortString(ExpressionPtr const& expr) - { - auto visitor = ExpressionToShortStringVisitor(); - return visitor.call(expr); - } } } diff --git a/lib/source/KernelGraph/ControlGraph/Operation.cpp b/lib/source/KernelGraph/ControlGraph/Operation.cpp index c347364d..b55018c6 100644 --- a/lib/source/KernelGraph/ControlGraph/Operation.cpp +++ b/lib/source/KernelGraph/ControlGraph/Operation.cpp @@ -69,15 +69,6 @@ namespace rocRoller::KernelGraph::ControlGraph { } - std::string SetCoordinate::toString() const - { - if(coordName != "") - { - return concatenate(name(), " ", coordName, ": ", toShortString(value)); - } - return concatenate(name(), ": ", value); - } - std::string ForLoopOp::toString() const { return concatenate(name(), " ", loopName, ": ", condition); diff --git a/lib/source/KernelGraph/KernelGraph.cpp b/lib/source/KernelGraph/KernelGraph.cpp index c263c5fe..831aa5c9 100644 --- a/lib/source/KernelGraph/KernelGraph.cpp +++ b/lib/source/KernelGraph/KernelGraph.cpp @@ -33,7 +33,6 @@ namespace rocRoller { namespace KernelGraph { - std::string KernelGraph::toDOT(bool drawMappings, std::string title) const { std::stringstream ss; @@ -46,26 +45,7 @@ namespace rocRoller ss << coordinates.toDOT("coord", false); ss << "subgraph clusterCF {"; ss << "label = \"Control Graph\";" << std::endl; - - auto setCoordRename = [&](ControlGraph::Operation& op, int index) { - std::visit( - rocRoller::overloaded{[&](ControlGraph::SetCoordinate& setCoord) { - auto connected = only(mapper.getConnections(index)); - if(!connected) - return; - auto 
connectedIndex = (*connected).coordinate; - if(connectedIndex < 0) - return; - std::stringstream coordss; - coordss << name(coordinates.getNode(connectedIndex)) - << "(" << connectedIndex << ")"; - setCoord.coordName = coordss.str(); - }, - [](auto&) {}}, - op); - }; - - ss << control.toDOT("cntrl", false, setCoordRename); + ss << control.toDOT("cntrl", false); ss << "}" << std::endl; if(drawMappings) { From b8865833936bd2b9796292a6574d3b7b66dec111 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Tue, 15 Apr 2025 16:31:49 +0000 Subject: [PATCH 14/17] Add explanatory comments --- .../KernelGraph/Transformations/FuseLoops.cpp | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 1ca1234e..10b49208 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -55,6 +55,13 @@ namespace rocRoller { using GD = rocRoller::Graph::Direction; + /** + * @brief Gather all for loops to fuse below the starting node + * + * @param graph + * @param start + * @return std::vector> + */ std::vector> gatherForLoops(KernelGraph& graph, int start) { auto bodies = graph.control.getOutputNodeIndices(start).to(); @@ -63,6 +70,7 @@ namespace rocRoller auto isSequence = graph.control.isElemType(); auto isBody = graph.control.isElemType(); + // Get a set of all ForLoops and SetCoordinates which contain (tail) ForLoops auto maybeForLoops = graph.control.depthFirstVisit(bodies, isSequence, GD::Downstream) .filter([&](int tag) -> bool { @@ -78,6 +86,7 @@ namespace rocRoller }) .to(); + // Filter the previous set of nodes to only the ForLoops under consideration std::vector forLoops; for(auto const& maybeForLoop : maybeForLoops) { @@ -101,6 +110,8 @@ namespace rocRoller } } + // Determine sets of loops to be fused together + // Currently loops should only be fused if they're the "same" loop (from 
unrolling) std::vector> loopGroupsToFuse; while(!forLoops.empty()) { @@ -186,6 +197,13 @@ namespace rocRoller return loopGroupsToFuse; } + /** + * @brief General routine to fuse one node into another + * + * @param graph + * @param fusedNodeTag + * @param nodeTag + */ void fuseNode(KernelGraph& graph, int fusedNodeTag, int nodeTag) { for(auto const& child : @@ -250,6 +268,10 @@ namespace rocRoller } } + /** + * @brief Visitor to determine if two nodes are the "same" operation for the purposes of fusion + * + */ struct IsSameOperationVisitor { template @@ -290,6 +312,12 @@ namespace rocRoller KernelGraph const& graph; }; + /** + * @brief Walks up the tree, fusing nodes which contain the start in their body and are the same operation + * + * @param graph + * @param tag + */ void fuseScopes(KernelGraph& graph, int tag) { auto parentsWithEdges @@ -414,6 +442,12 @@ namespace rocRoller return newGraph; } + /** + * @brief Ensure no node is a child of two nodes which are not contained within each other + * + * @param graph + * @return ConstraintStatus + */ ConstraintStatus BodyOfOnlyOneNode(const KernelGraph& graph) { ConstraintStatus retval; From 53c7633d417f7125d94b63e18e45b1befc686a15 Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 16 Apr 2025 16:24:31 +0000 Subject: [PATCH 15/17] Add comments suggested by @maemmett --- lib/source/KernelGraph/Transformations/FuseLoops.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/source/KernelGraph/Transformations/FuseLoops.cpp b/lib/source/KernelGraph/Transformations/FuseLoops.cpp index 10b49208..d07887f1 100644 --- a/lib/source/KernelGraph/Transformations/FuseLoops.cpp +++ b/lib/source/KernelGraph/Transformations/FuseLoops.cpp @@ -206,6 +206,7 @@ namespace rocRoller */ void fuseNode(KernelGraph& graph, int fusedNodeTag, int nodeTag) { + // Move operations that come after `nodeTag` to follow `fusedNodeTag`. 
for(auto const& child : graph.control.getOutputNodeIndices(nodeTag).to()) { @@ -215,6 +216,7 @@ namespace rocRoller } graph.control.deleteElement(std::vector{nodeTag}, std::vector{child}); + // Make sure we don't introduce a cycle std::unordered_set toDelete; for(auto descSeqOfChild : filter(graph.control.isElemType(), @@ -233,6 +235,7 @@ namespace rocRoller } } + // Fuse the bodies for(auto const& child : graph.control.getOutputNodeIndices(nodeTag).to()) { @@ -241,6 +244,7 @@ namespace rocRoller std::vector{child}); } + // Make sure dependencies are satisfied for(auto const& parent : graph.control.getInputNodeIndices(nodeTag).to()) { From f3d09dde553d80dbc3eb7f76ef2a6540339686ee Mon Sep 17 00:00:00 2001 From: Lauren Wrubleski Date: Wed, 16 Apr 2025 17:02:33 +0000 Subject: [PATCH 16/17] Re-add change unintentionally removed in cleanup Implement == on TypeAndSubDimension --- .../KernelGraph/ControlToCoordinateMapper.hpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp index aebc18ad..ce375f8e 100644 --- a/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp +++ b/lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp @@ -70,6 +70,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndSubDimension const& a, TypeAndSubDimension const& b) + { + return a.id == b.id && a.subdimension == b.subdimension; + } + struct TypeAndNaryArgument { std::string id; @@ -89,6 +94,11 @@ namespace rocRoller::KernelGraph return a.id < b.id; } + bool inline operator==(TypeAndNaryArgument const& a, TypeAndNaryArgument const& b) + { + return a.id == b.id && a.argument == b.argument; + } + enum class ComputeIndexArgument : int { TARGET = 0, @@ -117,6 +127,11 @@ namespace rocRoller::KernelGraph return a.argument < b.argument; } + bool inline operator==(ComputeIndex const& a, ComputeIndex 
const& b) + { + return a.argument == b.argument && a.index == b.index; + } + using ConnectionSpec = std::variant Date: Tue, 22 Apr 2025 18:44:38 +0000 Subject: [PATCH 17/17] Formatting --- scripts/lib/rrperf/rrsuites.py | 2 ++ test/unit/GEMMTest.cpp | 1 + 2 files changed, 3 insertions(+) diff --git a/scripts/lib/rrperf/rrsuites.py b/scripts/lib/rrperf/rrsuites.py index 2ece6977..b6794bfe 100644 --- a/scripts/lib/rrperf/rrsuites.py +++ b/scripts/lib/rrperf/rrsuites.py @@ -363,6 +363,7 @@ def tail_loop_reproducer(): mac_k=8, ) + def jammed_tail_loop(): yield mkGEMM( M=256, @@ -382,6 +383,7 @@ def jammed_tail_loop(): # unroll_k=2 implicit due to prefetchInFlight=2 ) + def guidepost_1(): yield mkGEMM( HGEMM_7680x8448x8448, diff --git a/test/unit/GEMMTest.cpp b/test/unit/GEMMTest.cpp index 11fc3518..b36b5fa2 100644 --- a/test/unit/GEMMTest.cpp +++ b/test/unit/GEMMTest.cpp @@ -3046,6 +3046,7 @@ namespace GEMMDriverTest gemm.m = 256; gemm.n = 512; gemm.k = 80; + gemm.macM = 128; gemm.macN = 256; gemm.macK = 16;