Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/catch/CommandTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ namespace CommandTest
{
SECTION("GEMM/TileAdd")
{
auto example1 = *rocRollerTest::Graphs::GEMM<float>().getCommand();
auto example2 = *rocRollerTest::Graphs::GEMM<float>().getCommand();
auto example3 = *rocRollerTest::Graphs::GEMM<Half>().getCommand();
auto example1 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand();
auto example2 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand();
auto example3 = *rocRollerTest::Graphs::GEMM(DataType::Half).getCommand();
auto example4 = *rocRollerTest::Graphs::TileDoubleAdd<Half>().getCommand();

CHECK(example1 == example2);
Expand All @@ -58,7 +58,7 @@ namespace CommandTest
{
SECTION("GEMM")
{
auto example = rocRollerTest::Graphs::GEMM<float>();
auto example = rocRollerTest::Graphs::GEMM(DataType::Float);

auto command0 = example.getCommand();
auto yaml = Command::toYAML(*command0);
Expand Down
4 changes: 2 additions & 2 deletions test/catch/IdentifyParallelDimensionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TEST_CASE("identifyParallelDimensionSets works for GEMM", "[kernel-graph]")
using namespace rocRoller;
auto ctx = TestContext::ForDefaultTarget();

auto example = rocRollerTest::Graphs::GEMM<float>();
auto example = rocRollerTest::Graphs::GEMM(DataType::Float);

auto kgraph = KernelGraph::translate(example.getCommand());

Expand Down Expand Up @@ -175,7 +175,7 @@ SCENARIO("IdentifyParallelDimensions transformation works for GEMM", "[kernel-gr
using namespace rocRoller;
auto ctx = TestContext::ForDefaultTarget();

auto example = rocRollerTest::Graphs::GEMM<float>();
auto example = rocRollerTest::Graphs::GEMM(DataType::Float);

GIVEN("The initial kernel graph for a GEMM")
{
Expand Down
4 changes: 2 additions & 2 deletions test/catch/KernelGraphRemoveDuplicatesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ TEST_CASE("Remove duplicates", "[kernel-graph]")
using namespace rocRoller::KernelGraph::ControlGraph;

auto ctx = TestContext::ForDefaultTarget().get();
auto example = rocRollerTest::Graphs::GEMM<float>();
auto example = rocRollerTest::Graphs::GEMM(DataType::Float);

example.setTileSize(128, 128, 32);
example.setTileSize(128, 64, 32);
example.setMFMA(32, 32, 16, 1);
example.setUseLDS(true, true, false);

Expand Down
256 changes: 256 additions & 0 deletions test/common/CommonGraphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,260 @@ namespace rocRollerTest::Graphs
return params;
}

GEMM::GEMM(DataType ta)
: GEMM(ta, ta)
{
}
GEMM::GEMM(DataType ta, DataType tb)
: GEMM(ta, tb, tb)
{
}
GEMM::GEMM(DataType ta, DataType tb, DataType tc)
: GEMM(ta, tb, tc, tc)
{
}
GEMM::GEMM(DataType ta, DataType tb, DataType tc, DataType td)
: mTa(ta)
, mTb(tb)
, mTc(tc)
, mTd(td)
{
}

void GEMM::createCommand()
{
m_command = std::make_shared<rocRoller::Command>();

std::vector<size_t> oneStridesN
= m_problem.literalStrides ? std::vector<size_t>({(size_t)1}) : std::vector<size_t>({});

std::vector<size_t> oneStridesT = m_problem.literalStrides
? std::vector<size_t>({(size_t)0, (size_t)1})
: std::vector<size_t>({});

mTagTensorA = m_command->addOperation(rocRoller::Operations::Tensor(
2, mTa, m_problem.transA == "N" ? oneStridesN : oneStridesT)); // A

m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorA));

mTagTensorB = m_command->addOperation(rocRoller::Operations::Tensor(
2, mTb, m_problem.transB == "N" ? oneStridesN : oneStridesT)); // B
m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorB));

mTagTensorC
= m_command->addOperation(rocRoller::Operations::Tensor(2, mTc, oneStridesN)); // C
m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorC));

