diff --git a/lib/include/rocRoller/AssemblyKernel.hpp b/lib/include/rocRoller/AssemblyKernel.hpp index 6d68bde1..8521fb3c 100644 --- a/lib/include/rocRoller/AssemblyKernel.hpp +++ b/lib/include/rocRoller/AssemblyKernel.hpp @@ -154,6 +154,7 @@ namespace rocRoller * @param args Vector of CommandArgument pointers that should be added as arguments. */ void addCommandArguments(std::vector args); + void addNewCommandArguments(std::vector args); Expression::ExpressionPtr addCommandArgument(CommandArgumentPtr arg); std::string args_string(); diff --git a/lib/include/rocRoller/AssemblyKernel_impl.hpp b/lib/include/rocRoller/AssemblyKernel_impl.hpp index 428944fc..c46a6742 100644 --- a/lib/include/rocRoller/AssemblyKernel_impl.hpp +++ b/lib/include/rocRoller/AssemblyKernel_impl.hpp @@ -228,6 +228,15 @@ namespace rocRoller } } + inline void AssemblyKernel::addNewCommandArguments(std::vector args) + { + for(auto arg : args) + { + if(m_argumentNames.find(arg->name()) == m_argumentNames.end()) + addCommandArgument(arg); + } + } + inline Expression::ExpressionPtr AssemblyKernel::addCommandArgument(CommandArgumentPtr arg) { return addArgument({arg->name(), diff --git a/lib/include/rocRoller/CommandSolution.hpp b/lib/include/rocRoller/CommandSolution.hpp index 8b269023..b2cdb86e 100644 --- a/lib/include/rocRoller/CommandSolution.hpp +++ b/lib/include/rocRoller/CommandSolution.hpp @@ -195,6 +195,17 @@ namespace rocRoller */ void generateKernel(); + /** + * @brief Generates the kernel Graph by graph lowering and + * doesn't do code-generation. + */ + void generateKernelGraphOnlyAfterTransforms(); + + /** + * @brief Lower command arguments to kernel arguments. + */ + void lowerToKernelArguments(); + /** * @brief Assembles a generated kernel. Does not try to load * it. diff --git a/lib/source/CommandSolution.cpp b/lib/source/CommandSolution.cpp index 0e00ccdf..9cdeb7f9 100644 --- a/lib/source/CommandSolution.cpp +++ b/lib/source/CommandSolution.cpp @@ -129,6 +129,8 @@ namespace rocRoller rv.reserve(m_context->kernel()->argumentSize(), argStructs.size()); + Log::debug("== getKernelArguments =="); + for(auto& arg : argStructs) { auto value = Expression::evaluate(arg.expression, args); @@ -145,6 +147,8 @@ namespace rocRoller arg.name)); } + Log::debug(" arg.name {} value {}", arg.name, toString(value)); + rv.append(arg.name, value); } @@ -161,11 +165,21 @@ namespace rocRoller auto const& workitems = m_context->kernel()->workitemCount(); if(workitems[0]) + { rv.workitemCount[0] = getUnsignedInt(evaluate(workitems[0], args)); + Log::debug("== getKernelInvocation =="); + Log::debug(" workitemCount[0] {}", rv.workitemCount[0]); + } if(workitems[1]) + { rv.workitemCount[1] = getUnsignedInt(evaluate(workitems[1], args)); + Log::debug(" workitemCount[1] {}", rv.workitemCount[1]); + } if(workitems[2]) + { rv.workitemCount[2] = getUnsignedInt(evaluate(workitems[2], args)); + Log::debug(" workitemCount[2] {}", rv.workitemCount[2]); + } auto const& sharedMem = m_context->kernel()->dynamicSharedMemBytes(); if(sharedMem) @@ -211,6 +225,15 @@ namespace rocRoller co_yield Instruction::Comment(m_command->argInfo()); } + void CommandKernel::lowerToKernelArguments() + { + for(auto arg : m_command->getArguments()) + { + Log::debug("command argument: {}, {}", arg->toString(), toString(arg->expression())); + } + m_context->kernel()->addNewCommandArguments(m_command->getArguments()); + } + void CommandKernel::generateKernelGraph(std::string name) { TIMER(t, "CommandKernel::generateKernelGraph"); @@ -396,6 +419,23 @@ namespace rocRoller generateKernelSource(); } else + { + Log::debug("generateKernel() is doing nothing"); + // Probably from a unit test. The context should contain + // scheduled instructions already. + } + } + + void CommandKernel::generateKernelGraphOnlyAfterTransforms() + { + TIMER(t, "CommandKernel::generateKernelGraphOnlyAfterTransforms()"); + + if(m_command) + { + // Only lower the KernelGraph and don't generate codes. + generateKernelGraph(m_name); + } + else { // Probably from a unit test. The context should contain // scheduled instructions already. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1d0bda1e..327339b9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -177,6 +177,7 @@ set( catch/CustomAssertions.cpp catch/TestKernels.cpp + catch/AddressCalculationTest.cpp catch/AnnotateInstructionsTest.cpp catch/BinaryExpressionTest.cpp catch/BranchGeneratorTest.cpp diff --git a/test/catch/AddressCalculationTest.cpp b/test/catch/AddressCalculationTest.cpp new file mode 100644 index 00000000..c5e677f7 --- /dev/null +++ b/test/catch/AddressCalculationTest.cpp @@ -0,0 +1,946 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "CustomMatchers.hpp" +#include "CustomSections.hpp" +#include "TestContext.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace rocRoller; + +namespace AddressCalculationTest +{ + namespace KernelG = rocRoller::KernelGraph; + namespace ControlG = rocRoller::KernelGraph::ControlGraph; + namespace CoordG = rocRoller::KernelGraph::CoordinateGraph; + using KernelGraphType = typename rocRoller::KernelGraph::KernelGraph; + + class AddressTrace + { + public: + AddressTrace(KernelGraphType const& graph, ContextPtr ctx) + : m_kGraph(graph) + , m_context(ctx){}; + std::vector traceComputeIndexWithBuffer(); + + private: + KernelGraphType m_kGraph; + ContextPtr m_context; + }; + + std::vector AddressTrace::traceComputeIndexWithBuffer() + { + auto isComputeIndex = [&](int tag) { + return isOperation(this->m_kGraph.control.getElement(tag)); + }; + + std::vector rv; + auto root = m_kGraph.control.roots().only(); + int count = 0; + + // Note that identity_transduer is intentionally being used here in place of FastArithmetic. + // It is alright to use FastArithmetic here but + // fastArithmetic is being used eventually when the expressions are generated. + auto identity_transducer = [&](auto expr) { return expr; }; + auto coords = CoordG::Transformer(&(m_kGraph.coordinates), identity_transducer); + coords.fillExecutionCoordinates(m_context); + + for(auto ciTag : filter(isComputeIndex, m_kGraph.control.depthFirstVisit(root.value()))) + { + + auto maybeCi = m_kGraph.control.get(ciTag); + AssertFatal(maybeCi.has_value()); + auto ci = maybeCi.value(); + + auto buffer + = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::BUFFER}); + if(buffer == -1) + continue; + + auto base = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::BASE}); + + // Currently, only base < 0 case is being covered. + if(base >= 0) + continue; + + { + // Debugging log + if(m_context->kernel()) + Log::debug("kernel is non-null \n"); + + auto const& kernelWorkgroupIndices = m_context->kernel()->workgroupIndex(); + Log::debug("Size of kernelWorkGroupIndices: {}", kernelWorkgroupIndices.size()); + } + + auto offset + = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::OFFSET}); + auto stride + = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::STRIDE}); + auto target + = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::TARGET}); + auto increment + = m_kGraph.mapper.get(ciTag, + KernelG::Connections::ComputeIndex{ + KernelG::Connections::ComputeIndexArgument::INCREMENT}); + + auto fullStop = [&](int tag) { return tag == increment; }; + auto direction = ci.forward ? Graph::Direction::Upstream : Graph::Direction::Downstream; + auto [required, path] = findRequiredCoordinates(target, direction, fullStop, m_kGraph); + + for(auto tag : required) + if((tag != increment) && (!coords.hasCoordinate(tag))) + coords.setCoordinate(tag, Expression::literal(0u)); + + // Set the increment coordinate to zero if it doesn't + // already have a value + bool initializeIncrement = !coords.hasPath({target}, ci.forward); + if(initializeIncrement) + { + coords.setCoordinate(increment, Expression::literal(0u)); + } + + // Compute an offset address if we don't have an + // associated base address to inherit from + { + // base < 0 by the time control reached here. + auto indexExpr + = ci.forward ? coords.forward({target})[0] : coords.reverse({target})[0]; + + rv.push_back(indexExpr); + + // Rests are for logging for debugging. + Log::debug("ci.forward for tag {} dir {}, stride {}, buffer {}", + ciTag, + ci.forward, + stride, + buffer); + Log::debug("IndexExpr base < 0 for tag {} in the original graph {}", + ciTag, + toString(indexExpr)); + + // Buffer Descriptor's base + auto user = m_kGraph.coordinates.get(target); + if(user) + { + Log::debug("User for tag {} {}, offset {}", + ciTag, + ShowValue(target), + (user->offset ? "yes" : "no")); + Log::debug("argument name {}, real name {}", + user->argumentName, + m_context->kernel()->findArgument(user->argumentName).name); + + // 1. user->argumentName + user->offset if any --> set to base of BufferDesc + // (Expression) + // 2. user->size --> set to size field (32-bit) of BuferDesc + } + } + // book-keeping for debugging purpose + count++; + } + + Log::debug("Count of computeIndex investigated: {}", count); + + return rv; + } + + struct AddressCalculationTest + { + /** + * gemmGraph should be a graph initialized by prob + */ + AddressCalculationTest(rocRoller::ContextPtr context, + GEMMProblem const& prob, + rocRollerTest::Graphs::GEMM gemmGraph) + : m_context(context) + , m_problem(prob) + , m_gemmGraph(gemmGraph) + { + } + + bool check_uint32_overflow(uint a, uint b) + { + uint64_t prod = static_cast(a) * b; + return prod > static_cast(std::numeric_limits::max()); + } + + std::pair, uint> + getWorkItemCount(CommandParametersPtr params, GEMMProblem const& problem) + { + int M = problem.m; + int N = problem.n; + int K = problem.k; + + AssertFatal(M > 0 && N > 0 && K > 0); + + auto workGroupSizes = params->getManualWorkgroupSize(); + + AssertFatal(workGroupSizes.has_value()); + auto workgroupSizeX = workGroupSizes.value()[0]; + auto workgroupSizeY = workGroupSizes.value()[1]; + + // compute NumWorkGroups + uint numWorkgroupX; + uint numWorkgroupY; + + if(problem.loopOverTiles > 0) + { + // multiple output macro tiles per workgroup + numWorkgroupX = M * N / problem.macM / problem.macN / 2; + numWorkgroupY = 1; + } + else if(problem.streamK) + { + numWorkgroupX = problem.numWGs; + numWorkgroupY = 1; + } + else + { + // one output macro tile per workgroup + numWorkgroupX = M / problem.macM; + numWorkgroupY = N / problem.macN; + } + + AssertFatal(!check_uint32_overflow(numWorkgroupX, workgroupSizeX)); + AssertFatal(!check_uint32_overflow(numWorkgroupY, workgroupSizeY)); + + auto NX_literal = numWorkgroupX * workgroupSizeX; + auto NY_literal = numWorkgroupY * workgroupSizeY; + + auto NX = std::make_shared(NX_literal); + auto NY = std::make_shared(NY_literal); + auto NZ = std::make_shared(1u); + + auto totalWorkitemCounts = NX_literal * NY_literal; + { + Log::debug("Calculated workitemcount[0]: {}", toString(NX)); + Log::debug("Calculated workitemcount[1]: {}", toString(NY)); + Log::debug("Calculated workitemcount[2]: {}", toString(NZ)); + Log::debug("totalWorkitemCounts: {}", totalWorkitemCounts); + } + + return {{NX, NY, NZ}, totalWorkitemCounts}; + } + + static Expression::ExpressionPtr + get64BitVectorOffset(ContextPtr context, + std::array const& workitemCount, + std::array const& workgroupSize) + { + std::array thread_index; + for(int i = 0; i < 3; i++) + thread_index[i] = std::make_shared( + context->kernel()->workitemIndex()[i]); + + std::array workgroup_index; + for(int i = 0; i < 3; i++) + workgroup_index[i] = std::make_shared( + context->kernel()->workgroupIndex()[i]); + + std::array workgroup_size; + for(int i = 0; i < 3; i++) + workgroup_size[i] = std::make_shared( + Register::Value::Literal(workgroupSize[i])); + + auto idx_x = thread_index[0] + workgroup_index[0] * workgroup_size[0]; + auto idx_y = thread_index[1] + workgroup_index[1] * workgroup_size[1]; + + auto compare_res_pointer = idx_x + idx_y * workitemCount[0]; + auto elementSize = std::make_shared( + Register::Value::Literal(sizeof(uint64_t))); + compare_res_pointer = compare_res_pointer * elementSize; + + Log::debug("Offset in kb: {}", toString(compare_res_pointer)); + + return compare_res_pointer; + } + + void setTensorArguments(CommandArguments& commandArgs, + GEMMProblem const& problem, + rocRollerTest::Graphs::GEMM const& gemm, + CommandKernelPtr const& commandKernelPtr) const + { + // calling setArgument - needed + TensorDescriptor descA( + gemm.mTa, {size_t(problem.m), size_t(problem.k)}, problem.transA); + TensorDescriptor descB( + gemm.mTb, {size_t(problem.k), size_t(problem.n)}, problem.transB); + TensorDescriptor descC(gemm.mTd, {size_t(problem.m), size_t(problem.n)}, "N"); + TensorDescriptor descD(gemm.mTd, {size_t(problem.m), size_t(problem.n)}, "N"); + + // Note that actually large matrix hipMalloc is not needed. + // But at the same time, larger malloc may incurr larger base address, which can lead to overflow. + // + // TODO: How to get "float" from DataType of gemm.tA ? DataType::Float --> "float" + // (An opposite way is TypeInfo::Var.dataType). + // Might be alright to use float everytime. In the end, we never uses the allocated values. + // + auto deviceA = make_shared_device(); + auto deviceB = make_shared_device(); + auto deviceC = make_shared_device(); + auto deviceD = make_shared_device(); + + // Note that gemm built the CommandGraph, and store OperationTags. We are reusing + // that original Command, but getting away building KernelGraph and its lowering. + setCommandTensorArg(commandArgs, gemm.mTagTensorA, descA, deviceA.get()); + setCommandTensorArg(commandArgs, gemm.mTagTensorB, descB, deviceB.get()); + setCommandTensorArg(commandArgs, gemm.mTagTensorC, descC, deviceC.get()); + setCommandTensorArg(commandArgs, gemm.mTagTensorD, descD, deviceD.get()); + + commandArgs.setArgument(gemm.mTagScalarAlpha, ArgumentType::Value, problem.alpha); + commandArgs.setArgument(gemm.mTagScalarBeta, ArgumentType::Value, problem.beta); + // seed doesn't seem to be relevant currently. + //if(seed.has_value()) + // commandArgs.setArgument(gemm.m_tagScalarSeed, ArgumentType::Value, seed.value()); + + // Create scratch space + if(problem.streamK) + { + commandArgs.setArgument(gemm.mTagNumWGs, ArgumentType::Value, problem.numWGs); + } + + // ?? Is it correct to use original commandKernel or addrTestCommandKernel + auto scratchSpaceRequired + = commandKernelPtr->scratchSpaceRequired(commandArgs.runtimeArguments()); + auto deviceScratch = make_shared_device(scratchSpaceRequired, 0); + commandArgs.setArgument(gemm.mTagScratch, ArgumentType::Value, deviceScratch.get()); + } + + void generateOrigKernelByProlog() + { + m_commandKernel = std::make_shared(m_gemmGraph.getCommand(), ""); + m_commandKernel->setContext(m_context); + m_commandKernel->setCommandParameters(m_gemmGraph.getCommandParameters()); + + Log::debug("lazyAddArguments: {}", m_context->kernelOptions().lazyAddArguments); + + // Add extra kargs for testing + m_rvTag = m_commandKernel->getCommand()->allocateTag(); + m_commandKernel->getCommand()->allocateArgument( + VariableType(DataType::UInt64, PointerType::PointerGlobal), + m_rvTag, + ArgumentType::Value, + DataDirection::WriteOnly, + "rv_ptr"); + + m_rvTag2 = m_commandKernel->getCommand()->allocateTag(); + m_commandKernel->getCommand()->allocateArgument( + VariableType(DataType::UInt64, PointerType::PointerGlobal), + m_rvTag2, + ArgumentType::Value, + DataDirection::WriteOnly, + "rv_ptr2"); + + m_commandKernel->generateKernelGraphOnlyAfterTransforms(); + + auto k = m_context->kernel(); + m_context->schedule(k->preamble()); + { + { // Original Command Arguments + auto commandArguments = m_commandKernel->getCommand()->getArguments(); + for(auto arg : commandArguments) + Log::debug("Original Arg: {}", arg->toString()); + } + + { // workitemcount + for(auto wit : k->workitemCount()) + Log::debug("Original Workitem expr: {}", toString(wit)); + } + } + + m_commandKernel->lowerToKernelArguments(); + auto commandParameters = m_commandKernel->getCommandParameters(); + auto workgroupSize = commandParameters->getManualWorkgroupSize(); + CHECK(workgroupSize.has_value()); + + { // workitemcount + for(auto wit : k->workitemCount()) + Log::debug("== Original Workitem expr: {}", toString(wit)); + } + + m_context->schedule(k->prolog()); + } + + void launchKernelAndCopyBackToHost() + { + // Setting workitem counts before launching + // Notice that addrTestCommandKernel's generateKernel() won't be called. + auto launch = std::make_shared(); + + // See the comments down there. Computation of workitemCount is redundant. + // "totalWorkitemCounts" is for debugging or logging. + std::tie(m_workitemCount, m_totalWorkitemCount) + = getWorkItemCount(m_commandKernel->getCommandParameters(), m_problem); + + // Notice that without following line of "setManualWorkitemCount". + // launch->setManualWorkitemCount(workitemCount); + // - workitemcounts were computed correctly + // within the workitemcount_kernelbody's execution. That, I think, is because + // The original graph already had workitemcount expressions as a function of + // input tensor sizes, one of the commandArguments. (e.g. Tensor_0_size_0_8) + // Thus, only the command arguments are set, workitemcounts can be computed. + // In ohter words, this manual computation of workitemcounts and setting by + // setManualWorkitemCount is not needed. Still, "getWorkItemCount" is kept + // for debugging purposes. In order to allocate device and host memory for storing + // results of kernel's execution, we still need a concrete number of allocation sizes. + + m_commandKernel->setLaunchParameters(launch); + + CommandArguments commandArgs = m_commandKernel->getCommand()->createArguments(); + setTensorArguments(commandArgs, m_problem, m_gemmGraph, m_commandKernel); + + auto outputSize = m_totalWorkitemCount; + auto rvPointer = make_shared_device(outputSize, 0); + commandArgs.setArgument(m_rvTag, ArgumentType::Value, rvPointer.get()); + + auto rvPointer2 = make_shared_device(outputSize, 0); + commandArgs.setArgument(m_rvTag2, ArgumentType::Value, rvPointer2.get()); + + m_commandKernel->launchKernel(commandArgs.runtimeArguments()); + + m_hostBuffer.resize(outputSize, 10); + CHECK_THAT(hipMemcpy(m_hostBuffer.data(), + rvPointer.get(), + sizeof(uint64_t) * outputSize, + hipMemcpyDefault), + HasHipSuccess(0)); + + m_hostBuffer2.resize(outputSize, 20); + CHECK_THAT(hipMemcpy(m_hostBuffer2.data(), + rvPointer2.get(), + sizeof(uint64_t) * outputSize, + hipMemcpyDefault), + HasHipSuccess(0)); + + // check hostBuffer's value + Log::debug("outputSize: {}", outputSize); + } + + void test_implicit_workitemcount() + { + generateOrigKernelByProlog(); + + auto k = m_context->kernel(); + + // [context, workitemcount_X, workitemcount_Y, workgroupSize] are used. + // Just print out passed expressions to device memory. + // If passed expressions are workitemCounts, the expectation is that + // the generated expressions' values are the same with the host-side computed values. + auto kb = [&]() -> Generator { + // store base addrs + Register::ValuePtr s_ptr; + co_yield m_context->argLoader()->getValue("rv_ptr", s_ptr); + Register::ValuePtr s_ptr2; + co_yield m_context->argLoader()->getValue("rv_ptr2", s_ptr2); + auto compare_res_pointer = get64BitVectorOffset( + m_context, + k->workitemCount(), + (m_commandKernel->getCommandParameters()->getManualWorkgroupSize()).value()); + Log::debug("Offset in kb: {}", toString(compare_res_pointer)); + + Register::ValuePtr v_offset_1 = nullptr; + co_yield Expression::generate( + v_offset_1, compare_res_pointer + s_ptr->expression(), m_context); + + Register::ValuePtr v_offset_2 = nullptr; + co_yield Expression::generate( + v_offset_2, compare_res_pointer + s_ptr2->expression(), m_context); + + // Compute the value 1 + Register::ValuePtr s_value_1 = nullptr; + co_yield Expression::generate(s_value_1, k->workitemCount()[0], m_context); + auto v_value_11 = Register::Value::Placeholder( + m_context, Register::Type::Vector, DataType::Int64, 1); + co_yield m_context->copier()->copy(v_value_11, s_value_1, "copy to v1"); + + // Compute the value 2 + Register::ValuePtr s_value_2 = nullptr; + co_yield Expression::generate(s_value_2, k->workitemCount()[1], m_context); + auto v_value_22 = Register::Value::Placeholder( + m_context, Register::Type::Vector, DataType::Int64, 1); + co_yield m_context->copier()->copy(v_value_22, s_value_2, "copy to v2"); + + co_yield m_context->mem()->storeGlobal(v_offset_1, v_value_11, 0, 8); + co_yield m_context->mem()->storeGlobal(v_offset_2, v_value_22, 0, 8); + }; + + m_context->schedule(kb()); + + m_context->schedule(k->postamble()); + m_context->schedule(k->amdgpu_metadata()); + + launchKernelAndCopyBackToHost(); + + // Remove ":" and subsequent parts to extract only leading literals. + auto const& host_x_string = toString(m_workitemCount[0]); + size_t del1 = host_x_string.find_first_of(":"); + auto const& host_x = host_x_string.substr(0, del1); + + auto const& host_y_string = toString(m_workitemCount[1]); + del1 = host_y_string.find_first_of(":"); + auto const& host_y = host_y_string.substr(0, del1); + + for(int i = 0, size = m_hostBuffer.size(); i < m_totalWorkitemCount; i++) + { + // workitemCount.x in device and host. + CAPTURE(toString(m_hostBuffer[i]), host_x, i); + CHECK(toString(m_hostBuffer[i]) == host_x); + + // workitemCount.y in device and host. + CAPTURE(toString(m_hostBuffer2[i]), host_y, i); + CHECK(toString(m_hostBuffer2[i]) == host_y); + } + } + + void test_wg_thr_indices() + { + generateOrigKernelByProlog(); + + auto k = m_context->kernel(); + + // This is for printing out workgroup index and thread index + // [context, workitemCount, workgroupSize] are used. + auto kb = [&]() -> Generator { + // store base addr + Register::ValuePtr s_ptr; + co_yield m_context->argLoader()->getValue("rv_ptr", s_ptr); + + Register::ValuePtr s_ptr2; + co_yield m_context->argLoader()->getValue("rv_ptr2", s_ptr2); + + auto compare_res_pointer = get64BitVectorOffset( + m_context, + k->workitemCount(), + (m_commandKernel->getCommandParameters()->getManualWorkgroupSize()).value()); + Log::debug("Offset in kb: {}", toString(compare_res_pointer)); + + Register::ValuePtr v_offset_1 = nullptr; + co_yield Expression::generate( + v_offset_1, compare_res_pointer + s_ptr->expression(), m_context); + + Register::ValuePtr v_offset_2 = nullptr; + co_yield Expression::generate( + v_offset_2, compare_res_pointer + s_ptr2->expression(), m_context); + + // workgroupIndex x + auto v_wg_x = Register::Value::Placeholder( + m_context, Register::Type::Vector, DataType::UInt32, 1); + co_yield m_context->copier()->copy( + v_wg_x, k->workgroupIndex()[0], "copy wgi.x to v"); + co_yield m_context->mem()->storeGlobal(v_offset_1, v_wg_x, 0, 4); + + co_yield m_context->mem()->storeGlobal(v_offset_2, (k->workitemIndex())[0], 0, 4); + }; + + m_context->schedule(kb()); + + m_context->schedule(k->postamble()); + m_context->schedule(k->amdgpu_metadata()); + + launchKernelAndCopyBackToHost(); + + for(int i = 0; i < m_totalWorkitemCount; i++) + Log::debug("wgx {} thx {}", m_hostBuffer[i], m_hostBuffer2[i]); + + // Remove ":" and subsequent parts to extract only leading literals. + auto const& host_x_string = toString(m_workitemCount[0]); + size_t del1 = host_x_string.find_first_of(":"); + auto const& host_x = host_x_string.substr(0, del1); + auto workitemcount_x = std::stoi(host_x); + + auto const& host_y_string = toString(m_workitemCount[1]); + del1 = host_y_string.find_first_of(":"); + auto const& host_y = host_y_string.substr(0, del1); + auto workitemcount_y = std::stoi(host_y); + + Log::debug("workitemcount_x {} workitemcount_y {}", workitemcount_x, workitemcount_y); + + auto workgroupsize_x + = ((m_commandKernel->getCommandParameters()->getManualWorkgroupSize()).value())[0]; + for(int rows = 0, numRows = workitemcount_y; rows < numRows; rows++) + { + auto blockIdx_y = rows % workgroupsize_x; + auto base = rows * workitemcount_x; + for(int cols = 0, numCols = workitemcount_x; cols < numCols; cols++) + { + auto blockIdx_x = cols / workgroupsize_x; + auto threadIdx_x = cols % workgroupsize_x; + + auto linearIdx = cols + base; + + CHECK(m_hostBuffer[linearIdx] == blockIdx_x); + CHECK(m_hostBuffer2[linearIdx] == threadIdx_x); + } + } + } + + void test_equal_one_pair() + { + generateOrigKernelByProlog(); + + // Get the expression to compute + std::vector indexExprPtrs; + indexExprPtrs = AddressTrace(m_commandKernel->getKernelGraph(), m_context) + .traceComputeIndexWithBuffer(); + std::vector widenedExprPtrs; + + // Test only one pair of expressions (the first pair.) + auto eptr = indexExprPtrs[0]; + Log::debug("== Expr : {} ", toString(eptr)); + + widenedExprPtrs.push_back(rocRollerTest::widenAddrExprTo64bit(eptr)); + Log::debug("++ Widen : {} ", toString(widenedExprPtrs.back())); + + auto fast = Expression::FastArithmetic(m_context); + Log::debug("** fast : {} ", toString(fast(widenedExprPtrs.back()))); + + auto k = m_context->kernel(); + + // context, input, widenedInput, workitemCount, workgroupSize, are used + // inside the kb. + auto kb = [&]() -> Generator { + // store base addr + Register::ValuePtr s_ptr; + co_yield m_context->argLoader()->getValue("rv_ptr", s_ptr); + + Register::ValuePtr s_ptr2; + co_yield m_context->argLoader()->getValue("rv_ptr2", s_ptr2); + + auto compare_res_pointer = get64BitVectorOffset( + m_context, + k->workitemCount(), + (m_commandKernel->getCommandParameters()->getManualWorkgroupSize()).value()); + Log::debug("Offset in kb: {}", toString(compare_res_pointer)); + + Register::ValuePtr v_offset_1 = nullptr; + co_yield Expression::generate( + v_offset_1, compare_res_pointer + s_ptr->expression(), m_context); + + Register::ValuePtr v_offset_2 = nullptr; + co_yield Expression::generate( + v_offset_2, compare_res_pointer + s_ptr2->expression(), m_context); + + // Compute the value 1 + Register::ValuePtr v_value_1 = nullptr; + co_yield Expression::generate(v_value_1, indexExprPtrs[0], m_context); + + // Compute the value 2 + Register::ValuePtr v_value_2 = nullptr; + + co_yield Expression::generate(v_value_2, widenedExprPtrs[0], m_context); + + co_yield m_context->mem()->storeGlobal(v_offset_1, v_value_1, 0, 8); + co_yield m_context->mem()->storeGlobal(v_offset_2, v_value_2, 0, 8); + }; + + m_context->schedule(kb()); + + m_context->schedule(k->postamble()); + m_context->schedule(k->amdgpu_metadata()); + + launchKernelAndCopyBackToHost(); + + // for test_kernelbody + for(int i = 0, size = m_hostBuffer.size(); i < size; i++) + { + if(m_hostBuffer[i] != m_hostBuffer2[i]) + { + Log::debug("diff at {}: {} {}", i, m_hostBuffer[i], m_hostBuffer2[i]); + } + CHECK(m_hostBuffer[i] == m_hostBuffer2[i]); + } + } + + void test_equal_all_pairs() + { + generateOrigKernelByProlog(); + + // Get the expression to compute + std::vector indexExprPtrs; + indexExprPtrs = AddressTrace(m_commandKernel->getKernelGraph(), m_context) + .traceComputeIndexWithBuffer(); + std::vector widenedExprPtrs; + for(int i = 0, size = indexExprPtrs.size(); i < size; i++) + { + auto eptr = indexExprPtrs[i]; + Log::debug("== Expr : {} ", toString(eptr)); + + widenedExprPtrs.push_back(rocRollerTest::widenAddrExprTo64bit(eptr)); + Log::debug("++ Widen : {} ", toString(widenedExprPtrs.back())); + } + + auto allone_uint64 + = std::make_shared(static_cast(0xFFFFFFFF)); + + auto k = m_context->kernel(); + + // context, input, widenedInput, workitemCount, workgroupSize, allone_uint64 are used + // inside the kb. + auto kb = [&]() -> Generator { + // store base addr + Register::ValuePtr s_ptr; + + co_yield m_context->argLoader()->getValue("rv_ptr", s_ptr); + + Register::ValuePtr s_ptr2; + + co_yield m_context->argLoader()->getValue("rv_ptr2", s_ptr2); + + // 2-D + auto compare_res_pointer = get64BitVectorOffset( + m_context, + k->workitemCount(), + (m_commandKernel->getCommandParameters()->getManualWorkgroupSize()).value()); + Log::debug("Offset in kb: {}", toString(compare_res_pointer)); + + Register::ValuePtr v_offset = nullptr; + co_yield Expression::generate( + v_offset, compare_res_pointer + s_ptr->expression(), m_context); + + // boolean_diff was allocated to s[0:1] + // diff should be computed per lane, + // but is_zero_diff is actually one-bit + auto boolean_true = Register::Value::WavefrontPlaceholder(m_context); + Register::ValuePtr v_allone; + co_yield Expression::generate(v_allone, allone_uint64, m_context); + + co_yield m_context->copier()->copy( + boolean_true, v_allone, "set to true for all lanes"); + + Register::ValuePtr temp_res; + for(int i = 0, size = indexExprPtrs.size(); i < size; i++) + { + // Compute the value 1 + Register::ValuePtr v_value_1 = nullptr; + co_yield Expression::generate(v_value_1, indexExprPtrs[i], m_context); + + // Compute the value 2 + Register::ValuePtr v_value_2 = nullptr; + co_yield Expression::generate(v_value_2, widenedExprPtrs[i], m_context); + + // Compute diff (value_1 == value_2) + auto is_zero_diff = v_value_1->expression() == v_value_2->expression(); + { + auto boolType = resultVariableType(is_zero_diff).dataType; + AssertFatal( + boolType == DataType::Bool64, "is_zero type: ", toString(boolType)); + } + + // boolean_true = boolean_true & is_zero_diff + auto accumRes = std::make_shared( + Expression::BitwiseAnd{boolean_true->expression(), is_zero_diff, "accum"}); + { + auto accumType = resultVariableType(accumRes).dataType; + AssertFatal( + accumType == DataType::Bool64, "accum type {}", toString(accumType)); + } + co_yield Expression::generate(boolean_true, accumRes, m_context); + } + + auto v_value = Register::Value::Placeholder( + m_context, Register::Type::Vector, DataType::UInt64, 1); + co_yield m_context->copier()->copy(v_value, boolean_true, "Move value"); + co_yield m_context->mem()->storeGlobal(v_offset, v_value, 0, 8); + }; + + m_context->schedule(kb()); + + m_context->schedule(k->postamble()); + m_context->schedule(k->amdgpu_metadata()); + + launchKernelAndCopyBackToHost(); + + for(int i = 0, size = m_hostBuffer.size(); i < size; i++) + { + // The two addresses are not same at i. + CAPTURE(i, m_hostBuffer[i]); + CHECK(m_hostBuffer[i] == 0xFFFFFFFF); + } + + // hipFree will be taken care of by make_shared_device + } + + private: + ContextPtr m_context; + CommandKernelPtr m_commandKernel; + GEMMProblem const& m_problem; + rocRollerTest::Graphs::GEMM m_gemmGraph; + + rocRoller::Operations::OperationTag m_rvTag; + rocRoller::Operations::OperationTag m_rvTag2; + + // For checking results on host-side + std::array m_workitemCount; + uint m_totalWorkitemCount; + std::vector m_hostBuffer; + std::vector m_hostBuffer2; + }; + + TEST_CASE("address calculation test generate and run", "[expression][gpu]") + { + // Noticed that for "float" type, all different combinations of + // problem sizes (m, n, k) by macro tile sizes (macM, macN) + // will generate the same kernel instructions. + // This is because address calculation expressions use the same arithmetic + // over the same command arguments and workgroup indices, workitem indices. + // Only the vales of the command arguments and workgroup, workitem indices change. + + // Also noticed that current expressions obtained from generation of computeIndex's + // VGPR base part of buffer_ instructions uses only following command arguments. + // Tensor_15_stride_1_14 (D's m or n) + // Tensor_4_stride_1_13 (C's m or n) + // Tensor_0_stride_1_10 (A's m ) + // Tensor_2_stride_0_11 (B's n) + // No problem size "k" is involved in the expression. + // This can be because current expressions are only base VGPR address of buffer_ + // instructions. The "k" part might be applied as increment. + // Or, it could be simply a bug. + + // Called single as the one data type is applied to all A, B, C and D matrices. + + // To cut down execution time further, consider running only large matrices. + auto singleDataType = GENERATE(DataType::Float, DataType::Double); + + std::cout << "singleType: " << singleDataType << "\n"; + DYNAMIC_SECTION(singleDataType) + { + auto [m, n, k] + = GENERATE(values(TestValues::gemmProblemSizes)); + + std::cout << "problemSize m: " << m << "\n"; + std::cout << "problemSize n: " << n << "\n"; + std::cout << "problemSize k: " << k << "\n"; + DYNAMIC_SECTION("ps_" << m << "x" << n << "x" << k) + { + auto [macM, macN] = GENERATE(values(TestValues::macroTileSizes)); + DYNAMIC_SECTION("mc_" << macM << "x" << macN) + { + std::cout << "macM: " << macM << "\n"; + std::cout << "macN: " << macN << "\n"; + + // Come up with a string from problem_size and data type, to be given to ForTestDevice(); + auto probSizeString + = std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k); + auto macroTileString = std::to_string(macM) + "x" + std::to_string(macN); + auto suffixForKernelName + = toString(singleDataType) + "_" + probSizeString + "_" + macroTileString; + + auto [transA, transB] = GENERATE(values>( + {{"N", "T"}, {"N", "N"}, {"T", "N"}, {"T", "T"}})); + DYNAMIC_SECTION(transA << transB) + { + std::cout << "transA: " << transA << "\n"; + std::cout << "transB: " << transB << "\n"; + + suffixForKernelName += "_" + transA + transB; + auto context = TestContext::ForTestDevice({}, suffixForKernelName); + + GEMMProblem problem{.m = m, + .n = n, + .k = k, + .macM = macM, + .macN = macN, + .transA = transA, + .transB = transB}; + rocRollerTest::Graphs::GEMM gemm(singleDataType); + gemm.setProblem(problem); + + AddressCalculationTest kernel(context.get(), problem, gemm); + + // Generate a kernel for testing address calculation and run. + // Verification of the result is done. + kernel.test_equal_all_pairs(); + } + } + } + } + } + + // Following test checks only one-pair, simpler version of above. + TEST_CASE("address calculation test generate and run one pair", "[expression][gpu]") + { + auto context = TestContext::ForTestDevice({}, "128x128_one_pair"); + + GEMMProblem problem{.m = 128, .n = 128, .macM = 64, .macN = 64, .macK = 64}; + rocRollerTest::Graphs::GEMM gemm(DataType::Float); + gemm.setProblem(problem); + + AddressCalculationTest kernel(context.get(), problem, gemm); + kernel.test_equal_one_pair(); + } + + // Sanity check 1 + TEST_CASE("address calculation test implicit workitemcount", "[expression][gpu]") + { + auto context = TestContext::ForTestDevice({}, "impl_workitemcnt"); + + GEMMProblem problem{.m = 128, .n = 128, .macM = 64, .macN = 64, .macK = 64}; + rocRollerTest::Graphs::GEMM gemm(DataType::Float); + gemm.setProblem(problem); + + AddressCalculationTest kernel(context.get(), problem, gemm); + kernel.test_implicit_workitemcount(); + } + + // Sanity check 2 + TEST_CASE("address calculation test workgroup thread index", "[expression][gpu]") + { + auto context = TestContext::ForTestDevice({}, "128x128_sanity_indices"); + + GEMMProblem problem{.m = 128, .n = 128, .macM = 64, .macN = 64, .macK = 64}; + rocRollerTest::Graphs::GEMM gemm(DataType::Float); + gemm.setProblem(problem); + + AddressCalculationTest kernel(context.get(), problem, gemm); + kernel.test_wg_thr_indices(); + } +} \ No newline at end of file diff --git a/test/catch/CommandTest.cpp b/test/catch/CommandTest.cpp index 85bce25a..2bfd8918 100644 --- a/test/catch/CommandTest.cpp +++ b/test/catch/CommandTest.cpp @@ -38,9 +38,9 @@ namespace CommandTest { SECTION("GEMM/TileAdd") { - auto example1 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example2 = *rocRollerTest::Graphs::GEMM().getCommand(); - auto example3 = *rocRollerTest::Graphs::GEMM().getCommand(); + auto example1 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example2 = *rocRollerTest::Graphs::GEMM(DataType::Float).getCommand(); + auto example3 = *rocRollerTest::Graphs::GEMM(DataType::Half).getCommand(); auto example4 = *rocRollerTest::Graphs::TileDoubleAdd().getCommand(); CHECK(example1 == example2); @@ -58,7 +58,7 @@ namespace CommandTest { SECTION("GEMM") { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto command0 = example.getCommand(); auto yaml = Command::toYAML(*command0); diff --git a/test/catch/IdentifyParallelDimensionsTest.cpp b/test/catch/IdentifyParallelDimensionsTest.cpp index b20ef993..8dc35f4c 100644 --- a/test/catch/IdentifyParallelDimensionsTest.cpp +++ b/test/catch/IdentifyParallelDimensionsTest.cpp @@ -74,7 +74,7 @@ TEST_CASE("identifyParallelDimensionSets works for GEMM", "[kernel-graph]") using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); auto kgraph = KernelGraph::translate(example.getCommand()); @@ -175,7 +175,7 @@ SCENARIO("IdentifyParallelDimensions transformation works for GEMM", "[kernel-gr using namespace rocRoller; auto ctx = TestContext::ForDefaultTarget(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); GIVEN("The initial kernel graph for a GEMM") { diff --git a/test/catch/KernelGraphRemoveDuplicatesTest.cpp b/test/catch/KernelGraphRemoveDuplicatesTest.cpp index 9bb470c5..595793f0 100644 --- a/test/catch/KernelGraphRemoveDuplicatesTest.cpp +++ b/test/catch/KernelGraphRemoveDuplicatesTest.cpp @@ -42,9 +42,9 @@ TEST_CASE("Remove duplicates", "[kernel-graph]") using namespace rocRoller::KernelGraph::ControlGraph; auto ctx = TestContext::ForDefaultTarget().get(); - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); - example.setTileSize(128, 128, 32); + example.setTileSize(128, 64, 32); example.setMFMA(32, 32, 16, 1); example.setUseLDS(true, true, false); diff --git a/test/common/CommonGraphs.cpp b/test/common/CommonGraphs.cpp index c0271b56..644b62dd 100644 --- a/test/common/CommonGraphs.cpp +++ b/test/common/CommonGraphs.cpp @@ -193,4 +193,260 @@ namespace rocRollerTest::Graphs return params; } + GEMM::GEMM(DataType ta) + : GEMM(ta, ta) + { + } + GEMM::GEMM(DataType ta, DataType tb) + : GEMM(ta, tb, tb) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc) + : GEMM(ta, tb, tc, tc) + { + } + GEMM::GEMM(DataType ta, DataType tb, DataType tc, DataType td) + : mTa(ta) + , mTb(tb) + , mTc(tc) + , mTd(td) + { + } + + void GEMM::createCommand() + { + m_command = std::make_shared(); + + std::vector oneStridesN + = m_problem.literalStrides ? std::vector({(size_t)1}) : std::vector({}); + + std::vector oneStridesT = m_problem.literalStrides + ? std::vector({(size_t)0, (size_t)1}) + : std::vector({}); + + mTagTensorA = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTa, m_problem.transA == "N" ? oneStridesN : oneStridesT)); // A + + m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorA)); + + mTagTensorB = m_command->addOperation(rocRoller::Operations::Tensor( + 2, mTb, m_problem.transB == "N" ? oneStridesN : oneStridesT)); // B + m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorB)); + + mTagTensorC + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTc, oneStridesN)); // C + m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(mTagTensorC)); + + mTagScalarAlpha + = m_command->addOperation(rocRoller::Operations::Scalar(DataType::Float)); // alpha + auto tagLoadAlpha + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarAlpha)); + + mTagScalarBeta = m_command->addOperation(rocRoller::Operations::Scalar(mTc)); // beta + auto tagLoadBeta + = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(mTagScalarBeta)); // beta + + auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B + + rocRoller::Operations::T_Execute execute(m_command->getNextTag()); + + auto tagBetaC + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C + + auto tagAlphaAB + = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) + + if(m_problem.betaInFma) + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagBetaC, tagAlphaAB)); + // alpha * (A * B) + beta * C + } + else + { + m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); + // alpha * (A * B) + beta * C + } + m_command->addOperation(std::move(execute)); + + mTagTensorD + = m_command->addOperation(rocRoller::Operations::Tensor(2, mTd, oneStridesN)); // D + m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, mTagTensorD)); // D + + if(m_problem.streamK) + { + mTagNumWGs = m_command->allocateTag(); + auto numWGsArg = m_command->allocateArgument(DataType::UInt32, + mTagNumWGs, + ArgumentType::Value, + DataDirection::ReadOnly, + rocRoller::NUMWGS); + } + + mTagScratch = m_command->allocateTag(); + m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + mTagScratch, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::SCRATCH); + } + + CommandPtr GEMM::getCommand() + { + if(!m_command) + createCommand(); + + return m_command; + } + + KernelGraph GEMM::getKernelGraph() + { + return rocRoller::KernelGraph::translate(getCommand()); + } + + void GEMM::setTileSize(int m, int n, int k) + { + m_problem.macM = m; + m_problem.macN = n; + m_problem.macK = k; + } + + void GEMM::setMFMA(int m, int n, int k, int b) + { + m_problem.waveM = m; + m_problem.waveN = n; + m_problem.waveK = k; + m_problem.waveB = b; + } + + void GEMM::setUseLDS(bool a, bool b, bool d) + { + m_problem.loadLDSA = a; + m_problem.loadLDSB = b; + m_problem.storeLDSD = d; + } + + void GEMM::setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps) + { + m_problem.prefetch = prefetch; + m_problem.prefetchInFlight = prefetchInFlight; + m_problem.prefetchLDSFactor = prefetchLDSFactor; + m_problem.prefetchMixMemOps = prefetchMixMemOps; + + m_problem.unrollK = prefetchInFlight; + } + + void GEMM::setProblem(GEMMProblem const& problem) + { + m_problem = problem; + } + + GEMMProblem const& GEMM::getProblem() const + { + return m_problem; + } + + CommandParametersPtr GEMM::getCommandParameters() const + { + using namespace rocRoller::KernelGraph::CoordinateGraph; + + auto params = std::make_shared(); + + params->setManualKernelDimension(2); + + AssertFatal(m_problem.m % m_problem.macM == 0, + "MacroTile size mismatch (M)", + ShowValue(m_problem.m), + ShowValue(m_problem.macM)); + AssertFatal(m_problem.n % m_problem.macN == 0, + "MacroTile size mismatch (N)", + ShowValue(m_problem.n), + ShowValue(m_problem.macN)); + + AssertFatal(m_problem.workgroupSizeX % m_problem.wavefrontSize == 0, + "Workgroup Size X must be multiply of wave front size"); + + AssertFatal(m_problem.macM % m_problem.waveM == 0, + "Macrotile size must be a multiple of wavetile size"); + AssertFatal(m_problem.macN % m_problem.waveN == 0, + "Macrotile size must be a multiple of wavetile size"); + + // i.e. jammedM + uint wavetilePerWavefrontM + = m_problem.wavefrontSize * m_problem.macM / m_problem.waveM / m_problem.workgroupSizeX; + AssertFatal(wavetilePerWavefrontM > 0, + "Wavetiles per wavefront in M should be positive integer"); + + // i.e. jammedN + uint wavetilePerWavefrontN = m_problem.macN / m_problem.waveN / m_problem.workgroupSizeY; + AssertFatal(wavetilePerWavefrontN > 0, + "Wavetiles per wavefront in N should be positive integer"); + + AssertFatal(m_problem.macM % (m_problem.waveM * wavetilePerWavefrontM) == 0, + "WaveTile size mismatch (M)"); + AssertFatal(m_problem.macN % (m_problem.waveN * wavetilePerWavefrontN) == 0, + "WaveTile size mismatch (N)"); + + uint workgroupSizeX = m_problem.workgroupSizeX * m_problem.workgroupSizeY; + uint workgroupSizeY = 1; + params->setManualWorkgroupSize({workgroupSizeX, workgroupSizeY, 1}); + + auto macTileA + = MacroTile({m_problem.macM, m_problem.macK}, + LayoutType::MATRIX_A, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSA ? MemoryType::LDS : MemoryType::WAVE); + auto macTileB + = MacroTile({m_problem.macK, m_problem.macN}, + LayoutType::MATRIX_B, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.loadLDSB ? MemoryType::LDS : MemoryType::WAVE); + auto macTileC + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}); + auto macTileD + = MacroTile({m_problem.macM, m_problem.macN}, + LayoutType::MATRIX_ACCUMULATOR, + {m_problem.waveM, m_problem.waveN, m_problem.waveK, m_problem.waveB}, + m_problem.storeLDSD ? MemoryType::LDS : MemoryType::WAVE); + + params->setDimensionInfo(m_tagA, macTileA); + params->setDimensionInfo(m_tagB, macTileB); + params->setDimensionInfo(m_tagC, macTileC); + params->setDimensionInfo(m_tagD, macTileD); + + Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); + Log::debug("GEMM jamming {} {}", wavetilePerWavefrontM, wavetilePerWavefrontN); + + params->setWaveTilesPerWavefront(wavetilePerWavefrontM, wavetilePerWavefrontN); + params->setManualWavefrontCount( + {static_cast(m_problem.macM / m_problem.waveM / wavetilePerWavefrontM), + static_cast(m_problem.macN / m_problem.waveN / wavetilePerWavefrontN)}); + + params->fuseLoops = m_problem.fuseLoops; + params->tailLoops = m_problem.tailLoops; + params->allowAmbiguousMemoryNodes = m_problem.allowAmbiguousMemoryNodes; + params->unrollK = m_problem.unrollK; + params->packMultipleElementsInto1VGPR = m_problem.packMultipleElementsInto1VGPR; + params->prefetch = m_problem.prefetch; + params->prefetchInFlight = m_problem.prefetchInFlight; + params->prefetchLDSFactor = m_problem.prefetchLDSFactor; + params->prefetchMixMemOps = m_problem.prefetchMixMemOps; + params->transposeMemoryAccess[LayoutType::MATRIX_A] = m_problem.transA == "T"; + params->transposeMemoryAccess[LayoutType::MATRIX_B] = m_problem.transB == "T"; + params->transposeMemoryAccess[LayoutType::None] = true; + + if(m_problem.streamK) + { + params->loopOverOutputTilesDimensions = {0, 1}; + params->streamK = true; + params->streamKTwoTile = m_problem.streamKTwoTile; + } + + return params; + } + } diff --git a/test/common/common/CommonGraphs.hpp b/test/common/common/CommonGraphs.hpp index 4ee7a3a4..ddb4b816 100644 --- a/test/common/common/CommonGraphs.hpp +++ b/test/common/common/CommonGraphs.hpp @@ -40,6 +40,8 @@ #include #include +#include + namespace rocRollerTest { namespace Graphs @@ -52,6 +54,7 @@ namespace rocRollerTest using ContextPtr = rocRoller::ContextPtr; using KernelArguments = rocRoller::KernelArguments; using KernelGraph = rocRoller::KernelGraph::KernelGraph; + using DataType = rocRoller::DataType; /** * @brief Graph for linear: alpha x + beta y. @@ -178,11 +181,13 @@ namespace rocRollerTest * - Assign(D = alpha * AB + beta * C) * - StoreTiled(D) */ - template class GEMM { public: - GEMM(); + GEMM(DataType ta); + GEMM(DataType ta, DataType tb); + GEMM(DataType ta, DataType tb, DataType tc); + GEMM(DataType ta, DataType tb, DataType tc, DataType td); CommandPtr getCommand(); KernelGraph getKernelGraph(); @@ -190,19 +195,28 @@ namespace rocRollerTest void setTileSize(int m, int n, int k); void setMFMA(int m, int n, int k, int b); void setUseLDS(bool a, bool b, bool d); + void setPrefetch(bool prefetch, + int prefetchInFlight, + int prefetchLDSFactor, + bool prefetchMixMemOps); + void setProblem(GEMMProblem const& problem); + GEMMProblem const& getProblem() const; CommandParametersPtr getCommandParameters() const; + rocRoller::Operations::OperationTag mTagTensorA, mTagTensorB, mTagTensorC, mTagTensorD, + mTagScalarAlpha, mTagScalarBeta, mTagScalarSeed, mTagScratch; + rocRoller::Operations::OperationTag mTagNumWGs; + + DataType mTa, mTb, mTc, mTd; + private: void createCommand(); - int m_macM, m_macN, m_macK; - int m_waveM, m_waveN, m_waveK, m_waveB; - bool m_useLDSA = false, m_useLDSB = false, m_useLDSD = false; - rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; - CommandPtr m_command; + CommandPtr m_command; + GEMMProblem m_problem; }; /** diff --git a/test/common/common/CommonGraphs_impl.hpp b/test/common/common/CommonGraphs_impl.hpp index ec097436..9a8d6201 100644 --- a/test/common/common/CommonGraphs_impl.hpp +++ b/test/common/common/CommonGraphs_impl.hpp @@ -256,132 +256,6 @@ namespace rocRollerTest::Graphs return rocRoller::KernelGraph::translate(m_command); } - /* - * GEMM - */ - - template - GEMM::GEMM() - { - createCommand(); - } - - template - void GEMM::createCommand() - { - m_command = std::make_shared(); - - auto dataType = TypeInfo::Var.dataType; - - auto tagTensorA = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // A - m_tagA = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorA)); - - auto tagTensorB = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // B - m_tagB = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorB)); - - auto tagTensorC = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // C - m_tagC = m_command->addOperation(rocRoller::Operations::T_Load_Tiled(tagTensorC)); - - auto tagScalarAlpha - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // alpha - auto tagLoadAlpha - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarAlpha)); - - auto tagScalarBeta - = m_command->addOperation(rocRoller::Operations::Scalar(dataType)); // beta - auto tagLoadBeta - = m_command->addOperation(rocRoller::Operations::T_Load_Scalar(tagScalarBeta)); // beta - - auto tagAB = m_command->addOperation(rocRoller::Operations::T_Mul(m_tagA, m_tagB)); // A * B - - rocRoller::Operations::T_Execute execute(m_command->getNextTag()); - auto tagAlphaAB - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadAlpha, tagAB)); // alpha * (A * B) - auto tagBetaC - = execute.addXOp(rocRoller::Operations::E_Mul(tagLoadBeta, m_tagC)); // beta * C - m_tagD = execute.addXOp(rocRoller::Operations::E_Add(tagAlphaAB, tagBetaC)); - // alpha * (A * B) + beta * C - m_command->addOperation(std::move(execute)); - - auto tagTensorD = m_command->addOperation(rocRoller::Operations::Tensor(2, dataType)); // D - m_command->addOperation(rocRoller::Operations::T_Store_Tiled(m_tagD, tagTensorD)); // D - } - - template - CommandPtr GEMM::getCommand() - { - return m_command; - } - - template - KernelGraph GEMM::getKernelGraph() - { - return rocRoller::KernelGraph::translate(m_command); - } - - template - void GEMM::setTileSize(int m, int n, int k) - { - m_macM = m; - m_macN = n; - m_macK = k; - } - - template - void GEMM::setMFMA(int m, int n, int k, int b) - { - m_waveM = m; - m_waveN = n; - m_waveK = k; - m_waveB = b; - } - - template - void GEMM::setUseLDS(bool a, bool b, bool d) - { - m_useLDSA = a; - m_useLDSB = b; - m_useLDSD = d; - } - - template - CommandParametersPtr GEMM::getCommandParameters() const - { - using namespace rocRoller::KernelGraph::CoordinateGraph; - - auto params = std::make_shared(); - - auto macTileA = MacroTile({m_macM, m_macK}, - LayoutType::MATRIX_A, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSA ? MemoryType::LDS : MemoryType::WAVE); - auto macTileB = MacroTile({m_macK, m_macN}, - LayoutType::MATRIX_B, - {m_waveM, m_waveN, m_waveK, m_waveB}, - m_useLDSB ? MemoryType::LDS : MemoryType::WAVE); - auto macTileC = MacroTile( - {m_macM, m_macN}, LayoutType::MATRIX_ACCUMULATOR, {m_waveM, m_waveN, m_waveK, m_waveB}); - - params->setDimensionInfo(m_tagA, macTileA); - params->setDimensionInfo(m_tagB, macTileB); - params->setDimensionInfo(m_tagC, macTileC); - - // Workgroup size - uint wavefrontSize = 64; - uint workgroupSizeX = 2 * wavefrontSize; - uint workgroupSizeY = 4; - - uint jammedM = wavefrontSize * m_macM / m_waveM / workgroupSizeX; - uint jammedN = m_macN / m_waveN / workgroupSizeY; - - Log::debug("GEMM workgroup sizes {} {} {}", workgroupSizeX, workgroupSizeY, 1); - Log::debug("GEMM jamming {} {}", jammedM, jammedN); - - params->setWaveTilesPerWavefront(jammedM, jammedN); - - return params; - } - /* * TileDoubleAdd */ diff --git a/test/common/common/GEMMProblem.hpp b/test/common/common/GEMMProblem.hpp index bf6e8f50..f7fe1371 100644 --- a/test/common/common/GEMMProblem.hpp +++ b/test/common/common/GEMMProblem.hpp @@ -94,4 +94,6 @@ struct GEMMProblem rocRoller::Operations::ScaleMode scaleAMode = rocRoller::Operations::ScaleMode::None; rocRoller::Operations::ScaleMode scaleBMode = rocRoller::Operations::ScaleMode::None; + + auto operator<=>(GEMMProblem const& rhs) const = default; }; diff --git a/test/common/common/TestValues.hpp b/test/common/common/TestValues.hpp index e2fa08ab..9fc4116d 100644 --- a/test/common/common/TestValues.hpp +++ b/test/common/common/TestValues.hpp @@ -27,10 +27,12 @@ #pragma once // -// Value "suites" for arithemtic and expression tests +// Value "suites" for arithmetic and expression tests // #include +#include +#include #include #include @@ -138,6 +140,21 @@ namespace TestValues 12981.0, 42e5}; + // Portions of GEMMProblem + struct GemmProblemSize + { + int m; + int n; + int k; + }; + + inline std::initializer_list gemmProblemSizes + = {{128, 128, 128}, {512, 512, 128}, {1024, 1024, 256}}; + + // Notice that {256, 256} was intentionally avoided due to extremely prolonged time of + // code generation. + inline std::initializer_list> macroTileSizes = {{64, 64}, {128, 128}}; + template struct ByType { diff --git a/test/unit/KernelGraphTest.cpp b/test/unit/KernelGraphTest.cpp index 72f602aa..1b084440 100644 --- a/test/unit/KernelGraphTest.cpp +++ b/test/unit/KernelGraphTest.cpp @@ -1250,12 +1250,12 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, LowerTensor) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; - example.setTileSize(128, 256, macK); + example.setTileSize(128, 128, macK); example.setMFMA(32, 32, waveK, 1); example.setUseLDS(true, false, false); @@ -1342,7 +1342,7 @@ namespace KernelGraphTest TEST_F(KernelGraphTest, InlineIncrement) { - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); @@ -2917,7 +2917,7 @@ namespace KernelGraphTest { using GD = Graph::Direction; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); diff --git a/test/unit/KernelGraphTest/UpdateParameters.cpp b/test/unit/KernelGraphTest/UpdateParameters.cpp index dd101c5c..e72f9960 100644 --- a/test/unit/KernelGraphTest/UpdateParameters.cpp +++ b/test/unit/KernelGraphTest/UpdateParameters.cpp @@ -88,7 +88,7 @@ namespace KernelGraphTest { using namespace rocRoller::KernelGraph; - auto example = rocRollerTest::Graphs::GEMM(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); int macK = 16; int waveK = 8; @@ -118,16 +118,16 @@ namespace KernelGraphTest // Now apply SetWorkitemCount and try again kgraph = kgraph.transform(std::make_shared(m_context)); - CommandArgumentPtr tensorDsizeX; + CommandArgumentPtr tensorAsizeX; { auto arguments = command->getArguments(); for(auto argument : arguments) { - if(argument->name() == "Tensor_4_size_0") - tensorDsizeX = argument; + if(argument->name() == "Tensor_0_size_0") + tensorAsizeX = argument; } } - ASSERT_NE(tensorDsizeX, nullptr) << "D size not found"; + ASSERT_NE(tensorAsizeX, nullptr) << "A size not found"; workitemCount = m_context->kernel()->workitemCount(); @@ -135,7 +135,7 @@ namespace KernelGraphTest auto workgroupSizeX = Expression::literal(128u); auto expected - = (((tensorDsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; + = (((tensorAsizeX->expression() + workgroupSizeX) - one) / workgroupSizeX) * one; EXPECT_TRUE(Expression::identical(expected, workitemCount[0])); }