Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5fe7fd2
Initial work on FuseLoops
lawruble13 Feb 14, 2025
b940af0
Fix graphconstraint output
lawruble13 Feb 27, 2025
0f7f3d3
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 8, 2025
da75299
fix compilation issue
lawruble13 Apr 8, 2025
9c60015
Add heuristic to restrict loops that can be merged
lawruble13 Apr 9, 2025
1183ce3
Add UnrollK version of jammed 2x2 test
lawruble13 Apr 9, 2025
ce3f3ed
Ensure no reindexing chains in RemoveDuplicates transformation
lawruble13 Apr 9, 2025
348eeb7
Merge remote-tracking branch 'origin/main' into lwrubles/fusetailloops
lawruble13 Apr 9, 2025
c593b46
Remove redundant gemm arguments
lawruble13 Apr 9, 2025
cf06225
Revert unintended merge changes
lawruble13 Apr 9, 2025
0c19abc
Formatting fixes
lawruble13 Apr 9, 2025
ac7cd66
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 10, 2025
4194566
Update compareNodes call
lawruble13 Apr 10, 2025
bed7f2a
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 15, 2025
09dacb3
Adjust checks on K size to accommodate correct unrolling
lawruble13 Apr 15, 2025
8455ee9
Add additional tests
lawruble13 Apr 15, 2025
5c1be1f
Remove debug changes to graph printing
lawruble13 Apr 15, 2025
b886583
Add explanatory comments
lawruble13 Apr 15, 2025
d104cd8
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 16, 2025
53c7633
Add comments suggested by @maemmett
lawruble13 Apr 16, 2025
f3d09dd
Re-add change unintentionally removed in cleanup
lawruble13 Apr 16, 2025
809c0f2
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 22, 2025
74718cf
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 Apr 24, 2025
5987093
Formatting
lawruble13 Apr 22, 2025
ea1003c
Merge branch 'main' into lwrubles/fusetailloops
lawruble13 May 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/include/DataParallelGEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ namespace rocRoller
// predicates
// unrollK size match predicates

if(params->unrollX <= 1 && params->unrollY <= 1 && !params->streamK)
if(params->tailLoops && !params->streamK)
{
auto unrollKPredicate = (aSizeExps[1] % macKExp == zero);
setComment(unrollKPredicate, "K must be a multiple of macK.");
Expand Down
15 changes: 15 additions & 0 deletions lib/include/rocRoller/KernelGraph/ControlToCoordinateMapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ namespace rocRoller::KernelGraph
return a.id < b.id;
}

bool inline operator==(TypeAndSubDimension const& a, TypeAndSubDimension const& b)
{
return a.id == b.id && a.subdimension == b.subdimension;
}

struct TypeAndNaryArgument
{
std::string id;
Expand All @@ -89,6 +94,11 @@ namespace rocRoller::KernelGraph
return a.id < b.id;
}

bool inline operator==(TypeAndNaryArgument const& a, TypeAndNaryArgument const& b)
{
return a.id == b.id && a.argument == b.argument;
}

enum class ComputeIndexArgument : int
{
TARGET = 0,
Expand Down Expand Up @@ -117,6 +127,11 @@ namespace rocRoller::KernelGraph
return a.argument < b.argument;
}

bool inline operator==(ComputeIndex const& a, ComputeIndex const& b)
{
return a.argument == b.argument && a.index == b.index;
}

using ConnectionSpec = std::variant<std::monostate,
JustNaryArgument,
ComputeIndex,
Expand Down
5 changes: 3 additions & 2 deletions lib/include/rocRoller/KernelGraph/Transforms/FuseLoops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ namespace rocRoller
class FuseLoops : public GraphTransform
{
public:
KernelGraph apply(KernelGraph const& original) override;
std::string name() const override
KernelGraph apply(KernelGraph const& original) override;
virtual std::vector<GraphConstraint> postConstraints() const override;
std::string name() const override
{
return "FuseLoops";
}
Expand Down
601 changes: 400 additions & 201 deletions lib/source/KernelGraph/Transformations/FuseLoops.cpp

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions lib/source/KernelGraph/Transformations/RemoveDuplicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,27 @@ namespace rocRoller
}
}

// If there exists a reindexing "chain" A -> B -> C, then processing
// B -> C before A -> B causes the old references to A to be left
// dangling. To avoid this, we collapse the chain so A -> C and B -> C,
// which matches the behaviour we would see if A -> B was processed first.
for(auto [oldTag, newTag] : expressionReindexer.coordinates)
{
while(expressionReindexer.coordinates.count(newTag) > 0 && newTag != oldTag)
{
auto nextTag = expressionReindexer.coordinates.at(newTag);
if(nextTag != newTag)
{
newTag = expressionReindexer.coordinates.at(newTag);
}
else
{
break;
}
}
expressionReindexer.coordinates[oldTag] = newTag;
}

// Update tile references
auto kernel = *graph.control.roots().begin();
reindexExpressions(graph, kernel, expressionReindexer);
Expand Down
16 changes: 0 additions & 16 deletions lib/source/KernelGraph/Transformations/UnrollLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,22 +622,6 @@ namespace rocRoller
}
}

{
auto containingForLoops = graph.control.nodesContaining(loop).filter(
graph.control.isElemType<ForLoopOp>());
for(auto containingLoop : containingForLoops)
{
if(getUnrollAmount(graph, containingLoop, m_params) > 1)
{
Log::debug("Not adding tail loop for {} because it is contained by {} "
"which is also unrolled.",
loop,
containingLoop);
return std::nullopt;
}
}
}

auto loopSizeType = resultVariableType(loopSize);
auto amount = Expression::literal(unrollAmount, loopSizeType);
auto loopSizeRoundedDown = (loopSize / amount) * amount;
Expand Down
20 changes: 20 additions & 0 deletions scripts/lib/rrperf/rrsuites.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,26 @@ def tail_loop_reproducer():
)


def jammed_tail_loop():
yield mkGEMM(
M=256,
N=256,
K=24,
trans_A="T",
trans_B="N",
wave_m=32,
wave_n=32,
wave_k=2,
wave_b=1,
mac_m=64,
mac_n=64,
mac_k=8,
unroll_x=2,
unroll_y=2,
# unroll_k=2 implicit due to prefetchInFlight=2
)


def guidepost_1():
yield mkGEMM(
HGEMM_7680x8448x8448,
Expand Down
2 changes: 1 addition & 1 deletion test/unit/GEMMFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace GEMMDriverTest
AssertFatal(M % gemm.macM == 0, "MacroTile size mismatch (M)");
AssertFatal(N % gemm.macN == 0, "MacroTile size mismatch (N)");

if(gemm.unrollK > 0)
if(gemm.unrollK > 0 && !gemm.tailLoops)
{
AssertFatal(K % (gemm.macK * gemm.unrollK) == 0,
"MacroTile size mismatch (K unroll)");
Expand Down
Loading