mTagScalarAlpha
= m_command->addOperation(rocRoller::Operations::Scalar(DataType::Float)); // alpha
auto tagLoadAlpha
= m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarAlpha));

mTagScalarBeta = m_command->addOperation(rocRoller::Operations::Scalar(mTc)); // beta
auto tagLoadBeta
= m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarBeta)); // beta

auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B

rocRoller::Operations::T_Execute execute(m_command->getNextTag());

auto tagBetaC
= execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C

auto tagAlphaAB
= execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B)

if(m_problem.betaInFma)
{
m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagBetaC, tagAlphaAB));
// alpha * (A * B) + beta * C
}
else
{
m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC));
// alpha * (A * B) + beta * C
}
m_command->addOperation(std::move(execute));

mTagTensorD
= m_command->addOperation(rocRoller::Operations::Tensor(2, mTd, oneStridesN)); // D
m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, mTagTensorD)); // D

if(m_problem.streamK)
{
mTagNumWGs = m_command->allocateTag();
auto numWGsArg = m_command->allocateArgument(DataType::UInt32,
mTagNumWGs,
ArgumentType::Value,
DataDirection::ReadOnly,
rocRoller::NUMWGS);
}

mTagScratch = m_command->allocateTag();
m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal),
mTagScratch,
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
}

/**
 * @brief Return the GEMM command, building it on first use.
 *
 * createCommand() is invoked only while m_command is still empty; later calls
 * return the cached pointer.
 *
 * @return The cached command pointer.
 */
CommandPtr GEMM::getCommand()
{
    if(m_command == nullptr)
    {
        createCommand();
    }

    return m_command;
}

/**
 * @brief Translate the GEMM command into a kernel graph.
 *
 * The command is obtained through getCommand(), so it is created here if it
 * does not exist yet.
 *
 * @return The result of rocRoller::KernelGraph::translate() on the command.
 */
KernelGraph GEMM::getKernelGraph()
{
    auto command = getCommand();
    return rocRoller::KernelGraph::translate(command);
}

void GEMM::setTileSize(int m, int n, int k)
{
m_problem.macM = m;
m_problem.macN = n;
m_problem.macK = k;
}

void GEMM::setMFMA(int m, int n, int k, int b)
{
m_problem.waveM = m;
m_problem.waveN = n;
m_problem.waveK = k;
m_problem.waveB = b;
}

void GEMM::setUseLDS(bool a, bool b, bool d)
{
m_problem.loadLDSA = a;
m_problem.loadLDSB = b;
m_problem.storeLDSD = d;
}

void GEMM::setPrefetch(bool prefetch,
int prefetchInFlight,
int prefetchLDSFactor,
bool prefetchMixMemOps)
{
m_problem.prefetch = prefetch;
m_problem.prefetchInFlight = prefetchInFlight;
m_problem.prefetchLDSFactor = prefetchLDSFactor;
m_problem.prefetchMixMemOps = prefetchMixMemOps;

m_problem.unrollK = prefetchInFlight;
}

/**
 * @brief Replace the whole problem description.
 *
 * @param problem Problem description copied into m_problem.
 */
void GEMM::setProblem(GEMMProblem const& problem)
{
    this->m_problem = problem;
}

/**
 * @brief Access the current problem description.
 *
 * @return Const reference to m_problem; it refers to a member of this object.
 */
GEMMProblem const& GEMM::getProblem() const
{
    return this->m_problem;
}

/**
 * @brief Build the CommandParameters that correspond to the current m_problem settings.
 *
 * Checks the problem / macro-tile / wave-tile / workgroup sizes with AssertFatal, then
 * fills in the kernel dimension, the workgroup size, the macro tiles for A, B, C and D,
 * the wave tiles per wavefront, the wavefront counts and the loop, prefetch, transpose
 * and stream-K options.
 *
 * NOTE(review): the macro tiles are attached to m_tagA..m_tagD, which are only assigned
 * in createCommand(); it looks like getCommand() has to be called before this method —
 * confirm against callers.
 *
 * @return Newly created parameters object.
 */
