Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,26 @@ createTargetMachine(llvm::Module *module, std::string proc,
return machine;
}

std::string translateLLVMIRToASM(llvm::Module &module,
const std::string &triple,
const std::string &proc,
const std::string &features,
const std::vector<std::string> &flags,
bool enable_fp_fusion, bool isObject) {
std::string translateLLVMIRToASM(
llvm::Module &module, const std::string &triple, const std::string &proc,
const std::string &features,
const std::vector<std::pair<std::string, std::variant<std::string, bool>>>
&flags,
bool enable_fp_fusion, bool isObject) {
using namespace mlir;
// options
auto options = llvm::cl::getRegisteredOptions();
for (std::string flag : flags) {
auto *shortPtr = static_cast<llvm::cl::opt<bool> *>(options[flag]);
assert(shortPtr);
shortPtr->setValue(true);
for (auto flag : flags) {
if (std::string *strVal = std::get_if<std::string>(&flag.second)) {
auto *shortPtr =
static_cast<llvm::cl::opt<std::string> *>(options[flag.first]);
assert(shortPtr);
shortPtr->setValue(*strVal);
} else if (bool *boolVal = std::get_if<bool>(&flag.second)) {
auto *shortPtr = static_cast<llvm::cl::opt<bool> *>(options[flag.first]);
assert(shortPtr);
shortPtr->setValue(*boolVal);
}
}
if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
auto optIt = options.find("print-after-all");
Expand Down Expand Up @@ -423,7 +430,9 @@ void init_triton_llvm(py::module &&m) {
m.def(
"translate_to_asm",
[](std::string llvmIR, std::string triple, std::string proc,
std::string features, std::vector<std::string> flags,
std::string features,
std::vector<std::pair<std::string, std::variant<std::string, bool>>>
flags,
bool enable_fp_fusion, bool isObject) -> py::object {
std::string obj;
{
Expand Down
37 changes: 37 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8118,3 +8118,40 @@ def kernel():
tl.device_assert(tl.sum(x) == x.sum())

kernel[(1, )]()


def test_schedule_hint(device):
if not is_hip():
pytest.skip("schedule_hint option is defined only for HIP")

@triton.jit
def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * BLOCK_K + off_k[None, :] * 1
Ys = Y + off_k[:, None] * 1 + off_n[None, :] * BLOCK_K
Zs = Z + off_m[:, None] * BLOCK_N + off_n[None, :] * 1
x = tl.load(Xs)
y = tl.load(Ys)
z = tl.dot(x, y)
tl.store(Zs, z)

# input
rs = RandomState(17)
M = 128
N = 128
K = 128

pgm_default = kernel.warmup(torch.float32, torch.float32, torch.float32, M, N, K, grid=(1, ))
pgm_ilp = kernel.warmup(torch.float32, torch.float32, torch.float32, M, N, K,
schedule_hint="iterative-ilp-scheduler", grid=(1, ))

def get_num_vgprs(text):
for line in text.split("\n"):
if ".vgpr_count" in line:
return int(line.split(" ")[-1])

default_vgprs = get_num_vgprs(pgm_default.asm['amdgcn'])
ilp_vgprs = get_num_vgprs(pgm_ilp.asm['amdgcn'])
assert ilp_vgprs != default_vgprs
15 changes: 12 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class HIPOptions:
# attention: enables a bunch of optimizations for attention kernels, including:
# - iglp 2 and sched.barrier around it
# - sink-insts-to-avoid-spills flag to avoid register spills
# iterative-ilp-scheduler: enables custom instruction scheduler in backend
#
# Option allows to set multiple variants divided by commas:
# schedule_hint="attention,iterative-ilp-scheduler"
schedule_hint: str = 'none'

def __post_init__(self):
Expand Down Expand Up @@ -231,7 +235,8 @@ def make_ttgir(mod, metadata, options):
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
passes.common.add_canonicalizer(pm)
if options.schedule_hint.lower() != "none":
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
for hint in options.schedule_hint.split(","):
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, hint)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
if is_in_thread_transpose_enabled(options.arch):
Expand Down Expand Up @@ -417,8 +422,12 @@ def make_amdgcn(src, metadata, options):
# into loops to avoid register spills in the MachineSinking pass, while it
# can also lead to regression in some cases. But from current observation,
# the regression is not significant. It would be better to have some heuristics.
if options.schedule_hint == 'attention':
flags.append('sink-insts-to-avoid-spills')

for hint in options.schedule_hint.split(","):
if hint == 'attention':
flags.append(('sink-insts-to-avoid-spills', True))
if hint == 'iterative-ilp-scheduler':
flags.append(('amdgpu-sched-strategy', 'iterative-ilp'))
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', flags, options.enable_fp_fusion, False)
if knobs.amd.dump_amdgcn:
print("// -----// AMDGCN Dump //----- //")
Expand Down