CommandParametersPtr GEMM::getCommandParameters() const
{
using namespace rocRoller::KernelGraph::CoordinateGraph;

auto params = std::make_shared<CommandParameters>();

params->setManualKernelDimension(2);

// The problem sizes in M and N must be whole multiples of the macro tile.
AssertFatal(m_problem.m % m_problem.macM == 0,
"MacroTile size mismatch (M)",
ShowValue(m_problem.m),
ShowValue(m_problem.macM));
AssertFatal(m_problem.n % m_problem.macN == 0,
"MacroTile size mismatch (N)",
ShowValue(m_problem.n),
ShowValue(m_problem.macN));

AssertFatal(m_problem.workgroupSizeX % m_problem.wavefrontSize == 0,
"Workgroup Size X must be multiply of wave front size");

AssertFatal(m_problem.macM % m_problem.waveM == 0,
"Macrotile size must be a multiple of wavetile size");
AssertFatal(m_problem.macN % m_problem.waveN == 0,
"Macrotile size must be a multiple of wavetile size");

// i.e. jammedM
// Integer division throughout; a zero result is rejected by the assertion below.
uint wavetilePerWavefrontM
= m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / m_problem.workgroupSizeX;
AssertFatal(wavetilePerWavefrontM > 0,
"Wavetiles per wavefront in M should be positive integer");

// i.e. jammedN
uint wavetilePerWavefrontN = m_problem.macN / m_problem.waveN / m_problem.workgroupSizeY;
AssertFatal(wavetilePerWavefrontN > 0,
"Wavetiles per wavefront in N should be positive integer");

// The macro tile must split evenly into (wave tile * wave tiles per wavefront).
AssertFatal(m_problem.macM % (m_problem.waveM * wavetilePerWavefrontM) == 0,
"WaveTile size mismatch (M)");
AssertFatal(m_problem.macN % (m_problem.waveN * wavetilePerWavefrontN) == 0,
"WaveTile size mismatch (N)");

// The X and Y workgroup sizes of the problem are folded into a single X extent.
uint workgroupSizeX = m_problem.workgroupSizeX * m_problem.workgroupSizeY;
uint workgroupSizeY = 1;
params->setManualWorkgroupSize({workgroupSizeX, workgroupSizeY, 1});

// Macro tiles: A is macM x macK, B is macK x macN, C and D are macM x macN.
// A, B and D use LDS memory when the matching m_problem flag is set and WAVE
// otherwise; C is constructed without an explicit memory type.
auto macTileA
= MacroTile({m_problem.macM, m_problem.macK},
LayoutType::MATRIX_A,
{m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB},
m_problem.loadLDSA ? MemoryType::LDS : MemoryType::WAVE);
auto macTileB
= MacroTile({m_problem.macK, m_problem.macN},
LayoutType::MATRIX_B,
{m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB},
m_problem.loadLDSB ? MemoryType::LDS : MemoryType::WAVE);
auto macTileC
= MacroTile({m_problem.macM, m_problem.macN},
LayoutType::MATRIX_ACCUMULATOR,
{m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB});
auto macTileD
= MacroTile({m_problem.macM, m_problem.macN},
LayoutType::MATRIX_ACCUMULATOR,
{m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB},
m_problem.storeLDSD ? MemoryType::LDS : MemoryType::WAVE);

// Attach each macro tile to the tag of the corresponding tile operation.
params->setDimensionInfo(m_tagA, macTileA);
params->setDimensionInfo(m_tagB, macTileB);
params->setDimensionInfo(m_tagC, macTileC);
params->setDimensionInfo(m_tagD, macTileD);

Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1);
Log::debug("GEMM jamming {} {}", wavetilePerWavefrontM, wavetilePerWavefrontN);

// Wavefront count per direction: macro tile / wave tile / wave tiles per wavefront.
params->setWaveTilesPerWavefront(wavetilePerWavefrontM, wavetilePerWavefrontN);
params->setManualWavefrontCount(
{static_cast<uint>(m_problem.macM / m_problem.waveM / wavetilePerWavefrontM),
static_cast<uint>(m_problem.macN / m_problem.waveN / wavetilePerWavefrontN)});

// Options copied straight from the problem description.
params->fuseLoops = m_problem.fuseLoops;
params->tailLoops = m_problem.tailLoops;
params->allowAmbiguousMemoryNodes = m_problem.allowAmbiguousMemoryNodes;
params->unrollK = m_problem.unrollK;
params->packMultipleElementsInto1VGPR = m_problem.packMultipleElementsInto1VGPR;
params->prefetch = m_problem.prefetch;
params->prefetchInFlight = m_problem.prefetchInFlight;
params->prefetchLDSFactor = m_problem.prefetchLDSFactor;
params->prefetchMixMemOps = m_problem.prefetchMixMemOps;
// A and B use transposed memory access when their trans flag is "T";
// LayoutType::None is always set to true.
params->transposeMemoryAccess[LayoutType::MATRIX_A] = m_problem.transA == "T";
params->transposeMemoryAccess[LayoutType::MATRIX_B] = m_problem.transB == "T";
params->transposeMemoryAccess[LayoutType::None] = true;

// Stream-K settings are applied only when requested by the problem description.
if(m_problem.streamK)
{
params->loopOverOutputTilesDimensions = {0, 1};
params->streamK = true;
params->streamKTwoTile = m_problem.streamKTwoTile;
}

return params;
}

}
28 changes: 21 additions & 7 deletions test/common/common/CommonGraphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#include <rocRoller/Operations/Command_fwd.hpp>
#include <rocRoller/Operations/OperationTag.hpp>

#include <common/GEMMProblem.hpp>

namespace rocRollerTest
{
namespace Graphs
Expand All @@ -52,6 +54,7 @@ namespace rocRollerTest
using ContextPtr = rocRoller::ContextPtr;
using KernelArguments = rocRoller::KernelArguments;
using KernelGraph = rocRoller::KernelGraph::KernelGraph;
using DataType = rocRoller::DataType;

/**
* @brief Graph for linear: alpha x + beta y.
Expand Down Expand Up @@ -178,31 +181,42 @@ namespace rocRollerTest
* - Assign(D = alpha * AB + beta * C)
* - StoreTiled(D)
*/
template <typename T>
class GEMM
{
public:
GEMM();
GEMM(DataType ta);
GEMM(DataType ta, DataType tb);
GEMM(DataType ta, DataType tb, DataType tc);
GEMM(DataType ta, DataType tb, DataType tc, DataType td);

CommandPtr getCommand();
KernelGraph getKernelGraph();

void setTileSize(int m, int n, int k);
void setMFMA(int m, int n, int k, int b);
void setUseLDS(bool a, bool b, bool d);
void setPrefetch(bool prefetch,
int prefetchInFlight,
int prefetchLDSFactor,
bool prefetchMixMemOps);
void setProblem(GEMMProblem const& problem);

GEMMProblem const& getProblem() const;
CommandParametersPtr getCommandParameters() const;

rocRoller::Operations::OperationTag mTagTensorA, mTagTensorB, mTagTensorC, mTagTensorD,
mTagScalarAlpha, mTagScalarBeta, mTagScalarSeed, mTagScratch;
rocRoller::Operations::OperationTag mTagNumWGs;

DataType mTa, mTb, mTc, mTd;

private:
void createCommand();

int m_macM, m_macN, m_macK;
int m_waveM, m_waveN, m_waveK, m_waveB;
bool m_useLDSA = false, m_useLDSB = false, m_useLDSD = false;

rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD;

CommandPtr m_command;
CommandPtr m_command;
GEMMProblem m_problem;
};

/**
Expand Down
Loading