diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h
index ff1f96c1d1..6be9c978c6 100644
--- a/bin/RegisterTritonDialects.h
+++ b/bin/RegisterTritonDialects.h
@@ -1,12 +1,11 @@
#pragma once
-
-#ifdef __AMD__
-#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
-#include "amd/include/TritonAMDGPUTransforms/Passes.h"
-#endif
-#ifdef __NVIDIA__
-#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
-#endif
+#include "AutoBlockify/Passes.h"
+#include "TritonToHFusion/Passes.h"
+#include "TritonToHIVM/Passes.h"
+#include "TritonToLLVM/Passes.h"
+// #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
+// #include "amd/include/TritonAMDGPUTransforms/Passes.h"
+// #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#ifdef __NVIDIA__
@@ -66,11 +65,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
#endif
  mlir::triton::registerConvertTritonToTritonGPUPass();
  mlir::triton::registerAllocateSharedMemoryPass();
-#ifdef __NVIDIA__
-  mlir::triton::registerConvertTritonGPUToLLVMPass();
-  mlir::triton::registerConvertNVGPUToLLVMPass();
-  mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();
-#endif
+  // mlir::triton::registerConvertTritonGPUToLLVMPass();
+  // mlir::triton::registerConvertNVGPUToLLVMPass();
+  // mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();
+  mlir::triton::registerTritonToHIVMPasses();
+  mlir::triton::registerTritonToHFusionPasses();
+  mlir::triton::registerTritonToLLVMPasses();
+  mlir::triton::registerAutoBlockifyPasses();
  mlir::registerLLVMDIScope();

#ifdef __AMD__
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index e5132b6d36..e13de65d94 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -20,6 +20,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
    "DISABLE_PTXAS_OPT",
    "LLVM_IR_ENABLE_DUMP",
    "LLVM_ENABLE_TIMING",
+   "MLIR_DISABLE_MULTITHREADING",
    "LLVM_PASS_PLUGIN_PATH",
    "MLIR_ENABLE_DIAGNOSTICS",
    "MLIR_ENABLE_DUMP",
diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index f9b98a2540..4aa24986c7 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -25,6 +25,7 @@
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include
#include
+#include
#include
#include
#include
@@ -172,7 +173,7 @@ void init_triton_llvm(py::module &&m) {
        [](llvm::Module::FunctionListType &s) {
          return py::make_iterator(s.begin(), s.end());
        },
-       py::keep_alive<0, 1>());
+       py::keep_alive<0, 1>(), py::call_guard<py::gil_scoped_release>());

  // Module Flag behavior. See
  // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293
@@ -388,7 +389,8 @@ void init_triton_llvm(py::module &&m) {
      // (optional) parameters
      py::arg("arch") = "", py::arg("features") = "",
      py::arg("flags") = std::vector<std::string>{},
-     py::arg("enable_fp_fusion") = false);
+     py::arg("enable_fp_fusion") = false,
+     py::call_guard<py::gil_scoped_release>());

  m.def(
      "translate_to_asm",
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 6b6228f0b4..83d599e0cd 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -23,6 +23,7 @@
    MockTensor,
)
from .runtime.jit import jit
+from .runtime._async_compile import AsyncCompileMode
from .compiler import compile, CompilationError
from .errors import TritonError
@@ -31,6 +32,7 @@ from .
import tools __all__ = [ + "AsyncCompileMode", "autotune", "cdiv", "CompilationError", diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba9..9dd06e623f 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -35,7 +35,12 @@ class Backend: def _discover_backends(): backends = dict() root = os.path.dirname(__file__) + # The package does not ship the files required to load the + # upstream nvidia and amd backends, so skip discovering them here. + ignored_dirs = {"nvidia", "amd"} for name in os.listdir(root): + if name in ignored_dirs: + continue if not os.path.isdir(os.path.join(root, name)): continue if name.startswith('__'): diff --git a/python/triton/extension/__init__.py b/python/triton/extension/__init__.py new file mode 100644 index 0000000000..6cbe0ecf88 --- /dev/null +++ b/python/triton/extension/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. diff --git a/python/triton/extension/buffer/__init__.py b/python/triton/extension/buffer/__init__.py new file mode 100644 index 0000000000..6cbe0ecf88 --- /dev/null +++ b/python/triton/extension/buffer/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
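For orientation, here is a minimal usage sketch of the buffer extension that the files below introduce, pieced together from the docstrings in language/core.py; the `bl` alias follows those docstrings, while the exact import path and the assumption that these builtins are callable inside @triton.jit are inferred, not stated by this patch:

import triton
import triton.language as tl
import triton.extension.buffer.language as bl  # assumed import path


@triton.jit
def scale_kernel(in_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(in_ptr + offs)
    tmp = bl.alloc(tl.float32, [BLOCK])        # scratch buffer in local memory
    buf = bl.to_buffer(x)                      # tl.tensor -> bl.buffer
    sub = bl.subview(buf, [0], [BLOCK], [1])   # strides all 1, offset 32-byte aligned
    y = bl.to_tensor(sub)                      # bl.buffer -> tl.tensor
    tl.store(out_ptr + offs, y)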
diff --git a/python/triton/extension/buffer/language/__init__.py b/python/triton/extension/buffer/language/__init__.py new file mode 100644 index 0000000000..9d02b82c93 --- /dev/null +++ b/python/triton/extension/buffer/language/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = [ + # core + "builtin", + "is_builtin", + + # buffer + "buffer", + + # base address space + "address_space", + + # alloc + "alloc", + + # to_buffer + "to_buffer", + + # to_tensor + "to_tensor", + "subview", +] + +from .core import builtin, is_builtin, address_space, buffer, alloc, to_buffer, to_tensor, subview diff --git a/python/triton/extension/buffer/language/builder.py b/python/triton/extension/buffer/language/builder.py new file mode 100644 index 0000000000..bb519df070 --- /dev/null +++ b/python/triton/extension/buffer/language/builder.py @@ -0,0 +1,73 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +buffer-specific builder utilities for code generation. 
+""" + +__all__ = [ + "create_builder_method_wrapper_with_buffer_builder", + "attach_builder_methods_with_buffer_builder", + "setup_unified_builder_with_buffer_builder", +] + + +def create_builder_method_wrapper_with_buffer_builder(main_builder, delegate_builder, method_name): + """ + Create a wrapper that delegates a method call to another builder while + synchronizing insertion points and locations. + """ + delegate_method = getattr(delegate_builder, method_name) + + def wrapper(*args, **kwargs): + saved_ip = main_builder.get_insertion_point() + saved_loc = main_builder.get_loc() + delegate_builder.restore_insertion_point(saved_ip) + if saved_loc: + delegate_builder.set_loc(saved_loc) + result = delegate_method(*args, **kwargs) + main_builder.restore_insertion_point(saved_ip) + if saved_loc: + main_builder.set_loc(saved_loc) + return result + + wrapper.__name__ = method_name + wrapper.__doc__ = getattr(delegate_method, '__doc__', None) + return wrapper + + +def attach_builder_methods_with_buffer_builder(main_builder, delegate_builder, method_names): + """Attach multiple methods from a delegate builder to the main builder.""" + for method_name in method_names: + wrapper = create_builder_method_wrapper_with_buffer_builder(main_builder, delegate_builder, method_name) + setattr(main_builder, method_name, wrapper) + + +def setup_unified_builder_with_buffer_builder(main_builder, buffer_builder): + """Set up a unified builder interface by attaching methods from specialized builders.""" + main_builder._buffer_builder = buffer_builder + buffer_methods = [ + 'get_null_attr', + 'get_str_array_attr', + 'alloc', + 'to_buffer', + 'to_tensor', + 'subview', + ] + attach_builder_methods_with_buffer_builder(main_builder, buffer_builder, buffer_methods) diff --git a/python/triton/extension/buffer/language/core.py b/python/triton/extension/buffer/language/core.py new file mode 100644 index 0000000000..d32a787d04 --- /dev/null +++ b/python/triton/extension/buffer/language/core.py @@ -0,0 +1,330 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+
+__all__ = [
+    "address_space",
+    "buffer_type",
+    "subview",
+    "alloc",
+    "buffer",
+    "to_buffer",
+    "to_tensor",
+]
+
+import importlib
+from typing import TypeVar, List
+from functools import wraps
+
+from triton._C.libtriton import ir
+import triton.language.core as tl
+from triton.language import semantic as real_semantic
+
+T = TypeVar("T")
+
+TRITON_BUILTIN = "__triton_builtin__"
+BUFFER_BUILTIN = "__buffer_builtin__"
+
+
+def builtin(fn: T) -> T:
+    """Mark a function as a buffer language builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "_builder" not in kwargs or kwargs["_builder"] is None:
+            raise ValueError("Did you forget to add @triton.jit ? "
+                             "(`_builder` argument must be provided outside of JIT functions.)")
+        return fn(*args, **kwargs)
+
+    # also set triton_builtin to true so that CodeGenerator will recognize this function
+    setattr(wrapper, TRITON_BUILTIN, True)
+    setattr(wrapper, BUFFER_BUILTIN, True)
+
+    return wrapper
+
+
+def is_builtin(fn) -> bool:
+    """Is this a registered buffer language builtin function?"""
+    return getattr(fn, BUFFER_BUILTIN, False)
+
+
+class address_space:
+    """Represents a buffer's address space.
+
+    The :code:`address_space` of a buffer is a target-specific attribute.
+    """
+
+    def to_ir(self, builder: ir.builder) -> ir.type:
+        raise NotImplementedError("Abstract address_space cannot be converted to ir")
+
+
+class buffer_type(tl.dtype):
+
+    def __init__(self, element_ty: tl.dtype, shape: List, space: address_space = None, strides: List = None):
+        self.element_ty = element_ty
+        self.shape = shape if isinstance(shape, list) else list(shape)
+        self.space = space
+        self.strides = strides if strides is not None else []
+        self.name = self._make_name()
+
+    def _make_name(self):
+        # e.g. "<16x32xfp32, space>"; mirrors buffer.__str__ below
+        res = '<' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.element_ty)
+        if self.space:
+            res += ', ' + str(self.space)
+        return res + '>'
+
+    def to_ir(self, builder: ir.builder) -> ir.type:
+        element_ty_ir = self.element_ty.to_ir(builder)
+        addr_space_attr = self.space.to_ir(builder) if self.space else builder.get_null_attr()
+
+        # use the method with strides if strides is not empty
+        if self.strides:
+            return builder.get_buffer_ty_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr)
+        else:
+            return builder.get_buffer_ty(self.shape, element_ty_ir, addr_space_attr)
+
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, buffer_type):
+            return False
+        return (self.element_ty == other.element_ty and self.shape == other.shape and self.space == other.space
+                and self.strides == other.strides)
+
+    def __ne__(self, other) -> bool:
+        return not self.__eq__(other)
+
+    @property
+    def scalar(self):
+        return self.element_ty
+
+
+# -----------------------
+# buffer
+# -----------------------
+
+
+class buffer(tl._value):
+    """Represents a region of memory.
+
+    :code:`buffer` is the fundamental data structure for Triton programs using
+    the buffer language extension. Most functions in
+    :py:mod:`triton.extension.buffer.language` operate on and return buffers.
+
+    Most of the named member functions here are duplicates of the free functions
+    in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is
+    equivalent to :code:`x.sqrt()`.
+
+    .. rubric:: Constructors
+    ..
+       For some reason Sphinx includes __init__ before printing the full table
+       of methods. Not what I want, but I can't figure out how to fix it. Give
+       it its own section so it looks intentional. :)
:) + """ + + def __init__(self, handle, buffer_ty: buffer_type): + """Not called by user code.""" + super().__init__(handle) + self.type = buffer_ty + self.dtype = buffer_ty.element_ty.scalar + self.shape = buffer_ty.shape + self.space = buffer_ty.space + self.strides = buffer_ty.strides + + def __str__(self) -> str: + # ex. "<16x32xfloat32, address_space>" + res = '<' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.dtype) + if self.space: + res += ', ' + str(self.space) + return res + '>' + + @builtin + def subview(self, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> 'buffer': + return subview(self, offsets, sizes, strides, _builder=_builder) + + @builtin + def to_tensor(self, writable=True, target_shape=None, _builder=None): + """Convert this buffer to a tl.tensor""" + return to_tensor(self, writable=writable, target_shape=target_shape, _builder=_builder) + + +semantic = importlib.import_module(".semantic", package=__package__) + + +@builtin +def alloc(etype: tl.dtype, shape: List[tl.constexpr], _address_space: address_space = None, is_mem_unique: bool = False, + _builder=None) -> buffer: + """ + Allocates a region of local memory with the specified shape and type. + + :param etype: the element type of the buffer. + :type etype: tl.dtype + :param shape: A list of non-negative integers representing the shape of the buffer. + :type shape: List[tl.constexpr] + :param _address_space: (Optional) backend-specific local memory address space + :type _address_space: bl.address_space + """ + return semantic.alloc(etype, shape, _address_space, is_mem_unique, _builder) + + +@builtin +def to_buffer(tensor: tl.tensor, space: address_space = None, bind_buffer: buffer = None, _builder=None) -> buffer: + """ + Convert a tensor to a buffer. + + :param tensor: the tensor to convert. + :type tensor: tl.tensor + :param space: the address space for the buffer (optional). + :type space: address_space + """ + return semantic.to_buffer(tensor, space, bind_buffer, _builder) + + +@builtin +def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder=None) -> tl.tensor: + """ + Create a tl.tensor from a bl.buffer. + + :param memref: the input bl.buffer object. + :memref type: bl.buffer + :param writable: If set true, the resultant tensor is considered "writable" during bufferization. + :type writable: bool + """ + return semantic.to_tensor(memref, writable, _builder, target_shape=target_shape) + + +def check_subview(src, offsets, sizes, strides): + """ + Check data of subview methods which the data length and the offset value must be 32-byte aligned. + + The conditions for checking data are as follows: + 1. offset value must be 32-bytes aligned. + 2. all strides must be 1. + 3. the first point's offset in the second row of the last dimension must be 32-bytes aligned. + + For instance, the following example fails to satisfy the specified criteria. + %subview = memref.subview %arg0[1, 1][4, 4][2, 2] + : memref<8x8xf32, strided<[8, 1], offset: 0>> to + memref<4x4xf32, strided<[16, 2], offset: 9>> + offsets = [8, 8] | sizes = [4, 4] | strides = [2, 2] + result_offset = 9 + second_row_start_offset = 25 + The scene will be go wrong because the follow conditions are not meet. + 1) result_offset is not 32-bytes aligned. + 2) strides = [2, 2], not all strides are equal to 1. + 3) second_row_start_offset are not 32-bytes aligned. 
+ """ + bytes_per_block = 32 + bits_per_byte = 8 + base_byte = bytes_per_block // (src.dtype.primitive_bitwidth // bits_per_byte) + result_strides = [] + result_offset = 0 + second_row_start_offset = 0 + length = len(strides) + src_strides = [1] * length + if length == 1: + if offset[0] % base_byte != 0: + raise TypeError("all strides should be 1 and the offset value should be 32-bytes aligned.") + return + for i in range(length - 2, -1, -1): + src_strides[i] = src_strides[i + 1] * src.shape[i + 1] + for i in range(0, length): + if isinstance(offsets[i], tl.tensor): + return + result_strides.append(src_strides[i] * strides[i]) + result_offset = result_offset + offsets[i] * src_strides[i] + second_row_start_offset = result_offset + src_strides[-2] * strides[-2] + is_unaligned = False + if sizes[1] > 1: + is_unaligned = second_row_start_offset % base_byte != 0 + stride_1 = all(s == 1 for s in strides) + is_unaligned = result_offset % base_byte != 0 or is_unaligned or not stride_1 + if is_unaligned: + raise TypeError("all strides should be 1 and the offset value should be 32-bytes aligned.") + + +@builtin +def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> buffer: + ''' + Creates a subview of the source buffer with the specified offsets, sizes, and strides. + + :param src: The source buffer to create a subview from. + :type src: buffer + :param offsets: A list of non-negative integers representing the offsets in each dimension. + :type offsets: List[tl.constexpr] + :param sizes: A list of non-negative integers representing the sizes in each dimension. + :type sizes: List[tl.constexpr] + :param strides: A list of non-negative integers representing the strides in each dimension. + :type strides: List[tl.constexpr] + :return: A new buffer representing the subview of the source buffer. 
+ :rtype: buffer + ''' + # Validate that sizes and strides contain only constexpr values + new_sizes = [] + for i, size in enumerate(sizes): + if isinstance(size, int): + # Convert regular integers to constexpr + new_sizes.append(tl.constexpr(size)) + elif isinstance(size, tl.constexpr): + new_sizes.append(size) + else: + raise TypeError(f"sizes[{i}] must be constexpr, got {type(size).__name__}") + + new_strides = [] + for i, stride in enumerate(strides): + if isinstance(stride, int): + # Convert regular integers to constexpr + new_strides.append(tl.constexpr(stride)) + elif isinstance(stride, tl.constexpr): + new_strides.append(stride) + else: + raise TypeError(f"strides[{i}] must be constexpr, got {type(stride).__name__}") + + check_offsets = [] + new_offsets = [] + for offset in offsets: + if isinstance(offset, tl.constexpr): + # Check that constexpr offset values cannot be negative + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(real_semantic.to_tensor(offset, _builder)) + check_offsets.append(offset) + elif isinstance(offset, int): + # Convert regular integers to constexpr and then to tensor + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(real_semantic.to_tensor(tl.constexpr(offset), _builder)) + check_offsets.append(tl.constexpr(offset)) + else: + # Assume it's already a tensor + new_offsets.append(offset) + check_offsets.append(offset) + + check_subview(src, check_offsets, new_sizes, new_strides) + return semantic.subview(src, new_offsets, new_sizes, new_strides, _builder) diff --git a/python/triton/extension/buffer/language/semantic.py b/python/triton/extension/buffer/language/semantic.py new file mode 100644 index 0000000000..b694e3ab29 --- /dev/null +++ b/python/triton/extension/buffer/language/semantic.py @@ -0,0 +1,128 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from typing import (TypeVar, List) + +from triton._C.libtriton import ir +import triton.language.core as tl + +from . 
import core as bl + +T = TypeVar('T') + + +def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: bl.address_space, is_mem_unique, + builder: ir.builder) -> bl.buffer: + shape = tl._unwrap_shape(shape) + if etype == tl.int1: + raise TypeError("Unsupported alloc int1 type") + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + etype = tl._constexpr_to_value(etype) + address_space = tl._constexpr_to_value(address_space) + element_ty_ir = etype.to_ir(builder) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) + memref_ty = builder.get_buffer_ty(shape, element_ty_ir, addr_space_attr) + handle = builder.alloc(memref_ty) + if is_mem_unique: + builder.create_annotation_mark(handle, "mem_unique", builder.get_unit_attr()) + builder.create_annotation_mark(handle, "effects", builder.get_str_array_attr(["write", "read"])) + + buffer_ty = bl.buffer_type(element_ty=etype, shape=shape, space=address_space) + return bl.buffer(handle, buffer_ty) + + +def to_buffer( + tensor: tl.tensor, + address_space: bl.address_space, + bind_buffer: bl.buffer, + builder: ir.builder, +) -> bl.buffer: + if not isinstance(tensor.shape, (tuple, list)) or not tensor.shape: + raise TypeError("scalar type cannot be converted to buffer") + if isinstance(bind_buffer, bl.buffer): + builder.create_bind_buffer(tensor.handle, bind_buffer.handle) + return bind_buffer + if bind_buffer is not None: + raise ValueError("bind_buffer must be a buffer or None") + address_space = tl._constexpr_to_value(address_space) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) + handle = builder.to_buffer(tensor.handle, addr_space_attr) + buffer_ty = bl.buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) + return bl.buffer(handle, buffer_ty) + + +def to_tensor(memref: bl.buffer, writable: bool, builder: ir.builder, target_shape=None) -> tl.tensor: + if not isinstance(memref, bl.buffer): + raise TypeError("memref must be bl.buffer") + + need_convert_layout = False + shape = memref.shape + if target_shape: + need_convert_layout = True + shape = tl._unwrap_shape(target_shape) + assert shape != memref.shape, "target shape is the same as source shape" + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + tensor_type = tl.block_type(memref.dtype, shape) + + memref_value = memref.handle + if need_convert_layout: + buffer_ty = bl.buffer_type( + element_ty=memref.dtype, + shape=shape, + space=memref.space, + ) + memref_value = builder.create_convert_layout(memref_value, buffer_ty.to_ir(builder)) + + return tl.tensor(builder.to_tensor(memref_value, writable), tensor_type) + + +def subview(src: bl.buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], + builder: ir.builder) -> bl.buffer: + + new_offsets = [offset.handle for offset in offsets] + sizes_int = tl._unwrap_shape(sizes) + strides_int = tl._unwrap_shape(strides) + + result_handle = builder.subview(src.handle, new_offsets, sizes_int, strides_int) + + # calculate the memory layout strides of the source buffer + if src.strides: + # use the strides of the source buffer + src_memory_strides = src.strides + else: + # calculate the default row-major strides + src_memory_strides = [] + stride = 1 + for dim_size in reversed(src.shape): + if dim_size < 0: + raise ValueError("Cannot compute strides for buffer with dynamic dimensions") + src_memory_strides.insert(0, stride) + 
stride *= dim_size + + result_memory_strides = [] + for src_stride, subview_stride in zip(src_memory_strides, strides_int): + result_memory_strides.append(src_stride * subview_stride) + + # create buffer_type with strides + buffer_ty = bl.buffer_type(element_ty=src.dtype, shape=sizes_int, space=src.space, strides=result_memory_strides) + return bl.buffer(result_handle, buffer_ty) diff --git a/python/triton/extension/buffer/src/buffer_ir.cc b/python/triton/extension/buffer/src/buffer_ir.cc new file mode 100644 index 0000000000..bd1bc917b6 --- /dev/null +++ b/python/triton/extension/buffer/src/buffer_ir.cc @@ -0,0 +1,177 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include + +#include "ir.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +using namespace mlir; +namespace py = pybind11; + +constexpr unsigned kIntegerAttrBitWidth = 64; + +struct BufferOpBuilder : public TritonOpBuilder {}; + +void init_buffer_ir(py::module &&m) { + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_( + m, "buffer_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + .def("get_null_attr", [](BufferOpBuilder &self) { return Attribute(); }) + .def("get_str_array_attr", + [](BufferOpBuilder &self, + const std::vector &array) -> ArrayAttr { + auto strRefVec = to_vector(llvm::map_range( + array, [](const auto &s) { return llvm::StringRef(s); })); + return self.getBuilder().getStrArrayAttr( + llvm::ArrayRef{strRefVec}); + }) + .def("alloc", + [](BufferOpBuilder &self, Type memrefType) -> Value { + return self.create( + mlir::cast(memrefType)); + }) + .def("to_buffer", + [](BufferOpBuilder &self, Value &src, + const Attribute &addressSpace) -> Value { + auto tensorType = dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = MemRefType::get(tensorType.getShape(), + tensorType.getElementType(), + MemRefLayoutAttrInterface{}); + // TODO: We need to add a pass before OneShotBufferize to generate + // MemorySpaceCastOp + Operation *memref = + self.create(memrefType, src); + if (addressSpace) { + memref = self.create( + MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout(), addressSpace), + memref->getResult(0)); + } + return memref->getResult(0); + }) + .def("to_tensor", + [](BufferOpBuilder &self, Value &src, bool writable) -> Value { + const auto &memrefType = mlir::cast(src.getType()); + auto hasAddressSpace = memrefType.getMemorySpace(); + if (hasAddressSpace) { + return self.create( + self.create( + MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout()), + src), + true, writable); + } + return self.create(src, true, writable); + }) + .def("subview", + [](BufferOpBuilder &self, Value source, std::vector &offsets, + const std::vector &sizes, + const std::vector &strides) -> Value { + SmallVector mixedOffsets; + auto *context = self.getBuilder().getContext(); + auto &builder = self.getBuilder(); + + // Get source memref type for validation + auto sourceType = mlir::cast(source.getType()); + int64_t rank = sourceType.getRank(); + // Verify the number of parameters + if (offsets.size() != rank || sizes.size() != rank || + strides.size() != rank) { + throw std::runtime_error("Number of offsets, sizes, and strides " + "must match memref rank"); + } + + for (const auto &offset : offsets) { + auto indexType = builder.getIndexType(); + if (offset.getType() != indexType) { + Value offset_val = + self.create(indexType, offset); + mixedOffsets.push_back(offset_val); + } else { + mixedOffsets.push_back(offset); + } + } + + SmallVector mixedSizes; + SmallVector mixedStrides; + for (int64_t i = 0; i < rank; ++i) { + int64_t size = sizes[i]; + int64_t stride = strides[i]; + int64_t srcDim = sourceType.getDimSize(i); + + // verify sizes cannot be negative or zero + if 
(size <= 0) {
                 throw std::runtime_error("Expected sizes to be positive");
               }

               // verify strides cannot be negative or zero
               if (stride <= 0) {
                 throw std::runtime_error("Expected strides to be positive");
               }

               // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic
               // dimensions
               if (!ShapedType::isDynamic(srcDim)) {
                 // verify the subview size does not exceed the source dimension
                 if (size > srcDim) {
                   throw std::runtime_error(
                       "Subview size cannot exceed source dimension size");
                 }

                 // verify strides cannot exceed the source dimension size
                 if (stride > srcDim) {
                   throw std::runtime_error(
                       "Stride cannot exceed source dimension size");
                 }
               }

               mixedSizes.push_back(IntegerAttr::get(
                   IntegerType::get(context, kIntegerAttrBitWidth), size));
               mixedStrides.push_back(IntegerAttr::get(
                   IntegerType::get(context, kIntegerAttrBitWidth), stride));
             }

             return self.create(source, mixedOffsets,
                                mixedSizes, mixedStrides);
           });
}
diff --git a/python/triton/tools/get_ascend_devices.py b/python/triton/tools/get_ascend_devices.py
new file mode 100644
index 0000000000..f6ba1c9c01
--- /dev/null
+++ b/python/triton/tools/get_ascend_devices.py
@@ -0,0 +1,49 @@
+import os
+import glob
+import logging
+import subprocess
+
+logger = logging.getLogger(__name__)
+
+
+def get_ascend_devices():
+    devices = []
+    pci_path = '/sys/bus/pci/devices/*'
+
+    for dev in glob.glob(pci_path):
+        try:
+            vendor_path = os.path.join(dev, 'vendor')
+            device_path = os.path.join(dev, 'device')
+
+            if os.path.exists(vendor_path):
+                with open(vendor_path, 'r') as f:
+                    vendor = f.read().strip()
+
+                if vendor == "0x19e5" and os.path.exists(device_path):
+                    with open(device_path, 'r') as f:
+                        device = f.read().strip()
+                    devices.append(device)
+        except (IOError, OSError) as e:
+            logger.warning(f"cannot fetch device {dev}: {e}")
+            continue
+
+    return devices
+
+
+def check_npu_smi_device():
+    try:
+        result = subprocess.run(["npu-smi", "info"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True,
+                                shell=False, timeout=100)
+        if result.returncode == 0:
+            output = result.stdout.lower()
+            return "ascend910_95" in output or "ascend950" in output or "910_958b" in output
+        return False
+    except Exception:
+        logger.warning("cannot use command: npu-smi info")
+        return False
+
+
+ascend_devices = get_ascend_devices()
+pci_condition = any("0xd806" in dev for dev in ascend_devices)
+npu_smi_condition = check_npu_smi_device()
+is_compile_on_910_95 = pci_condition or npu_smi_condition
diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt
index 1bf6cb570a..742063d2ce 100644
--- a/third_party/ascend/CMakeLists.txt
+++ b/third_party/ascend/CMakeLists.txt
@@ -12,9 +12,9 @@ include_directories(${CMAKE_BINARY_DIR}/third_party/flir/include)
# set(BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES ON)
# set(BISHENGIR_BUILD_STANDALONE_IR_ONLY ON)
-# add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR})
-# include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include)
-# include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files
+# AscendNPU-IR is already added from the top-level CMakeLists when
+# FLAGTREE_BACKEND=ascend. Do not add it again here, otherwise CMake will
+# fail with "binary directory is already used to build a source directory".
add_subdirectory(backend/spec/lib)
@@ -36,7 +36,6 @@ endif()
add_triton_plugin(TritonAscend
    ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/ascend_ir.cc
-
    LINK_LIBS
    TritonToLinalgIncubated
    BiShengIRScopeDialect
@@ -44,7 +43,34 @@ add_triton_plugin(TritonAscend
    ${_MLIRMeshDialect_LIB}
)
-# target_link_libraries(TritonAscend PRIVATE Python3::Module pybind11::headers)
+option(TRITON_ENABLE_COVERAGE_LLVM_COV "Enable the llvm-cov code coverage tool for the Ascend plugin" OFF)
+if(TRITON_ENABLE_COVERAGE_LLVM_COV)
+  message(STATUS "Enabling llvm-cov coverage tool flags for TritonAscend")
+  target_compile_options(TritonAscend PRIVATE
+    -fprofile-arcs
+    -ftest-coverage
+    -O0
+    -fprofile-update=atomic
+    --coverage
+  )
+  target_link_options(TritonAscend PRIVATE
+    --coverage
+    -lgcov
+  )
+  # branch coverage
+  target_compile_definitions(TritonAscend PRIVATE
+    COVERAGE_ENABLED=1
+  )
+endif()
+
+
+# To enable the hitest coverage tool
+if(TRITON_ENABLE_COVERAGE_HITEST)
+  set_target_properties(TritonAscend PROPERTIES
+    RULE_LAUNCH_COMPILE "hitestwrapper"
+    RULE_LAUNCH_LINK "hitestwrapper"
+  )
+endif()

if(TRITON_BUILD_UT)
  add_subdirectory(unittest)
diff --git a/third_party/ascend/ascend_ir.cc b/third_party/ascend/ascend_ir.cc
index 7cab69c452..97ee3f0c2b 100644
--- a/third_party/ascend/ascend_ir.cc
+++ b/third_party/ascend/ascend_ir.cc
@@ -24,6 +24,7 @@
#include "ir.h"
#include "pybind11/pybind11.h"
+#include
#include
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
@@ -33,6 +34,8 @@
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
@@ -44,20 +47,41 @@ namespace py = pybind11;
struct AscendNPUIROpBuilder : public TritonOpBuilder {
  std::string target;
  static constexpr char kTarget910_95[] = "Ascend910_95";
+  static constexpr char kTarget950[] = "Ascend950";
  explicit AscendNPUIROpBuilder(MLIRContext *context, std::string target = "")
      : TritonOpBuilder(context), target(target) {}
-  bool is_910_95() {
+  bool is_910_95() const {
    // TODO: Use enum instead of strings after enabling HACC in standalone
    // build
-    constexpr size_t kTargetLen = sizeof(kTarget910_95) - 1;
-    return target.size() >= kTargetLen &&
-           target.compare(0, kTargetLen, kTarget910_95) == 0;
+    constexpr size_t kLen910 = sizeof(kTarget910_95) - 1;
+    bool match_910 = target.size() >= kLen910 &&
+                     target.compare(0, kLen910, kTarget910_95) == 0;
+
+    constexpr size_t kLen950 = sizeof(kTarget950) - 1;
+    bool match_950 =
+        target.size() >= kLen950 && target.compare(0, kLen950, kTarget950) == 0;
+
+    return match_910 || match_950;
  }
};

namespace {

+MLIRContext *gDefaultAscendContext = nullptr;
+
+MLIRContext *resolveContext(const py::object &contextObj) {
+  if (!contextObj.is_none()) {
+    return &py::cast<MLIRContext &>(contextObj);
+  }
+  if (gDefaultAscendContext) {
+    return gDefaultAscendContext;
+  }
+  throw std::invalid_argument(
+      "No default MLIR context. 
Pass context explicitly or call " + "ascend_ir.load_dialects(context) first."); +} + struct ModeAndPipes { hivm::SyncBlockModeAttr modeAttr = {}; hivm::PipeAttr cubePipe = {}; @@ -143,6 +167,264 @@ ModeAndPipes GetSyncBlockModeAndPipes(MLIRContext *ctx, } // namespace void init_ascend_ir(py::module &&m) { + auto affineExprClass = + py::class_(m, "affine_expr", py::module_local()); + affineExprClass + .def("__str__", + [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", + [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) + .def("is_symbolic_or_constant", &AffineExpr::isSymbolicOrConstant) + .def("is_pure_affine", &AffineExpr::isPureAffine) + .def("is_function_of_dim", &AffineExpr::isFunctionOfDim) + .def("compose", + [](AffineExpr self, AffineMap map) { return self.compose(map); }) + .def("get_largest_known_divisor", &AffineExpr::getLargestKnownDivisor) + .def("floordiv", [](AffineExpr self, + AffineExpr other) { return self.floorDiv(other); }) + .def("ceildiv", [](AffineExpr self, + AffineExpr other) { return self.ceilDiv(other); }) + .def("mod", + [](AffineExpr self, AffineExpr other) { return self % other; }) + .def("__hash__", + [](AffineExpr self) { + return py::int_(static_cast(mlir::hash_value(self))); + }) + .def("__eq__", [](AffineExpr lhs, AffineExpr rhs) { return lhs == rhs; }) + .def(py::self + py::self) + .def(py::self - py::self) + .def(py::self * py::self) + .def(py::self % py::self); + affineExprClass + .def_static( + "get_constant", + [](int64_t val, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineConstantExpr(val, context); + }, + py::arg("value"), py::arg("context") = py::none()) + .def_static( + "get_dim", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineDimExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()) + .def_static( + "get_symbol", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineSymbolExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()); + + py::class_(m, "affine_constant_expr", + py::module_local()) + .def("get_value", &AffineConstantExpr::getValue); + py::class_(m, "affine_dim_expr", + py::module_local()) + .def("get_position", &AffineDimExpr::getPosition); + py::class_(m, "affine_symbol_expr", + py::module_local()) + .def("get_position", &AffineSymbolExpr::getPosition); + py::class_(m, "affine_binary_op_expr", + py::module_local()) + .def("get_lhs", &AffineBinaryOpExpr::getLHS) + .def("get_rhs", &AffineBinaryOpExpr::getRHS); + + auto affineMapClass = + py::class_(m, "affine_map", py::module_local()); + affineMapClass + .def("__str__", + [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", + [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) + .def("is_identity", &AffineMap::isIdentity) + .def("is_permutation", &AffineMap::isPermutation) + .def("get_num_dims", &AffineMap::getNumDims) + .def("get_num_symbols", &AffineMap::getNumSymbols) + .def("get_num_results", &AffineMap::getNumResults) + .def("is_empty", &AffineMap::isEmpty) + .def("is_single_constant", &AffineMap::isSingleConstant) + .def("is_constant", &AffineMap::isConstant) + .def("get_constant_result", + 
[](AffineMap &self) -> int64_t { + if (!self.isSingleConstant()) { + throw std::runtime_error( + "affine map is not a single constant map"); + } + return self.getSingleConstantResult(); + }) + .def("get_result", + [](AffineMap &self, uint32_t pos) { + if (pos >= self.getNumResults()) { + throw py::index_error("result index out of range"); + } + return self.getResult(pos); + }) + .def("get_sub_map", + [](AffineMap &self, const std::vector &resultPos) { + return self.getSubMap(resultPos); + }) + .def("replace", + [](AffineMap &self, AffineExpr expr, AffineExpr replacement, + uint32_t numResultDims, uint32_t numResultSymbols) { + return self.replace(expr, replacement, numResultDims, + numResultSymbols); + }) + .def("compose", + [](AffineMap &self, AffineMap map) { return self.compose(map); }) + .def("get_results", + [](AffineMap &self) -> std::vector { + auto results = self.getResults(); + return std::vector(results.begin(), results.end()); + }) + .def("__hash__", + [](AffineMap &self) { + return py::int_(static_cast(mlir::hash_value(self))); + }) + .def("__eq__", [](AffineMap &lhs, AffineMap &rhs) { return lhs == rhs; }) + .def("inverse_permutation", + [](AffineMap &self) -> py::object { + // Validate it's a permutation first + if (!self.isPermutation()) { + throw py::value_error( + "AffineMap must be a valid permutation to compute inverse"); + } + + // Returns AffineMap directly, not a pointer + AffineMap inverse = mlir::inversePermutation(self); + + // Check if result is valid (null AffineMap) + if (!inverse) { + throw py::value_error("Failed to compute inverse permutation"); + } + + return py::cast(inverse); + }) + .def("to_dict", [](AffineMap &self) -> py::dict { + py::list results; + for (AffineExpr result : self.getResults()) { + if (auto dimExpr = dyn_cast(result)) { + results.append(dimExpr.getPosition()); + } else { + std::string exprStr; + llvm::raw_string_ostream os(exprStr); + result.print(os); + results.append(py::str(exprStr)); + } + } + + py::dict ret; + ret["num_dims"] = self.getNumDims(); + ret["num_symbols"] = self.getNumSymbols(); + ret["results"] = std::move(results); + return ret; + }); + affineMapClass + .def_static( + "get", + [](int64_t numDims, int64_t numSymbols, const py::iterable &resultsIn, + py::object contextObj) -> AffineMap { + MLIRContext *context = nullptr; + if (numDims < 0 || numSymbols < 0) { + throw std::invalid_argument( + "num_dims and num_symbols must be non-negative"); + } + llvm::SmallVector results; + for (const auto &item : resultsIn) { + if (py::isinstance(item)) { + auto expr = py::cast(item); + if (!context) { + context = expr.getContext(); + } + results.push_back(expr); + continue; + } + if (py::isinstance(item)) { + if (!context) { + context = resolveContext(contextObj); + } + int64_t pos = py::cast(item); + if (pos < 0 || pos >= numDims) { + throw std::invalid_argument( + "result dim index is out of range for num_dims"); + } + results.push_back(getAffineDimExpr(pos, context)); + continue; + } + throw std::invalid_argument( + "results must contain affine_expr or int dim indices"); + } + if (!context) { + context = resolveContext(contextObj); + } + return AffineMap::get(numDims, numSymbols, results, context); + }, + py::arg("num_dims"), py::arg("num_symbols"), py::arg("result_dims"), + py::arg("context") = py::none()) + .def_static( + "get_identity", + [](int64_t numDims, py::object contextObj) -> AffineMap { + auto *context = resolveContext(contextObj); + if (numDims < 0) { + throw std::invalid_argument("num_dims must be non-negative"); + } 
+ return AffineMap::getMultiDimIdentityMap(numDims, context); + }, + py::arg("num_dims"), py::arg("context") = py::none()) + .def_static( + "get_minor_identity", + [](int64_t dims, int64_t results, py::object contextObj) { + auto *context = resolveContext(contextObj); + if (dims < 0 || results < 0) { + throw std::invalid_argument("dims/results must be non-negative"); + } + return AffineMap::getMinorIdentityMap(dims, results, context); + }, + py::arg("dims"), py::arg("results"), py::arg("context") = py::none()) + .def_static( + "get_empty", + [](py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::get(0, 0, {}, context); + }, + py::arg("context") = py::none()) + .def_static( + "get_permutation", + [](const std::vector &permutation, py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getPermutationMap(permutation, context); + }, + py::arg("permutation"), py::arg("context") = py::none()) + .def_static( + "get_constant", + [](int64_t value, py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getConstantMap(value, context); + }, + py::arg("value"), py::arg("context") = py::none()); + py::enum_(m, "AddressSpace", py::module_local()) .value("L1", hivm::AddressSpace::L1) .value("UB", hivm::AddressSpace::UB) @@ -175,6 +457,21 @@ void init_ascend_ir(py::module &&m) { .value("MIX", hivm::VFMode::MIX) .export_values(); + py::enum_(m, "IteratorType", py::module_local()) + .value("Parallel", hivm::IteratorType::kParallel) + .value("Broadcast", hivm::IteratorType::kBroadcast) + .value("Transpose", hivm::IteratorType::kTranspose) + .value("Reduction", hivm::IteratorType::kReduction) + .value("Interleave", hivm::IteratorType::kInterleave) + .value("Deinterleave", hivm::IteratorType::kDeinterleave) + .value("Inverse", hivm::IteratorType::kInverse) + .value("Pad", hivm::IteratorType::kPad) + .value("Concat", hivm::IteratorType::kConcat) + .value("Gather", hivm::IteratorType::kGather) + .value("Cumulative", hivm::IteratorType::kCumulative) + .value("Opaque", hivm::IteratorType::kOpaque) + .export_values(); + py::enum_(m, "FixpipeDMAMode", py::module_local()) .value("NZ2DN", hivm::FixpipeDMAMode::NZ2DN) .value("NZ2ND", hivm::FixpipeDMAMode::NZ2ND) @@ -209,9 +506,7 @@ void init_ascend_ir(py::module &&m) { .export_values(); m.def("load_dialects", [](MLIRContext &context) { - // Allow unregistered dialects so we can parse HACC attributes without - // registering the dialect - context.allowUnregisteredDialects(); + gDefaultAscendContext = &context; DialectRegistry registry; registry.insert(); @@ -223,6 +518,10 @@ void init_ascend_ir(py::module &&m) { m, "ascendnpu_ir_builder", py::module_local(), py::dynamic_attr()) .def(py::init(), py::arg("context"), py::arg("target") = "") + .def("get_int_attr", + [](AscendNPUIROpBuilder &self, int64_t value) -> Attribute { + return IntegerAttr::get(self.getBuilder().getI64Type(), value); + }) .def("get_core_type_attr", [](AscendNPUIROpBuilder &self, hivm::TCoreType core_type) -> Attribute { @@ -236,6 +535,17 @@ void init_ascend_ir(py::module &&m) { [](AscendNPUIROpBuilder &self, hivm::VFMode mode) -> Attribute { return self.getBuilder().getAttr(mode); }) + .def("get_iterator_types_attr", + [](AscendNPUIROpBuilder &self, + const std::vector &array) { + llvm::SmallVector attrs; + attrs.reserve(array.size()); + for (auto type : array) { + attrs.push_back(self.getBuilder().getI32IntegerAttr( + static_cast(type))); + } + return self.getBuilder().getArrayAttr(attrs); + 
}) .def("get_t_core_type_attr_name", [](AscendNPUIROpBuilder &self) -> std::string { return hivm::TCoreTypeAttr::name.str(); @@ -253,8 +563,33 @@ void init_ascend_ir(py::module &&m) { .def("parse_attr", [](TritonOpBuilder &self, std::string value) -> Attribute { auto *ctx = self.getBuilder().getContext(); + // Enable parsing of HACC attributes by allowing unregistered + // dialects. + ctx->allowUnregisteredDialects(); return mlir::parseAttribute(value, ctx); }) + .def("get_affine_map_attr", + [](AscendNPUIROpBuilder &self, AffineMap affineMap) -> Attribute { + return AffineMapAttr::get(affineMap); + }) + .def("get_affine_map_array_attr", + [](AscendNPUIROpBuilder &self, + const std::vector &affineMaps) -> Attribute { + auto *ctx = self.getBuilder().getContext(); + llvm::SmallVector attrs; + attrs.reserve(affineMaps.size()); + for (const auto &map : affineMaps) { + attrs.push_back(AffineMapAttr::get(map)); + } + return ArrayAttr::get(ctx, attrs); + }) + .def("get_buffer_ty_with_affine_map", + [](AscendNPUIROpBuilder &self, std::vector &shape, + Type &elementType, AffineMap affineMap, + const Attribute &memorySpace) -> Type { + auto layout = AffineMapAttr::get(affineMap); + return MemRefType::get(shape, elementType, layout, memorySpace); + }) .def("create_fixpipe", [](AscendNPUIROpBuilder &self, Value src, Value dst, hivm::FixpipeDMAMode dma_mode, @@ -281,6 +616,13 @@ void init_ascend_ir(py::module &&m) { mlir::TypeRange{}, src, dst, dma_mode_attr, dual_dst_mode_attr, pre_quant_mode_attr, pre_relu_mode_attr, channel_split); }) + .def("create_annotation_mark", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }) .def("create_bind_buffer", [](TritonOpBuilder &self, Value &src, Value &alloc) -> void { auto ctx = self.getBuilder().getContext(); @@ -298,17 +640,46 @@ void init_ascend_ir(py::module &&m) { .def("create_custom_op", [](AscendNPUIROpBuilder &self, const std::string &name, const py::dict &attrs, const std::vector &ins, - const std::vector &outs) -> std::vector { + const std::vector &outs, + const std::vector &arg_attrs) -> std::vector { ValueRange inputs{ins}; ValueRange outputs{outs}; + ValueRange temp_buffers{}; TypeRange res_types{outputs}; - auto op = - self.create(res_types, name, inputs, outputs); + auto op = self.create(res_types, name, inputs, + outputs, temp_buffers); for (auto &attr : attrs) { std::string attr_name = py::cast(attr.first); Attribute attr_value = py::cast(attr.second); op->setAttr(attr_name, attr_value); } + + SmallVector dictAttrs(arg_attrs.size()); + Attribute emptyDict = self.getBuilder().getDictionaryAttr({}); + for (const auto &[idx, attrs] : llvm::enumerate(arg_attrs)) { + if (idx >= op.getNumOperands()) + continue; + + if (attrs.is_none()) { + dictAttrs[idx] = emptyDict; + continue; + } + + llvm::SmallVector namedAttrs; + for (const auto &attr : attrs) { + std::string attr_name = py::cast(attr.first); + Attribute attr_value = py::cast(attr.second); + namedAttrs.push_back(NamedAttribute( + self.getBuilder().getStringAttr(attr_name), attr_value)); + } + + dictAttrs[idx] = self.getBuilder().getDictionaryAttr(namedAttrs); + } + + ArrayAttr arg_attrs_array = + self.getBuilder().getArrayAttr(dictAttrs); + op->setAttr("arg_attrs", arg_attrs_array); + auto results = op->getResults(); return std::vector(results.begin(), results.end()); }) diff --git a/third_party/ascend/backend/backend_register.py 
b/third_party/ascend/backend/backend_register.py index 480f2a7fed..ca2d39c401 100644 --- a/third_party/ascend/backend/backend_register.py +++ b/third_party/ascend/backend/backend_register.py @@ -33,7 +33,7 @@ def decorator(func: Callable): if category not in self.strategies: self.strategies[category] = {} if method in self.strategies[category]: - raise ValueError(f"Strategy {name} already registered") + raise ValueError(f"Strategy {method} already registered") self.strategies[category][method] = func return func @@ -164,7 +164,7 @@ def get_empty_tensor(size): @backend_strategy_registry.register("mindspore", "get_tensor_params_shape") -def get_tensor_params_shape(args): +def get_tensor_params_shape(*args): import mindspore tensor_params = [arg for arg in args if isinstance(arg, mindspore.Tensor)] tensor_params_shape = [] @@ -174,7 +174,7 @@ def get_tensor_params_shape(args): @backend_strategy_registry.register("torch_npu", "get_tensor_params_shape") -def get_tensor_params_shape(args): +def get_tensor_params_shape(*args): import torch tensor_params = [arg for arg in args if isinstance(arg, torch.Tensor)] tensor_params_shape = [] @@ -188,10 +188,13 @@ def get_cc_cmd(build_pch): import mindspore mindspore_path = os.path.dirname(os.path.realpath(mindspore.__file__)) cc_cmd = [ + f"-I{mindspore_path}", + f"-I{os.path.join(mindspore_path, 'include/')}", f"-I{os.path.join(mindspore_path, 'include/third_party')}", f"-I{os.path.join(mindspore_path, 'include/third_party/robin_hood_hashing')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/core')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/core/include')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/core/mindrt/include')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc/include')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ops')}", @@ -255,23 +258,31 @@ def set_current_device(device_id): @backend_strategy_registry.register("mindspore", "get_current_stream") def get_current_stream(device): import mindspore - return mindspore.current_stream().id + try: + return mindspore.current_stream().stream_ptr() + except Exception: + return mindspore.current_stream().id @backend_strategy_registry.register("torch_npu", "get_current_stream") def get_current_stream(device): import torch import torch_npu - from torch_npu._C import _npu_getCurrentRawStream if device is None: device = torch.npu.current_device() - return _npu_getCurrentRawStream(device) + if hasattr(torch_npu._C, "_npu_getCurrentRawStreamNoWait"): + from torch_npu._C import _npu_getCurrentRawStreamNoWait + return _npu_getCurrentRawStreamNoWait(device) + else: + from torch_npu._C import _npu_getCurrentRawStream + return _npu_getCurrentRawStream(device) @backend_strategy_registry.register("mindspore", "header_file") def header_file(enable_taskqueue): return f'''#include "include/utils/device_manager_conf.h" #include "include/runtime/hardware_abstract/device_context/device_context_manager.h" +#include "include/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h" {'#include "include/pynative/utils/runtime/op_executor.h"' if {enable_taskqueue} else ''} {'#include "include/runtime/pipeline/pipeline.h"' if {enable_taskqueue} else ''}''' @@ -285,34 +296,43 @@ def header_file(enable_taskqueue): @backend_strategy_registry.register("mindspore", "allocate_memory") def allocate_memory(size, stream): - return f"device_context->device_res_manager_->AllocateMemory({size}, 
reinterpret_cast({stream}));" + return f'''auto work_ptr = std::make_shared(device_context, {size}, reinterpret_cast({stream})); + workspace_addr_ptr = work_ptr->ptr_;''' @backend_strategy_registry.register("torch_npu", "allocate_memory") -def allocate_memory(size, option): - return f"const_cast(at::empty({size}, {option}).storage().data());" +def allocate_memory(size, stream): + return f"workspace_addr_ptr = const_cast(at::empty({size}, at::TensorOptions().device(at::kPrivateUse1).dtype(at::kByte)).storage().data());" + + +@backend_strategy_registry.register("mindspore", "allocate_sync_block_lock") +def allocate_sync_block_lock(size, stream): + return f'''auto sync_ptr = std::make_shared(device_context, {size}, reinterpret_cast({stream})); + syncBlockLock_ptr = sync_ptr->ptr_;''' @backend_strategy_registry.register("torch_npu", "allocate_sync_block_lock") def allocate_sync_block_lock(size, stream): - return f"const_cast(at_npu::native::allocate_workspace({size}, {stream}).storage().data());" + return f"syncBlockLock_ptr = const_cast(at_npu::native::allocate_workspace({size}, {stream}).storage().data());" @backend_strategy_registry.register("mindspore", "pre_launch") -def pre_launch(): - return '''static auto device_context = mindspore::device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({mindspore::device::DeviceType::kAscend, mindspore::DeviceManagerConf::GetInstance()->device_id()}); - device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' +def pre_launch(first_call): + if first_call: + return '''static auto device_context = mindspore::device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({mindspore::device::DeviceType::kAscend, mindspore::DeviceManagerConf::GetInstance()->device_id()}); + device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' + else: + return '''device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' @backend_strategy_registry.register("torch_npu", "pre_launch") -def pre_launch(): +def pre_launch(first_call): return "" @backend_strategy_registry.register("mindspore", "async_launch") def async_launch(func): - return f'''mindspore::runtime::OpExecutor::DispatchLaunchTask({func}); - mindspore::runtime::Pipeline::Get().launch_stage()->Wait();''' + return f'''mindspore::runtime::OpExecutor::DispatchLaunchTask({func});''' @backend_strategy_registry.register("torch_npu", "async_launch") diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 2bb1c550af..dc7a069737 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -21,6 +21,7 @@ import ctypes import functools import hashlib +import glob import os import re import subprocess @@ -37,6 +38,7 @@ _check_bishengir_is_regbased, _enable_unpublished_feature, _enable_print_ub_bits, + _enable_dump_memory_info, _get_kernel_target, _get_llvm_path, _get_mlir_path, @@ -47,6 +49,7 @@ _is_auto_map_parallel_blocks_enabled, downgrade_llir, force_disable_ffts, + triton_enable_libdevice_simt, ) from triton.backends.ascend.driver import (NPUUtils) from triton.backends.compiler import ( @@ -57,12 +60,7 @@ ) from triton.runtime import driver from triton.runtime.cache import get_dump_manager - -try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") -except Exception as e: - is_compile_on_910_95 = False +from triton.tools.get_ascend_devices import is_compile_on_910_95 # TODO: materialize the concrete min shape @@ -96,6 +94,48 @@ def
make_ttir(mod, metadata, opt): def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): # use triton_adapter to lower Triton-MLIR to linalg # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ttir.mlir") + dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + Path(src_path).write_text(ttir_code) + triton_adapter_opt_path = _get_triton_adapter_opt_path() + + enable_nd2nz_on_vector = metadata["enable_nd2nz_on_vector"] + enable_select_analysis = metadata["enable_select_analysis"] + compile_on_910_95 = metadata["compile_on_910_95"] + force_simt_template = metadata["force_simt_template"] + enable_sync_block_lock = metadata["enable_sync_block_lock"] + enable_mask_fallback_conversion = metadata["enable_mask_fallback_conversion"] + optimize_dynamic_offset = metadata["optimize_dynamic_offset"] + auto_blockify_size = metadata["auto_blockify_size"] + if not _is_auto_map_parallel_blocks_enabled(): + auto_blockify_size = 1 + pm = ir.pass_manager(mod.context) + pm.enable_debug() + ascend.passes.ttir.add_auto_blockify(pm, auto_blockify_size) + if (metadata["add_auto_scheduling"]): + ascend.passes.ttir.add_dag_sync(pm) + ascend.passes.ttir.add_dag_scope(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + ascend.passes.ttir.add_dag_ssbuffer(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + + ascend.passes.ttir.add_triton_to_structure(pm, enable_mask_fallback_conversion, optimize_dynamic_offset) + ascend.passes.ttir.add_discrete_mask_access_conversion(pm, compile_on_910_95, force_simt_template, + enable_sync_block_lock) + ascend.passes.ttir.add_triton_to_annotation(pm) + ascend.passes.ttir.add_triton_to_unstructure(pm, compile_on_910_95, force_simt_template) + ascend.passes.ttir.add_triton_to_hivm(pm) + ascend.passes.ttir.add_triton_to_hfusion(pm) + ascend.passes.ttir.add_triton_to_llvm(pm) + ascend.passes.ttir.add_bubble_up_operation(pm) + ascend.passes.ttir.add_triton_to_structure(pm, enable_mask_fallback_conversion, optimize_dynamic_offset) + ascend.passes.ttir.add_triton_to_linalg(pm, False, named_ops, enable_nd2nz_on_vector, enable_select_analysis, + compile_on_910_95) + pm.run(mod) enable_nd2nz_on_vector = metadata["enable_nd2nz_on_vector"] enable_select_analysis = metadata["enable_select_analysis"] @@ -129,104 +169,6 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): return str(mod) -def linalg_to_llir(linalg: str, metadata, opt): - with tempfile.TemporaryDirectory() as tmpdir: - ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") - llmlir_path = os.path.join(tmpdir, "kernel.llir.mlir") - llir_path = os.path.join(tmpdir, "kernel.ll") - Path(ttadapter_path).write_text(linalg) - mlir_opt_path = _get_mlir_path("bin", "mlir-opt") - # TritonAdapter-MLIR to LLVM-MLIR - subprocess.check_call([ - mlir_opt_path, - ttadapter_path, - "--convert-linalg-to-affine-loops", - "--eliminate-empty-tensors", - "--empty-tensor-to-alloc-tensor", - "--one-shot-bufferize=allow-return-allocs-from-loops=true", - "--lower-affine", - "--convert-linalg-to-loops", - "--convert-scf-to-cf", - "--convert-cf-to-llvm", - "--convert-arith-to-llvm", - "--convert-math-to-llvm", - "--convert-complex-to-llvm", - "--convert-vector-to-llvm", - "--convert-index-to-llvm", - "--memref-expand", - "--expand-strided-metadata", - "--finalize-memref-to-llvm", - "--convert-func-to-llvm", - # Lowering memrefs creates more affine.apply ops. 
- # Lowering these affine ops again creates further arith ops, - # so we have to run these two passes again here. - "--lower-affine", - "--convert-arith-to-llvm", - # Remove all unrealized casts created - "--reconcile-unrealized-casts", - "-o", - llmlir_path, - ]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(llmlir_path).read_text(), "kernel.llir.mlir", binary=False) - - # LLVM-MLIR to LLVM-IR - mlir_translate_path = _get_mlir_path("bin", "mlir-translate") - subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(llir_path).read_text(), "kernel.ll", binary=False) - - return Path(llir_path).read_text() - - -def llir_to_cpuasm(llir: str, metadata, opt): - # add metadata at final stage - # Note: Compiled Kernel requires to estimate size of shared memory to occupy - # Currently, CPU backend requires no limit on shared memory size - metadata["shared"] = 1 - # We can get a function name (C naming) from - # LLVM-IR by getting the first "define void @". - fn_name = llir.split("define void @")[1].split("(")[0].strip() - metadata["name"] = fn_name + " cpu" - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "kernel.ll") - linked_path = os.path.join(tmpdir, "kernel_linked.ll") - dst_path = os.path.join(tmpdir, "kernel.s") - - llir = downgrade_llir(llir) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(llir, "kernel_downgrade.ll", binary=False) - - Path(src_path).write_text(llir) - - linker_path = _get_llvm_path("bin", "llvm-link") - libclc_path = _get_llvm_path("lib", "clc", "libspirv-aarch64--.bc") - subprocess.check_call([ - linker_path, - src_path, - libclc_path, - "--only-needed", - "-S", - "-o", - linked_path, - ]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(linked_path).read_text(), "kernel_linked.ll", binary=False) - - llc_path = _get_llvm_path("bin", "llc") - subprocess.check_call([llc_path, linked_path, "-o", dst_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(dst_path).read_text(), "kernel.s", binary=False) - - # Actually it's text-format assembly. Use read_text(). - return Path(dst_path).read_text() - - def __get_metadata_attr_by_callback(lib, postfix: str, metadata, meta_key: str): func_symbol = metadata["kernel_name"] + postfix if hasattr(lib, func_symbol): @@ -264,8 +206,8 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): # Example: %arg1: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32} -> ('1', '0') TENSOR_KIND_REGEX = r'%arg(\d+):[^,)]*?\{[^}]*?tt\.tensor_kind\s*=\s*([^:\s}]+)\s*:[^}]*?\}' - # Example removal: ', mix_mode = "aiv"' → '' - REMOVE_MIX_MODE_REGEX = r', mix_mode\s*=\s*"[^"]*"' + # Example: bitcode = "a.bc" + BITCODES_REGEX = r'bitcode\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|(\w+))' # Note: Compiled Kernel requires to estimate size of shared memory to occupy # Currently, NPU backend does not limit on shared memory @@ -276,15 +218,17 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): metadata["mix_mode"] = re.search(MIX_MODE_REGEX, linalg).group(1) metadata["parallel_mode"] = re.search(PARALLEL_MODE_REGEX, linalg).group(1) metadata["kernel_name"] = re.search(KERNEL_NAME_REGEX, linalg).group(1) - # Use while space to split kernel_name and mix_mode. + # Use "_" to split kernel_name and mix_mode.
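The separator between kernel_name and mix_mode changes here from a space to "_". Since kernel names routinely contain underscores themselves, consumers of the packed name must split it from the right exactly once; a minimal illustration of the convention (the packed name below is hypothetical):

    # Packed as "<kernel_name>_<mix_mode>"; split from the right exactly once,
    # because the kernel name itself may contain "_".
    name = "triton_add_kernel_aiv"  # hypothetical packed name
    kernel_name, mix_mode = name.rsplit("_", 1)
    assert (kernel_name, mix_mode) == ("triton_add_kernel", "aiv")
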
# Check the function load_binary in npu_driver.py. - metadata["name"] = metadata["kernel_name"] + " " + metadata["mix_mode"] + metadata["name"] = metadata["kernel_name"] + "_" + metadata["mix_mode"] # Parse all tensor kinds from arguments metadata["tensor_kinds"] = [int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, linalg)] # init the ub bits of triton kernel for inductor autotune using metadata["required_ub_bits"] = 0 - # remove the mix_mode attribute - linalg = re.sub(REMOVE_MIX_MODE_REGEX, "", linalg) + + # Parse all bitcode paths + bitcodes = re.findall(BITCODES_REGEX, linalg) + metadata["bitcodes"] = [val for group in bitcodes for val in group if val] return linalg, metadata @@ -308,7 +252,7 @@ def _parse_ttir_metadata(ttir: str, metadata: dict): # Note: Currently, for TTIR inputs, we only support vector kernels. metadata["mix_mode"] = "aiv" metadata["kernel_name"] = re.search(KERNEL_NAME_REGEX, ttir).group(1) - metadata["name"] = metadata["kernel_name"] + " " + metadata["mix_mode"] + metadata["name"] = metadata["kernel_name"] + "_" + metadata["mix_mode"] # Parse all tensor kinds from arguments metadata["tensor_kinds"] = [int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, ttir)] return metadata @@ -320,6 +264,28 @@ def get_common_bishengir_compile_options(metadata): return [bishengir_target_opt] +def get_auto_bind_sub_block_option(metadata): + # auto_tile_and_bind_subblock is read from the module. + # enable_auto_bind_sub_block is set by the user and has a higher priority. + enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] + return (metadata["auto_tile_and_bind_subblock"] + if enable_auto_bind_sub_block is None else enable_auto_bind_sub_block) + + +def _save_npuir_debug_output(stdout_bytes: bytes, stderr_bytes: bytes, tmpdir: str, metadata_hash: str): + stdout = stdout_bytes.decode('utf-8') if stdout_bytes else '' + stderr = stderr_bytes.decode('utf-8') if stderr_bytes else '' + combined = stdout + stderr + if not combined.strip(): + combined = "No output captured." 
+ output_path = os.path.join(tmpdir, "kernel.npuir.mlir") + with open(output_path, 'w', encoding='utf-8') as f: + f.write(combined) + + dump_manager = get_dump_manager(metadata_hash) + dump_manager.put(Path(output_path).read_text(encoding='utf-8'), "kernel.npuir.mlir", binary=False) + + def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): linalg, metadata = _parse_linalg_metadata(linalg, metadata) with tempfile.TemporaryDirectory() as tmpdir: @@ -340,17 +306,14 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): f"--enable-auto-multi-buffer={multibuffer}", ] - enable_ubuf_saving = metadata["enable_ubuf_saving"] - if enable_ubuf_saving is not None: - _compile_option_list += [ - f"--enable-ubuf-saving={enable_ubuf_saving}", - ] + disable_tightly_coupled_buffer_reuse = metadata["disable_tightly_coupled_buffer_reuse"] + if disable_tightly_coupled_buffer_reuse: + _compile_option_list += ["--disable-tightly-coupled-buffer-reuse"] + + _compile_option_list += [ + f"--enable-auto-bind-sub-block={get_auto_bind_sub_block_option(metadata)}", + ] - enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] - if enable_auto_bind_sub_block is not None: - _compile_option_list += [ - f"--enable-auto-bind-sub-block={enable_auto_bind_sub_block}", - ] if force_disable_ffts(): _compile_option_list += ["--disable-ffts"] if _is_ascend_sanitizer_enabled(): @@ -366,6 +329,11 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--enable-hivm-auto-cv-balance={enable_hivm_auto_cv_balance}"] + sync_solver = metadata["sync_solver"] + if sync_solver is not None: + _compile_option_list += \ + [f"--enable-hivm-graph-sync-solver={sync_solver}"] + unit_flag = metadata["unit_flag"] if unit_flag is not None: _compile_option_list += \ @@ -376,6 +344,11 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--enable-hivm-inject-barrier-all-sync={inject_barrier_all}"] + inject_block_all = metadata["inject_block_all"] + if inject_block_all is not None: + _compile_option_list += \ + [f"--enable-hivm-inject-block-all-sync={inject_block_all}"] + limit_auto_multi_buffer_only_for_local_buffer = metadata["limit_auto_multi_buffer_only_for_local_buffer"] if limit_auto_multi_buffer_only_for_local_buffer is not None: _compile_option_list += \ @@ -386,16 +359,6 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--set-workspace-multibuffer={set_workspace_multibuffer}"] - tile_mix_vector_loop = metadata["tile_mix_vector_loop"] - if tile_mix_vector_loop is not None: - _compile_option_list += \ - [f"--tile-mix-vector-loop={tile_mix_vector_loop}"] - - tile_mix_cube_loop = metadata["tile_mix_cube_loop"] - if tile_mix_cube_loop is not None: - _compile_option_list += \ - [f"--tile-mix-cube-loop={tile_mix_cube_loop}"] - auto_multi_buffer = metadata["limit_auto_multi_buffer_of_local_buffer"] if auto_multi_buffer is not None: _compile_option_list += \ @@ -404,33 +367,55 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): enable_mixed_cv = metadata["enable_mixed_cv"] if enable_mixed_cv is not None: _compile_option_list += \ - [f"--enable_mixed_cv={enable_mixed_cv}"] + [f"--enable-mixed-cv={enable_mixed_cv}"] enable_cce_vf_auto_sync = metadata["enable_cce_vf_auto_sync"] if enable_cce_vf_auto_sync is not None: _compile_option_list += \ - [f"--apend-bisheng-options=-mllvm --cce-vf-auto-sync={enable_cce_vf_auto_sync}"] + 
[f"--append-bisheng-options=-mllvm --cce-vf-auto-sync={enable_cce_vf_auto_sync}"] enable_cce_vf_remove_membar = metadata["enable_cce_vf_remove_membar"] if enable_cce_vf_remove_membar is not None: _compile_option_list += \ - [f"--apend-bisheng-options=-mllvm --cce-vf-remove-membar={enable_cce_vf_remove_membar}"] + [f"--append-bisheng-options=-mllvm --cce-vf-remove-membar={enable_cce_vf_remove_membar}"] + + if metadata["enable_vf_fusion"]: + _compile_option_list += ["--enable-vf-fusion"] enable_drop_unit_dims = metadata["enable_drop_unit_dims"] if enable_drop_unit_dims is not None: _compile_option_list += \ [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + enable_flatten = metadata["enable_flatten"] + if enable_flatten is not None: + _compile_option_list += \ + [f"--enable-flatten={enable_flatten}"] + enable_auto_vectorize_v2 = metadata["enable_auto_vectorize_v2"] if enable_auto_vectorize_v2 is not None: _compile_option_list += \ [f"--enable-auto-vectorize-v2={enable_auto_vectorize_v2}"] + auto_vectorize_v2_max_fused_ops_num = metadata["auto_vectorize_v2_max_fused_ops_num"] + if auto_vectorize_v2_max_fused_ops_num is not None: + _compile_option_list += \ + [f"--hfusion-max-fused-ops-in-auto-vectorize-v2={auto_vectorize_v2_max_fused_ops_num}"] + prevec_max_fused_ops_num = metadata["prevec_max_fused_ops_num"] + if prevec_max_fused_ops_num is not None: + _compile_option_list += \ + [f"--hfusion-max-fused-elementwise-ops={prevec_max_fused_ops_num}"] disable_auto_inject_block_sync = metadata["disable_auto_inject_block_sync"] if disable_auto_inject_block_sync is not None: _compile_option_list += \ [f"--disable-auto-inject-block-sync={disable_auto_inject_block_sync}"] + bitcodes = metadata["bitcodes"] + if bitcodes is not None: + for bitcode in bitcodes: + _compile_option_list += \ + [f"--link-aicore-bitcode={bitcode}"] + if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] npu_compiler_path, env = _get_npucompiler_path() @@ -445,20 +430,44 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): mix_mode = opt.mix_mode if mix_mode in ["aic"]: _compile_option_list += ["--disable-hfusion-vectorize=true"] + + if opt.debug: + _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] + cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - # TODO both bishengir-compile and triton-compile use passing attr by module - auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] - if auto_tile_and_bind_subblock is False: - cmd_list += ["--enable-auto-bind-sub-block=false"] vf_merge_level = metadata["vf_merge_level"] - if vf_merge_level: + if vf_merge_level is not None: cmd_list += [f"--enable-vf-merge-level={vf_merge_level}"] - ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) - match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) + hfusion_enable_multiple_consumer_fusion = metadata["hfusion_enable_multiple_consumer_fusion"] + if hfusion_enable_multiple_consumer_fusion: + cmd_list += [f"--hfusion-enable-multiple-consumer-fusion={hfusion_enable_multiple_consumer_fusion}"] + + if opt.debug: + print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") + + try: + ret = subprocess.run(cmd_list, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + except subprocess.CalledProcessError as e: + if opt.debug: + _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) + raise + + if opt.debug: + 
_save_npuir_debug_output(ret.stdout, ret.stderr, tmpdir, metadata["hash"]) + + stdout_str = ret.stdout.decode('utf-8') if ret.stdout else '' + match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', stdout_str) if match: # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) + + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') if ret.stderr else '' + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) + if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) __get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") @@ -501,11 +510,16 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): f"--enable-ubuf-saving={enable_ubuf_saving}", ] - enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] - if enable_auto_bind_sub_block is not None: + enable_preload = metadata["enable_preload"] + if enable_preload is not None: _compile_option_list += [ - f"--enable-auto-bind-sub-block={enable_auto_bind_sub_block}", + f"--enable-preload={enable_preload}", ] + + _compile_option_list += [ + f"--enable-auto-bind-sub-block={get_auto_bind_sub_block_option(metadata)}", + ] + if _is_ascend_sanitizer_enabled(): _compile_option_list += ["--enable-sanitizer=true"] if not _is_debug_line_info_disabled(): @@ -514,6 +528,9 @@ if _enable_print_ub_bits(): _compile_option_list += ["--enable-print-memory-allocated-size"] + if _enable_dump_memory_info(): + _compile_option_list += ["--enable-memory-display=true"] + enable_hivm_auto_cv_balance = metadata["enable_hivm_auto_cv_balance"] if enable_hivm_auto_cv_balance is not None: _compile_option_list += \ @@ -521,8 +538,10 @@ sync_solver = metadata["sync_solver"] if sync_solver is not None: - _compile_option_list += \ - [f"--enable-hivm-graph-sync-solver={sync_solver}"] + _compile_option_list += [ + f"--enable-hivm-graph-sync-solver={sync_solver}", + f"--enable-hivm-cross-core-gss={sync_solver}", + ] unit_flag = metadata["unit_flag"] if unit_flag is not None: @@ -534,6 +553,11 @@ _compile_option_list += \ [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + enable_flatten = metadata["enable_flatten"] + if enable_flatten is not None: + _compile_option_list += \ + [f"--enable-flatten={enable_flatten}"] + enable_auto_vectorize_v2 = metadata["enable_auto_vectorize_v2"] if enable_auto_vectorize_v2 is not None: _compile_option_list += \ @@ -579,6 +603,21 @@ _compile_option_list += \ [f"--disable-auto-inject-block-sync={disable_auto_inject_block_sync}"] + bitcodes = metadata["bitcodes"] + if bitcodes is not None: + for bitcode in bitcodes: + _compile_option_list += \ + [f"--link-aicore-bitcode={bitcode}"] + + enable_libdevice = os.getenv("TRITON_ENABLE_LIBDEVICE", 'false').lower() in ('true', '1') + if enable_libdevice: + _compile_option_list += [f"--link-aicore-bitcode={get_libdevice()}"] + + disable_size_align_for_cast = metadata["disable_size_align_for_cast"] + if disable_size_align_for_cast is not None: + _compile_option_list += \ + [f"--disable-size-align-for-cast={disable_size_align_for_cast}"] + if _is_auto_map_parallel_blocks_enabled():
_compile_option_list += ["--enable-auto-blockify-loop"] npu_compiler_path, env = _get_npucompiler_path() @@ -588,15 +627,34 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): bishengir_hivm_opt, "--enable-triton-kernel-compile=true", ] + + if opt.debug: + _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] - if auto_tile_and_bind_subblock is False: - cmd_list += ["--enable-auto-bind-sub-block=false"] - ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) - match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) + if opt.debug: + print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") + + try: + ret = subprocess.run(cmd_list, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + except subprocess.CalledProcessError as e: + if opt.debug: + _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) + raise + + if opt.debug: + _save_npuir_debug_output(ret.stdout, ret.stderr, tmpdir, metadata["hash"]) + + stdout_str = ret.stdout.decode('utf-8') if ret.stdout else '' + match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', stdout_str) if match: - # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) + + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') if ret.stderr else '' + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) + if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) __get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") @@ -607,6 +665,11 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): return Path(bin_path).read_bytes() +def get_libdevice(): + current = os.path.dirname(__file__) + return os.path.join(current, "lib/libdevice.10.bc") + + @dataclass(frozen=True) class NPUOptions: debug: bool = False @@ -616,7 +679,7 @@ class NPUOptions: arch: str = "" cluster_dims: tuple = (1, 1, 1) - num_warps: int = 4 + num_warps: int = 32 num_ctas: int = 1 num_stages: int = 1 warp_size: int = 32 @@ -625,6 +688,7 @@ class NPUOptions: reg_dec_producer: int = 0 reg_inc_consumer: int = 0 + auto_blockify_size: int = 1 compile_on_910_95: bool = is_compile_on_910_95 optimize_dynamic_offset: bool = False enable_mask_fallback_conversion: bool = False @@ -639,14 +703,17 @@ class NPUOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15", "fp8e4nv", "fp8e4b8", "fp8e5b16") deprecated_fp8_dtypes: Tuple[str] = () vf_merge_level: int = 1 + default_dot_input_precision: str = "ieee" allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") max_num_imprecise_acc_default: int = 0 extern_libs: dict = None - bisheng_options: str = None + bisheng_options: str = "-cce-link-aicore-ll-module " + get_libdevice() multibuffer: bool = not is_compile_on_910_95 enable_ubuf_saving: bool = None - enable_auto_bind_sub_block: bool = not is_compile_on_910_95 + enable_preload: bool = None + enable_auto_bind_sub_block: bool = None + disable_tightly_coupled_buffer_reuse: bool = False enable_select_analysis: bool = True enable_hivm_auto_cv_balance: bool = None sync_solver: bool = None @@ -654,9 +721,13 @@ class NPUOptions: enable_cce_vf_auto_sync: bool = None 
enable_cce_vf_remove_membar: bool = None enable_drop_unit_dims: bool = None + enable_flatten: bool = None enable_auto_vectorize_v2: bool = None + auto_vectorize_v2_max_fused_ops_num: int = None + prevec_max_fused_ops_num: int = None inject_barrier_all: bool = None inject_block_all: bool = None + disable_size_align_for_cast: bool = None limit_auto_multi_buffer_only_for_local_buffer: bool = None limit_auto_multi_buffer_of_local_buffer: str = None set_workspace_multibuffer: int = None @@ -664,13 +735,17 @@ class NPUOptions: tile_mix_cube_loop: int = None disable_auto_inject_block_sync: bool = None enable_mixed_cv: bool = None + enable_vf_fusion: bool = False + add_auto_scheduling: bool = False + hfusion_enable_multiple_consumer_fusion: bool = False stream: int = None parallel_mode: str = "simd" force_simt_only: bool = False force_simt_template: bool = False + enable_sync_block_lock: bool = False # only take effect on the simt-only & simd-simt-mix scenarios - shared_mem_dynamic_size: int = 221184 + shared_mem_dynamic_size: int = None # enable_bishengir_simt_optimization is passed as # -enable-bishengir-simt-optimization flag to bishengir-compile. enable_bishengir_simt_optimization: int = 000 @@ -679,6 +754,10 @@ class NPUOptions: compile_mode: str = "simd" mix_mode: str = "" simt_stack_limit: int = None + # take effect on the reorder instruction pattern for SIMT. The pattern is disabled by default. + enable_simt_reorder_instruction: bool = False + # disable simt fma optimization to get high precision + disable_fma: bool = False def __post_init__(self): # Parse compile_mode and set related fields @@ -690,31 +769,12 @@ def __post_init__(self): elif self.compile_mode == "simt_only": object.__setattr__(self, "force_simt_only", True) object.__setattr__(self, "parallel_mode", "simt") - object.__setattr__(self, "shared_mem_dynamic_size", 122880) - - def hash(self): - key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) - return hashlib.sha256(key.encode("utf-8")).hexdigest() - -@dataclass(frozen=True) -class CPUOptions: - debug: bool = False - llvm_version: int = 15 - kernel_name: str = "triton_" - - cluster_dims: tuple = (1, 1, 1) - num_warps: int = -1 - num_ctas: int = -1 - num_stages: int = -1 - - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: int = 0 - extern_libs: dict = None + if self.force_simt_only: + if self.shared_mem_dynamic_size is None: + object.__setattr__(self, "shared_mem_dynamic_size", 122880) + else: + object.__setattr__(self, "shared_mem_dynamic_size", 221184) def hash(self): key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) @@ -744,6 +804,7 @@ def ttir_to_npubin(mod, metadata, opt): # build compile options _compile_option_list = get_common_bishengir_compile_options(metadata) if opt.force_simt_only: + _compile_option_list += ["--enable-hivm-compile=false"] _compile_option_list += ["--enable-triton-ir-compile"] _compile_option_list += ["--pure-simt"] _compile_option_list += [f"--num-warps={opt.num_warps}"] @@ -754,12 +815,27 @@ def ttir_to_npubin(mod, metadata, opt): ] if opt.simt_stack_limit: _compile_option_list += [f"--simt-stack-limit={opt.simt_stack_limit}"] - if opt.shared_mem_dynamic_size: + if opt.shared_mem_dynamic_size is not None: _compile_option_list += [f"--shared-mem-dynamic-size={opt.shared_mem_dynamic_size}"] + if opt.enable_simt_reorder_instruction: + 
_compile_option_list += ["--enable-simt-reorder-instruction=true"] + if opt.disable_fma: + _compile_option_list += [f"--disable-fma"] + + enable_libdevice_simt = triton_enable_libdevice_simt() + if (enable_libdevice_simt): + bisheng_options = metadata["bisheng_options"] + if bisheng_options is not None: + _compile_option_list += [f"--append-bisheng-options={bisheng_options}"] npu_compiler_path, env = _get_npucompiler_path() cmd_list = ([npu_compiler_path, src_path] + _compile_option_list + ["-o", bin_file]) ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) return Path(bin_path).read_bytes() @@ -767,13 +843,11 @@ class AscendBackend(BaseBackend): @staticmethod def supports_target(target: GPUTarget): - return target.backend == "cpu" or target.backend == "npu" + return target.backend == "npu" def __init__(self, target: GPUTarget) -> None: super().__init__(target) - if target.backend == "cpu": - self.binary_ext = "cpuasm" - elif target.backend == "npu": + if target.backend == "npu": self.binary_ext = "npubin" def parse_options(self, opts) -> Any: @@ -783,8 +857,8 @@ def parse_options(self, opts) -> Any: args.setdefault("arch", self.target.arch) options = NPUOptions(**args) else: - args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} - options = CPUOptions(**args) + raise NotImplementedError(f"Backend '{self.target.backend}' is not supported. " + "Please ensure the target backend is set to 'npu'.") return options def pack_metadata(self, metadata): @@ -795,7 +869,7 @@ def pack_metadata(self, metadata): # CANN runtime limits the length of kernel name <= 50. # Considering '\n' is appended, thus the real kernel name <= 49. KERNEL_NAME_MAX_LEN = 49 - kernel_name_orig, mix_mode = metadata.name.split() + kernel_name_orig, _ = metadata.name.rsplit("_", 1) if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN: kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] else: @@ -835,10 +909,8 @@ def add_stages(self, stages, options): stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile_A2_A3(src, metadata, options)) else: - stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) - stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options) - stages["llir"] = lambda src, metadata: linalg_to_llir(src, metadata, options) - stages["cpuasm"] = lambda src, metadata: llir_to_cpuasm(src, metadata, options) + raise NotImplementedError(f"Backend '{self.target.backend}' is not supported. 
" + "Please ensure the target backend is set to 'npu'.") @functools.lru_cache() def hash(self): diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index e1fe48bb41..b70608ec57 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -28,7 +28,7 @@ from typing import Optional import functools import hashlib -from triton.runtime.cache import get_cache_manager, get_dump_manager +from triton.runtime.cache import get_cache_manager, get_dump_manager, default_cache_dir from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget from triton.backends.ascend.utils import (_precompile_npu_hash, _precompile_npu_ext, _build_npu_ext, _check_cxx11_abi, @@ -69,7 +69,7 @@ def __init__(self): env_arch = get_ascend_arch_from_env() def load_binary(self, name, kernel, shared, device): - fnname, mix_mode = name.split() + fnname, mix_mode = name.rsplit("_", 1) return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode) @functools.lru_cache() @@ -94,19 +94,6 @@ def get_aicore_num(self): def get_aivector_core_num(self): return self.get_device_properties("npu")["num_vectorcore"] - @functools.lru_cache() - def set_device_limit(self, device, ty, val): - """ - Set npu device limit - - Args: - device: Device id - ty: The type of the limit, valid types include: - "LOW_POWER_TIMEOUT", "WARP_STACK_SIZE", "DVG_WARP_STACK_SIZE", "STACK_SIZE" - val: The specific meaning of the value depends on the type of limit. - """ - self.npu_utils_mod.set_device_limit(device, ty, val) - class NPULauncher(object): @@ -225,15 +212,38 @@ def get_empty_cache_for_benchmark(self): return get_backend_func("get_empty_tensor", cache_size // 4) -# fixed the issue of corrupted gch header files in multi-threaded scenarios. 
-def _precompile_npu_ext_with_lock(header_path): +def _precompile_npu_ext_with_lock(header_src, enable_precompile): import fcntl - src_path = os.path.dirname(header_path) - lock_path = os.path.join(src_path, "precompiled.lock") + precompile_hash = _precompile_npu_hash(header_src) + cache = get_cache_manager(precompile_hash) + gch_path = cache.get_file("precompiled.h.gch") + header_path = cache.get_file("precompiled.h") + if enable_precompile: + if header_path is not None and gch_path is not None: + return header_path + else: + if header_path is not None: + return header_path + cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + lock_path = os.path.join(cache_dir, f"{precompile_hash}.lock") with open(lock_path, "a+") as f: try: fcntl.flock(f, fcntl.LOCK_EX) - _precompile_npu_ext(header_path) + header_path = cache.get_file("precompiled.h") + if enable_precompile: + gch_path = cache.get_file("precompiled.h.gch") + if header_path is not None and gch_path is not None: + return header_path + else: + if header_path is not None: + return header_path + header_path = cache.put(header_src, "precompiled.h", binary=False) + if not enable_precompile: + return header_path + src_dir = os.path.dirname(header_path) + gch_path = os.path.join(src_dir, "precompiled.h.gch") + _precompile_npu_ext(header_path, gch_path) + return header_path finally: fcntl.flock(f, fcntl.LOCK_UN) @@ -242,14 +252,10 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): """ Generate the launcher stub to launch the kernel """ - precompile_hash = _precompile_npu_hash(header_src) - cache = get_cache_manager(precompile_hash) - header_path = cache.get_file("precompiled.h") - gch_path = cache.get_file("precompiled.h.gch") + enable_precompile = not os.getenv("TRITON_DISABLE_PRECOMPILE", 'false').lower() in ('true', '1') # if precompile header file and its gch file not exist, do precompile - if header_path is None and gch_path is None: - header_path = cache.put(header_src, "precompiled.h", binary=False) - _precompile_npu_ext_with_lock(header_path) + header_path = _precompile_npu_ext_with_lock(header_src, enable_precompile) + assert header_path is not None, "the precompiled.h path is empty." # try to get cached file so_cache_key = hashlib.sha256(wrapper_src.encode("utf-8")).hexdigest() @@ -274,15 +280,13 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): return cache_path kernel_launcher_type = "torch" - enable_taskqueue = os.getenv("TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1') - if not enable_taskqueue: - kernel_launcher_type = None with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, f"{name}.cxx") with open(src_path, "w") as f: f.write(wrapper_src) - so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, precompile=True) + so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, + precompile=enable_precompile) if debug: with open(so_path, "rb") as f: dump_manager.put(f.read(), so_name, binary=True) @@ -555,10 +559,32 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); if(!ptr_info.dev_ptr) return ptr_info; - Py_DECREF(ret); // Thanks ChatGPT! 
+ aclrtPtrAttributes attributes; + aclError status = aclrtPointerGetAttributes(ptr_info.dev_ptr, &attributes); + + if (status == ACL_SUCCESS) { + if (attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE && attributes.location.type != 4) { + Py_DECREF(ret); + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + return ptr_info; + } + } else { + Py_DECREF(ret); + PyErr_Format(PyExc_RuntimeError, + "Failed to query pointer attributes at argument %d. " + "Error code: %d. This may indicate invalid memory address " + "or NPU device error.", + idx, status); + ptr_info.valid = false; + return ptr_info; + } + Py_DECREF(ret); return ptr_info; } PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; return ptr_info; } """ @@ -746,12 +772,13 @@ name.append(kernelName); void *workspace_addr_ptr = NULL; uint32_t blockNum4Workspace = gridX * gridY * gridZ; + {get_backend_func("pre_launch", True)} {f''' uint64_t totalWorkSpaceSize = {workspace_size} * blockNum4Workspace; - auto optionsWorkspace = at::TensorOptions().device(at::kPrivateUse1).dtype(at::kByte); - workspace_addr_ptr = {get_backend_func("allocate_memory", "totalWorkSpaceSize", "optionsWorkspace")} + {get_backend_func("allocate_memory", "totalWorkSpaceSize", "stream")} ''' if workspace_size > 0 else ''} {'auto launch_call = [=]() -> rtError_t' if enable_taskqueue else ''} {{ + {get_backend_func("pre_launch", False)} uint32_t blockNum = gridX * gridY * gridZ; #ifdef ENABLE_GRID_WARN_PRINT @@ -761,14 +788,13 @@ warned = true; }} #endif - {get_backend_func("pre_launch")} {'blockNum = std::min(blockNum, (uint32_t)' + str(num_physical_blocks) + ');' if enable_auto_map_parallel_blocks else ''} // set mixBlockNumRation for nodeBasicBlockDim for msprof report uint32_t mixBlockNumRation = {mix_block_dim_ratio}; uint32_t nodeBasicBlockDim = (mixBlockNumRation << 16) + blockNum; {'cce::internal::DebugTunnelData *DTData = cce::internal::DebugTunnel::Open(blockNum);' if enable_device_print else ''} - rtError_t ret; + rtError_t ret = RT_ERROR_NONE; {'void *ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len);' if target_support_ffts else ''} {'if (ret != RT_ERROR_NONE) return ret;' if (target_support_ffts and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) return;' if (target_support_ffts and (not enable_taskqueue)) else ''} // stub argument for workspace @@ -776,7 +802,7 @@ uint16_t ModuleId = 0; {f''' uint64_t syncBlockLockSize = {lock_num} * sizeof(int64_t); - syncBlockLock_ptr = {get_backend_func("allocate_sync_block_lock", "syncBlockLockSize", "stream")} + {get_backend_func("allocate_sync_block_lock", "syncBlockLockSize", "stream")} if (!syncBlockLock_ptr) {{ {alloc_success_code if enable_taskqueue else sync_lock_fail_code} }} @@ -880,8 +906,12 @@ } }} - if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ - return NULL; + if (launch_enter_hook != Py_None) {{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + Py_DECREF(ret); }} // get kernel_name @@ -904,8 +934,12 @@ def
_format_of_msprof_task_type_ratio(bs_task_type, mix_mode): if (PyErr_Occurred()) {{ return NULL; }} - if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ - return NULL; + if (launch_exit_hook != Py_None) {{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + Py_DECREF(ret); }} Py_RETURN_NONE; }} diff --git a/third_party/ascend/backend/lib/libdevice.10.bc b/third_party/ascend/backend/lib/libdevice.10.bc new file mode 100644 index 0000000000..76bcbc63a5 Binary files /dev/null and b/third_party/ascend/backend/lib/libdevice.10.bc differ diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp index 28c4ec1354..1665ff1b4a 100644 --- a/third_party/ascend/backend/npu_utils.cpp +++ b/third_party/ascend/backend/npu_utils.cpp @@ -37,9 +37,10 @@ static std::unordered_map registered_names; static std::unordered_map> func_stubs; -static std::tuple -registerKernel(const char *name, const void *data, size_t data_size, int shared, - int device, const char *kernel_mode_str) { +static std::tuple registerKernel(const char *name, + const void *data, + size_t data_size, int device, + const char *kernel_mode_str) { rtError_t rtRet; rtDevBinary_t devbin; @@ -55,14 +56,14 @@ registerKernel(const char *name, const void *data, size_t data_size, int shared, rtRet = rtSetDevice(device); if (rtRet != RT_ERROR_NONE) { printf("rtSetDevice failed, 0x%x\n", rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } - void *devbinHandle = NULL; + void *devbinHandle = nullptr; rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); if (rtRet != RT_ERROR_NONE) { printf("rtDevBinaryRegister failed, 0x%x\n", rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } std::string stubName = name; @@ -75,7 +76,7 @@ registerKernel(const char *name, const void *data, size_t data_size, int shared, if (rtRet != RT_ERROR_NONE) { printf("rtFunctionRegister failed(stubName = %s), 0x%x\n", stubName.c_str(), rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } return std::make_tuple(devbinHandle, func_stub_handle); @@ -91,16 +92,16 @@ static PyObject *loadKernelBinary(PyObject *self, PyObject *args) { if (!PyArg_ParseTuple(args, "ss#iis", &name, &data, &data_size, &shared, &device, &kernel_mode)) { - return NULL; + return nullptr; } auto [module_handle, func_handle] = - registerKernel(name, data, data_size, shared, device, kernel_mode); + registerKernel(name, data, data_size, device, kernel_mode); uint64_t mod = reinterpret_cast(module_handle); uint64_t func = reinterpret_cast(func_handle); if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("(KKii)", mod, func, 0, 0); @@ -113,10 +114,10 @@ static PyObject *getArch(PyObject *self, PyObject *args) { if (rtRet != RT_ERROR_NONE) { printf("rtGetSocVersion failed, 0x%x", rtRet); - return NULL; + return nullptr; } if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("s", name); } @@ -128,10 +129,10 @@ static PyObject *getAiCoreNum(PyObject *self, PyObject *args) { if (rtRet != RT_ERROR_NONE) { printf("rtGetAiCoreCount failed, 0x%x", rtRet); - return NULL; + return nullptr; } if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("I", aiCoreCnt); } @@ -143,15 +144,15 @@ static PyObject *createStream(PyObject *self, PyObject *args) { if (rtRet != RT_ERROR_NONE) { printf("rtStreamCreate failed, 0x%x", rtRet); - return NULL; + return nullptr;
} if (PyErr_Occurred()) { - return NULL; + return nullptr; } uint64_t stream_uint64 = reinterpret_cast(stream); PyObject *result = Py_BuildValue("K", stream_uint64); - if (result == NULL) { + if (result == nullptr) { rtStreamDestroy(stream); } @@ -196,7 +197,7 @@ static PyObject *readDataFromBinaryFileWrapper(PyObject *self, PyObject *args) { const char *filename; uint64_t arr_ptr; if (!PyArg_ParseTuple(args, "sK", &filename, &arr_ptr)) { - return NULL; + return nullptr; } try { @@ -206,7 +207,7 @@ static PyObject *readDataFromBinaryFileWrapper(PyObject *self, PyObject *args) { return Py_None; } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); - return NULL; + return nullptr; } } @@ -230,7 +231,7 @@ static PyObject *writeDataToBinaryFileWrapper(PyObject *self, PyObject *args) { size_t num_bytes; if (!PyArg_ParseTuple(args, "sKn", &filename, &arr_ptr, &num_bytes)) { - return NULL; + return nullptr; } try { @@ -239,27 +240,27 @@ static PyObject *writeDataToBinaryFileWrapper(PyObject *self, PyObject *args) { return Py_None; } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); - return NULL; + return nullptr; } } static PyObject *allocateHostMemory(PyObject *self, PyObject *args) { uint64_t num_bytes; if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return NULL; + return nullptr; } - void *host_ptr = NULL; + void *host_ptr = nullptr; rtError_t error = rtMallocHost(&host_ptr, num_bytes, RT_MEMORY_HOST); if (error != RT_ERROR_NONE) { PyErr_Format(PyExc_RuntimeError, "rtMallocHost failed with error code: 0x%x", error); - return NULL; + return nullptr; } PyObject *result = Py_BuildValue("K", (uint64_t)host_ptr); - if (result == NULL) { + if (result == nullptr) { rtFreeHost(host_ptr); } @@ -269,20 +270,20 @@ static PyObject *allocateHostMemory(PyObject *self, PyObject *args) { static PyObject *allocateDeviceMemory(PyObject *self, PyObject *args) { uint64_t num_bytes; if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return NULL; + return nullptr; } - void *device_ptr = NULL; + void *device_ptr = nullptr; rtError_t error = rtMalloc(&device_ptr, num_bytes, RT_MEMORY_HBM, 0); if (error != RT_ERROR_NONE) { PyErr_Format(PyExc_RuntimeError, "rtMalloc failed with error code: 0x%x", error); - return NULL; + return nullptr; } PyObject *result = Py_BuildValue("K", (uint64_t)device_ptr); - if (result == NULL) { + if (result == nullptr) { rtFree(device_ptr); } @@ -298,7 +299,7 @@ static PyObject *copyMemory(PyObject *self, PyObject *args) { if (!PyArg_ParseTuple(args, "KKns", &dst_ptr, &src_ptr, &count, &direction_str)) { - return NULL; + return nullptr; } if (strcmp(direction_str, "H2D") == 0) { @@ -308,7 +309,7 @@ static PyObject *copyMemory(PyObject *self, PyObject *args) { } else { PyErr_SetString(PyExc_ValueError, "Invalid copy direction. 
Must be 'H2D' or 'D2H'."); - return NULL; + return nullptr; } void *dst = (void *)dst_ptr; @@ -318,45 +319,13 @@ static PyObject *copyMemory(PyObject *self, PyObject *args) { if (error != RT_ERROR_NONE) { PyErr_Format(PyExc_RuntimeError, "rtMemcpy failed with error code: 0x%x", error); - return NULL; + return nullptr; } Py_INCREF(Py_None); return Py_None; } -static const std::unordered_map LimitTypeMap = { - {"LOW_POWER_TIMEOUT", rtLimitType_t::RT_LIMIT_TYPE_LOW_POWER_TIMEOUT}, - {"WARP_STACK_SIZE", rtLimitType_t::RT_LIMIT_TYPE_SIMT_WARP_STACK_SIZE}, - {"DVG_WARP_STACK_SIZE", - rtLimitType_t::RT_LIMIT_TYPE_SIMT_DVG_WARP_STACK_SIZE}, - {"STACK_SIZE", rtLimitType_t::RT_LIMIT_TYPE_STACK_SIZE}}; - -static PyObject *setDeviceLimit(PyObject *self, PyObject *args) { - int device; // device ID - const char *type_str; - uint32_t val; - if (!PyArg_ParseTuple(args, "isI", &device, &type_str, &val)) { - return NULL; - } - - auto it = LimitTypeMap.find(type_str); - if (it == LimitTypeMap.end()) { - printf("Invalid limit type: %s.\n", type_str); - return NULL; - } - - rtError_t rtRet = rtDeviceSetLimit(device, it->second, val); - if (rtRet != RT_ERROR_NONE) { - printf("rtDeviceSetLimit failed, 0x%x\n", rtRet); - return NULL; - } - if (PyErr_Occurred()) { - return NULL; - } - return Py_None; -} - static PyMethodDef NpuUtilsMethods[] = { {"load_kernel_binary", loadKernelBinary, METH_VARARGS, "Load NPU kernel binary into NPU driver"}, @@ -374,8 +343,7 @@ static PyMethodDef NpuUtilsMethods[] = { "Allocate host memory"}, {"copy_memory", copyMemory, METH_VARARGS, "Copy data between host and device"}, - {"set_device_limit", setDeviceLimit, METH_VARARGS, "Set the limit of NPU"}, - {NULL, NULL, 0, NULL}}; + {nullptr, nullptr, 0, nullptr}}; static PyModuleDef ModuleDef = { PyModuleDef_HEAD_INIT, "npu_utils", @@ -384,8 +352,8 @@ static PyModuleDef ModuleDef = { PyMODINIT_FUNC PyInit_npu_utils(void) { PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; + if (m == nullptr) { + return nullptr; } PyModule_AddFunctions(m, NpuUtilsMethods); diff --git a/third_party/ascend/backend/runtime/autoparser.py b/third_party/ascend/backend/runtime/autoparser.py index 2176d71e59..4ff29ed9e4 100644 --- a/third_party/ascend/backend/runtime/autoparser.py +++ b/third_party/ascend/backend/runtime/autoparser.py @@ -97,6 +97,14 @@ def get_axis(self, var: str, node=None): axis = self.handle_lt_node(var, child_node) elif isinstance(child_node, ast.Assign): axis = self.handle_assign_node(var, child_node) + + elif isinstance(child_node, ast.BinOp) and \ + isinstance(child_node.op, ast.BitAnd): + + axis = self.handle_lt_node(var, child_node.left) + if axis is None: + axis = self.handle_lt_node(var, child_node.right) + if axis is not None: return axis self.checked_vars.append(var) @@ -178,6 +186,13 @@ def __init__(self, func_ast: ast.AST, keys: Dict[str, str], candidates_params: L super().__init__(func_ast, keys) self.split_axes = dict() self.program_id_vars = list() + self.program_id_var_dims = dict() + self.num_programs_var_dims = dict() + self.grid_stride_tiling_only = dict() + # axis_name -> program_id axis dim + self.split_axis_pid_dims = dict() + # axis_name -> program_id axis dim (includes axes inferred without split params) + self.axis_pid_dims = dict() self.candidates_params = candidates_params def parse(self) -> Dict[str, str]: @@ -185,41 +200,226 @@ def parse(self) -> Dict[str, str]: return self.split_axes def visit_Assign(self, node): - if isinstance(node.value, ast.Call) and isinstance(node.value.func, 
ast.Attribute): - if isinstance(node.value.func.value, ast.Name): - if node.value.func.value.id == "tl" and node.value.func.attr == "program_id": - if isinstance(node.targets[0], ast.Name) and \ - node.targets[0].id not in self.program_id_vars: - self.program_id_vars.append(node.targets[0].id) + pid_dim = self._get_program_id_dim(node.value) + if pid_dim is not None: + if (len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) + and node.targets[0].id not in self.program_id_vars): + self.program_id_vars.append(node.targets[0].id) + self.program_id_var_dims[node.targets[0].id] = pid_dim + num_programs_dim = self._get_num_programs_dim(node.value) + if num_programs_dim is not None: + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + self.num_programs_var_dims[node.targets[0].id] = num_programs_dim self.generic_visit(node) def visit_BinOp(self, node): if isinstance(node.op, ast.Mult): split_axes_val = None + split_axis_pid_dim = None if isinstance(node.left, ast.Name) and node.left.id in self.program_id_vars: if isinstance(node.right, ast.Name): split_axes_val = node.right.id + split_axis_pid_dim = self.program_id_var_dims.get(node.left.id) elif isinstance(node.left, ast.Call) and isinstance(node.left.func, ast.Attribute): if node.left.func.value.id == "tl" and \ node.left.func.attr == "program_id": if isinstance(node.right, ast.Name): split_axes_val = node.right.id + split_axis_pid_dim = self._get_program_id_dim(node.left) if isinstance(node.right, ast.Name) and node.right.id in self.program_id_vars: if isinstance(node.left, ast.Name): split_axes_val = node.left.id + split_axis_pid_dim = self.program_id_var_dims.get(node.right.id) elif isinstance(node.right, ast.Call) and isinstance(node.right.func, ast.Attribute): if node.right.func.value.id == "tl" and node.right.func.attr == "program_id": if isinstance(node.left, ast.Name): split_axes_val = node.left.id + split_axis_pid_dim = self._get_program_id_dim(node.right) if split_axes_val in self.candidates_params and \ split_axes_val not in self.split_axes.values(): split_axes_key = self.get_axis(split_axes_val) - if split_axes_key: + if split_axes_key and not self._is_tiling_only_split(split_axes_key, split_axes_val): self.split_axes[split_axes_key] = split_axes_val + if split_axis_pid_dim is not None: + self._record_axis_pid_dim(split_axes_key, split_axis_pid_dim) + self.generic_visit(node) + + def visit_For(self, node): + if not isinstance(node.iter, ast.Call): + self.generic_visit(node) + return + + iter_fn = node.iter.func + is_range = isinstance(iter_fn, ast.Name) and iter_fn.id == "range" + is_tl_range = (isinstance(iter_fn, ast.Attribute) and isinstance(iter_fn.value, ast.Name) + and iter_fn.value.id == "tl" and iter_fn.attr == "range") + if not (is_range or is_tl_range): + self.generic_visit(node) + return + + if len(node.iter.args) == 0: + self.generic_visit(node) + return + + start = node.iter.args[0] if len(node.iter.args) >= 2 else None + stop = node.iter.args[1] if len(node.iter.args) >= 2 else node.iter.args[0] + pid_dim = self._extract_pid_dim_from_expr(start) + axis = self._axis_from_expr(stop) + if axis is not None and pid_dim is not None: + self._record_axis_pid_dim(axis, pid_dim) + if len(node.iter.args) >= 3: + step = node.iter.args[2] + loop_tiling_only_param = self._extract_grid_stride_split_param(start, step, pid_dim) + if loop_tiling_only_param is not None: + self._mark_tiling_only_param(axis, loop_tiling_only_param) + self.generic_visit(node) + def _get_program_id_dim(self, node): + if not 
(isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name) and node.func.value.id == "tl" and node.func.attr == "program_id"): + return None + + axis_dim = 0 + if len(node.args) > 0: + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, int): + axis_dim = node.args[0].value + else: + return None + + for kw in node.keywords: + if kw.arg == "axis": + if isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, int): + axis_dim = kw.value.value + else: + return None + break + return axis_dim + + def _get_num_programs_dim(self, node): + if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name) and node.func.value.id == "tl" and node.func.attr == "num_programs"): + return None + + axis_dim = 0 + if len(node.args) > 0: + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, int): + axis_dim = node.args[0].value + else: + return None + + for kw in node.keywords: + if kw.arg == "axis": + if isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, int): + axis_dim = kw.value.value + else: + return None + break + return axis_dim + + def _extract_pid_dim_from_expr(self, node): + if node is None: + return None + for child in ast.walk(node): + if isinstance(child, ast.Name) and child.id in self.program_id_var_dims: + return self.program_id_var_dims[child.id] + pid_dim = self._get_program_id_dim(child) + if pid_dim is not None: + return pid_dim + return None + + def _contains_pid_dim(self, node, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if isinstance(child, ast.Name): + if self.program_id_var_dims.get(child.id, None) == pid_dim: + return True + if self._get_program_id_dim(child) == pid_dim: + return True + return False + + def _contains_num_programs_dim(self, node, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if isinstance(child, ast.Name): + if self.num_programs_var_dims.get(child.id, None) == pid_dim: + return True + if self._get_num_programs_dim(child) == pid_dim: + return True + return False + + def _is_candidate_name(self, node, candidate_name): + return (isinstance(node, ast.Name) and node.id == candidate_name and candidate_name in self.candidates_params) + + def _extract_pid_multiplied_candidate(self, node, pid_dim): + if node is None: + return None + candidates = set() + for child in ast.walk(node): + if not isinstance(child, ast.BinOp) or not isinstance(child.op, ast.Mult): + continue + left = child.left + right = child.right + if isinstance(left, ast.Name) and left.id in self.candidates_params and \ + self._contains_pid_dim(right, pid_dim): + candidates.add(left.id) + if isinstance(right, ast.Name) and right.id in self.candidates_params and \ + self._contains_pid_dim(left, pid_dim): + candidates.add(right.id) + if len(candidates) == 1: + return next(iter(candidates)) + return None + + def _contains_num_programs_multiplied_candidate(self, node, candidate_name, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if not isinstance(child, ast.BinOp) or not isinstance(child.op, ast.Mult): + continue + if self._is_candidate_name(child.left, candidate_name): + if self._contains_num_programs_dim(child.right, pid_dim): + return True + if self._is_candidate_name(child.right, candidate_name): + if self._contains_num_programs_dim(child.left, pid_dim): + return True + return False + + def _extract_grid_stride_split_param(self, start, 
step, pid_dim):
+        if start is None or step is None:
+            return None
+        candidate_name = self._extract_pid_multiplied_candidate(start, pid_dim)
+        if candidate_name is None:
+            return None
+        if self._contains_num_programs_multiplied_candidate(step, candidate_name, pid_dim):
+            return candidate_name
+        return None
+
+    def _mark_tiling_only_param(self, axis, candidate_name):
+        self.grid_stride_tiling_only.setdefault(axis, set()).add(candidate_name)
+        if self.split_axes.get(axis, None) == candidate_name:
+            del self.split_axes[axis]
+            self.split_axis_pid_dims.pop(axis, None)
+
+    def _is_tiling_only_split(self, axis, candidate_name):
+        return candidate_name in self.grid_stride_tiling_only.get(axis, set())
+
+    def _axis_from_expr(self, node):
+        if node is None:
+            return None
+        for k, v in self.keys.items():
+            if self.contains_target_var(node, v):
+                return k
+        return None
+
+    def _record_axis_pid_dim(self, axis, pid_dim):
+        self.axis_pid_dims[axis] = pid_dim
+        if axis in self.split_axes:
+            self.split_axis_pid_dims[axis] = pid_dim
+

 class TilingAxesParser(AxesKeyParser):
     """
@@ -262,12 +462,10 @@ def parse(self) -> Dict[str, str]:
         return self.tiling_axes

     def visit_For(self, node):
-        if isinstance(node.iter, ast.Call) and \
-                len(node.iter.args) == 3 and \
-                isinstance(node.iter.args[2], ast.Name):
-            for_loop_param = node.iter.args[2].id
-            if for_loop_param in self.candidates_params and \
-                    for_loop_param not in self.candidates_params_for_loop:
+        if isinstance(node.iter, ast.Call) and len(node.iter.args) == 3:
+            step_expr = node.iter.args[2]
+            for_loop_param = self._extract_unique_candidate(step_expr)
+            if (for_loop_param is not None and for_loop_param not in self.candidates_params_for_loop):
                 self.candidates_params_for_loop.append(for_loop_param)
         self.generic_visit(node)
@@ -276,10 +474,10 @@ def visit_Assign(self, node):
         # handle FloorDiv
         if isinstance(node.value, ast.BinOp) and isinstance(node.value.op, ast.FloorDiv):
             denominator = node.value.right
-            if isinstance(denominator, ast.Name) and \
-                    denominator.id in self.candidates_params and \
-                    denominator.id not in self.candidates_params_for_loop:
-                self.candidates_params_for_loop.append(denominator.id)
+            denominator_param = self._extract_unique_candidate(denominator)
+            if denominator_param is not None and \
+                    denominator_param not in self.candidates_params_for_loop:
+                self.candidates_params_for_loop.append(denominator_param)

         self.visit(self.func_ast)
         tiling_axes_val = self.get_tiling_axes_val(node.value)
@@ -312,6 +510,18 @@ def get_tiling_axes_val(self, node):
                 return val
         return None

+    def _extract_unique_candidate(self, expr):
+        """
+        Extract a unique tiling candidate from an expression.
+        Return None when no candidate is found, or when more than one
+        candidate appears (ambiguous). 
+ """ + if expr is None: + return None + candidates = [param for param in self.candidates_params if self.contains_target_var(expr, param)] + if len(candidates) == 1: + return candidates[0] + return None + class ReductionAxesParser(AxesKeyParser): """ @@ -344,11 +554,38 @@ def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): super().__init__(func_ast, keys) self.reduction_axes = list() self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx + self.ndim = 1 def parse(self) -> List[str]: super().parse() return self.reduction_axes + def visit_Assign(self, node): + self._scan_subscripts(node.value) + self.generic_visit(node) + + def _scan_subscripts(self, node): + if isinstance(node, ast.Subscript): + ndim = self._get_subscripts_ndim(node) + if ndim > self.ndim: + self.ndim = ndim + + for child in ast.iter_child_nodes(node): + self._scan_subscripts(child) + + def _get_subscripts_ndim(self, subscript_node): + slice_node = subscript_node.slice + + if isinstance(slice_node, ast.Tuple): + # e.g. [:, None] -> Tuple(elts=[Slice(), Constant(None)]) + return len(slice_node.elts) + elif isinstance(slice_node, (ast.Slice, ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)): + # e.g. [0], [:], [i], [-1], [i+1] + return 1 + else: + # Fallback: treat as 1D + return 1 + def visit_Call(self, node): if not isinstance(node.func, ast.Attribute): return @@ -359,22 +596,42 @@ def visit_Call(self, node): if func.attr not in self.reduction_func: return + axis_dim = None args = node.args if len(args) == 1: - keywords = node.keywords - for keyword in keywords: + # Axis passed as keyword argument + for keyword in node.keywords: if keyword.arg == 'axis': - if isinstance(keyword.value, ast.Constant): - axis_dim = keyword.value.value + axis_dim = self.get_axis_dim(keyword.value) + break + elif len(args) == 2: - if isinstance(args[1], ast.Constant): # check the second param - axis_dim = args[1].value + # Axis passed as positional argument. 
Check the second param
+            axis_dim = self.get_axis_dim(args[1])
+
         else:
-            return
+            raise ValueError("Unexpected arguments for reduction function")
+
+        if axis_dim is not None:
+            reduction_axis = self.get_axis(axis_dim)
+            if reduction_axis and reduction_axis not in self.reduction_axes:
+                self.reduction_axes.append(reduction_axis)
+
+    def get_axis_dim(self, node):
+        if isinstance(node, ast.Constant):
+            axis_dim = node.value
+        elif isinstance(node, ast.UnaryOp) and \
+                isinstance(node.op, ast.USub) and \
+                isinstance(node.operand, ast.Constant):
+            # Negative axis, e.g. axis=-1 maps to self.ndim - 1
+            axis_dim = self.ndim - node.operand.value
+        else:
+            raise ValueError(f"Reduction function axis error, got: {ast.dump(node)}")

-        reduction_axis = self.get_axis(axis_dim)
-        if reduction_axis and reduction_axis not in self.reduction_axes:
-            self.reduction_axes.append(reduction_axis)
+        if not isinstance(axis_dim, int):
+            raise ValueError("Reduction function axis must be an integer, "
+                             f"got {type(axis_dim).__name__}: {axis_dim}")
+        return axis_dim

     def get_axis(self, axis_dim: int):
         """
diff --git a/third_party/ascend/backend/runtime/autotuner.py b/third_party/ascend/backend/runtime/autotuner.py
index 3ea0381054..0e72d51acb 100644
--- a/third_party/ascend/backend/runtime/autotuner.py
+++ b/third_party/ascend/backend/runtime/autotuner.py
@@ -23,16 +23,21 @@
 from __future__ import annotations

 import builtins
+import copy
+import functools
+import ast
 import os
 import time
-import copy
+from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List
+
 from torch import Tensor

+import triton
 from triton.runtime.autotuner import Autotuner, Config
+from .autoparser import (LowDimsAxesParser, PtrNumsParser, ReductionAxesParser, SplitAxesParser, TilingAxesParser)
 from .utils import get_byte_per_numel, is_valid_axis_name, valid_axis_names
-from .autoparser import SplitAxesParser, TilingAxesParser, ReductionAxesParser, LowDimsAxesParser, PtrNumsParser


 class AutoTilingTuner(Autotuner):
@@ -88,7 +93,13 @@ def __init__(
         tiling_params = self.hints.get("tiling_params", None)
         low_dim_axes = self.hints.get("low_dim_axes", None)
         reduction_axes = self.hints.get("reduction_axes", None)
-        self._init_axis_params(key, split_params, tiling_params, low_dim_axes, reduction_axes)
+        self._init_axis_params(
+            key,
+            split_params,
+            tiling_params,
+            low_dim_axes,
+            reduction_axes,
+        )
         self.auto_gen_config = not configs or self.hints.get("auto_gen_config", False)
         self.gen_configs = []  # generated configs from TileGenerator
@@ -98,8 +109,13 @@ def __init__(
         else:
             self.user_configs = configs
         self.is_simt_mode = False
+        self.simt_stack_limit = 8192
         self.user_specified_warps = None
         self.print_autotuning = os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1"
+        # Compile kernels in parallel by default for triton.runtime.JITFunction,
+        # but not for others, e.g., LibEntry, since it's not compatible with AsyncCompileMode
+        self.compile_parallel = (isinstance(self.fn, triton.runtime.JITFunction)
+                                 and os.getenv("TRITON_AUTOTUNE_PARALLEL_COMPILE", "1") == "1")

     def _init_axis_params(self, key, split_params, tiling_params, low_dim_axes, reduction_axes):
         if isinstance(key, list):
@@ -138,9 +154,15 @@ def _init_axis_params(self, key, split_params, tiling_params, low_dim_axes, redu
                                                            set(self.keys.keys())))

         self.split_params = split_params
+        self.all_split_params = {}
+        self.fixed_split_params = {}
         self.tiling_params = tiling_params
         self.low_dim_axes = low_dim_axes
         self.reduction_axes = reduction_axes
+        self.fixed_grid_dims = set()
+        self.fixed_grid_dim_values = {}
+
self.split_axis_pid_dims = {} + self.axis_pid_dims = {} self.dual_reduction = False self.persistent_reduction = False self.num_buffers = -1 @@ -167,7 +189,42 @@ def _autoparse_axis_params(self, all_args): self.persistent_reduction = True if not self.split_params: - self.split_params = self._autoparse_split_params(miss_params) + all_split_params = self._autoparse_split_params(self._get_constexpr_candidates()) + self.all_split_params = dict(all_split_params) + self.fixed_split_params = {} + self.fixed_grid_dim_values = self._get_fixed_grid_dim_values( + all_args.get("grid", None), + all_args, + ) + self.fixed_grid_dims = set(self.fixed_grid_dim_values.keys()) + + fixed_grid_axes = {axis for axis, pid_dim in self.axis_pid_dims.items() if pid_dim in self.fixed_grid_dims} + + # Only missing constexpr params are tunable, and fixed-grid axes + # should not be tuned on split. + self.split_params = { + axis: param + for axis, param in all_split_params.items() + if param in miss_params and axis not in fixed_grid_axes + } + + # Fixed split is inferred only from fixed grid dims. + for axis, pid_dim in self.axis_pid_dims.items(): + if pid_dim not in self.fixed_grid_dims: + continue + core_num = self.fixed_grid_dim_values.get(pid_dim, 0) + axis_len_name = self.keys.get(axis, None) + axis_len = all_args.get(axis_len_name, None) + if not isinstance(core_num, int) or core_num <= 0: + continue + if not isinstance(axis_len, int) or axis_len <= 0: + continue + + self.fixed_split_params[axis] = (axis_len + core_num - 1) // core_num + elif not self.axis_pid_dims: + # When split axes are provided by hints, parse axis->program_id mapping + # independently for fixed-grid semantics and diagnostics. + self._autoparse_axis_pid_dims() miss_params = [arg for arg in miss_params if arg not in self.split_params.values()] if not self.tiling_params: self.tiling_params = self._autoparse_tiling_params(miss_params) @@ -191,6 +248,7 @@ def _gen_tile_configs(self, kv_dict: Dict[str, int], dtype: torch.dtype) -> List kernel_meta = KernelMeta( axis_sizes, self.split_params, + self.fixed_split_params, self.tiling_params, self.low_dim_axes, dtype, @@ -272,6 +330,8 @@ def generate_key_and_configs(self, *args, **kwargs): def run(self, *args, **kwargs): key = self.generate_key_and_configs(*args, **kwargs) + if self.is_simt_mode and kwargs.get('simt_stack_limit', None) is None: + kwargs['simt_stack_limit'] = self.simt_stack_limit used_cached_result = True if key not in self.cache: # prune configs @@ -319,14 +379,40 @@ def _batch_bench(self, *args, configs, **kwargs): exc = None exc_stack = "" - for config, fn in kernels_call.items(): + if self.compile_parallel: + import psutil + + max_workers = min(psutil.cpu_count(logical=False) // 2, len(kernels_call)) + future_kernels = [] try: - fn() - run_fns[config] = fn - except (CompileTimeAssertionFailure, MLIRCompilationError, OutOfResources) as e: - import traceback - exc_stack = traceback.format_exc() - exc = e + with ( + ThreadPoolExecutor(max_workers=max_workers) as executor, + triton.AsyncCompileMode(executor), + ): + for config, fn in kernels_call.items(): + future_kernels.append((config, fn(warmup=True))) + + for config, fut in future_kernels: + try: + if hasattr(fut, "result"): + fut = fut.result() + run_fns[config] = functools.partial(kernels_call[config], warmup=False) + except (CompileTimeAssertionFailure, MLIRCompilationError) as e: + import traceback + exc_stack = traceback.format_exc() + exc = e + except Exception as e: + # ignore exception from __exit__() of AsyncCompileMode + 
triton.runtime._async_compile.active_mode.set(None) + else: + for config, fn in kernels_call.items(): + try: + fn(warmup=False) + run_fns[config] = functools.partial(fn, warmup=False) + except (CompileTimeAssertionFailure, MLIRCompilationError, OutOfResources) as e: + import traceback + exc_stack = traceback.format_exc() + exc = e if len(run_fns) == 0: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc} \nStack trace: {exc_stack}") @@ -356,15 +442,18 @@ def _make_kernel_call(self, *args, config, **meta): current = dict(meta, **config.all_kwargs()) full_nargs = {**self.nargs, **current} - def kernel_call(): + def kernel_call(warmup): if config.pre_hook: config.pre_hook(full_nargs) self.pre_hook(full_nargs) try: - self.fn.run( + current.update({"warmup": warmup}) + res = self.fn.run( *args, **current, ) + if warmup: + return res except Exception as e: try: self.post_hook(full_nargs, exception=e) @@ -380,8 +469,19 @@ def warmup(self, *args, **kwargs): _ = self.generate_key_and_configs(*args, **kwargs) pruned_configs = self.prune_configs(kwargs) ret = [] - for config in pruned_configs: - ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) + if self.compile_parallel: + import psutil + + max_workers = min(psutil.cpu_count(logical=False) // 2, len(pruned_configs)) + with ( + ThreadPoolExecutor(max_workers=max_workers) as executor, + triton.AsyncCompileMode(executor), + ): + for config in pruned_configs: + ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) + else: + for config in pruned_configs: + ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) self.nargs = None return ret @@ -389,7 +489,8 @@ def _profile(self, *args, config, **meta): from ..testing import do_bench_npu kernel_call = self._make_kernel_call(*args, config=config, **meta) - do_bench_npu(kernel_call, prof_dir=self.auto_profile_dir, keep_res=True) + fn = functools.partial(kernel_call, warmup=False) + do_bench_npu(fn, prof_dir=self.auto_profile_dir, keep_res=True) def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str]: """ @@ -398,10 +499,139 @@ def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str func_ast = self.fn.parse() parser = SplitAxesParser(func_ast, self.keys, candidates_params) split_axes = parser.parse() + self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) + self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) if self.print_autotuning: - print(f"Ascend autotuning parse split axes: {split_axes}") + print(f"Ascend autotuning parse split axes: {split_axes}, " + f"split axis pid dims: {self.split_axis_pid_dims}, " + f"axis pid dims: {self.axis_pid_dims}") return split_axes + def _autoparse_axis_pid_dims(self) -> Dict[str, int]: + """ + Extract axis -> program_id dim mapping without relying on split-parameter + classification, so fixed-grid semantics can always consume it. + """ + func_ast = self.fn.parse() + parser = SplitAxesParser( + func_ast, + self.keys, + self._get_constexpr_candidates(), + ) + _ = parser.parse() + self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) + self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) + if self.print_autotuning: + print("Ascend autotuning parse axis pid dims (independent): " + f"{self.axis_pid_dims}") + return self.axis_pid_dims + + def _get_constexpr_candidates(self) -> List[str]: + """ + Returns all constexpr parameter names from the kernel function definition. 
+ """ + func_ast = self.fn.parse() + constexpr_names = [] + for node in ast.walk(func_ast): + if not isinstance(node, ast.FunctionDef): + continue + if not isinstance(node.args, ast.arguments): + continue + for arg in node.args.args: + if not isinstance(arg, ast.arg): + continue + ann = arg.annotation + if (isinstance(ann, ast.Attribute) and isinstance(ann.value, ast.Name) and ann.value.id == "tl" + and ann.attr == "constexpr"): + constexpr_names.append(arg.arg) + break + return constexpr_names + + def _get_fixed_grid_dim_values(self, grid, all_args: Dict[str, object] = None) -> Dict[int, int]: + """ + Returns fixed grid dim -> value. + - Static tuple/list grid: direct extraction + - Callable grid: infer fixed dims by perturbing missing constexpr params + """ + if grid is None: + return {} + if callable(grid): + return self._infer_fixed_dims_from_callable_grid(grid, all_args or {}) + return self._extract_fixed_grid_dims(grid) + + def _extract_fixed_grid_dims(self, grid) -> Dict[int, int]: + if isinstance(grid, int): + grid = (grid, ) + if not isinstance(grid, (tuple, list)): + return {} + fixed_dims = {} + for idx, dim in enumerate(grid): + if isinstance(dim, int) and dim > 0: + fixed_dims[idx] = dim + return fixed_dims + + def _normalize_grid_tuple(self, grid_out): + if isinstance(grid_out, int): + return (grid_out, ) + if isinstance(grid_out, (tuple, list)): + return tuple(grid_out) + return None + + def _infer_fixed_dims_from_callable_grid(self, grid_fn, all_args: Dict[str, object]) -> Dict[int, int]: + constexpr_candidates = self._get_constexpr_candidates() + base_meta = dict(all_args or {}) + + # Fill missing constexpr with stable probe defaults so grid(meta) can execute. + for name in constexpr_candidates: + if name not in base_meta: + base_meta[name] = 128 + + try: + base_grid_raw = grid_fn(dict(base_meta)) + except Exception: + return {} + + base_grid = self._normalize_grid_tuple(base_grid_raw) + if base_grid is None: + return {} + + dynamic_dims = set() + # Missing constexpr are tunable candidates. + tunable_probe_names = [name for name in constexpr_candidates if name not in (all_args or {})] + probe_values = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + for name in tunable_probe_names: + baseline = base_meta.get(name, 128) + for probe in probe_values: + if probe == baseline: + continue + probe_meta = dict(base_meta) + probe_meta[name] = probe + try: + probe_grid_raw = grid_fn(probe_meta) + except Exception: + continue + probe_grid = self._normalize_grid_tuple(probe_grid_raw) + if probe_grid is None: + continue + if len(probe_grid) != len(base_grid): + dynamic_dims.update(range(min(len(probe_grid), len(base_grid)))) + continue + for idx, (base_dim, probe_dim) in enumerate(zip(base_grid, probe_grid)): + if not (isinstance(base_dim, int) and isinstance(probe_dim, int)): + dynamic_dims.add(idx) + continue + if base_dim != probe_dim: + dynamic_dims.add(idx) + + fixed_dims = {} + for idx, dim in enumerate(base_grid): + if idx in dynamic_dims: + continue + if isinstance(dim, int) and dim > 0: + fixed_dims[idx] = dim + return fixed_dims + def _autoparse_tiling_params(self, candidates_params: List[str]) -> Dict[str, str]: """ Extracts the tiling axis parameters from triton kernel code. 
diff --git a/third_party/ascend/backend/runtime/tile_generator.py b/third_party/ascend/backend/runtime/tile_generator.py index 9adf013953..a4e2a9a83c 100644 --- a/third_party/ascend/backend/runtime/tile_generator.py +++ b/third_party/ascend/backend/runtime/tile_generator.py @@ -52,7 +52,9 @@ class AxisInfo: split_name: str = "" tiling_name: str = "" is_split_axis: bool = False + is_tunable_split_axis: bool = False is_tiling_axis: bool = False + fixed_split_size: int = 0 @property def is_reduction(self): @@ -65,6 +67,7 @@ def __init__( self, axis_sizes: Dict[str, int], split_params: Dict[str, str], + fixed_split_params: Dict[str, int], tiling_params: Dict[str, str], low_dims: List[str], dtype: torch.dtype, @@ -90,7 +93,7 @@ def __init__( :param dual_reduction: performing reduction on more than one axis. :param persistent_reduction: there is no splitting in reduction axis. """ - self._validate_axis(axis_sizes, split_params, tiling_params, low_dims) + self._validate_axis(axis_sizes, split_params, fixed_split_params, tiling_params, low_dims) axis_dict = {} idx = 0 @@ -99,9 +102,11 @@ def __init__( if name.startswith("r"): prefix = "r" - is_split_axis = name in split_params + is_tunable_split_axis = name in split_params + fixed_split_size = fixed_split_params.get(name, 0) + is_split_axis = is_tunable_split_axis or fixed_split_size > 0 is_tiling_axis = name in tiling_params - split_name = "" if name not in split_params else split_params[name] + split_name = "" if not is_tunable_split_axis else split_params[name] tiling_name = "" if name not in tiling_params else tiling_params[name] axis_dict[name] = AxisInfo( @@ -112,12 +117,15 @@ def __init__( split_name=split_name, tiling_name=tiling_name, is_split_axis=is_split_axis, + is_tunable_split_axis=is_tunable_split_axis, is_tiling_axis=is_tiling_axis, + fixed_split_size=fixed_split_size, ) idx += 1 self.axis_info = list(axis_dict.values()) self.split_axis = [x for x in axis_dict.values() if x.is_split_axis] + self.tunable_split_axis = [x for x in axis_dict.values() if x.is_tunable_split_axis] self.tiling_axis = [x for x in axis_dict.values() if x.is_tiling_axis] self.low_dims_axis = [x for x in axis_dict.values() if x.name in low_dims] self.dtype = dtype @@ -131,6 +139,7 @@ def _validate_axis( cls, axis_sizes: Dict[str, int], split_params: Dict[str, str], + fixed_split_params: Dict[str, int], tiling_params: Dict[str, str], low_dims: List[str], ) -> None: @@ -144,6 +153,7 @@ def check_keys(params: List[str], context="parameter"): raise KeyError(f"{context} '{k}' not found in known axes: {axis_sizes.keys()}") check_keys(split_params.keys(), "split axis") + check_keys(fixed_split_params.keys(), "fixed split axis") check_keys(tiling_params.keys(), "tiling axis") check_keys(low_dims, "low dim axis") @@ -182,10 +192,16 @@ def __init__(self, kernel_meta: KernelMeta): self.max_numel_threshold = local_mem_size * 1024 // self.dtype_bytes // self.num_buffers self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size for x in self.blocks]) if self.blocks else 1 - self.tiny_kernel = self.max_total_numel < 128 * 1024 + self.small_kernel = self.max_total_numel < 128 * 1024 + self.tiny_kernel = self.max_total_numel <= 32 * 1024 self.stop_numel = min(1024 // self.dtype_bytes, self.max_total_numel // - (num_vector_core * 2)) if self.tiny_kernel else 1024 // self.dtype_bytes + (num_vector_core * 2)) if self.small_kernel else 1024 // self.dtype_bytes self.max_programs_num = 65535 + self.tiny_program_threshold = num_vector_core // 8 + 
self.tiny_per_program_cap = 1 + self.tiny_low_program_hist = {p: 0 for p in range(1, self.tiny_program_threshold + 1)} + self.tiny_low_program_active = False + self.tiny_low_program_tile_floor = 0 @classmethod def init_blocks_info(cls, kernel_meta: KernelMeta) -> List[BlockInfo]: @@ -193,7 +209,7 @@ def init_blocks_info(cls, kernel_meta: KernelMeta) -> List[BlockInfo]: for axis in kernel_meta.axis_info: block_name = axis.split_name sub_block_name = axis.tiling_name - block_size = axis.length + block_size = axis.fixed_split_size if axis.fixed_split_size > 0 else axis.length sub_block_size = block_size blocks.append(BlockInfo(block_name, sub_block_name, block_size, sub_block_size)) @@ -213,6 +229,7 @@ def calcu_last_split_blocks(self, axis_idx): break last_splits = num_vector_core // splits + last_splits = max(1, last_splits) last_blocks = (self.numels[axis_idx] + last_splits - 1) // last_splits return last_blocks @@ -244,10 +261,12 @@ def fill_config(self, cfg, candi_block): curr_numel = candi_block[axis.index] if not axis.is_tiling_axis: curr_numel = self.aligned_numel(curr_numel) - cfg[block_info.block_name] = curr_numel + if block_info.block_name: + cfg[block_info.block_name] = curr_numel if axis.is_tiling_axis: tiling_numel = self.aligned_numel(block_info.sub_block_size) - cfg[block_info.sub_block_name] = min(tiling_numel, candi_block[axis.index]) + cfg[block_info.sub_block_name] = (tiling_numel if self.is_simt_mode else min( + tiling_numel, candi_block[axis.index])) def find_config(self, cfg): for config_var in self.configs: @@ -255,11 +274,44 @@ def find_config(self, cfg): return True return False + def _try_add_tiny_low_program_config(self, total_programs): + if (not self.tiny_kernel or total_programs < 1 or total_programs > self.tiny_program_threshold): + return + + if self.tiny_low_program_hist.get(total_programs, 0) >= self.tiny_per_program_cap: + return + + candi_block = tuple([x.block_size for x in self.blocks]) + if self.add_to_configs(list(candi_block)): + if candi_block not in self.candidate_blocks: + self.candidate_blocks.append(candi_block) + if not self.tiny_low_program_active: + self.tiny_low_program_active = True + self.tiny_low_program_tile_floor = self.calculate_tile_numel() + self.tiny_low_program_hist[total_programs] = (self.tiny_low_program_hist.get(total_programs, 0) + 1) + + def _calc_total_programs(self, candi_block=None): + grids = [] + for axis in self.kernel_meta.split_axis: + numel = self.numels[axis.index] + block_size = (self.blocks[axis.index].block_size if candi_block is None else candi_block[axis.index]) + programs = (numel + block_size - 1) // block_size + grids.append(programs) + + total_programs = functools.reduce(lambda x, y: x * y, grids) if grids else 1 + return total_programs + def add_to_configs(self, candi_block): newcfg = {} self.fill_config(newcfg, candi_block) tile_numel = self.calculate_tile_numel() - stop_numel_threshold = 0 if len(self.configs) < 10 or self.tiny_kernel else self.stop_numel + 100 + stop_numel_threshold = 0 if len(self.configs) < 10 or self.small_kernel else self.stop_numel + 100 + if self.tiny_low_program_active and self.tiny_low_program_tile_floor > 0: + total_programs = self._calc_total_programs(candi_block) + program_threshold = self.tiny_program_threshold if self.small_kernel else num_vector_core // 2 + if total_programs <= program_threshold: + tiny_low_program_threshold = max(self.stop_numel, self.tiny_low_program_tile_floor // 2) + stop_numel_threshold = max(stop_numel_threshold, tiny_low_program_threshold) if 
(tile_numel <= self.max_numel_threshold and tile_numel >= stop_numel_threshold and not self.find_config(newcfg)): self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) @@ -330,19 +382,21 @@ def calc_total_programs(): self.candidate_blocks.append(tuple([x.block_size for x in self.blocks])) break - program_threshold = num_vector_core // 8 if self.tiny_kernel else num_vector_core // 2 + program_threshold = self.tiny_program_threshold if self.small_kernel else num_vector_core // 2 + if self.tiny_kernel and total_programs <= program_threshold: + self._try_add_tiny_low_program_config(total_programs) if total_programs > program_threshold or self.dual_reduction: if len(self.candidate_blocks) > 2: self.candidate_blocks.pop(0) self.candidate_blocks.append(tuple([x.block_size for x in self.blocks])) - if self.tiny_kernel: + if self.small_kernel: self.add_to_configs(list(tuple([x.block_size for x in self.blocks]))) slow_decend_split = (total_programs > num_vector_core_tile // 2) if not slow_decend_split: - self.blocks[axis_idx].block_size = numel // 2 + self.blocks[axis_idx].block_size = (numel + 1) // 2 else: - step = numel // 4 if numel // 4 > 1 else 1 + step = (numel + 3) // 4 if (numel + 3) // 4 > 1 else 1 self.blocks[axis_idx].block_size = numel - step self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size total_programs = calc_total_programs() @@ -404,7 +458,7 @@ def descend_split_tiling(self): tiling_not_low_dims = [x for x in self.kernel_meta.tiling_axis if x not in self.kernel_meta.low_dims_axis] def descend_split_axis(): - for axis in self.kernel_meta.split_axis: + for axis in self.kernel_meta.tunable_split_axis: if self.descend_one_axis(axis.index, is_split=True): return True diff --git a/third_party/ascend/backend/runtime/utils.py b/third_party/ascend/backend/runtime/utils.py index f470fc61da..b5e3eae719 100644 --- a/third_party/ascend/backend/runtime/utils.py +++ b/third_party/ascend/backend/runtime/utils.py @@ -41,11 +41,11 @@ def _init_npu_params(): ub_size_in_kbytes = 192 rf_size_in_kbytes = None - ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95"] + ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95", "Ascend950"] if any(variant in target.arch for variant in ASCEND_VARIANTS): num_vector_core = num_cube_core * 2 - if '910_95' in target.arch: + if target.arch.startswith("Ascend910_95") or target.arch.startswith("Ascend950"): ub_size_in_kbytes = 256 rf_size_in_kbytes = 128 diff --git a/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h b/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h index 2b7690ac01..730b3e6072 100644 --- a/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h +++ b/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h @@ -1,5 +1,4 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI * diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h index 3dc7a3b844..457a751fbd 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -1,5 +1,4 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI * diff --git a/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp b/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp index ab9c81660e..1e2a9071e9 100644 --- a/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp +++ b/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp @@ -1,3 +1,26 @@ +/* + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + #include "runtime/libentry/libentry.h" using namespace libentry; diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index cc7ba30e7c..0630699daa 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -169,6 +169,11 @@ def triton_key(): return f'{__version__}' + '-'.join(contents) +def get_cache_key(src, backend, backend_options, env_vars): + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}" + return key + + def parse(full_name, ext, context): if ext == "ttir" or ext == "ttgir": module = ir.parse_mlir_module(full_name, context) @@ -217,7 +222,7 @@ def filter_traceback(e: BaseException): e.__traceback__ = frames[0] -def compile(src, target=None, options=None): +def compile(src, target=None, options=None, _env_vars=None): if target is None: target = driver.active.get_current_target() assert isinstance(target, GPUTarget), "target must be of GPUTarget type" @@ -230,8 +235,8 @@ def compile(src, target=None, options=None): extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager - env_vars = get_cache_invalidating_env_vars() - key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars + key = get_cache_key(src, backend, options, env_vars=env_vars) hash = hashlib.sha256(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) # For dumping/overriding only hash the source as we want it to be independent of triton @@ -291,7 +296,11 @@ def compile(src, target=None, options=None): else: stage_name = "MLIRCompile" error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else 
str(e) - error_detail += f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" + from ..runtime.cache import FileCacheManager + if isinstance(fn_cache_manager, FileCacheManager): + error_detail += f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" + else: + error_detail += f"\n\n[INFO]: The compiled kernel cache is {file_name}.{ext}\n\n" raise MLIRCompilationError(stage_name, error_detail) from e ir_filename = f"{file_name}.{ext}" if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): diff --git a/third_party/ascend/backend/spec/triton/language/__init__.py b/third_party/ascend/backend/spec/triton/language/__init__.py index ea7f8e7d66..d89c541442 100644 --- a/third_party/ascend/backend/spec/triton/language/__init__.py +++ b/third_party/ascend/backend/spec/triton/language/__init__.py @@ -1,10 +1,130 @@ -def language_extend_globals(globals_dict): - try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") - except Exception as e: - is_compile_on_910_95 = False - globals_dict["is_compile_on_910_95"] = is_compile_on_910_95 +"""isort:skip_file""" +# Import order is significant here. +from triton.tools.get_ascend_devices import is_compile_on_910_95 +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + # cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + topk, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + add, + advance, + arange, + associative_scan, + assume, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dot_scaled, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + gather, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + nv_tma_desc_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil, cdiv) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) def language_extend_exports(globals_dict, all_list): diff --git a/third_party/ascend/backend/spec/triton/language/core.py b/third_party/ascend/backend/spec/triton/language/core.py index 9ed34a899d..8c6b1119bf 100644 --- a/third_party/ascend/backend/spec/triton/language/core.py +++ b/third_party/ascend/backend/spec/triton/language/core.py @@ -1542,15 +1542,18 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i specified (i.e. at least one must be :code:`None`). 
""" assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" - assert not allow_tf32, "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + assert not allow_tf32, "allow_tf32 is not supported as 'True', please use fp32 on Ascend instead." if input_precision is None: supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + # when setting allow_tf32, use input_precision='hf32' on Ascend instead. + if allow_tf32: + default_precision = "hf32" input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) else: - assert input_precision not in [ - "tf32", "tf32x3" - ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." + if input_precision == "tf32": + input_precision = "hf32" + input_precision = _constexpr_to_value(input_precision) out_dtype = _constexpr_to_value(out_dtype) max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) @@ -2091,9 +2094,9 @@ def expand_ndims(t, ndims): ret = semantic.reduction(input, axis, make_combine_region, _builder) if keep_dims: if axis is not None: - ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + ret = builtins.tuple(expand_dims(t, axis, _builder=_builder) for t in ret) else: - ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + ret = builtins.tuple(expand_ndims(t, len(input[0].shape)) for t in ret) return ret @@ -2523,7 +2526,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): if not has_multiple_outputs: return tensor(call.get_result(0), res_tys[0]) - return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + return builtins.tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) # ----------------------- diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 7be3b70db7..af88701b48 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -336,15 +336,16 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - floor = math.floor(fdiv(input, other, False, builder), _builder=builder) - ret = sub(input, mul(floor, other, True, builder), True, builder) - return ret + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) # % int elif scalar_ty.is_int(): if scalar_ty.int_signedness != other_scalar_ty.int_signedness: raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " "because they have different signedness;" "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + if hasattr(input, 'was_bool_to_int8'): + false_val = builder.get_int1(False) + return tl.tensor(false_val, tl.int1) if scalar_ty.is_int_signed(): return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) else: @@ -1114,6 +1115,9 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti # Check `boundary_check` argument boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + if boundary_check and padding is None: + padding = ir.PADDING_OPTION.PAD_ZERO + # Build IR return tl.tensor( builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) @@ -1177,11 +1181,15 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ # Build IR if mask is None: - ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + load_handle = builder.create_load(ptr.handle, cache, eviction, is_volatile) else: - ret = tl.tensor( - builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, - is_volatile), dst_ty) + load_handle = builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, + eviction, is_volatile) + + if is_bool: + load_handle.set_attr("was_bool_to_int8", builder.get_bool_attr(True)) + + ret = tl.tensor(load_handle, dst_ty) # Do not cast back to int1 when is_bool=true. We directly use the int8 tensor given by tl.load if is_bool: ret.was_bool_to_int8 = True @@ -1608,7 +1616,8 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): - raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + # when input and result is not fp32, ignore input_precision (default is ieee) + input_precision = _str_to_dot_input_precision(builder.options.default_dot_input_precision, builder) if max_num_imprecise_acc is not None: print("max_num_imprecise_acc in tl.dot is not supported on Ascend yet. 
Thus it is ignored.") @@ -1653,8 +1662,14 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str, acc: Union[tl.tensor, None], out_dtype: tl.dtype, lhs_k_pack, rhs_k_pack, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() - assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" - assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + if is_compile_on_910_95: + assert lhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, + tl.float8e4nv], f"lhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + assert rhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, + tl.float8e4nv], f"rhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + else: + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + assert rhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) @@ -1663,26 +1678,35 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + if is_compile_on_910_95: + allowed_formats = {"bf16", "fp16", "e4m3", "e5m2"} + else: + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) - assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + assert isinstance(lhs_scale, tl.tensor) and (lhs_scale.dtype == tl.int8 or lhs_scale.dtype + == tl.uint8), f"lhs_scale must be int8 or uint8 tensor" if not rhs_scale_is_none: - assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + assert isinstance(rhs_scale, tl.tensor) and (rhs_scale.dtype == tl.int8 or rhs_scale.dtype + == tl.uint8), f"rhs_scale must be int8 or uint8 tensor" lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) - if lhs_k_pack == False: + assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + + lhs_k_pack_v = lhs_k_pack.value if isinstance(lhs_k_pack, tl.constexpr) else lhs_k_pack + rhs_k_pack_v = rhs_k_pack.value if isinstance(rhs_k_pack, tl.constexpr) else rhs_k_pack + + if lhs_k_pack_v is False: dims = (1, 0) - dims = core._unwrap_iterable(dims) tmp_lhs = permute(lhs, dims, builder) lhs = reshape(tmp_lhs, (lhs.shape[0], lhs.shape[1]), True, builder) - if rhs_k_pack == False: + if rhs_k_pack_v is False: dims = (1, 0) - dims = core._unwrap_iterable(dims) tmp_rhs = permute(rhs, dims, builder) rhs = reshape(tmp_rhs, (rhs.shape[0], rhs.shape[1]), True, builder) diff --git 
a/third_party/ascend/backend/spec/triton/runtime/_async_compile.py b/third_party/ascend/backend/spec/triton/runtime/_async_compile.py new file mode 100644 index 0000000000..a6c773123e --- /dev/null +++ b/third_party/ascend/backend/spec/triton/runtime/_async_compile.py @@ -0,0 +1,55 @@ +from __future__ import annotations +from typing import Callable, Optional +from concurrent.futures import Executor, as_completed, Future +from contextvars import ContextVar + +active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None) + + +class FutureKernel: + + def __init__(self, finalize_compile: Callable, future: Future): + self.finalize_compile = finalize_compile + self.kernel = None + self.future = future + + def result(self): + if self.kernel is not None: + return self.kernel + + kernel = self.future.result() + self.finalize_compile(kernel) + self.kernel = kernel + return kernel + + +class AsyncCompileMode: + + def __init__(self, executor: Executor): + self.executor = executor + self.raw_futures = [] + self.future_kernels = {} + + def submit(self, key, compile_fn, finalize_fn): + future = self.future_kernels.get(key) + if future is not None: + return future + + future = self.executor.submit(compile_fn) + future._key = key + self.raw_futures.append(future) + future_kernel = FutureKernel(finalize_fn, future) + self.future_kernels[key] = future_kernel + return future_kernel + + def __enter__(self): + if active_mode.get() is not None: + raise RuntimeError("Another AsyncCompileMode is already active") + active_mode.set(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Finalize any outstanding compiles + for future in as_completed(self.raw_futures): + self.future_kernels[future._key].result() + active_mode.set(None) diff --git a/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py new file mode 100644 index 0000000000..9482e57ca7 --- /dev/null +++ b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py @@ -0,0 +1,736 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Ascend-specific interpreter builder extensions. + +This module extends the base InterpreterBuilder with Ascend-specific operations +(extension ops) without modifying the public base class. All Ascend-related +features are isolated here and can be extended independently. 
+
+Author: Triton-Ascend Contributors
+"""
+
+import warnings
+import contextlib
+import numpy as np
+import triton.language as tl
+from .interpreter import InterpreterBuilder, TensorHandle, ReduceOps, _get_np_dtype
+from .._C.libtriton import interpreter as _interpreter
+
+
+class AscendReduceOps(ReduceOps):
+    """
+    Ascend reduce operations that override only the apply_impl logic.
+    All other methods (sum, min_max, generic_reduce, etc.) are inherited from ReduceOps.
+    """
+
+    def apply_impl(self, input_param):
+        if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
+            return self.min_max(input_param[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
+        elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
+            return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
+        # Triton-Ascend has modified the implementation of tl.max
+        elif self.combine_fn == tl.standard._elementwise_max_default:
+            return self.min_max(input_param[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
+        elif self.combine_fn == tl.standard._elementwise_max_propagate_nan:
+            return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=None)
+        elif self.combine_fn == tl.standard._elementwise_min:
+            return self.min_max(input_param[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
+        elif self.combine_fn == tl.standard._sum_combine:
+            return self.sum(input_param[0])
+        else:
+            # Fall back to the slow mode
+            return self.generic_reduce(input_param)
+
+
+def _compute_strides(shape):
+    strides = [1] * len(shape)
+    for i in range(len(shape) - 2, -1, -1):
+        strides[i] = strides[i + 1] * shape[i + 1]
+    return strides
+
+
+class AscendInterpreterBuilder(InterpreterBuilder):
+    """
+    Extended InterpreterBuilder with Ascend-specific extension operations.
+
+    This class inherits from InterpreterBuilder and adds support for:
+    - get_element (extract_scalar): Extract scalar from tensor using indices
+    - insert_slice: Insert sub-tensor into full tensor
+    - extract_slice: Extract slice from tensor
+    - index_select_simd: SIMD gather operation
+    - get_sub_vec_id: Get vector core ID for 1:2 ratio emulation
+    - Synchronization operations: sync_block_set/wait/all
+
+    All extension operations handle both TensorHandle and Python int types
+    for interpreter mode compatibility.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Sub-vector core ID for simulating 1:2 hardware ratio
+        self.sub_vec_id = 0
+        # Flag to track if sub_vec_id simulation is needed
+        self._sub_vec_simulation_enabled = False
+
+    def to_int_val(self, val):
+        """
+        Convert a value (int or TensorHandle) to Python int.
+
+        :param val: Value to convert (int, TensorHandle, or other)
+        :return: Python integer
+        """
+        if isinstance(val, TensorHandle):
+            return int(val.data.item())
+        return int(val)
+
+    def _patch_lang_ascend(self, fn):
+
+        def _new_range(arg1, arg2=None, step=None, **kwargs):
+            if step is None:
+                step = 1
+            if arg2 is None:
+                start, end = 0, arg1
+            else:
+                start, end = arg1, arg2
+            return range(start, end, step)
+
+        def _new_reduce(input_param, axis, combine_fn, keep_dims=False, **kwargs):
+            return AscendReduceOps(axis, combine_fn, keep_dims).apply(input_param)
+
+        @contextlib.contextmanager
+        def _dummy_scope(*args, **kwargs):
+            yield
+
+        tl.extra.cann.extension.scope = _dummy_scope
+        tl.extra.cann.extension.parallel = _new_range
+        tl.reduce = _new_reduce
+        tl.core.reduce = _new_reduce
+
+    def get_additional_reserved_keywords(self):
+        """
+        Return additional reserved keywords specific to Ascend backend. 
+ + These keywords will be filtered out from kernel call arguments + and are not supported by the interpreter. + + :return: List of additional reserved keyword strings + """ + return [ + "multibuffer", # Ascend-specific memory buffering + "debug", + "optimize_dynamic_offset", + "enable_mixed_cv", + "enable_auto_bind_sub_block", + "sync_solver", + # Add more Ascend-specific keywords here as needed + # "ascend_option1", + # "ascend_option2", + ] + + def patch_extensions(self, fn): + """ + Patch Ascend extension modules for the given function. + + This method handles all Ascend-specific extension module patching, + including CANN extensions and any other extension modules found in + the function's global namespace. + + :param fn: The kernel function to patch extensions for + """ + # Import _patch_builtin from parent module + from .interpreter import _patch_builtin + self._patch_lang_ascend(fn) + + # Patch all modules in fn's globals that might be extension modules + for name, value in list(fn.__globals__.items()): + if value is None: + continue + try: + # Check if it looks like an extension module (has builtin functions) + if hasattr(value, '__name__') and 'extension' in str(value.__name__): + _patch_builtin(value, self) + # Also try patching any module-like object that might have builtin functions + elif hasattr(value, '__dict__') and not isinstance(value, type): + # Try to patch it and ignore if it fails + try: + _patch_builtin(value, self) + except Exception: + pass + except Exception: + pass + + # Also try importing extension directly as fallback + try: + import triton.language.extra.cann.extension as extension + _patch_builtin(extension, self) + except (ImportError, AttributeError): + # Extension module not available (e.g., non-Ascend backend) + pass + + def execute_with_sub_vec_simulation(self, fn, args, grid): + """ + Execute function with optional 1:2 sub-vector core simulation. + + Sub-vector simulation is only activated when create_get_sub_vec_id() is + actually called during execution. This avoids unnecessary double execution + for code that doesn't use sub_vec_id functionality. + + :param fn: The kernel function to execute + :param args: Function arguments + :param grid: Grid dimensions (nx, ny, nz) + """ + # Reset simulation flag at the beginning of each execution + self._sub_vec_simulation_enabled = False + self.sub_vec_id = 0 + + # First, try a single execution to see if sub_vec_id is used + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + self.set_grid_idx(x, y, z) + fn(**args) + + # If sub_vec_id was accessed during execution, run again with sub_vec_id=1 + if self._sub_vec_simulation_enabled: + self.sub_vec_id = 1 + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + self.set_grid_idx(x, y, z) + fn(**args) + + # ======================================================================== + # Extension ops for Ascend + # ======================================================================== + + def create_extract_scalar(self, tensor_handle, indices): + """ + Extract a scalar from a tensor using indices (equivalent to get_element). + + Handles mixed types: Python int (from loops) and TensorHandle (from other ops). 
+ + :param tensor_handle: The tensor to extract from (TensorHandle) + :param indices: List of scalar indices (can be TensorHandle or Python int) + :return: Scalar value as TensorHandle + """ + # Convert indices from TensorHandle or Python int to integers + index_values = [] + for idx in indices: + if isinstance(idx, int): + # Python int passed directly (e.g., from loop counter) + index_values.append(idx) + elif isinstance(idx, TensorHandle): + # Interpreter TensorHandle + index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) + else: + # Fallback: try to extract data + index_values.append( + int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') else + int(idx.data) if hasattr(idx, 'data') else int(idx)) + + # Extract the scalar value + scalar_data = tensor_handle.data[tuple(index_values)] + return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) + + def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): + """ + Insert a sub-tensor into a full tensor at specified offsets. + + Handles mixed types: Python int and TensorHandle for offsets. + + :param full_tensor: The full tensor (destination, TensorHandle) + :param sub_tensor: The sub-tensor to insert (TensorHandle) + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Modified tensor with sub_tensor inserted (TensorHandle) + """ + result = full_tensor.data.copy() + + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for insertion + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Insert the sub-tensor + result[tuple(slices)] = sub_tensor.data + + return TensorHandle(result, full_tensor.dtype.scalar) + + def create_extract_slice(self, full_tensor, offsets, sizes, strides): + """ + Extract a slice from a full tensor. + + Handles mixed types: Python int and TensorHandle for offsets. 
+ + :param full_tensor: The full tensor (TensorHandle) + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Extracted sub-tensor (TensorHandle) + """ + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for extraction + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Extract the slice + extracted = full_tensor.data[tuple(slices)] + + return TensorHandle(extracted, full_tensor.dtype.scalar) + + def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): + """ + SIMD index_select operation (gather with indices along a dimension). + + This is a hardware-accelerated gather operation that selects elements + from a tensor using a set of indices along a specified dimension. + + :param src_ptr: Source tensor pointer (TensorHandle), just ptr address, not value + :param index_tensor: 1D tensor of indices (TensorHandle or array) + :param dim: Dimension to select from (int) + :param src_shape: List of source shape (int or TensorHandle) + :param src_offset: List of source offset (int or TensorHandle) + :param read_shape: List of read shape (int or TensorHandle) + :param result_shape: List of result shape (int or TensorHandle) + :return: Result tensor with selected indices (TensorHandle) + """ + # Convert src_shape, src_offset, read_shape to integers + src_shape_vals = [self.to_int_val(s) for s in src_shape] + src_offset_vals = [self.to_int_val(s) if s != -1 else -1 for s in src_offset] + read_shape_vals = [self.to_int_val(r) if r != -1 else -1 for r in read_shape] + result_shape_vals = [self.to_int_val(r) for r in result_shape] + + # Get index values - handle both array and TensorHandle + if isinstance(index_tensor, TensorHandle): + indices = index_tensor.data.flatten() + else: + indices = np.asarray(index_tensor).flatten() + + # Ensure indices are integers + if indices.dtype not in [np.int32, np.int64]: + indices = indices.astype(np.int32) + + # Element type + dtype_tt = src_ptr.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + src_strides = _compute_strides(src_shape_vals) + base_addr = int(src_ptr.data.item()) + + # Create result tensor + result = np.empty(result_shape_vals, dtype=dtype_np) + + # Perform index_select: for each index, read the specified data + for out_idx, in_idx in enumerate(indices): + in_idx = int(in_idx) + # Generate all coordinates in the tile + ranges = [] + for d in range(len(src_shape_vals)): + if d == dim: + ranges.append([in_idx]) + else: + offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 + read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] + # Clamp to valid range + offset = max(0, min(offset, src_shape_vals[d] - 1)) + read_size = min(read_size, src_shape_vals[d] - offset) + 
ranges.append(list(range(offset, offset + read_size))) + from itertools import product + coords = list(product(*ranges)) + + # Compute address for each element in the tile + addresses = [] + for coord in coords: + offset = sum(coord[i] * src_strides[i] for i in range(len(coord))) + addr = base_addr + offset * np.dtype(dtype_np).itemsize + addresses.append(addr) + # load data + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.ones_like(addr_array, dtype=bool) + other_array = np.zeros_like(addr_array, dtype=dtype_np) + tile_data = _interpreter.load(addr_array, mask_array, other_array, dtype_np) + # Reshape tile_data to match read_shape with dim=1 at dim + tile_shape = [] + for d in range(len(src_shape_vals)): + if d == dim: + tile_shape.append(1) + else: + offset = src_offset_vals[d] + read_size = read_shape_vals[d] + offset = max(0, min(offset, src_shape_vals[d] - 1)) + read_size = min(read_size, src_shape_vals[d] - offset) + tile_shape.append(read_size) + tile_data = tile_data.reshape(tile_shape) + + # Build result slice + result_slices = [] + for d in range(len(result_shape_vals)): + if d == dim: + result_slices.append(slice(out_idx, out_idx + 1)) + else: + result_slices.append(slice(None)) + result[tuple(result_slices)] = tile_data + + return TensorHandle(result, dtype_tt) + + def create_get_sub_vec_id(self): + """ + Get the Vector Core index on the AI Core. + + In Interpreter mode, simulate multiple vector cores by maintaining + a sub_vec_id counter. This is used for 1:2 hardware ratio emulation + where different vector cores process different partitions of the data. + + The first call to this method enables sub_vec_simulation, causing + the kernel to be executed twice (once for each sub_vec_id value). + + :return: Vector Core ID as TensorHandle (int64, scalar) + """ + # Enable sub_vec_id simulation when this method is called + self._sub_vec_simulation_enabled = True + + # Return the current sub_vec_id + vec_id = np.int64(self.sub_vec_id) + return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) + + def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Set synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Wait for synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_all(self, mode, event_id): + """ + Synchronize all compute or vector units globally. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. 
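+
+        Illustrative call shape (hypothetical event id; returns immediately):
+
+        >>> builder.sync_block_all("all_vector", 0)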
+ + :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") + :param event_id: Event ID (int, constexpr, or TensorHandle) + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def get_int1_ty(self): + return tl.int1 + + def get_all_ones_value(self, tt_type): + np_type = _get_np_dtype(tt_type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), tt_type.scalar) + elif np_type == np.bool_: + return TensorHandle(np.full(1, True, dtype=np_type), tt_type.scalar) + else: + raise TypeError(f"unsupported type {tt_type}") + + def is_simt_mode(self): + return False + + def create_sort(self, ptr_data, dim: int, descending: bool): + ndim = ptr_data.data.ndim + norm_dim = dim if dim >= 0 else dim + ndim + if not (0 <= norm_dim < ndim): + raise IndexError(f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + + if descending: + sorted_asc = np.sort(ptr_data.data, axis=norm_dim) + sorted_desc = np.flip(sorted_asc, axis=norm_dim) + return TensorHandle(sorted_desc, ptr_data.dtype.scalar) + else: + return TensorHandle(np.sort(ptr_data.data, axis=norm_dim), ptr_data.dtype.scalar) + + def create_flip(self, ptr_data, dim): + ndim = ptr_data.data.ndim + norm_dim = dim if dim >= 0 else dim + ndim + if not (0 <= norm_dim < ndim): + raise IndexError(f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + return TensorHandle(np.flip(ptr_data.data, axis=norm_dim), ptr_data.dtype.scalar) + + def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, src_stride, end_offset, start_offset, + other=None): + # Convert src_stride, start_offset, end_offset to integers + src_stride_vals = [self.to_int_val(s) for s in src_stride] + start_offset_vals = [self.to_int_val(s) for s in start_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] + + # Element type + dtype_tt = src_ptr.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + element_size = np.dtype(dtype_np).itemsize + base_addr = int(src_ptr.data.item()) + index_shape = index_tensor.data.shape + index_rank = len(index_shape) + total_elements = np.prod(index_shape) + + # Generate coordinates + all_coords = [] + for idx in range(total_elements): + coord = np.unravel_index(idx, index_shape) + all_coords.append(coord) + + # Compute the source tensor coordinates for each position in all_coords + src_coords = [] + for coord in all_coords: + src_coord = [] + for d in range(index_rank): + if d == dim: + index_value = index_tensor.data[coord] + if index_value >= index_boundary: + src_coord.append(-1) + else: + src_coord.append(start_offset_vals[d] + index_value) + else: + src_coord.append(start_offset_vals[d] + coord[d]) + src_coords.append(src_coord) + + # Compute address and mask + addresses = [] + valid_mask = [] + for _, src_coord in enumerate(src_coords): + if -1 in src_coord: + addresses.append(0) + valid_mask.append(False) + else: + offset = 0 + for d in range(index_rank): + offset += src_coord[d] * src_stride_vals[d] + address = base_addr + offset * element_size + addresses.append(address) + valid_mask.append(True) + + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.array(valid_mask, dtype=bool) + + # Create other value array + if other is not None: + if isinstance(other, TensorHandle): + other_value = other.data.item() + else: + other_value = other + other_array = np.full(addr_array.shape, other_value, dtype=dtype_np) + else: + other_array = 
np.zeros(addr_array.shape, dtype=dtype_np)
+
+        # Load data
+        flat_result = _interpreter.load(addr_array, mask_array, other_array, dtype_np)
+        result = flat_result.reshape(index_shape)
+        return TensorHandle(result, dtype_tt)
+
+    def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_boundary, dim, dst_stride, end_offset,
+                                 start_offset):
+        # Convert dst_stride, start_offset, end_offset to integers
+        dst_stride_vals = [self.to_int_val(s) for s in dst_stride]
+        start_offset_vals = [self.to_int_val(s) for s in start_offset]
+        end_offset_vals = [self.to_int_val(s) for s in end_offset]
+
+        # Element type
+        dtype_tt = dst_ptr.get_element_ty()
+        dtype_np = _get_np_dtype(dtype_tt)
+        element_size = np.dtype(dtype_np).itemsize
+        base_addr = int(dst_ptr.data.item())
+
+        index_shape = index_tensor.data.shape
+        index_rank = len(index_shape)
+        total_elements = np.prod(index_shape)
+        flat_values = value_tensor.data.flatten()
+        flat_indices = index_tensor.data.flatten()
+
+        # Generate coordinates
+        all_coords = []
+        for idx in range(total_elements):
+            coord = np.unravel_index(idx, index_shape)
+            all_coords.append(coord)
+
+        # Compute address and mask
+        addresses = []
+        valid_mask = []
+        for coord in all_coords:
+            index_value = index_tensor.data[coord]
+            if index_value >= index_boundary:
+                addresses.append(0)
+                valid_mask.append(False)
+            else:
+                dst_coord = []
+                for d in range(index_rank):
+                    if d == dim:
+                        dst_coord.append(start_offset_vals[d] + index_value)
+                    else:
+                        dst_coord.append(start_offset_vals[d] + coord[d])
+                offset = 0
+                for d in range(index_rank):
+                    offset += dst_coord[d] * dst_stride_vals[d]
+                address = base_addr + offset * element_size
+                addresses.append(address)
+                valid_mask.append(True)
+
+        addr_array = np.array(addresses, dtype=np.uint64)
+        mask_array = np.array(valid_mask, dtype=bool)
+
+        _interpreter.store(addr_array, flat_values, mask_array)
+
+    def create_index_put(self, dst_ptr, index_tensor, value_tensor, dim, index_boundary, end_offset, start_offset,
+                         dst_stride):
+        # Convert dst_stride, start_offset, end_offset to integers
+        dst_stride_vals = [self.to_int_val(s) for s in dst_stride]
+        start_offset_vals = [self.to_int_val(s) for s in start_offset]
+        end_offset_vals = [self.to_int_val(s) for s in end_offset]
+
+        # Element type
+        dtype_tt = dst_ptr.get_element_ty()
+        dtype_np = _get_np_dtype(dtype_tt)
+        element_size = np.dtype(dtype_np).itemsize
+        base_addr = int(dst_ptr.data.item())
+
+        value_shape = value_tensor.data.shape
+        value_rank = len(value_shape)
+
+        flat_values = value_tensor.data.flatten()
+        total_elements = flat_values.size
+
+        # Generate coordinates
+        all_coords = []
+        for idx in range(total_elements):
+            coord = np.unravel_index(idx, value_shape)
+            all_coords.append(coord)
+
+        # Per-dimension read ranges (currently informational only)
+        read_ranges = []
+        for d in range(value_rank):
+            start = start_offset_vals[d]
+            end = end_offset_vals[d]
+            read_ranges.append((start, end))
+
+        # Compute addresses, validity mask, and the values to store
+        addresses = []
+        valid_mask = []
+        values_to_store = []
+        for i, coord in enumerate(all_coords):
+            index_pos = coord[dim]
+            index_value = index_tensor.data[index_pos]
+            if index_value >= index_boundary:
+                addresses.append(0)
+                valid_mask.append(False)
+            else:
+                dst_coord = []
+                for d in range(value_rank):
+                    if d == dim:
+                        dst_coord.append(index_value)
+                    else:
+                        dst_coord.append(start_offset_vals[d] + coord[d])
+                offset = 0
+                for d in range(value_rank):
+                    offset += dst_coord[d] * dst_stride_vals[d]
+                address = base_addr + offset * element_size
+                addresses.append(address)
+                values_to_store.append(flat_values[i])
+                valid_mask.append(True)
+
+        addr_array = np.array(addresses, dtype=np.uint64)
+        mask_array = np.array(valid_mask, dtype=bool)
+        values_array = np.array(values_to_store, dtype=dtype_np)
+
+        _interpreter.store(addr_array, values_array, mask_array)
+
+    def get_bool_attr(self, val):
+        return bool(val)
+
+    def get_unit_attr(self):
+        return None  # a None value in compile_hint maps to a unit attribute
+
+    def get_int32_attr(self, val):
+        return int(val)
+
+    def get_str_attr(self, val):
+        return str(val)
+
+    def get_i64_array_attr(self, val):
+        return [int(x) for x in val]
+
+    def create_annotation_mark(self, ptr_data, hint_name: str, hint_val):
+        if hint_name == "overflow_mode":
+            raise ValueError("overflow_mode is not supported in interpreter mode and may have accuracy issues")
+        else:
+            warnings.warn(f"compile_hint '{hint_name}' is not supported in interpreter mode and is ignored", UserWarning,
+                          stacklevel=2)
diff --git a/third_party/ascend/backend/spec/triton/runtime/autotuner.py b/third_party/ascend/backend/spec/triton/runtime/autotuner.py
index 5afa5228d5..993aee56ad 100644
--- a/third_party/ascend/backend/spec/triton/runtime/autotuner.py
+++ b/third_party/ascend/backend/spec/triton/runtime/autotuner.py
@@ -4,7 +4,9 @@
 import os
 import time
 import inspect
-from typing import Dict
+import itertools
+
+from typing import Any, Dict, List
 from .jit import KernelInterface
 from .errors import OutOfResources
@@ -116,7 +118,6 @@ def _post_hook(kwargs, exception):
                 quantiles=quantiles,
             )
             return
-
         import triton.testing
         self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
             kernel_call,
@@ -406,3 +407,216 @@ def decorator(fn):
         return Heuristics(fn, fn.arg_names, values)
 
     return decorator
+
+
+_ALL_PARAMS = {
+    "BM_list", "BN_list", "multibuffer", "unit_flag", "limit_auto_multi_buffer_only_for_local_buffer",
+    "limit_auto_multi_buffer_of_local_buffer", "set_workspace_multibuffer", "enable_hivm_auto_cv_balance",
+    "tile_mix_vector_loop", "tile_mix_cube_loop"
+}
+
+_DEFAULTS = {
+    "BM_list": [16, 32, 64, 128], "BN_list": [16, 32, 64, 128], "multibuffer": [False], "unit_flag": [False],
+    "limit_auto_multi_buffer_only_for_local_buffer": [True], "limit_auto_multi_buffer_of_local_buffer": ["no-l0c"],
+    "set_workspace_multibuffer": [2, 4], "enable_hivm_auto_cv_balance": [True], "tile_mix_vector_loop": [2, 4],
+    "tile_mix_cube_loop": [2, 4]
+}
+
+_VALID_VALUES = {
+    "limit_auto_multi_buffer_of_local_buffer": ["no-limit", "no-l0c"], "set_workspace_multibuffer": [2, 4],
+    "tile_mix_vector_loop": [2, 4, 8], "tile_mix_cube_loop": [2, 4, 8]
+}
+
+_CUBE_PARAMS = {"multibuffer", "unit_flag", "limit_auto_multi_buffer_of_local_buffer"}
+_MIXCV_PARAMS = {
+    "multibuffer", "unit_flag", "limit_auto_multi_buffer_only_for_local_buffer",
+    "limit_auto_multi_buffer_of_local_buffer", "set_workspace_multibuffer", "enable_hivm_auto_cv_balance",
+    "tile_mix_vector_loop", "tile_mix_cube_loop"
+}
+_VECTOR_PARAMS = {"multibuffer"}
+
+
+def _check_boolean_list(val: List[Any], param_name: str) -> bool:
+    return isinstance(val, (list, tuple)) and len(val) > 0 and all(isinstance(x, bool) for x in val)
+
+
+def _check_string_in_set(val: List[Any], valid_set: set, param_name: str) -> bool:
+    return isinstance(val, (list, tuple)) and len(val) > 0 and all(v in valid_set for v in val)
+
+
+def _check_int_in_set(val: List[Any], valid_set: set, param_name: str) -> bool:
+    return isinstance(val, (list, tuple)) and len(val) > 0 and all(isinstance(v, int) and v in valid_set for v in val)
+
+
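+# Illustrative behaviour of the validators above (hypothetical inputs):
+#   _check_boolean_list([True, False], "multibuffer")       -> True
+#   _check_boolean_list([], "multibuffer")                  -> False (empty)
+#   _check_int_in_set([2, 4], [2, 4], "tile_mix_cube_loop") -> True
+#   _check_int_in_set([3], [2, 4], "tile_mix_cube_loop")    -> False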
+_VALIDATION_RULES = {
+    "multibuffer": {
+        "desc": "must be non-empty list/tuple of boolean values",
+        "check": _check_boolean_list,
+    },
+    "unit_flag": {
+        "desc": "must be non-empty list/tuple of boolean values",
+        "check": _check_boolean_list,
+    },
+    "limit_auto_multi_buffer_only_for_local_buffer": {
+        "desc": "must be non-empty list/tuple of boolean values",
+        "check": _check_boolean_list,
+    },
+    "limit_auto_multi_buffer_of_local_buffer": {
+        "desc": f"must be one or more of: {_VALID_VALUES['limit_auto_multi_buffer_of_local_buffer']}",
+        "check": lambda val, param_name: _check_string_in_set(
+            val, _VALID_VALUES['limit_auto_multi_buffer_of_local_buffer'],
+            "limit_auto_multi_buffer_of_local_buffer"),
+    },
+    "set_workspace_multibuffer": {
+        "desc": f"must be one or more of: {_VALID_VALUES['set_workspace_multibuffer']}",
+        "check": lambda val, param_name: _check_int_in_set(
+            val, _VALID_VALUES['set_workspace_multibuffer'], "set_workspace_multibuffer"),
+    },
+    "enable_hivm_auto_cv_balance": {
+        "desc": "must be non-empty list/tuple of boolean values",
+        "check": _check_boolean_list,
+    },
+    "tile_mix_vector_loop": {
+        "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_vector_loop']}",
+        "check": lambda val, param_name: _check_int_in_set(
+            val, _VALID_VALUES['tile_mix_vector_loop'], "tile_mix_vector_loop"),
+    },
+    "tile_mix_cube_loop": {
+        "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_cube_loop']}",
+        "check": lambda val, param_name: _check_int_in_set(
+            val, _VALID_VALUES['tile_mix_cube_loop'], "tile_mix_cube_loop"),
+    },
+}
+
+
+class BaseAutotuner:
+    """
+    Base class used to generate auto-tuning configurations for Triton kernels.
+    Instances are parameterized with:
+        operator_name: The name of the operator.
+        supported_params: A set of supported parameter names.
+        default_params: A dictionary of default parameter values.
+        validation_rules: Validation rules for parameters (see _VALIDATION_RULES above).
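+
+    Illustrative use (hypothetical parameter values):
+
+        tuner = BaseAutotuner("vector", _VECTOR_PARAMS, _DEFAULTS, _VALIDATION_RULES)
+        configs = tuner.get_configs(multibuffer=[True, False])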
+    """
+
+    def __init__(self, operator_name: str, supported_params: set, default_params: Dict[str, Any],
+                 validation_rules: Dict[str, Dict[str, Any]]):
+        self.operator_name = operator_name
+        self.supported_params = supported_params
+        self.default_params = default_params
+        self.validation_rules = validation_rules
+
+    SPECIAL_PARAMS_NO_WARNING = {"BM_list", "BN_list"}
+
+    def validate_parameters(self, **kwargs: Any) -> bool:
+        invalid_params = [k for k in kwargs.keys() if k not in _ALL_PARAMS]
+        if invalid_params:
+            print(f"[ERROR] Invalid parameters for {self.operator_name}: {invalid_params}")
+            return False
+
+        for param_name, rule in self.validation_rules.items():
+            if param_name not in kwargs:
+                continue
+
+            value = kwargs[param_name]
+            if not rule["check"](value, param_name):
+                print(f"[ERROR] Invalid value for '{param_name}' in {self.operator_name}: {value}")
+                print(f"    Expected: {rule['desc']}")
+                return False
+
+        return True
+
+    def get_configs(self, **kwargs: Any) -> List["triton.Config"]:
+        # The annotation is quoted because `triton` is only imported locally,
+        # so the name is not bound when the function object is created.
+        import triton
+        if not self.validate_parameters(**kwargs):
+            return []
+
+        params = self.default_params.copy()
+        bm_list = kwargs.get("BM_list")
+        bn_list = kwargs.get("BN_list")
+
+        if bm_list is not None:
+            params["BM_list"] = bm_list
+        if bn_list is not None:
+            params["BN_list"] = bn_list
+
+        for k, v in kwargs.items():
+            if k in self.supported_params:
+                params[k] = v
+
+        valid_kwargs = {k: v for k, v in kwargs.items() if k in self.supported_params}
+
+        other_kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if k not in self.supported_params and k not in self.SPECIAL_PARAMS_NO_WARNING
+        }
+        if other_kwargs:
+            print(
+                f"[WARNING] Parameter(s) {list(other_kwargs.keys())} do not belong to {self.operator_name} and have been ignored."
+            )
+
+        bm_list = params.get("BM_list", _DEFAULTS["BM_list"])
+        bn_list = params.get("BN_list", _DEFAULTS["BN_list"])
+
+        # Every supported parameter resolves to the caller's value when given and to
+        # its default otherwise (the former per-parameter branches were all identical).
+        other_params = {
+            param_name: valid_kwargs.get(param_name, _DEFAULTS[param_name])
+            for param_name in sorted(self.supported_params)
+        }
+
+        configs = []
+        bm_bn_combos = list(itertools.product(bm_list, bn_list))
+        other_combos = list(itertools.product(*other_params.values()))
+        all_combos = list(itertools.product(bm_bn_combos, other_combos))
+        for (bm, bn), other_values in all_combos:
+            config_kwargs = {
+                "BLOCK_M": bm,
+                "BLOCK_N": bn,
+            }
+            for i, param_name in enumerate(sorted(other_params.keys())):
+                config_kwargs[param_name] = other_values[i]
+            configs.append(triton.Config(config_kwargs))
+        return configs
+
+
+CubeAutotuner = BaseAutotuner(operator_name="cube", supported_params=_CUBE_PARAMS, default_params=_DEFAULTS,
+                              validation_rules=_VALIDATION_RULES)
+
+MixcvAutotuner = BaseAutotuner(operator_name="mixcv", supported_params=_MIXCV_PARAMS, default_params=_DEFAULTS,
+                               validation_rules=_VALIDATION_RULES)
+
+VectorAutotuner = BaseAutotuner(operator_name="vector", supported_params=_VECTOR_PARAMS, default_params=_DEFAULTS,
+                                validation_rules=_VALIDATION_RULES)
+
+
+def get_autotune_cube_config(**kwargs: Any) -> List["triton.Config"]:
+    """
+    Generate autotune configurations for the cube operator.
+    Supported parameters: multibuffer, unit_flag, limit_auto_multi_buffer_of_local_buffer.
+    """
+    return CubeAutotuner.get_configs(**kwargs)
+
+
+def get_autotune_cv_config(**kwargs: Any) -> List["triton.Config"]:
+    """
+    Generate autotune configurations for the mixcv operator.
+    Supported parameters: multibuffer, unit_flag, limit_auto_multi_buffer_only_for_local_buffer,
+    limit_auto_multi_buffer_of_local_buffer, set_workspace_multibuffer,
+    enable_hivm_auto_cv_balance, tile_mix_vector_loop, tile_mix_cube_loop.
+    """
+    return MixcvAutotuner.get_configs(**kwargs)
+
+
+def get_autotune_vector_config(**kwargs: Any) -> List["triton.Config"]:
+    """
+    Generate autotune configurations for the vector operator.
+    Supported parameters: multibuffer.
+    """
+    return VectorAutotuner.get_configs(**kwargs)
diff --git a/third_party/ascend/backend/spec/triton/runtime/code_cache.py b/third_party/ascend/backend/spec/triton/runtime/code_cache.py
index 43d841cd3d..563d46c8af 100644
--- a/third_party/ascend/backend/spec/triton/runtime/code_cache.py
+++ b/third_party/ascend/backend/spec/triton/runtime/code_cache.py
@@ -1,4 +1,4 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+# Copyright (c) FlagOpen contributors
 # Copyright 2018-2020 Philippe Tillet
 # Copyright 2020-2022 OpenAI
 # Copyright © 2024 BAAI. All rights reserved.
diff --git a/third_party/ascend/backend/spec/triton/runtime/interpreter.py b/third_party/ascend/backend/spec/triton/runtime/interpreter.py
index 7ad9b1b9f0..0f7be77b03 100644
--- a/third_party/ascend/backend/spec/triton/runtime/interpreter.py
+++ b/third_party/ascend/backend/spec/triton/runtime/interpreter.py
@@ -14,6 +14,25 @@
 from .._C.libtriton import interpreter as _interpreter
 from .._C.libtriton import ir as _ir
 
+# Import the Ascend-specific interpreter builder (deferred to avoid a circular dependency)
+_has_ascend_support = False
+AscendInterpreterBuilder = None
+
+
+def _try_import_ascend():
+    global _has_ascend_support, AscendInterpreterBuilder
+    try:
+        from . import ascend_interpreter
+        AscendInterpreterBuilder = ascend_interpreter.AscendInterpreterBuilder
+        _has_ascend_support = True
+    except ImportError:
+        _has_ascend_support = False
+        AscendInterpreterBuilder = None
+    except Exception:
+        # Catch other exceptions (such as a circular import) and fall back
+        _has_ascend_support = False
+        AscendInterpreterBuilder = None
+
 
 class TensorHandle:
 
@@ -80,7 +99,7 @@ class InterpreterOptions:
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
     deprecated_fp8_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
-    allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
+    allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", "hf32")
     max_num_imprecise_acc_default: int = 0
     backend_name: str = "interpreter"
 
@@ -140,6 +159,8 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode):
     bias_input = input_dtype.exponent_bias
     bias_output = output_dtype.exponent_bias
     exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
+    # mark NaN values
+    input_nan_index = (exponent == (1 << input_exponent_width) - 1) & (significand != 0)
     subnormal_index = exponent == 0
     if np.any(subnormal_index):
         # Credit to Phil: phil@openai.com
@@ -159,8 +180,13 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode):
         significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
             (1 << input_dtype.fp_mantissa_width) - 1)
     # Prevent overflow and underflow
-    exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
+    exponent_unclamped = exponent - bias_input + bias_output
+    output_max_exponent = (1 << output_exponent_width) - 1
+    exponent_output = np.maximum(0, np.minimum(exponent_unclamped, output_max_exponent))
     exponent_output = exponent_output.astype(output_unint_dtype)
+    # mark overflow indices
+    overflow_index = exponent_unclamped > output_max_exponent - 1
+
     sign_output = sign.astype(output_unint_dtype)
     if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth:  # Downcast
         significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
@@ -188,6 +214,8 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode):
         shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
         significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
             1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
+    # convert overflowed values to inf (clear the significand), leaving NaNs intact
+    significand_output[overflow_index & ~input_nan_index] = 0
     output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
         exponent_output << output_dtype.fp_mantissa_width) | significand_output
     return output.reshape(input.shape)
@@ -245,8 +273,6 @@ def __init__(self) -> None:
         # For interpreter mode, don't enforce GPU hardware shape constraints
         # NumPy matmul works with any size, including small matrices
         self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)
-        # Sub-vector core ID for simulating 1:2 hardware ratio
-        self.sub_vec_id = 0
 
     def set_grid_idx(self, x, y, z):
         if not x < self.grid_dim[0]:
@@ -612,261 +638,6 @@ def create_splat(self, arg, shape):
         else:  # scalar
             return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
 
-    # Extension ops for Ascend
-    def create_extract_scalar(self, tensor_handle, indices):
-        """
-        Extract 
a scalar from a tensor using indices (equivalent to get_element). - - :param tensor_handle: The tensor to extract from - :param indices: List of scalar indices (can be TensorHandle or Python int) - :return: Scalar value - """ - # Convert indices from TensorHandle or Python int to integers - index_values = [] - for idx in indices: - if isinstance(idx, int): - # Python int passed directly (e.g., from loop counter) - index_values.append(idx) - elif isinstance(idx, TensorHandle): - # Interpreter TensorHandle - index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) - else: - # Fallback: try to extract data - index_values.append( - int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') else - int(idx.data) if hasattr(idx, 'data') else int(idx)) - - # Extract the scalar value - scalar_data = tensor_handle.data[tuple(index_values)] - return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) - - def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): - """ - Insert a sub-tensor into a full tensor at specified offsets. - - :param full_tensor: The full tensor (destination) - :param sub_tensor: The sub-tensor to insert - :param offsets: List of offset TensorHandle objects or Python ints - :param sizes: List of size integers - :param strides: List of stride integers - :return: Modified tensor with sub_tensor inserted - """ - result = full_tensor.data.copy() - - # Convert offsets from TensorHandle or Python int to integers - offset_values = [] - for off in offsets: - if isinstance(off, int): - # Python int passed directly - offset_values.append(off) - elif isinstance(off, TensorHandle): - # Interpreter TensorHandle - offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) - else: - # Fallback - offset_values.append( - int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else - int(off.data) if hasattr(off, 'data') else int(off)) - - # Build slices for insertion - slices = [] - for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): - end = offset + size * stride - if stride == 1: - slices.append(slice(offset, end)) - else: - slices.append(slice(offset, end, stride)) - - # Insert the sub-tensor - result[tuple(slices)] = sub_tensor.data - - return TensorHandle(result, full_tensor.dtype.scalar) - - def create_extract_slice(self, full_tensor, offsets, sizes, strides): - """ - Extract a slice from a full tensor. 
- - :param full_tensor: The full tensor - :param offsets: List of offset TensorHandle objects or Python ints - :param sizes: List of size integers - :param strides: List of stride integers - :return: Extracted sub-tensor - """ - # Convert offsets from TensorHandle or Python int to integers - offset_values = [] - for off in offsets: - if isinstance(off, int): - # Python int passed directly - offset_values.append(off) - elif isinstance(off, TensorHandle): - # Interpreter TensorHandle - offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) - else: - # Fallback - offset_values.append( - int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else - int(off.data) if hasattr(off, 'data') else int(off)) - - # Build slices for extraction - slices = [] - for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): - end = offset + size * stride - if stride == 1: - slices.append(slice(offset, end)) - else: - slices.append(slice(offset, end, stride)) - - # Extract the slice - extracted = full_tensor.data[tuple(slices)] - - return TensorHandle(extracted, full_tensor.dtype.scalar) - - def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): - """ - SIMD index_select operation (gather with indices along a dimension). - - :param src_ptr: Source tensor pointer - :param index_tensor: 1D tensor of indices - :param dim: Dimension to select from - :param src_shape: List of source shape (int or TensorHandle) - :param src_offset: List of source offset (int or TensorHandle) - :param read_shape: List of read shape (int or TensorHandle) - :param result_shape: List of result shape (int or TensorHandle) - :return: Result tensor with selected indices - """ - - # Convert src_shape, src_offset, read_shape to integers - def to_int(val): - if isinstance(val, TensorHandle): - return int(val.data.item()) - return int(val) - - src_shape_vals = [to_int(s) for s in src_shape] - src_offset_vals = [to_int(o) if o != -1 else -1 for o in src_offset] - read_shape_vals = [to_int(r) if r != -1 else -1 for r in read_shape] - result_shape_vals = [to_int(r) for r in result_shape] - - # Get index values - handle both array and TensorHandle - if isinstance(index_tensor, TensorHandle): - indices = index_tensor.data.flatten() - else: - indices = np.asarray(index_tensor).flatten() - - # Ensure indices are integers - if indices.dtype not in [np.int32, np.int64]: - indices = indices.astype(np.int32) - - # Create result tensor - result = np.empty(result_shape_vals, dtype=src_ptr.data.dtype) - - # Perform index_select: for each index, read the specified data - for out_idx, in_idx in enumerate(indices): - in_idx = int(in_idx) - - # Validate index bounds - if not (0 <= in_idx < src_shape_vals[dim]): - # Out of bounds - fill with zeros - result_slices = [slice(None)] * len(result_shape_vals) - result_slices[dim] = slice(out_idx, out_idx + 1) - result[tuple(result_slices)] = 0 - continue - - # Build source slice - src_slices = [] - for d in range(len(src_shape_vals)): - if d == dim: - src_slices.append(slice(in_idx, in_idx + 1)) - else: - offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 - read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] - # Clamp to valid range - offset = max(0, min(offset, src_shape_vals[d] - 1)) - read_size = min(read_size, src_shape_vals[d] - offset) - src_slices.append(slice(offset, offset + read_size)) - - # Build result slice - result_slices = [] - 
for d in range(len(result_shape_vals)): - if d == dim: - result_slices.append(slice(out_idx, out_idx + 1)) - else: - result_slices.append(slice(None)) - - # Copy data with proper shape handling - try: - src_data = src_ptr.data[tuple(src_slices)] - # Handle shape mismatch by resizing - target_shape = [result_shape_vals[d] if d != dim else 1 for d in range(len(result_shape_vals))] - if src_data.shape != tuple(target_shape): - # Pad or trim as needed - pad_width = [(0, target_shape[d] - src_data.shape[d]) for d in range(len(target_shape))] - src_data = np.pad(src_data, pad_width, mode='constant', constant_values=0) - result[tuple(result_slices)] = src_data - except Exception as e: - # On error, fill with zeros - result[tuple(result_slices)] = 0 - - return TensorHandle(result, src_ptr.dtype.scalar) - - def create_get_sub_vec_id(self): - """ - Get the Vector Core index on the AI Core. - - In Interpreter mode, simulate multiple vector cores by maintaining - a sub_vec_id counter. This is used for 1:2 hardware ratio emulation - where different vector cores process different partitions of the data. - - :return: Vector Core ID as TensorHandle (int64, scalar) - """ - # Return the current sub_vec_id (set by GridExecutor) - vec_id = np.int64(self.sub_vec_id) - return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) - - def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): - """ - Set synchronization event between compute and vector units. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. - - :param sender: Source unit ("cube" or "vector") - :param receiver: Destination unit ("cube" or "vector") - :param event_id: Event ID (TensorHandle) - :param sender_pipe_value: Sender pipe value - :param receiver_pipe_value: Receiver pipe value - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass - - def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): - """ - Wait for synchronization event between compute and vector units. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. - - :param sender: Source unit ("cube" or "vector") - :param receiver: Destination unit ("cube" or "vector") - :param event_id: Event ID (TensorHandle) - :param sender_pipe_value: Sender pipe value - :param receiver_pipe_value: Receiver pipe value - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass - - def sync_block_all(self, mode, event_id): - """ - Synchronize all compute or vector units globally. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. 
- - :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") - :param event_id: Event ID (int, constexpr, or TensorHandle) - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass - def create_atomic_cas(self, ptr, cmp, val, sem, scope): if sem not in self.ir_sem_to_interpreter_sem: raise ValueError(f"unsupported semantic {sem}") @@ -1266,31 +1037,9 @@ def _patch_lang(fn): _patch_lang_tensor(lang.tensor) _patch_lang_core(lang) - # Patch all modules in fn's globals that might be extension modules - for name, value in list(fn.__globals__.items()): - if value is None: - continue - try: - # Check if it looks like an extension module (has builtin functions) - if hasattr(value, '__name__') and 'extension' in str(value.__name__): - _patch_builtin(value, interpreter_builder) - # Also try patching any module-like object that might have builtin functions - elif hasattr(value, '__dict__') and not isinstance(value, type): - # Try to patch it and ignore if it fails - try: - _patch_builtin(value, interpreter_builder) - except Exception: - pass - except Exception: - pass - - # Also try importing extension directly as fallback - try: - import triton.language.extra.cann.extension as extension - _patch_builtin(extension, interpreter_builder) - except (ImportError, AttributeError): - # Extension module not available (e.g., non-Ascend backend) - pass + # Patch Ascend extensions if using AscendInterpreterBuilder + if hasattr(interpreter_builder, 'patch_extensions'): + interpreter_builder.patch_extensions(fn) # TODO: wrap everything in triton tensors @@ -1317,10 +1066,19 @@ def _implicit_cvt(arg): return arg -interpreter_builder = InterpreterBuilder() +# Use AscendInterpreterBuilder if available, otherwise fall back to base InterpreterBuilder +_try_import_ascend() +if _has_ascend_support and AscendInterpreterBuilder is not None: + interpreter_builder = AscendInterpreterBuilder() +else: + interpreter_builder = InterpreterBuilder() # These keywords are not supported by the interpreter -RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg", "multibuffer"] +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + +# Allow Ascend interpreter to extend reserved keywords +if hasattr(interpreter_builder, 'get_additional_reserved_keywords'): + RESERVED_KWS.extend(interpreter_builder.get_additional_reserved_keywords()) class GridExecutor: @@ -1380,25 +1138,13 @@ def __call__(self, *args_dev, **kwargs): grid = grid + (1, ) * (3 - len(grid)) interpreter_builder.set_grid_dim(*grid) - # Infer the number of sub-vector cores from kernel parameters - # Check for M and sub_M parameters (common pattern for 1:2 ratio) - num_sub_vec_ids = 1 - if 'M' in args and 'sub_M' in args: - M = args['M'] - sub_M = args['sub_M'] - # Extract scalar values if they're TensorHandle - if isinstance(M, TensorHandle): - M = int(M.data.item() if hasattr(M.data, 'item') else M.data) - if isinstance(sub_M, TensorHandle): - sub_M = int(sub_M.data.item() if hasattr(sub_M.data, 'item') else sub_M.data) - # Number of vector cores = M / sub_M - if isinstance(M, int) and isinstance(sub_M, int) and sub_M > 0: - num_sub_vec_ids = max(1, M // sub_M) - try: - # Loop over sub-vector IDs to simulate parallel vector core execution - for sub_vec_id in range(num_sub_vec_ids): - interpreter_builder.sub_vec_id = sub_vec_id + # Execute kernels - sub_vec_id simulation handled by AscendInterpreterBuilder + if hasattr(interpreter_builder, 
'execute_with_sub_vec_simulation'): + # Ascend builder with sub-vector simulation + interpreter_builder.execute_with_sub_vec_simulation(self.fn, args, grid) + else: + # Standard execution for base interpreter for x in range(grid[0]): for y in range(grid[1]): for z in range(grid[2]): diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py b/third_party/ascend/backend/spec/triton/runtime/jit.py index 45178a40bb..17271aac93 100644 --- a/third_party/ascend/backend/spec/triton/runtime/jit.py +++ b/third_party/ascend/backend/spec/triton/runtime/jit.py @@ -9,9 +9,12 @@ from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple -from ..runtime.driver import driver from types import ModuleType +from triton._C.libtriton import get_cache_invalidating_env_vars +from .driver import driver +from . import _async_compile + TRITON_MODULE = __name__[:-len(".runtime.jit")] T = TypeVar("T") @@ -616,17 +619,9 @@ def run(self, *args, grid, warmup, **kwargs): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + kernel = self._do_compile(key, signature, device, backend, target, constants, options, configs[0], warmup) + if kernel is None: return None - # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) - self.cache[device][key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. not_present = object() @@ -647,6 +642,8 @@ def run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 + if hasattr(kernel, "result"): + kernel = kernel.result() # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) @@ -728,7 +725,7 @@ def warmup(self, *args, grid, **kwargs): return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) def preload(self, specialization_data): - from ..compiler import compile, ASTSource + from ..compiler import make_backend from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl @@ -742,14 +739,54 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() } key = deserialized_obj['key'] - kernel = compile(src, None, options) - self.cache[device][key] = kernel + target = driver.active.get_current_target() + backend = make_backend(target) + options = backend.parse_options(options) + attrs = AttrsDescriptor.from_dict(deserialized_obj['attrs']) + return self._do_compile( + key, + signature, + device, + backend, + target, + constants, + options, + attrs, + warmup=True, + ) + + def _do_compile(self, key, signature, device, backend, target, constants, options, attrs, warmup): + kernel_cache = self.cache[device] + + if self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=True): + return None + src = self.ASTSource(self, 
signature, constants, attrs) + + async_mode = _async_compile.active_mode.get() + if async_mode is not None: + from triton.compiler.compiler import get_cache_key + + env_vars = get_cache_invalidating_env_vars() + cache_key = get_cache_key(src, backend, options, env_vars) + + def async_compile(): + return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars) + + def finalize_compile(kernel): + kernel_cache[key] = kernel + self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False) + + kernel = async_mode.submit(cache_key, async_compile, finalize_compile) + else: + kernel = self.compile(src, target=target, options=options.__dict__) + kernel_cache[key] = kernel + self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False) return kernel # we do not parse `src` in the constructor because diff --git a/third_party/ascend/backend/spec/triton/runtime/libentry.py b/third_party/ascend/backend/spec/triton/runtime/libentry.py index a358b9ae8c..3a4a0231e6 100644 --- a/third_party/ascend/backend/spec/triton/runtime/libentry.py +++ b/third_party/ascend/backend/spec/triton/runtime/libentry.py @@ -1,4 +1,3 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. # Copyright 2018-2020 Philippe Tillet # Copyright 2020-2022 OpenAI # Copyright © 2024 BAAI. All rights reserved. @@ -285,6 +284,9 @@ def libentry(): """ def decorator(fn): + from triton.runtime.interpreter import InterpretedFunction + if isinstance(fn, InterpretedFunction): + return fn return LibEntry(fn) return decorator diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py index 65d5968dd7..5e588c5631 100644 --- a/third_party/ascend/backend/testing.py +++ b/third_party/ascend/backend/testing.py @@ -57,7 +57,8 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None if clear_l2_cache: buffer = runtime.driver.active.get_empty_cache_for_benchmark() - buffer.zero_() + buffer = buffer.float() # to avoid type cast + buffer.sum() torch.npu.synchronize() # shake out of any npu error total = warmup + active @@ -74,7 +75,8 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None for fn in funcs: for _ in builtins.range(total): if clear_l2_cache: - buffer.zero_() + buffer.sum() # use buffer read to clear l2 cache + torch.npu.synchronize() fn() torch.npu.synchronize() if clear_l2_cache: @@ -172,7 +174,7 @@ def _collect_prof_result(base_dir: str, funcs, num_warmup: int, num_active: int, df = pd.read_csv(kernel_details_file) # filter out l2 cache clearing operation - filter_cond = ~df["Type"].str.contains(r"^ZerosLike$", case=False, na=False) + filter_cond = ~df["Type"].str.contains(r"^ReduceSum$", case=False, na=False) filter_df = df[filter_cond] if key is not None: key_rows = filter_df[filter_df["Name"].str.contains(key, na=False)] diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py index 8f90b5fa4e..cf9983b78f 100644 --- a/third_party/ascend/backend/utils.py +++ b/third_party/ascend/backend/utils.py @@ -28,6 +28,7 @@ from pathlib import Path import logging from platform import python_version +from triton.tools.get_ascend_devices import is_compile_on_910_95 from triton.backends.ascend.backend_register import backend_strategy_registry import pybind11 @@ -152,7 +153,10 @@ def _get_llvm_path(path: str, *paths) -> str: def _get_npucompiler_path() -> str: ascend_dir = os.path.dirname(os.path.abspath(__file__)) env = os.environ.copy() - 
npu_compiler_path = os.path.join(ascend_dir, "bishengir", "bin", "bishengir-compile") + if is_compile_on_910_95: + npu_compiler_path = os.path.join(ascend_dir, "bishengir-a5", "bin", "bishengir-compile") + else: + npu_compiler_path = os.path.join(ascend_dir, "bishengir", "bin", "bishengir-compile") if os.path.exists(npu_compiler_path): npuir_env_path = os.path.dirname(npu_compiler_path) env["PATH"] = npuir_env_path + ":" + env["PATH"] @@ -263,6 +267,10 @@ def _enable_print_ub_bits() -> bool: return os.getenv("ENABLE_PRINT_UB_BITS", "false").lower() in ("true", "1") +def _enable_dump_memory_info() -> bool: + return os.getenv("TRITON_MEMORY_DISPLAY", "false").lower() in ("true", "1") + + def _get_cxx(): cxx = os.environ.get("CC") if cxx is None: @@ -302,11 +310,8 @@ def _precompile_npu_hash(header_src): return hash_txt -def _precompile_npu_ext(header_path): - src_dir = os.path.dirname(header_path) - gch_path = os.path.join(src_dir, "precompiled.h.gch") +def _precompile_npu_ext(header_path, gch_path): cxx = _get_cxx() - cc_cmd = [cxx, "-x", "c++-header", header_path] # disable all warnings cc_cmd += [f"-w"] @@ -344,12 +349,12 @@ def _precompile_npu_ext(header_path): cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", gch_path] - ret = subprocess.check_call(cc_cmd) - - if ret != 0: - print(f"Unable to precompile header file, ret is: {ret}") + result = subprocess.run(cc_cmd, capture_output=True, text=True) - return header_path + if result.returncode == 0: + return header_path + else: + raise RuntimeError(f"Failed to compile {gch_path}, error: {result.stderr},cmd={cc_cmd}") def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="torch", precompile=False) -> str: @@ -399,8 +404,8 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor "-lascendcl", ] # FIXME: check why this condition works wrong in parall scene - # if kernel_launcher == "torch": - cc_cmd += get_backend_func("get_cc_cmd", build_pch=False) + if kernel_launcher == "torch": + cc_cmd += get_backend_func("get_cc_cmd", build_pch=False) cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-Winvalid-pch", "-o", so_path] @@ -413,7 +418,7 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor # only for clang++, when precompile invalid, fallback to normal compile return _build_npu_ext(obj_name, header_path, src_path, precompile=False) else: - raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr}") + raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr},cmd={cc_cmd}") def _get_kernel_target(metadata: dict): @@ -531,8 +536,11 @@ def is_ffts_supported(arch: str): Cases: - empty str: User does not specify arch, thus it runs on 910B/910D both of which support ffts. Return True. - Ascend310B4: 310B4 does not support ffts. Return False. + - Ascend910_95*: 910_95 does not support ffts. Return False. - Other arch: 910B/910D supports ffts. Return True. 
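+    Illustrative results (hypothetical arch strings, non-910_95 build):
+        is_ffts_supported("")            -> True
+        is_ffts_supported("Ascend310B4") -> False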
    '''
+    if is_compile_on_910_95:
+        return False
     if arch in ["Ascend910A", "Ascend310B4"]:
         return False
     return True
@@ -541,5 +549,17 @@ def is_ffts_supported(arch: str):
 def force_disable_ffts():
     '''
     '''
+    if is_compile_on_910_95:
+        return True
     disable_ffts = os.getenv("TRITON_DISABLE_FFTS", "false").lower() in ("true", "1")
     return disable_ffts
+
+
+def triton_support_ffts():
+    arch = get_ascend_arch_from_env()
+    return is_ffts_supported(arch) and (not force_disable_ffts())
+
+
+def triton_enable_libdevice_simt():
+    # Parse the env var like the other toggles instead of returning the raw string
+    return os.getenv("TRITON_ENABLE_LIBDEVICE_SIMT", "false").lower() in ("true", "1")
diff --git a/third_party/ascend/include/AutoBlockify/AutoBlockify.h b/third_party/ascend/include/AutoBlockify/AutoBlockify.h
new file mode 100644
index 0000000000..3ce8159ef9
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/AutoBlockify.h
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
diff --git a/third_party/ascend/include/AutoBlockify/AutoBlockify.h b/third_party/ascend/include/AutoBlockify/AutoBlockify.h
new file mode 100644
index 0000000000..3ce8159ef9
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/AutoBlockify.h
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#pragma once
+
+#include "mlir/Pass/Pass.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/IR/PatternMatch.h"
+
+#define GEN_PASS_DECL_AUTOBLOCKIFY
+#include "ascend/include/AutoBlockify/Passes.h.inc"
+
+#define GEN_PASS_DEF_AUTOBLOCKIFY
+#include "ascend/include/AutoBlockify/Passes.h.inc"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createAutoBlockifyPass(const AutoBlockifyOptions &options = {});
+
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+using namespace triton;
+
+class PropagateUnrealizedCastDown
+    : public OpRewritePattern<UnrealizedConversionCastOp> {
+public:
+  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
+
+  explicit PropagateUnrealizedCastDown(MLIRContext *context,
+                                       Value logicalBlockId,
+                                       Value logicalBlockNum,
+                                       int autoBlockifySize);
+
+  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  void handleBlockifyLoop(scf::ForOp blockifyLoop, Operation *op,
+                          PatternRewriter &rewriter) const;
+  void rewriteSplat(UnrealizedConversionCastOp op, triton::SplatOp splatOp,
+                    PatternRewriter &rewriter) const;
+  void rewriteExpandDims(UnrealizedConversionCastOp op,
+                         triton::ExpandDimsOp expandDimsOp,
+                         PatternRewriter &rewriter) const;
+  void rewriteReduce(UnrealizedConversionCastOp op, triton::ReduceOp reduceOp,
+                     PatternRewriter &rewriter) const;
+  void rewriteScan(UnrealizedConversionCastOp op, triton::ScanOp scanOp,
+                   PatternRewriter &rewriter) const;
+  void rewriteLoad(UnrealizedConversionCastOp op, triton::LoadOp loadOp,
+                   PatternRewriter &rewriter) const;
+  void rewriteStore(UnrealizedConversionCastOp op, triton::StoreOp storeOp,
+                    PatternRewriter &rewriter) const;
+  void rewriteAtomicRMW(UnrealizedConversionCastOp op,
+                        triton::AtomicRMWOp atomicRMWOp,
+                        PatternRewriter &rewriter) const;
+  void rewriteAssert(UnrealizedConversionCastOp op, triton::AssertOp assertOp,
+                     PatternRewriter &rewriter) const;
+  void rewriteExtractSlice(UnrealizedConversionCastOp op,
+                           tensor::ExtractSliceOp extractSliceOp,
+                           PatternRewriter &rewriter) const;
+  void rewriteInsertSlice(UnrealizedConversionCastOp op,
+                          tensor::InsertSliceOp insertSliceOp,
+                          PatternRewriter &rewriter) const;
+  void rewriteWhile(UnrealizedConversionCastOp op, scf::WhileOp whileOp,
+                    PatternRewriter &rewriter) const;
+  void rewriteLoop(UnrealizedConversionCastOp op, LoopLikeOpInterface loopOp,
+                   PatternRewriter &rewriter) const;
+  void rewriteIf(UnrealizedConversionCastOp &op, scf::IfOp ifOp,
+                 ArrayRef<int64_t> indices, PatternRewriter &rewriter) const;
+  void rewriteYield(UnrealizedConversionCastOp &op, scf::YieldOp yieldOp,
+                    PatternRewriter &rewriter) const;
+  void rewriteCondition(UnrealizedConversionCastOp op,
+                        scf::ConditionOp conditionOp,
+                        PatternRewriter &rewriter) const;
+  void rewriteGeneraleOp(UnrealizedConversionCastOp op, Operation *generalOp,
+                         PatternRewriter &rewriter) const;
+
+  Value logicalBlockId;
+  Value logicalBlockNum;
+  int autoBlockifySize;
+};
+
+class AutoBlockifyPass : public ::impl::AutoBlockifyBase<AutoBlockifyPass> {
+public:
+  explicit AutoBlockifyPass(const AutoBlockifyOptions &options);
+  void runOnOperation() override;
+
+private:
+  bool checkBlockifiable(Value v);
+  void preProcess(triton::FuncOp func);
+
+  DenseSet<Value> checkedValues;
+  Value logicalBlockId;
+  Value logicalBlockNum;
+};
diff --git a/third_party/ascend/include/AutoBlockify/CMakeLists.txt b/third_party/ascend/include/AutoBlockify/CMakeLists.txt
new file mode 100644
index 0000000000..ca4cf9f552
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name AutoBlockify)
+add_public_tablegen_target(AutoBlockifyPassIncGen)
diff --git a/third_party/ascend/include/AutoBlockify/Passes.h b/third_party/ascend/include/AutoBlockify/Passes.h
new file mode 100644
index 0000000000..7d5147ef92
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/Passes.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#ifndef TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H
+#define TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H
+
+#include "AutoBlockify.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "ascend/include/AutoBlockify/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H
diff --git a/third_party/ascend/include/AutoBlockify/Passes.td b/third_party/ascend/include/AutoBlockify/Passes.td
new file mode 100644
index 0000000000..7d9f1a80a3
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/Passes.td
@@ -0,0 +1,21 @@
+#ifndef AUTO_BLOCKIFY_PASSES
+#define AUTO_BLOCKIFY_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def AutoBlockify : Pass<"auto-blockify", "mlir::ModuleOp"> {
+  let summary = "Apply auto blockify v2";
+  let constructor = "triton::createAutoBlockifyPass()";
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::tensor::TensorDialect",
+    "mlir::triton::TritonDialect"
+  ];
+  let options = [
+    Option<"autoBlockifySize", "auto-blockify-size", "int", "1",
+           "Apply auto blockify v2 when TRITON_ALL_BLOCKS_PARALLEL is 1. "
+           "Expand the highest dimension by the blockify size">
+  ];
+}
+
+#endif // AUTO_BLOCKIFY_PASSES
diff --git a/third_party/ascend/include/AutoBlockify/Utils.h b/third_party/ascend/include/AutoBlockify/Utils.h
new file mode 100644
index 0000000000..385fa51a10
--- /dev/null
+++ b/third_party/ascend/include/AutoBlockify/Utils.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#pragma once
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+using namespace mlir;
+using namespace triton;
+
+constexpr llvm::StringLiteral autoBlockifySizeAttr = "auto_blockify_size";
+constexpr llvm::StringLiteral logicalBlockIdAttr = "logical_block_id";
+constexpr llvm::StringLiteral autoBlockifyLoopAttr = "auto_blockify_loop";
+constexpr llvm::StringLiteral autoBlockifyRegionOpAttr =
+    "auto_blockify_region_op";
+
+RankedTensorType getExpandedType(Type type, UnrealizedConversionCastOp op);
+
+Value rewriteValue(Value value, UnrealizedConversionCastOp op,
+                   OpBuilder &builder);
+
+void replaceValue(Operation *newOp, Operation *oldOp, Value newMask,
+                  RewriterBase &rewriter,
+                  ArrayRef<int64_t> replaceIndices = {});
+
+Value createMask(Value mask, Value uccMask, ArrayRef<int64_t> targetShape,
+                 RewriterBase &rewriter);
+
+void mapRegionIterArg(IRMapping &mapping, ValueRange oldArgs,
+                      ValueRange newArgs, ArrayRef<int64_t> indices, Value mask,
+                      OpBuilder &builder);
+
+void mapYieldedValue(IRMapping &mapping, scf::YieldOp yieldOp,
+                     ArrayRef<int64_t> indices, UnrealizedConversionCastOp op,
+                     OpBuilder &builder);
+
+Operation *createBlockifyLoop(Operation *targetOp,
+                              UnrealizedConversionCastOp op,
+                              Value logicalBlockId, Value logicalBlockNum,
+                              int autoBlockifySize, RewriterBase &rewriter);
+
+std::optional<scf::ForOp> getBlockifyLoop(Operation *op);
diff --git a/third_party/ascend/include/CMakeLists.txt b/third_party/ascend/include/CMakeLists.txt
index 9cd93fe4ae..e1eae53c09 100644
--- a/third_party/ascend/include/CMakeLists.txt
+++ b/third_party/ascend/include/CMakeLists.txt
@@ -1,3 +1,5 @@
-add_subdirectory(TritonToLLVM)
-add_subdirectory(TritonToHIVM)
-add_subdirectory(TritonToHFusion)
+add_subdirectory(TritonToHFusion)
+add_subdirectory(TritonToHIVM)
+add_subdirectory(TritonToLLVM)
+add_subdirectory(AutoBlockify)
+add_subdirectory(TritonAffinityOpt)
diff --git a/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt
new file mode 100644
index 0000000000..c6193d1f5d
--- /dev/null
+++ b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonAffinityOpt)
+add_public_tablegen_target(TritonAffinityOptConversionPassIncGen)
diff --git a/third_party/ascend/include/TritonAffinityOpt/DAG.h b/third_party/ascend/include/TritonAffinityOpt/DAG.h
new file mode 100644
index 0000000000..364c20258c
--- /dev/null
+++ b/third_party/ascend/include/TritonAffinityOpt/DAG.h
@@ -0,0 +1,304 @@
+#ifndef AffinityDAGDEF
+#define AffinityDAGDEF
+#include "Utils.hpp"
+#include "bishengir/Dialect/HIVM/IR/HIVM.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TinyPtrVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cstddef>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+namespace AffinityDAG {
+
+enum class OpAbility {
+  PREFER_VECTOR = 1 << 0,
+  CUBE_ONLY = 1 << 1,
+  CUBE_AND_VECTOR = PREFER_VECTOR | CUBE_ONLY
+};
+
+enum CoreType {
+  UNDETERMINED = 0,
+  VECTOR_ONLY = 1 << 0,
+  CUBE_ONLY = 1 << 1,
+  CUBE_AND_VECTOR = VECTOR_ONLY | CUBE_ONLY
+};
+
+inline constexpr CoreType toCoreType(OpAbility ct) {
+  using U = std::underlying_type_t<OpAbility>;
+  return static_cast<CoreType>(static_cast<U>(ct));
+}
+
+constexpr inline CoreType operator|(CoreType lhs, CoreType rhs) {
+  return enumOp(std::bit_or<>(), lhs, rhs);
+}
+
+inline CoreType operator&(CoreType lhs, CoreType rhs) {
+  return enumOp(std::bit_and<>(), lhs, rhs);
+}
+
+inline bool intersects(CoreType lhs, CoreType rhs) {
+  return (lhs & rhs) != CoreType::UNDETERMINED;
+}
+
+inline CoreType operator&(OpAbility lhs, CoreType rhs) {
+  return toCoreType(lhs) & rhs;
+}
+
+inline CoreType operator!(CoreType ct) {
+  CoreType newCt = UNDETERMINED;
+  if ((ct & CoreType::CUBE_ONLY) == UNDETERMINED) {
+    newCt = newCt | CoreType::CUBE_ONLY;
+  }
+
+  if ((ct & CoreType::VECTOR_ONLY) == UNDETERMINED) {
+    newCt = newCt | CoreType::VECTOR_ONLY;
+  }
+
+  return newCt;
+}
+
+inline hivm::TCoreType toHivm(CoreType ct) {
+  switch (ct) {
+  case UNDETERMINED:
+    return hivm::TCoreType::CUBE_OR_VECTOR;
+  case CUBE_ONLY:
+    return hivm::TCoreType::CUBE;
+  case VECTOR_ONLY:
+    return hivm::TCoreType::VECTOR;
+  case CUBE_AND_VECTOR:
+    return hivm::TCoreType::CUBE_AND_VECTOR;
+  default:
+    llvm_unreachable("Invalid CoreType that cannot convert to hivm");
+  }
+}
+
+inline bool intersects(OpAbility lhs, CoreType rhs) {
+  return (lhs & rhs) != CoreType::UNDETERMINED;
+}
+
+inline bool exactlyOneType(CoreType ct) {
+  return (ct == CUBE_ONLY) || (ct == VECTOR_ONLY);
+}
+
+const char *literalCoreType(CoreType ct);
+
+class MoveOnly {
+protected:
+  MoveOnly() = default;
+  ~MoveOnly() = default;
+
+  MoveOnly(const MoveOnly &) = delete;
+  MoveOnly &operator=(const MoveOnly &) = delete;
+
+  MoveOnly(MoveOnly &&) = default;
+  MoveOnly &operator=(MoveOnly &&) = default;
+};
+
+class Node;
+class OpNode;
+class ValueNode;
+
+ValueNode *getDataSource(OpNode *op);
+
+class Graph : MoveOnly {
+public:
+  using OpMapRaw = llvm::DenseMap<Operation *, std::unique_ptr<OpNode>>;
+  using ValueMapRaw = llvm::DenseMap<Value, std::unique_ptr<ValueNode>>;
+  using OpMap = std::shared_ptr<OpMapRaw>;
+  using ValueMap = std::shared_ptr<ValueMapRaw>;
+
+  Graph(Block *block, Graph *parent = nullptr, OpMap opMap = nullptr,
+        ValueMap valueMap = nullptr, bool inheritParent = true);
+
+  static std::unique_ptr<Graph> fromMultiBlockFunc(triton::FuncOp funcOp);
+
+  OpMapRaw &getOpMap() const { return *opMap; }
+
+  ValueMapRaw &getValueMap() const { return *valueMap; }
+
+  // [DEBUG] start
+  std::unique_ptr<llvm::DenseMap<Operation *, OpNode *>> legacyOpMap = nullptr;
+  std::unique_ptr<llvm::DenseMap<Value, CoreType>> legacyValueTypes = nullptr;
+
+  inline llvm::DenseMap<Operation *, OpNode *> &getOpMapLegacy() {
+    if (!legacyOpMap) {
+      legacyOpMap =
+          std::make_unique<llvm::DenseMap<Operation *, OpNode *>>();
+      for (auto &[key, val] : *opMap) {
+        (*legacyOpMap)[key] = val.get();
+      }
+    }
+
+    return *legacyOpMap;
+  }
+
+  llvm::DenseMap<Value, CoreType> &getValueTypes();
+
+  // [DEBUG] end
+
+private:
+  friend class Node;
+  friend class OpNode;
+  OpMap opMap;
+  ValueMap valueMap;
+  Block *block;
+  Graph *parent;
+  OpNode *terminator = nullptr;
+  size_t opCount = 0;
+  llvm::SmallVector<ValueNode *> blockArgs;
+};
+
+class Node : MoveOnly {
+protected:
+  friend class Graph;
+  friend class ValueNode;
+  bool isUpstreamOfCubeMem = false;
+  virtual CoreType absorbImpl() = 0;
+  llvm::SmallVector<Node *> outputs;
+
+public:
+  CoreType isOnPrivate = UNDETERMINED;
+
+  enum NodeKind { NK_Op, NK_Value };
+
+  inline CoreType isOn() const { return isOnPrivate; }
+
+  bool absorb() {
+    auto newCoreType = absorbImpl();
+    auto changed = newCoreType != isOnPrivate;
+    isOnPrivate = newCoreType;
+
+    return changed;
+  }
+
+  virtual llvm::SmallVector<Node *> getAffected() const = 0;
+  virtual OpNode *getSourceOpNode() = 0;
+
+  ArrayRef<Node *> getOutputs() const { return outputs; }
+
+  CoreType absorbCommon();
+
+private:
+  const NodeKind kind;
+
+public:
+  NodeKind getKind() const { return kind; }
+
+protected:
+  Node(NodeKind kind) : kind(kind) {}
+};
+
+class OpNode : public Node {
+  friend class Graph;
+  friend class ValueNode;
+  llvm::SmallVector<ValueNode *> inputs;
+  llvm::SmallVector<std::unique_ptr<Graph>> subgraphs;
+  virtual CoreType absorbImpl() override;
+
+public:
+  Operation *op;
+
+  OpNode(Operation *op, Graph *graph);
+  OpAbility canRunOn() const;
+  inline ArrayRef<ValueNode *> getInputs() const { return inputs; }
+
+  static bool classof(const Node *node) { return node->getKind() == NK_Op; }
+
+  virtual llvm::SmallVector<Node *> getAffected() const override {
+    llvm::SmallVector<Node *> result(inputs.begin(), inputs.end());
+    result.append(outputs.begin(), outputs.end());
+
+    return result;
+  }
+
+  virtual OpNode *getSourceOpNode() override { return this; }
+};
+
+class ValueNode : public Node {
+  friend class Graph;
+  friend class OpNode;
+  virtual CoreType absorbImpl() override;
+
+public:
+  Node *source = nullptr;
+  Value value;
+  // ValueNode(OpResult value);
+  // ValueNode(BlockArgument value);
+
+  ValueNode(Value value) : Node(NK_Value), value(value) {}
+  virtual OpNode *getSourceOpNode() override {
+    if (!source) {
+      return nullptr;
+    }
+
+    return source->getSourceOpNode();
+  }
+  static bool classof(const Node *node) { return node->getKind() == NK_Value; }
+
+  virtual llvm::SmallVector<Node *> getAffected() const override {
+    llvm::SmallVector<Node *> result(outputs.begin(), outputs.end());
+    if (source)
+      result.push_back(source);
+
+    return result;
+  }
+};
+
+class GraphManager {
+private:
+  llvm::DenseMap<llvm::StringRef, std::shared_ptr<Graph>> graphs;
+
+public:
+  static GraphManager &getInstance() {
+    static GraphManager instance;
+    return instance;
+  }
+
+  void registerGraph(llvm::StringRef funcName,
+                     std::shared_ptr<Graph> graph) {
+    graphs[funcName] = graph;
+  }
+
+  AffinityDAG::Graph *getGraph(llvm::StringRef funcName) {
+    auto it = graphs.find(funcName);
+    return it != graphs.end() ? it->second.get() : nullptr;
+  }
+
+  void removeGraph(llvm::StringRef funcName) { graphs.erase(funcName); }
+};
+
+inline llvm::DenseMap<Value, CoreType> &Graph::getValueTypes() {
+  static std::mutex mtx;
+  std::lock_guard<std::mutex> lock(mtx);
+  if (!legacyValueTypes) {
+    legacyValueTypes =
+        std::make_unique<llvm::DenseMap<Value, CoreType>>();
+    for (auto &[key, val] : *valueMap) {
+      llvm::dbgs() << key << "\n";
+      llvm::dbgs().flush();
+      (*legacyValueTypes)[key] = val.get()->isOn();
+    }
+  }
+
+  return *legacyValueTypes;
+}
+
+} // namespace AffinityDAG
+} // namespace mlir
+#endif
diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.h b/third_party/ascend/include/TritonAffinityOpt/Passes.h
new file mode 100644
index 0000000000..f58c7563bc
--- /dev/null
+++ b/third_party/ascend/include/TritonAffinityOpt/Passes.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#ifndef TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H
+#define TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+// Forward declarations.
+class ModuleOp;
+
+namespace triton {
+
+/// Creates a pass that converts vector operations to shared storage buffer
+/// operations.
+std::unique_ptr<OperationPass<ModuleOp>> createDAGSSBufferPass();
+
+std::unique_ptr<OperationPass<ModuleOp>> createDAGSyncPass();
+
+std::unique_ptr<OperationPass<ModuleOp>> createDAGScopePass();
+
+#define GEN_PASS_REGISTRATION
+#include "ascend/include/TritonAffinityOpt/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H
diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.td b/third_party/ascend/include/TritonAffinityOpt/Passes.td
new file mode 100644
index 0000000000..f12de8444e
--- /dev/null
+++ b/third_party/ascend/include/TritonAffinityOpt/Passes.td
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Huawei Technologies Co.
+ * Licensed under the MIT license.
+ */
+
+#ifndef TRITON_AFFINITY_OPTIMIZATION_PASSES
+#define TRITON_AFFINITY_OPTIMIZATION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def DAGSSBuffer : Pass<"dag-ssbuf", "mlir::ModuleOp"> {
+  let summary = "Convert vector operations to shared storage buffer operations";
+  let constructor = "triton::createDAGSSBufferPass()";
+  let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "scope::ScopeDialect", "annotation::AnnotationDialect"];
+}
+
+def DAGScope : Pass<"dag-scope", "mlir::ModuleOp"> {
+  let summary = "Convert native Triton code to NPU-affine code";
+  let constructor = "triton::createDAGScopePass()";
+  let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "scope::ScopeDialect", "annotation::AnnotationDialect"];
+}
+
+def DAGSync : Pass<"dag-sync", "mlir::ModuleOp"> {
+  let summary = "Insert synchronization based on the affinity DAG";
+  let constructor = "triton::createDAGSyncPass()";
+  let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "annotation::AnnotationDialect"];
+}
+
+#endif // TRITON_AFFINITY_OPTIMIZATION_PASSES
diff --git a/third_party/ascend/include/TritonAffinityOpt/Utils.hpp b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp
new file mode 100644
index 0000000000..d3aa63be77
--- /dev/null
+++ b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp
@@ -0,0 +1,20 @@
+#ifndef TRITON_AFFINITY_UTILS_HPP
+#define TRITON_AFFINITY_UTILS_HPP
+
+#include <functional>
+#include <type_traits>
+
+namespace mlir::AffinityDAG {
+
+// Apply a binary functor to two enum values on their underlying type.
+template <typename T, typename F>
+constexpr inline T enumOp(F &&func, T lhs, T rhs) {
+  static_assert(std::is_enum_v<T>, "T must be an enum type");
+
+  using U = std::underlying_type_t<T>;
+
+  return static_cast<T>(std::invoke(std::forward<F>(func), static_cast<U>(lhs),
+                                    static_cast<U>(rhs)));
+}
+
+} // namespace mlir::AffinityDAG
+
+#endif
diff --git a/third_party/ascend/language/cann/__init__.py b/third_party/ascend/language/cann/__init__.py
index d7feaad57a..d599b4569b 100644
--- a/third_party/ascend/language/cann/__init__.py
+++ b/third_party/ascend/language/cann/__init__.py
@@ -18,17 +18,19 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 
+from triton.language import math
+from triton.backends.ascend.utils import triton_enable_libdevice_simt
+
 from . import libdevice
 from . import extension
 
 extension.parallel = extension.aux_ops.parallel
 
-libdevice.atan2 = extension.math_ops.atan2
+if not triton_enable_libdevice_simt():
+    libdevice.atan2 = extension.math_ops.atan2
 libdevice.isfinited = extension.math_ops.isfinited
 libdevice.finitef = extension.math_ops.finitef
 libdevice.flip = extension.flip
 
-from triton.language import math
-
 libdevice.umulhi = math.umulhi
 libdevice.exp = math.exp
 libdevice.exp2 = math.exp2
diff --git a/third_party/ascend/language/cann/extension/__init__.py b/third_party/ascend/language/cann/extension/__init__.py
index 20c339bc8a..efa09c71fd 100644
--- a/third_party/ascend/language/cann/extension/__init__.py
+++ b/third_party/ascend/language/cann/extension/__init__.py
@@ -1,14 +1,27 @@
-try:
-    import acl
-    is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95")
-except Exception as e:
-    is_compile_on_910_95 = False
+from triton.tools.get_ascend_devices import is_compile_on_910_95
+from triton._C.libtriton.ascend import ir as _ascend_ir
+
+# MLIR affine bindings (same objects as triton._C.libtriton.ascend.ir).
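These bindings exist mainly so that custom ops can carry `indexing_map` attributes (see custom_op.py below). A hedged sketch of their intended use, assuming the binding mirrors MLIR's `AffineMap::get(dimCount, symbolCount, exprs)` and that `AffineExpr` objects support Python arithmetic operators as in upstream MLIR bindings; the authoritative signatures live in the pybind registration:

```python
# Illustrative only: build (d0, d1) -> (d0 * 4 + d1), e.g. a row-major
# indexing map for a 2D operand with a row stride of 4.
d0 = affine_dim_expr(0)
d1 = affine_dim_expr(1)
amap = affine_map(2, 0, [d0 * affine_constant_expr(4) + d1])
```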
+affine_expr = _ascend_ir.affine_expr +affine_constant_expr = _ascend_ir.affine_constant_expr +affine_dim_expr = _ascend_ir.affine_dim_expr +affine_symbol_expr = _ascend_ir.affine_symbol_expr +affine_binary_op_expr = _ascend_ir.affine_binary_op_expr +affine_map = _ascend_ir.affine_map + +AffineExpr = affine_expr +AffineConstantExpr = affine_constant_expr +AffineDimExpr = affine_dim_expr +AffineSymbolExpr = affine_symbol_expr +AffineBinaryOpExpr = affine_binary_op_expr +AffineMap = affine_map from .core import ( ascend_address_space, builtin, CORE, copy_from_ub_to_l1, + copy, debug_barrier, fixpipe, FixpipeDMAMode, @@ -19,6 +32,7 @@ is_builtin, MODE, PIPE, + IteratorType, sub_vec_id, sub_vec_num, sync_block_all, @@ -55,7 +69,6 @@ ) from .mem_ops import ( - index_select, index_put, gather_out_to_ub, scatter_ub_to_out, @@ -66,6 +79,7 @@ # core "builtin", "copy_from_ub_to_l1", + "copy", "CORE", "debug_barrier", "fixpipe", @@ -77,6 +91,7 @@ "is_builtin", "MODE", "PIPE", + "IteratorType", "sub_vec_id", "sub_vec_num", "sync_block_all", @@ -85,6 +100,20 @@ # address space "ascend_address_space", + # ascend IR affine (MLIR) + "affine_expr", + "affine_constant_expr", + "affine_dim_expr", + "affine_symbol_expr", + "affine_binary_op_expr", + "affine_map", + "AffineExpr", + "AffineConstantExpr", + "AffineDimExpr", + "AffineSymbolExpr", + "AffineBinaryOpExpr", + "AffineMap", + # scope "scope", @@ -114,7 +143,6 @@ "cast", # mem ops - "index_select", "index_put", "gather_out_to_ub", "scatter_ub_to_out", diff --git a/third_party/ascend/language/cann/extension/aux_ops.py b/third_party/ascend/language/cann/extension/aux_ops.py index d872e44ac6..fd134a787b 100644 --- a/third_party/ascend/language/cann/extension/aux_ops.py +++ b/third_party/ascend/language/cann/extension/aux_ops.py @@ -112,10 +112,11 @@ def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder # FIXME: is_simt_mode # if builder.is_simt_mode(): # return - if not hint_val: - hint_val = builder.get_unit_attr() - elif isinstance(hint_val, bool): + # Check isinstance(hint_val, bool) first to handle False explicitly + if isinstance(hint_val, bool): hint_val = builder.get_bool_attr(hint_val) + elif not hint_val: + hint_val = builder.get_unit_attr() elif isinstance(hint_val, int): hint_val = builder.get_int32_attr(hint_val) elif isinstance(hint_val, core.constexpr): @@ -125,7 +126,7 @@ def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder hint_val = builder.get_i64_array_attr(hint_val) else: raise ValueError(f"Unsupported hint value type: {type(hint_val)}") - builder.create_annotation(ptr.handle, hint_name, hint_val) + builder.create_annotation_mark(ptr.handle, hint_name, hint_val) @builtin @@ -156,4 +157,4 @@ def multibuffer(src: tensor, size, _builder=None): """ buffer_size = _constexpr_to_value(size) assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" - compile_hint_impl(src, "multi_buffer", buffer_size, _builder) + compile_hint_impl(src, "hivm.multi_buffer", buffer_size, _builder) diff --git a/third_party/ascend/language/cann/extension/builder.py b/third_party/ascend/language/cann/extension/builder.py index cfd4be3b0b..8cf699f63a 100644 --- a/third_party/ascend/language/cann/extension/builder.py +++ b/third_party/ascend/language/cann/extension/builder.py @@ -73,6 +73,7 @@ def setup_unified_builder(main_builder, ascend_builder): 'create_copy_buffer', 'create_copy_tensor', 'create_fixpipe', + 'create_annotation_mark', 'create_bind_buffer', 
'create_debug_barrier', 'is_910_95', diff --git a/third_party/ascend/language/cann/extension/core.py b/third_party/ascend/language/cann/extension/core.py index 1710d7b960..e9520c679c 100644 --- a/third_party/ascend/language/cann/extension/core.py +++ b/third_party/ascend/language/cann/extension/core.py @@ -1,312 +1,343 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# Copyright 2018-2020 Philippe Tillet -# Copyright 2020-2022 OpenAI -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -__all__ = [ - "ascend_address_space", "builtin", "CORE", "copy_from_ub_to_l1", "debug_barrier", "fixpipe", "FixpipeDMAMode", - "FixpipeDualDstMode", "FixpipePreQuantMode", "FixpipePreReluMode", "int64", "is_builtin", "MODE", "PIPE", - "sub_vec_id", "sub_vec_num", "sync_block_all", "sync_block_set", "sync_block_wait", "SYNC_IN_VF" -] - -import enum -from typing import TypeVar, List, Union -from functools import wraps - -from triton._C.libtriton import ir -from triton._C.libtriton.ascend import ir as ascend_ir -import triton.language.core as tl - -import triton.extension.buffer.language as bl -from triton.language.core import _constexpr_to_value -from triton.backends.ascend.driver import NPUUtils - -from . import semantic as semantic - -PIPE = semantic.PIPE - -T = TypeVar("T") - -TRITON_BUILTIN = "__triton_builtin__" -ASCEND_BUILTIN = "__ascend_builtin__" - - -def builtin(fn: T) -> T: - """Mark a function as a buffer language builtin.""" - assert callable(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - if "_builder" not in kwargs or kwargs["_builder"] is None: - raise ValueError("Did you forget to add @triton.jit ? " - "(`_builder` argument must be provided outside of JIT functions.)") - return fn(*args, **kwargs) - - # also set triton_builtin to true so that CodeGenerator will recognize this function - setattr(wrapper, TRITON_BUILTIN, True) - setattr(wrapper, ASCEND_BUILTIN, True) - - return wrapper - - -def is_builtin(fn) -> bool: - """Is this a registered ascend language builtin function?""" - return getattr(fn, ASCEND_BUILTIN, False) - - -class int64(int): - """ - For custom op, python int argument will be converted to int32 by default, - if a device-side int64 is required, you can pass an al.int64(x) to it. 
- """ - - def __new__(cls, value): - obj = int.__new__(cls, value) - obj.type = tl.int64 - return obj - - -class CORE(enum.Enum): - VECTOR = ascend_ir.CoreType.VECTOR - CUBE = ascend_ir.CoreType.CUBE - CUBE_OR_VECTOR = ascend_ir.CoreType.CUBE_OR_VECTOR - CUBE_AND_VECTOR = ascend_ir.CoreType.CUBE_AND_VECTOR - - -class PIPE(enum.Enum): - PIPE_S = ascend_ir.PIPE.PIPE_S - PIPE_V = ascend_ir.PIPE.PIPE_V - PIPE_M = ascend_ir.PIPE.PIPE_M - PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 - PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 - PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 - PIPE_ALL = ascend_ir.PIPE.PIPE_ALL - PIPE_FIX = ascend_ir.PIPE.PIPE_FIX - - -class MODE(enum.Enum): - SIMD = ascend_ir.MODE.SIMD - SIMT = ascend_ir.MODE.SIMT - MIX = ascend_ir.MODE.MIX - - -class ascend_address_space_base(bl.address_space): - - def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: - super().__init__() - self.real_address_space = address_space_value - - def to_ir(self, builder: ir.builder) -> ir.attribute: - return builder.get_target_attribute(self.real_address_space) - - -class ascend_address_space_group: - - def __init__(self): - for k, v in {k: v - for k, v in ascend_ir.AddressSpace.__dict__.items() - if isinstance(v, ascend_ir.AddressSpace)}.items(): - setattr(self, k, ascend_address_space_base(v)) - - -ascend_address_space = ascend_address_space_group() - - -@builtin -def sub_vec_id(_builder=None) -> tl.tensor: - """ - Get the Vector Core index on the AI Core. - """ - return semantic.sub_vec_id(_builder) - - -@builtin -def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: - """ - Copies data from the Unified Buffer (UB) to the L1 Buffer. - - :param src: The source data located in the Unified Buffer. - :type src: tl.tensor | bl.buffer - :param dst: The destination buffer located in L1 memory. 
- :type dst: tl.tensor | bl.buffer - """ - return semantic.copy_from_ub_to_l1(src, dst, _builder) - - -def create_sync_block(sender, receiver, event_id, is_set: bool, sender_pipe=None, receiver_pipe=None, _builder=None): - sender = _constexpr_to_value(sender) - receiver = _constexpr_to_value(receiver) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" - if isinstance(event_id, int): - assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - if sender == receiver: - raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') - if sender_pipe is None and receiver_pipe is None: - if sender == "cube": - sender_pipe = PIPE.PIPE_FIX - receiver_pipe = PIPE.PIPE_MTE2 - if sender == "vector": - sender_pipe = PIPE.PIPE_MTE3 - receiver_pipe = PIPE.PIPE_MTE2 - if not isinstance(sender_pipe, PIPE) or not isinstance(receiver_pipe, PIPE): - raise TypeError("sender_pipe and receiver_pipe must be instances of PIPE enum") - if is_set: - return semantic.create_sync_block_set(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) - return semantic.create_sync_block_wait(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_set(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): - return create_sync_block(sender, receiver, event_id, True, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_wait(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): - return create_sync_block(sender, receiver, event_id, False, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_all(mode, event_id, _builder=None): - mode = _constexpr_to_value(mode) - event_id = _constexpr_to_value(event_id) - assert isinstance(mode, str), f"mode: {mode} is not string" - assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - assert mode in ("all_cube", "all_vector", "all", - "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" - _builder.sync_block_all(mode, event_id) - - -class FixpipeDMAMode(enum.Enum): - NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN - NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND - NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ - - -class FixpipeDualDstMode(enum.Enum): - NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL - COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT - ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT - - -class FixpipePreQuantMode(enum.Enum): - NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT - F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 - F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 - S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 - - -class FixpipePreReluMode(enum.Enum): - LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU - NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU - NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU - P_RELU = ascend_ir.FixpipePreReluMode.P_RELU - - -@builtin -def fixpipe( - src: tl.tensor, - dst: bl.buffer, - dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, - dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, - _builder=None, -) -> None: - """ - Directly store a tensor on L0C to a local buffer via fixpipe. 
- Fixpipe is pipeline that performing data movement from L0C to other memory hierarchies. - Currently support: - - L0C to UB (for Ascend910_95 sereies) - - :param src: the source tensor, Must be located in the l0C memory region. - :type src: tl.tensor - :param dst: The destination buffer, Must be located in the UB memory region. - :type dst: bl.buffer - :param dma_mode: DMA transfer mode, "nz2nd" enables NZ to ND layout transformation - :type dma_mode: str - """ - if not _builder.is_910_95(): - raise RuntimeError("this feature is only supported on Ascend910_95") - if not isinstance(src, tl.tensor): - raise TypeError("src is not of tensor type") - elif not isinstance(dst, bl.buffer): - raise TypeError("dst is not of buffer type") - if dst.space != ascend_address_space.UB: - raise TypeError("dst must be located in the UB memory region") - - if len(dst.shape) == 2 and (dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32): - N = dst.shape[1] - if N % 8 != 0: - raise ValueError("32b Fixpipe last dim must be aligned to 8") - if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0): - raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16") - if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0): - raise ValueError("32b Column split dual Fixpipe last dim must be aligned to 32") - M = dst.shape[0] - if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0): - raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8") - dst16bits = (dst.type.element_ty == tl.float16 or dst.type.element_ty == tl.int16 - or dst.type.element_ty == tl.bfloat16) - if len(dst.shape) == 2 and dst16bits: - N = dst.shape[1] - if N % 16 != 0: - raise ValueError("16b Fixpipe last dim must be aligned to 16") - M = dst.shape[0] - if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0): - raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 16") - - return semantic.fixpipe(src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU, - _builder) - - -class SYNC_IN_VF(enum.Enum): - VV_ALL = enum.auto() - VST_VLD = enum.auto() - VLD_VST = enum.auto() - VST_VST = enum.auto() - VS_ALL = enum.auto() - VST_LD = enum.auto() - VLD_ST = enum.auto() - VST_ST = enum.auto() - SV_ALL = enum.auto() - ST_VLD = enum.auto() - LD_VST = enum.auto() - ST_VST = enum.auto() - - -@builtin -def debug_barrier( - sync_mode: SYNC_IN_VF, - _builder=None, -) -> None: - return semantic.debug_barrier(sync_mode.name, _builder) - - -@builtin -def sub_vec_num(_builder=None) -> tl.constexpr: - """ - Get the Vector Core Num on one AI Core. - """ - npuUtils = NPUUtils() - cube_num = npuUtils.get_aivector_core_num() - vector_num = npuUtils.get_aicore_num() - const_val = cube_num // vector_num - return tl.constexpr(const_val) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = [ + "ascend_address_space", "builtin", "CORE", "copy_from_ub_to_l1", "copy", "debug_barrier", "fixpipe", + "FixpipeDMAMode", "FixpipeDualDstMode", "FixpipePreQuantMode", "FixpipePreReluMode", "int64", "is_builtin", "MODE", + "PIPE", "IteratorType", "sub_vec_id", "sub_vec_num", "sync_block_all", "sync_block_set", "sync_block_wait", + "SYNC_IN_VF" +] + +import enum +from typing import TypeVar, List, Union +from functools import wraps + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl + +import triton.extension.buffer.language as bl +from triton.language.core import _constexpr_to_value +from triton.backends.ascend.driver import NPUUtils + +from . import semantic as semantic + +PIPE = semantic.PIPE + +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" +ASCEND_BUILTIN = "__ascend_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a buffer language builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + # also set triton_builtin to true so that CodeGenerator will recognize this function + setattr(wrapper, TRITON_BUILTIN, True) + setattr(wrapper, ASCEND_BUILTIN, True) + + return wrapper + + +def is_builtin(fn) -> bool: + """Is this a registered ascend language builtin function?""" + return getattr(fn, ASCEND_BUILTIN, False) + + +class int64(int): + """ + For custom op, python int argument will be converted to int32 by default, + if a device-side int64 is required, you can pass an al.int64(x) to it. 
+ """ + + def __new__(cls, value): + obj = int.__new__(cls, value) + obj.type = tl.int64 + return obj + + +class CORE(enum.Enum): + VECTOR = ascend_ir.CoreType.VECTOR + CUBE = ascend_ir.CoreType.CUBE + CUBE_OR_VECTOR = ascend_ir.CoreType.CUBE_OR_VECTOR + CUBE_AND_VECTOR = ascend_ir.CoreType.CUBE_AND_VECTOR + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +class MODE(enum.Enum): + SIMD = ascend_ir.MODE.SIMD + SIMT = ascend_ir.MODE.SIMT + MIX = ascend_ir.MODE.MIX + + +class IteratorType(enum.Enum): + Parallel = ascend_ir.IteratorType.Parallel + Broadcast = ascend_ir.IteratorType.Broadcast + Transpose = ascend_ir.IteratorType.Transpose + Reduction = ascend_ir.IteratorType.Reduction + Interleave = ascend_ir.IteratorType.Interleave + Deinterleave = ascend_ir.IteratorType.Deinterleave + Inverse = ascend_ir.IteratorType.Inverse + Pad = ascend_ir.IteratorType.Pad + Concat = ascend_ir.IteratorType.Concat + Gather = ascend_ir.IteratorType.Gather + Cumulative = ascend_ir.IteratorType.Cumulative + Opaque = ascend_ir.IteratorType.Opaque + + +class ascend_address_space_base(bl.address_space): + + def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: + super().__init__() + self.real_address_space = address_space_value + + def to_ir(self, builder: ir.builder) -> ir.attribute: + return builder.get_target_attribute(self.real_address_space) + + +class ascend_address_space_group: + + def __init__(self): + for k, v in {k: v + for k, v in ascend_ir.AddressSpace.__dict__.items() + if isinstance(v, ascend_ir.AddressSpace)}.items(): + setattr(self, k, ascend_address_space_base(v)) + + +ascend_address_space = ascend_address_space_group() + + +@builtin +def sub_vec_id(_builder=None) -> tl.tensor: + """ + Get the Vector Core index on the AI Core. + """ + return semantic.sub_vec_id(_builder) + + +@builtin +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: + """ + Copies data from the Unified Buffer (UB) to the L1 Buffer. + + :param src: The source data located in the Unified Buffer. + :type src: tl.tensor | bl.buffer + :param dst: The destination buffer located in L1 memory. + :type dst: tl.tensor | bl.buffer + """ + from warnings import warn + warn("copy_from_ub_to_l1 is deprecated, please use copy instead.") + return semantic.copy_from_ub_to_l1(src, dst, _builder) + + +@builtin +def copy(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: + """ + Copies data from the Unified Buffer (UB) to the Unified Buffer (UB) or L1 Buffer. + + :param src: The source data located in the Unified Buffer. + :type src: tl.tensor | bl.buffer + :param dst: The destination buffer located Unified Buffer (UB) or L1 memory. 
+ :type dst: tl.tensor | bl.buffer + """ + return semantic.copy(src, dst, _builder) + + +def create_sync_block(sender, receiver, event_id, is_set: bool, sender_pipe=None, receiver_pipe=None, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + assert isinstance(sender, str) and (sender == "cube" + or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver + == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + if isinstance(event_id, int): + assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + if sender_pipe is None and receiver_pipe is None: + if sender == "cube": + sender_pipe = PIPE.PIPE_FIX + receiver_pipe = PIPE.PIPE_MTE2 + if sender == "vector": + sender_pipe = PIPE.PIPE_MTE3 + receiver_pipe = PIPE.PIPE_MTE2 + if not isinstance(sender_pipe, PIPE) or not isinstance(receiver_pipe, PIPE): + raise TypeError("sender_pipe and receiver_pipe must be instances of PIPE enum") + if is_set: + return semantic.create_sync_block_set(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + return semantic.create_sync_block_wait(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_set(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, True, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_wait(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, False, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + mode = _constexpr_to_value(mode) + event_id = _constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode in ("all_cube", "all_vector", "all", + "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" + _builder.sync_block_all(mode, event_id) + + +class FixpipeDMAMode(enum.Enum): + NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN + NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND + NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ + + +class FixpipeDualDstMode(enum.Enum): + NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL + COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT + ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT + + +class FixpipePreQuantMode(enum.Enum): + NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT + F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 + F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 + S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 + + +class FixpipePreReluMode(enum.Enum): + LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU + NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU + NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU + P_RELU = ascend_ir.FixpipePreReluMode.P_RELU + + +@builtin +def fixpipe( + src: tl.tensor, + dst: bl.buffer, + dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, + dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, + _builder=None, +) -> None: + """ + Directly store a tensor on L0C to a local buffer via fixpipe. 
+    Fixpipe is the pipeline that performs data movement from L0C to other memory hierarchies.
+    Currently supported:
+    - L0C to UB (for the Ascend910_95 series)
+
+    :param src: the source tensor; must be located in the L0C memory region.
+    :type src: tl.tensor
+    :param dst: The destination buffer; must be located in the UB memory region.
+    :type dst: bl.buffer
+    :param dma_mode: DMA transfer mode; NZ2ND enables NZ-to-ND layout transformation
+    :type dma_mode: FixpipeDMAMode
+    """
+    if not _builder.is_910_95():
+        raise RuntimeError("this feature is only supported on Ascend910_95")
+    if not isinstance(src, tl.tensor):
+        raise TypeError("src is not of tensor type")
+    elif not isinstance(dst, bl.buffer):
+        raise TypeError("dst is not of buffer type")
+    if dst.space != ascend_address_space.UB:
+        raise TypeError("dst must be located in the UB memory region")
+
+    if len(dst.shape) == 2 and (dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32):
+        N = dst.shape[1]
+        if N % 8 != 0:
+            raise ValueError("32b Fixpipe last dim must be aligned to 8")
+        if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0):
+            raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16")
+        if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0):
+            raise ValueError("32b Column split dual Fixpipe last dim must be aligned to 32")
+        M = dst.shape[0]
+        if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0):
+            raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8")
+    dst16bits = (dst.type.element_ty == tl.float16 or dst.type.element_ty == tl.int16
+                 or dst.type.element_ty == tl.bfloat16)
+    if len(dst.shape) == 2 and dst16bits:
+        N = dst.shape[1]
+        if N % 16 != 0:
+            raise ValueError("16b Fixpipe last dim must be aligned to 16")
+        M = dst.shape[0]
+        if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0):
+            raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 16")
+
+    return semantic.fixpipe(src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU,
+                            _builder)
+
+
+class SYNC_IN_VF(enum.Enum):
+    VV_ALL = enum.auto()
+    VST_VLD = enum.auto()
+    VLD_VST = enum.auto()
+    VST_VST = enum.auto()
+    VS_ALL = enum.auto()
+    VST_LD = enum.auto()
+    VLD_ST = enum.auto()
+    VST_ST = enum.auto()
+    SV_ALL = enum.auto()
+    ST_VLD = enum.auto()
+    LD_VST = enum.auto()
+    ST_VST = enum.auto()
+
+
+@builtin
+def debug_barrier(
+    sync_mode: SYNC_IN_VF,
+    _builder=None,
+) -> None:
+    return semantic.debug_barrier(sync_mode.name, _builder)
+
+
+@builtin
+def sub_vec_num(_builder=None) -> tl.constexpr:
+    """
+    Get the Vector Core Num on one AI Core.
+    """
+    npuUtils = NPUUtils()
+    vec_core_num = npuUtils.get_aivector_core_num()
+    ai_core_num = npuUtils.get_aicore_num()
+    const_val = vec_core_num // ai_core_num
+    return tl.constexpr(const_val)
diff --git a/third_party/ascend/language/cann/extension/custom_op.py b/third_party/ascend/language/cann/extension/custom_op.py
index b3352c26b2..8644a6d9e6 100644
--- a/third_party/ascend/language/cann/extension/custom_op.py
+++ b/third_party/ascend/language/cann/extension/custom_op.py
@@ -152,22 +152,121 @@ def _args_to_operands(op, builder, args, kwargs):
     return operands
 
 
+def _bind_op_arguments(op, args, kwargs):
+    if not op.signature.parameters:
+        return None
+    return op.signature.bind(*args, **kwargs)
+
+
+def _make_align_dim_attrs(op, builder, arg_attrs):
+    # Populate arg_attrs with an integer 'align_dim' attribute for each
+    # argument named (or indexed) by a key of op.align_dim.
+    name = 'align_dim'
+    if not hasattr(op, name):
+        return
+
+    # To find argument indices matching each align_dim key, check the
+    # op.signature parameters and map each align_dim key (argument name)
+    # to its index position.
+    align_arg_indices = {}
+    if hasattr(op, "signature"):
+        param_names = list(op.signature.parameters.keys())
+        for arg_name in op.align_dim.keys():
+            if arg_name in param_names:
+                align_arg_indices[arg_name] = param_names.index(arg_name)
+
+    for arg, align_val in op.align_dim.items():
+        if isinstance(arg, str) and arg in align_arg_indices:
+            arg_attrs[align_arg_indices[arg]] = {name: builder.get_int_attr(align_val)}
+        elif isinstance(arg, int):
+            arg_attrs[arg] = {name: builder.get_int_attr(align_val)}
+        else:
+            assert False, f"{name}'s keys should be string or int"
+
+
+def _make_arg_attrs(op, builder):
+    num_args = len(op.signature.parameters) if hasattr(op, "signature") else 0
+    arg_attrs = [{} for _ in range(num_args)]
+
+    _make_align_dim_attrs(op, builder, arg_attrs)
+    return arg_attrs
+
+
 def _add_optional_attr(op, name, builder, attrs):
     if hasattr(op, name):
         attrs[name] = builder.get_str_attr(getattr(op, name))
 
 
+def _add_bitcode_attr(op, builder, attrs):
+    name = 'bitcode'
+    if not hasattr(op, name):
+        return
+
+    from pathlib import Path
+    bitcode = Path(getattr(op, name))
+    assert bitcode.exists(), f"Provided bitcode ({bitcode}) does not exist"
+    attrs[name] = builder.get_str_attr(str(bitcode.absolute()))
+
+
+def _add_optional_extra_buffer_attr(op, builder, attrs):
+    name = 'extra_buffers'
+    if not hasattr(op, name):
+        return
+
+    extra_buffers = getattr(op, name)
+    if isinstance(extra_buffers, tuple):
+        extra_buffers = [extra_buffers]
+
+    extra_buffer_types, extra_buffer_sizes = zip(*extra_buffers)
+    attrs[name + "_types"] = builder.get_type_array_attr([ty.to_ir(builder) for ty in extra_buffer_types])
+    attrs[name + "_sizes"] = builder.get_i64_array_attr(list(extra_buffer_sizes))
+
+
+def _add_optional_indexing_map_attr(op, builder, attrs):
+    # Optional indexing map attribute:
+    # `indexing_map` should be an iterable of al.affine_map (MLIR AffineMap) objects.
+    name = 'indexing_map'
+    if not hasattr(op, name):
+        return
+
+    indexing_map = getattr(op, name)
+    attrs[name] = builder.get_affine_map_array_attr(indexing_map)
+
+
+def _add_optional_iterator_types_attr(op, builder, attrs):
+    name = 'iterator_types'
+    if not hasattr(op, name):
+        return
+
+    attrs[name] = builder.get_iterator_types_attr([iterator_type.value for iterator_type in getattr(op, name)])
+
+
 def _make_attrs(op, builder):
     attrs = {
         'hivm.tcore_type': builder.get_core_type_attr(op.core.value),
         'hivm.pipe': builder.get_pipe_attr(op.pipe.value),
         'hivm.vf_mode': builder.get_vf_mode_attr(op.mode.value),
     }
+
+    if not op.name.startswith('__builtin_'):
+        assert hasattr(op, 'symbol'), "Non-builtin custom op: 'symbol' is required."
+        assert hasattr(op, 'bitcode'), "Non-builtin custom op: a bitcode path is required."
+
+    # Add the bitcode path attribute, normalized to an absolute path.
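For context, a hypothetical non-builtin custom op declaration that would exercise the attribute helpers above; the field names follow the asserts and helpers in this hunk, while the function body, symbol, path, and alignment values are purely illustrative:

```python
def my_vector_op(x, y):
    # Lowered to a call into the user-provided bitcode; never run in Python.
    ...

my_vector_op.core = CORE.VECTOR
my_vector_op.pipe = PIPE.PIPE_V
my_vector_op.mode = MODE.SIMD
my_vector_op.symbol = "my_vector_op_kernel"        # required for non-builtin ops
my_vector_op.bitcode = "/path/to/my_vector_op.bc"  # must exist on disk
my_vector_op.align_dim = {"x": 32, 1: 16}          # arg name or position -> alignment
my_vector_op.iterator_types = [IteratorType.Parallel]

register_custom_op(my_vector_op)
```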
+ _add_bitcode_attr(op, builder, attrs) + + _add_optional_indexing_map_attr(op, builder, attrs) + _add_optional_iterator_types_attr(op, builder, attrs) + + _add_optional_extra_buffer_attr(op, builder, attrs) + _add_optional_attr(op, 'symbol', builder, attrs) _add_optional_attr(op, 'source', builder, attrs) _add_optional_attr(op, 'compile', builder, attrs) # Extra attributes can be added here, such as op.extra_attr="attr_a=xx" _add_optional_attr(op, 'extra_attr', builder, attrs) + return attrs @@ -207,8 +306,9 @@ def custom_semantic(name: str, *args, _builder=None, **kwargs): inputs = _args_to_operands(op, _builder, args, kwargs) # Setup attributes. attrs = _make_attrs(op, _builder) + arg_attrs = _make_arg_attrs(op, _builder) # Build IR for the custom op. - res = _builder.create_custom_op(name, attrs, inputs, outputs) + res = _builder.create_custom_op(name, attrs, inputs, outputs, arg_attrs) # Results with same types as outputs. res_types = [out.type for out in outs] return _to_result(res, res_types) @@ -228,6 +328,7 @@ def register_custom_op(op): setattr(op, 'name', op.__name__) # The op name should not be used. assert op.name not in _custom_op_registry, f"Custom op name '{op.name}' already used." + # Check required core, pipe, mode fields. assert hasattr(op, 'core'), "'core' field is required." assert hasattr(op, 'pipe'), "'pipe' field is required." diff --git a/third_party/ascend/language/cann/extension/mem_ops.py b/third_party/ascend/language/cann/extension/mem_ops.py index 859bbecd67..72ba4a54ea 100644 --- a/third_party/ascend/language/cann/extension/mem_ops.py +++ b/third_party/ascend/language/cann/extension/mem_ops.py @@ -1,551 +1,538 @@ -import numbers -import triton.language as tl -from triton.language import semantic as real_semantic -from triton.language.core import ( - _constexpr_to_value, - _tensor_member_fn, - _unwrap_iterable, - builtin, - constexpr, - dtype, - tensor, - check_bit_width, - _unwrap_if_constexpr, -) -from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, - _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, - bitcast, - bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, - _str_to_eviction_policy, - _str_to_padding_option, - _canonicalize_boundary_check, -) - -from typing import Optional, Tuple, List, overload, Union -from triton._C.libtriton import ir - -from ._utils import _convert_elem_to_ir_value - - -@_tensor_member_fn -@builtin -def index_select(src: tensor, idx: tensor, bound, lstdim_blksiz, offsets, numels, _builder=None): - """ - Embedding - :src_ptr: - :idx: - """ - - def embedding_gather_impl(src: tl.tensor, idx: tl.tensor, bound: int, blksiz: int, offsets: Tuple, numels: Tuple, - builder: ir.builder) -> tl.tensor: - assert idx.dtype.is_int(), "index must be an integer tensor" - if not src.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype.element_ty}") - - require_i64 = idx.dtype.is_int64() - # require_i64 = True - offsets = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in offsets] - numels = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in numels] - ret = builder.create_embedding_gather(src.handle, idx.handle, bound, blksiz, offsets, numels) - ret_shape = [_unwrap_if_constexpr(s) for s in idx.shape] - ret_shape.append(blksiz) - return wrap_tensor(ret, src.dtype.element_ty, ret_shape) - - bound = 
_constexpr_to_value(bound) - lstdim_blksiz = _constexpr_to_value(lstdim_blksiz) - - return embedding_gather_impl(src, idx, bound, lstdim_blksiz, offsets, numels, _builder) - - -@_tensor_member_fn -@builtin -def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, - start_offset: tuple, dst_stride: tuple, _builder=None): - """ - Index put values from a tensor into a destination tensor. - - Index put operation for different tensor ranks: - 1. 2D index scatter (0 <= dim < 1): - 1.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] - 2. 3D index scatter (0 <= dim < 2): - 2.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] - = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] - 2.2 dim = 1 - out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] - = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] - - - :param ptr: pointer type, the destination tensor pointer (in GM) - :param index: tensor, a index to scatter (in UB) - :param value: tensor, a value to store (in UB) - :param dim: int32, the dimension to scatter along - :param index_boundary: int64, the upper boundary for index values - :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region - :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region - :param dst_stride: tuple of int, the stride of each dimension of destination tensor - - Constraints - *********** - - `ptr` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. - - `index.numel` must equal `value.shape[dim]`. - - `value` support 2~5D tensors. - - `dim` must be valid (0 <= dim < rank(value) - 1). - - Example - ******* - .. 
code-block:: python - - import torch - import triton - import triton.language as tl - from triton.language.extra.cann.extension import index_put - - @triton.jit - def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): - # index tile shape: [2] - index_local = tl.arange(0, 2) - x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) - - index_tile = tl.load(index_ptr + index_local) - value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) - - index_put( - ptr=dst_ptr, - index=index_tile, - value=value_tile, - dim=0, - index_boundary=4, - end_offset=(2, 2), - start_offset=(0, 0), - dst_stride=(2, 1) - ) - - dst = torch.zeros((4,2), device='npu', dtype=torch.float32) - value = torch.tensor([[1.,2.], [3.,4.]], device='npu') - index = torch.tensor([2, 0], device='npu') - - simple_index_put_kernel[(1,)](value, index, dst) - print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] - """ - - def index_put_impl(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, - end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, builder: ir.builder): - assert index.dtype.is_int(), "index must be an integer tensor" - if not ptr.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") - if not isinstance(dim, int): - raise ValueError("dim must be of type tl.constexpr") - - v_rank = len(value.shape) - idx_rank = len(index.shape) - if v_rank < 2 or v_rank > 5: - raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") - if dim < 0 or dim >= v_rank - 1: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 0 - - dim = _constexpr_to_value(dim) - index_boundary = _constexpr_to_value(index_boundary) - value = _constexpr_to_value(value) - - if not _is_ranked_tensor(value) or isinstance(value, constexpr): - element_ty = ptr.type.scalar.element_ty - value = real_semantic.full(index.shape, value, element_ty, _builder) - return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, - _builder) - - -@_tensor_member_fn -@builtin -def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: - """ - Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). - - Selects data from multiple indices along a specified dimension and loads - them as tiles from GM directly to UB with zero-copy semantics. 
- - :param src: Source tensor pointer (in GM) - :type src: tensor (pointer type) - :param dim: The dimension along which to select indices - :type dim: int or constexpr - :param index: 1D tensor of indices to select (in UB) - :type index: tensor - :param src_shape: Complete shape of the source tensor (can be int or tensor) - :type src_shape: List[Union[int, tensor]] - :param src_offset: Starting offset for reading (can be int or tensor) - :type src_offset: List[Union[int, tensor]] - :param read_shape: Size to read (tile shape, can be int or tensor) - :type read_shape: List[Union[int, tensor]] - - **Constraints:** - - - ``read_shape[dim]`` must be ``-1`` - - ``src_offset[dim]`` can be ``-1`` (will be ignored) - - Boundary handling: ``src_offset + read_shape > src_shape`` automatically - truncates to ``src_shape`` boundary - - Does not check if ``index`` contains out-of-bounds values - - **Example:** - - .. code-block:: python - - @triton.jit - def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): - # Load indices (e.g., [5, 10, 15, 20]) - indices = tl.load(indices_ptr + tl.arange(0, 4)) - - # Example 1: Static shapes (constants) - # Index select from dimension 1 - # src: [8, 100, 256], index_select at dim=1 - # Read: [4, ?, 128] starting from [4, ?, 128] - result = extension.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[8, 100, 256], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - # result shape: [4, 4, 128] - - # Example 2: Dynamic shapes (variables) - result2 = extension.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[M, N, D], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - - tl.store(output_ptr + ..., result) - - :return: Result tensor in UB with shape where ``dim`` is replaced - by the length of ``index`` - :rtype: tensor - """ - - def index_select_simd_impl(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], - src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], - builder: ir.builder) -> tl.tensor: - # Validate inputs - ndim = len(src_shape) - assert len(src_offset) == ndim, \ - f"src_offset length {len(src_offset)} must match src_shape length {ndim}" - assert len(read_shape) == ndim, \ - f"read_shape length {len(read_shape)} must match src_shape length {ndim}" - assert 0 <= dim < ndim, \ - f"dim={dim} must be in range [0, {ndim})" - assert len(index.shape) == 1, \ - f"index must be 1D tensor, got {len(index.shape)}D" - assert dim < ndim - 1, \ - f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" - - newsrc_shape = [o.handle for o in src_shape] - newsrc_offset = [o.handle for o in src_offset] - # Create output type - return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] - element_ty = src.type.element_ty - output_ty = tl.block_type(element_ty, return_shape) - out = builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, - return_shape) - return tl.tensor(out, output_ty) - - dim = _constexpr_to_value(dim) - - # Process shape parameters: convert constexpr to values, keep tensors as-is - def process_param(val): - """Convert constexpr to value, keep tensor or int as-is""" - if isinstance(val, tensor): - return val - else: - return _constexpr_to_value(val) - - newsrc_shape = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape] - newsrc_offset = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) 
else o for o in src_offset] - assert len(index.shape) == 1, "index must be a 1D tensor" - - return index_select_simd_impl(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder) +import numbers +import triton.language as tl +from triton.language import semantic as real_semantic +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype, + tensor, + check_bit_width, + _unwrap_if_constexpr, +) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from typing import Optional, Tuple, List, overload, Union +from triton._C.libtriton import ir + +from ._utils import _convert_elem_to_ir_value + + +@_tensor_member_fn +@builtin +def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, + start_offset: tuple, dst_stride: tuple, _builder=None): + """ + Index put values from a tensor into a destination tensor. + + Index put operation for different tensor ranks: + 1. 2D index scatter (0 <= dim < 1): + 1.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] + 2. 3D index scatter (0 <= dim < 2): + 2.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] + = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] + 2.2 dim = 1 + out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] + = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] + + + :param ptr: pointer type, the destination tensor pointer (in GM) + :param index: tensor, an index to scatter (in UB) + :param value: tensor, a value to store (in UB) + :param dim: int32, the dimension to scatter along + :param index_boundary: int64, the upper boundary for index values + :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region + :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region + :param dst_stride: tuple of int, the stride of each dimension of destination tensor + + Constraints + *********** + - `ptr` and `value` must have the same rank. + - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. + - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. + - `index.numel` must equal `value.shape[dim]`. + - `value` supports 2~5D tensors. + - `dim` must be valid (0 <= dim < rank(value) - 1). + + Example + ******* + ..
code-block:: python + + import torch + import triton + import triton.language as tl + from triton.language.extra.cann.extension import index_put + + @triton.jit + def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): + # index tile shape: [2] + index_local = tl.arange(0, 2) + x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) + + index_tile = tl.load(index_ptr + index_local) + value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) + + index_put( + ptr=dst_ptr, + index=index_tile, + value=value_tile, + dim=0, + index_boundary=4, + end_offset=(2, 2), + start_offset=(0, 0), + dst_stride=(2, 1) + ) + + dst = torch.zeros((4,2), device='npu', dtype=torch.float32) + value = torch.tensor([[1.,2.], [3.,4.]], device='npu') + index = torch.tensor([2, 0], device='npu') + + simple_index_put_kernel[(1,)](value, index, dst) + print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] + """ + + def index_put_impl(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, + end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, _builder: ir.builder): + assert index.dtype.is_int(), "index must be an integer tensor" + if not ptr.dtype.element_ty.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") + if not isinstance(dim, int): + raise ValueError("dim must be of type tl.constexpr") + + v_rank = len(value.shape) + idx_rank = len(index.shape) + if v_rank < 2 or v_rank > 5: + raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") + if dim < 0 or dim >= v_rank - 1: + raise ValueError(f"dim must satisfy 0<=dim<v_rank-1, got dim={dim}") + if idx_rank < 1 or idx_rank > 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim<idx_rank, got dim={dim}") + + def scatter_ub_to_out_impl(ptr: tl.tensor, value: tl.tensor, index: tl.tensor, index_boundary: int, dim: int, + dst_stride: Tuple, end_offset: Tuple, start_offset: Tuple, _builder: ir.builder): + idx_rank = len(index.shape) + if idx_rank < 1 or idx_rank > 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim<idx_rank, got dim={dim}") + + def _is_ranked_tensor(v): + return isinstance(v, tl.tensor) and len(v.shape) > 0 + + dim = _constexpr_to_value(dim) + index_boundary = _constexpr_to_value(index_boundary) + value = _constexpr_to_value(value) + + if not _is_ranked_tensor(value) or isinstance(value, constexpr): + element_ty = ptr.type.scalar.element_ty + value = real_semantic.full(index.shape, value, element_ty, _builder) + return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, + _builder) + + +@_tensor_member_fn +@builtin +def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: + """ + Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). + + Selects data from multiple indices along a specified dimension and loads + them as tiles from GM directly to UB with zero-copy semantics.
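+ + The result keeps the rank of the source: every dimension keeps its + ``read_shape`` extent except ``dim``, whose extent becomes the length + of ``index``.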
+ + :param src: Source tensor pointer (in GM) + :type src: tensor (pointer type) + :param dim: The dimension along which to select indices + :type dim: int or constexpr + :param index: 1D tensor of indices to select (in UB) + :type index: tensor + :param src_shape: Complete shape of the source tensor (can be int or tensor) + :type src_shape: List[Union[int, tensor]] + :param src_offset: Starting offset for reading (can be int or tensor) + :type src_offset: List[Union[int, tensor]] + :param read_shape: Size to read (tile shape, can be int or tensor) + :type read_shape: List[Union[int, tensor]] + + **Constraints:** + + - ``read_shape[dim]`` must be ``-1`` + - ``src_offset[dim]`` can be ``-1`` (will be ignored) + - Boundary handling: ``src_offset + read_shape > src_shape`` automatically + truncates to ``src_shape`` boundary + - Does not check if ``index`` contains out-of-bounds values + + **Example:** + + .. code-block:: python + + @triton.jit + def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): + # Load indices (e.g., [5, 10, 15, 20]) + indices = tl.load(indices_ptr + tl.arange(0, 4)) + + # Example 1: Static shapes (constants) + # Index select from dimension 1 + # src: [8, 100, 256], index_select at dim=1 + # Read: [4, ?, 128] starting from [4, ?, 128] + result = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[8, 100, 256], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + # result shape: [4, 4, 128] + + # Example 2: Dynamic shapes (variables) + result2 = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[M, N, D], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + + tl.store(output_ptr + ..., result) + + :return: Result tensor in UB with shape where ``dim`` is replaced + by the length of ``index`` + :rtype: tensor + """ + + def index_select_simd_impl(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], + src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], + _builder: ir.builder) -> tl.tensor: + # Validate inputs + ndim = len(src_shape) + assert len(src_offset) == ndim, \ + f"src_offset length {len(src_offset)} must match src_shape length {ndim}" + assert len(read_shape) == ndim, \ + f"read_shape length {len(read_shape)} must match src_shape length {ndim}" + assert 0 <= dim < ndim, \ + f"dim={dim} must be in range [0, {ndim})" + assert len(index.shape) == 1, \ + f"index must be 1D tensor, got {len(index.shape)}D" + assert dim < ndim - 1, \ + f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" + # Handle both tensor and int offsets (for interpreter mode) + newsrc_shape = [] + for s in src_shape: + if isinstance(s, tensor): + newsrc_shape.append(s.handle) + elif isinstance(s, int): + # For interpreter mode: keep as int + newsrc_shape.append(s) + else: + newsrc_shape.append(s.handle if hasattr(s, 'handle') else s) + newsrc_offset = [] + for s in src_offset: + if isinstance(s, tensor): + newsrc_offset.append(s.handle) + elif isinstance(s, int): + # For interpreter mode: keep as int + newsrc_offset.append(s) + else: + newsrc_offset.append(s.handle if hasattr(s, 'handle') else s) + + # Create output type + return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] + element_ty = src.type.element_ty + output_ty = tl.block_type(element_ty, return_shape) + out = _builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, + return_shape) + return 
tl.tensor(out, output_ty) + + dim = _constexpr_to_value(dim) + + # Process shape parameters: convert constexpr to values, keep tensors as-is + def process_param(val): + """Convert constexpr to value, keep tensor or int as-is""" + if isinstance(val, tensor): + return val + else: + return _constexpr_to_value(val) + + newsrc_shape = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape] + newsrc_offset = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_offset] + assert len(index.shape) == 1, "index must be a 1D tensor" + + return index_select_simd_impl(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder) diff --git a/third_party/ascend/language/cann/extension/semantic.py b/third_party/ascend/language/cann/extension/semantic.py index 29df62e651..e4a90ad9d5 100644 --- a/third_party/ascend/language/cann/extension/semantic.py +++ b/third_party/ascend/language/cann/extension/semantic.py @@ -1,129 +1,148 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# Copyright 2018-2020 Philippe Tillet -# Copyright 2020-2022 OpenAI -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -__all__ = [ - "fixpipe", - "create_address_space", -] - -import enum -from typing import (TypeVar, List, Union) - -from triton._C.libtriton import ir -from triton._C.libtriton.ascend import ir as ascend_ir -import triton.language.core as tl -import triton.language.extra.cann.extension as al -import triton.extension.buffer.language as bl - -from triton.language import semantic as real_semantic - -T = TypeVar('T') - - -def create_address_space(address_space: ascend_ir.AddressSpace, - builder: ascend_ir.ascendnpu_ir_builder) -> ir.attribute: - return builder.get_target_attribute(address_space) - - -class PIPE(enum.Enum): - PIPE_S = ascend_ir.PIPE.PIPE_S - PIPE_V = ascend_ir.PIPE.PIPE_V - PIPE_M = ascend_ir.PIPE.PIPE_M - PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 - PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 - PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 - PIPE_ALL = ascend_ir.PIPE.PIPE_ALL - PIPE_FIX = ascend_ir.PIPE.PIPE_FIX - - -def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): - if isinstance(event_id, int): - _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, - receiver_pipe.value) - elif isinstance(event_id, tl.constexpr): - _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, - receiver_pipe.value) - else: - _builder.sync_block_set(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) - - -def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): - if isinstance(event_id, int): - _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, - receiver_pipe.value) - elif isinstance(event_id, tl.constexpr): - _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, - receiver_pipe.value) - else: - _builder.sync_block_wait(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) - - -def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: - return tl.tensor(builder.create_get_sub_vec_id(), tl.int64) - - -def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): - if not builder.is_910_95(): - raise RuntimeError("this feature is only supported on Ascend910_95") - if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): - raise TypeError("tensor not support yet") - if src.shape != dst.shape: - raise TypeError("src and dst must have same shape") - if src.dtype != dst.dtype: - raise TypeError("src and dst need to have the same type") - if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): - if src.space != al.ascend_address_space.UB: - raise TypeError("src's AddressSpace must be UB") - if dst.space != al.ascend_address_space.L1: - raise TypeError("dst's AddressSpace must be L1") - builder.create_copy_buffer(src.handle, dst.handle) - else: - raise TypeError("src and dst must be tl.tensor or bl.buffer") - - -def fixpipe( - src: tl.tensor, - dst, - dma_mode, - dual_dst_mode, - pre_quant_mode, - pre_relu_mode, - builder: ascend_ir.ascendnpu_ir_builder, -) -> None: - builder.create_fixpipe( - src.handle, - dst.handle, - dma_mode.value, - dual_dst_mode.value, - pre_quant_mode.value, - pre_relu_mode.value, - ) - - -def debug_barrier(sync_mode: str, builder) -> None: - target = tl.tensor(builder.get_int64(0), tl.int64) - attr = 
builder.get_str_attr(sync_mode) - builder.create_debug_barrier(target.handle, "SYNC_IN_VF", attr) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = [ + "fixpipe", + "create_address_space", +] + +import enum +from typing import (TypeVar, List, Union) + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl +import triton.language.extra.cann.extension as al +import triton.extension.buffer.language as bl + +from triton.language import semantic as real_semantic + +T = TypeVar('T') + + +def create_address_space(address_space: ascend_ir.AddressSpace, + builder: ascend_ir.ascendnpu_ir_builder) -> ir.attribute: + return builder.get_target_attribute(address_space) + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, + receiver_pipe.value) + else: + _builder.sync_block_set(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, + receiver_pipe.value) + else: + _builder.sync_block_wait(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: + return tl.tensor(builder.create_get_sub_vec_id(), 
tl.int64) + + +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): + if not builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): + raise TypeError("tensor is not supported yet") + if src.shape != dst.shape: + raise TypeError("src and dst must have the same shape") + if src.dtype != dst.dtype: + raise TypeError("src and dst need to have the same type") + if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): + if src.space != al.ascend_address_space.UB: + raise TypeError("src's AddressSpace must be UB") + if dst.space != al.ascend_address_space.L1: + raise TypeError("dst's AddressSpace must be L1") + builder.create_copy_buffer(src.handle, dst.handle) + else: + raise TypeError("src and dst must be tl.tensor or bl.buffer") + + +def copy(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): + if not builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): + raise TypeError("tensor is not supported yet") + if src.shape != dst.shape: + raise TypeError("src and dst must have the same shape") + if src.dtype != dst.dtype: + raise TypeError("src and dst need to have the same type") + if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): + if src.space != al.ascend_address_space.UB: + raise TypeError("src's AddressSpace must be UB") + if dst.space not in (al.ascend_address_space.L1, al.ascend_address_space.UB): + raise TypeError("dst's AddressSpace must be UB or L1") + builder.create_copy_buffer(src.handle, dst.handle) + else: + raise TypeError("src and dst must be tl.tensor or bl.buffer") + + +def fixpipe( + src: tl.tensor, + dst, + dma_mode, + dual_dst_mode, + pre_quant_mode, + pre_relu_mode, + builder: ascend_ir.ascendnpu_ir_builder, +) -> None: + builder.create_fixpipe( + src.handle, + dst.handle, + dma_mode.value, + dual_dst_mode.value, + pre_quant_mode.value, + pre_relu_mode.value, + ) + + +def debug_barrier(sync_mode: str, builder) -> None: + target = tl.tensor(builder.get_int64(0), tl.int64) + attr = builder.get_str_attr(sync_mode) + builder.create_debug_barrier(target.handle, "SYNC_IN_VF", attr) diff --git a/third_party/ascend/language/cann/extension/vec_ops.py index effbbc0fa7..57e152f9a2 100644 --- a/third_party/ascend/language/cann/extension/vec_ops.py +++ b/third_party/ascend/language/cann/extension/vec_ops.py @@ -1,535 +1,548 @@ -# insert_slice -# extract_slice -# get_element -# sort -# flip -# gather - -import triton.language as tl -from triton.language import semantic, core, standard -from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, - tensor, check_bit_width, _unwrap_if_constexpr, range) -from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, - _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, - bitcast, - bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, - _str_to_eviction_policy, - _str_to_padding_option, - _canonicalize_boundary_check, -) - -from . 
import is_compile_on_910_95 -from .aux_ops import compile_hint_impl - -from typing import Optional, Tuple, List, overload -from triton._C.libtriton import ir - - -@_tensor_member_fn -@builtin -def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. - - :param ful: The tensor to receive tensor. - :type ful: Tensor - :param sub: The tensor to be inserted. - :type sub: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - - def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], - builder: ir.builder) -> tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in sizes])) - assert (all([s >= 0 for s in strides])) - # Handle both tensor and int offsets (for interpreter mode) - new_offsets = [] - for o in offsets: - if isinstance(o, tensor): - new_offsets.append(o.handle) - elif isinstance(o, int): - # For interpreter mode: keep as int - new_offsets.append(o) - else: - new_offsets.append(o.handle if hasattr(o, 'handle') else o) - ret_type = tl.block_type(ful.type.scalar, ful.shape) - out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) - return tensor(out, ret_type) - - assert len(ful.shape) > 0 - assert len(ful.shape) == len(sub.shape) - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) - return out - - -@_tensor_member_fn -@builtin -def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. - - :param ful: The tensor to split. - :type ful: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - - def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], - builder: ir.builder) -> tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in sizes])) - assert (all([s >= 0 for s in strides])) - # Handle both tensor and int offsets (for interpreter mode) - new_offsets = [] - for o in offsets: - if isinstance(o, tensor): - new_offsets.append(o.handle) - elif isinstance(o, int): - # For interpreter mode: keep as int - new_offsets.append(o) - else: - new_offsets.append(o.handle if hasattr(o, 'handle') else o) - ret_type = tl.block_type(ful.type.scalar, sizes) - out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) - return tensor(out, ret_type) - - assert len(ful.shape) > 0 - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) - return sub - - -@_tensor_member_fn -@builtin -def get_element(src, indice, _builder=None, _generator=None): - """ - get_element op reads a ranked tensor and returns one element as specified by the given indices. - The result of the op is a value with the same type as the elements of the tensor. 
- The arity of indices must match the rank of the accessed value. - - :param src: The tensor to be accessed. - :type src: Tensor - :param indice: - :type indice: tuple of ints - """ - - def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): - if len(src.shape) != len(indice): - raise ValueError("Indice's rank must be equal to src tensor's rank") - - # Handle both tensor and int indices (for interpreter mode) - new_indice = [] - for i in indice: - if isinstance(i, tensor): - new_indice.append(i.handle) - elif isinstance(i, int): - # For interpreter mode: convert int to TensorHandle - new_indice.append(i) - else: - # Try to use .handle attribute if available - new_indice.append(i.handle if hasattr(i, 'handle') else i) - - result = builder.create_extract_scalar(src.handle, new_indice) - return wrap_tensor(result, src.type.scalar, None) - - assert len(src.shape) > 0 - new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] - return get_element_impl(src, new_indice, _builder) - - -@builtin -def flip(ptr, dim=-1, _builder=None, _generator=None): - - def flip_impl(ptr: tensor, dim: int, builder: ir.builder, generator=None): - """ - Flips a tensor `ptr` along the dimension `dim`. - - :param ptr: the first input tensor - :type ptr: tensor - :param dim: the dimension to flip along - :type dim: int - :param generator: the code generator (required for reduce operations) - :type generator: generator object - """ - - def _get_flip_dim(dim, shape): - dim = _unwrap_if_constexpr(dim) - shape = _unwrap_if_constexpr(shape) - if dim is None: - dim = len(shape) - 1 - if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index - dim += len(shape) - return constexpr(dim) - - def _log2(i: core.constexpr): - log2 = 0 - n = core.constexpr(i).value - while n > 1: - n >>= 1 - log2 += 1 - return core.constexpr(log2) - - def flip_simd(ptr: tensor, dim: int, builder: ir.builder): - """ - Triton flip operation for simd - - Args: - ptr: tensor, input tensor - dim: int, dimension to flip (can be negative, normalized here) - builder: ir.builder, underlying IR builder - Returns: - flipped: tensor, same type and shape as input - """ - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("ascend.flip requires tensor rank >= 1") - norm_dim = dim if dim >= 0 else dim + rank - if not (0 <= norm_dim < rank): - raise ValueError(f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}") - dim = norm_dim - else: - if dim < 0: - raise ValueError("ascend.flip with unknown rank requires non-negative dim") - - flipped_vals = builder.create_flip(ptr.handle, dim) - flipped = tensor(flipped_vals, type=ptr.type) - return flipped - - # If compile_mode is not simt, use the simd implementation - if not builder.is_simt_mode(): - return flip_simd(ptr, dim, builder) - core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) - _dim: core.constexpr = _get_flip_dim(dim, ptr.shape) - core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) - steps: core.constexpr = _log2(ptr.shape[_dim]) - # If steps is 0, return the original tensor - if steps == 0: - return ptr - # reshape the swap dimension to (2, 2, ..., 2) - idtype = 
core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) - y = core.reshape( - ptr.to(idtype, bitcast=True, _builder=builder), - ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), - _builder=builder) - for i in static_range(steps): - y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) - ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) - return ptr - - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e - - dim = len(ptr.shape) - 1 if dim == -1 else dim - return flip_impl(ptr, dim, _builder, _generator) - - -class static_range: - """ - Iterator for non-JIT Python functions that need to iterate over constexpr values. - This is used in functions like flip that are called during compilation. - """ - - def __init__(self, arg1, arg2=None, step=None): - if step is None: - self.step = core.constexpr(1) - else: - self.step = step - if arg2 is None: - self.start = core.constexpr(0) - self.end = arg1 - else: - self.start = arg1 - self.end = arg2 - - def __iter__(self): - # Extract actual values from constexpr objects for iteration - start_val = core._constexpr_to_value(self.start) - end_val = core._constexpr_to_value(self.end) - step_val = core._constexpr_to_value(self.step) - # Store as regular Python integers for iteration - self._current = start_val - self._end = end_val - self._step = step_val - return self - - def __next__(self): - if self._current >= self._end: - raise StopIteration - value = self._current - self._current += self._step - return value - - -@builtin -def sort(ptr, dim=-1, descending=False, _builder=None): - """ - sort the input tensor along 'dim' - - param: - ptr: tensor, input tensor - dim: int or tl.constexpr[int], dimension to sort - descending: bool or tl.constexpr[bool], the result is descending or not - _builder: ir.builder - return: - values: tensor, the sorted tensor - """ - - def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): - allowed_types = { - tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 - } - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty not in allowed_types: - raise TypeError( - f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" - f"but got {ptr.type}") - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("ascend.sort requires tensor rank >= 1") - last_dim = rank - 1 - norm_dim = dim if dim >= 0 else dim + rank - if norm_dim != last_dim: - raise ValueError(f"ascend.sort only supports sorting along the last dimension " - f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") - dim = last_dim - else: - if dim != -1: - raise ValueError("ascend.sort only supports the last dimension; when rank is unknown " - "you must pass dim=-1") - - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - - sorted_vals = builder.create_sort(ptr.handle, dim, descending) - - values = tensor(sorted_vals, 
type=ptr.type) - - return values - - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e - - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - - ret = sort_impl(ptr, dim, descending, _builder) - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty.is_int8() or base_ty.is_int16(): - compile_hint_impl(ret, "overflow_mode", constexpr("saturate"), _builder) - return ret - - -def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, - overflow_mode: Optional[str] = None) -> tensor: - src_ty = input.type - if isinstance(dst_ty, tl.constexpr): - dst_ty = dst_ty.value - if isinstance(fp_downcast_rounding, tl.constexpr): - fp_downcast_rounding = fp_downcast_rounding.value - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) - if src_ty == dst_ty: - return input - - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - if src_sca_ty == dst_sca_ty: - return input - - # For fp downcasting default rounding mode should be RTNE, for all other conversions it should - # not be set - fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) - use_custom_rounding = False - if dst_sca_ty.is_floating() and src_sca_ty.is_floating( - ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: - if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE - elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True - else: - if fp_downcast_rounding is not None: - raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if not is_compile_on_910_95: - if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): - raise ValueError("[fp8, fp64] is unsupported on Ascend for now." - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): - assert builder.codegen_fns.get( - "convert_custom_types") is not None, "target doesn't provide conversion for this type." 
- return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) - # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 - # and non-default rounding modes for downcasting - if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ - (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ - use_custom_rounding: - return tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) - - # bf16 <=> (not fp32) - if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ - (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): - return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) - - # Standard floating types' casting: truncation - # fp64 => fp32, fp16, bf16 - # fp32 => fp16, bf16 - truncate_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth - if truncate_fp: - return tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Standard floating types' casting: extension - # fp32 => fp64 - # fp16 => fp32, fp64 - # bf16 => fp32, fp64 - ext_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth - if ext_fp: - return tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting between integer types - if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ - (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): - sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) - elif overflow_mode == "saturate" and \ - (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ - src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: - return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) - return tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) - - # Casting standard floating types to integer types - if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) - elif dst_sca_ty.is_int_signed(): - return tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting integer types to standard floating types - if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): - if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to integer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): - bitwidth = dst_sca_ty.int_bitwidth - if bitwidth == 64: - return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) - if bitwidth == 1: - return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), - builder) - - # Casting integer types to pointer types - if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): - return 
tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to pointer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - - assert False, f'cannot cast {input} to {dst_ty}' - - -@_tensor_member_fn -@builtin -def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, - overflow_mode: Optional[str] = None, _builder=None): - """ - Casts a tensor to the given :code:`dtype`. - - :param dtype: The target data type. - :type dtype: dtype - :param fp_downcast_rounding: The rounding mode for downcasting - floating-point values. This parameter is only used when self is a - floating-point tensor and dtype is a floating-point type with a - smaller bitwidth. Supported values are :code:`"rtne"` (round to - nearest, ties to even) and :code:`"rtz"` (round towards zero). - :type fp_downcast_rounding: str, optional - :param bitcast: If true, the tensor is bitcasted to the given - :code:`dtype`, instead of being numerically casted. - :type bitcast: bool, optional - :param overflow_mode: When overflow_mode is not set or is "trunc", - truncation (cut-off) will be used to handle overflow. When - overflow_mode is "sautrate", the maximum value of the data type - will be used to handle overflow. - :type overflow_mode: string, optional - """ - overflow_modes = ["trunc", "saturate"] - input = semantic.to_tensor(input, _builder) - if isinstance(bitcast, constexpr): - bitcast = bitcast.value - if bitcast: - return semantic.bitcast(input, dtype, _builder) - ret = ascend_cast_impl(input, dtype, _builder, fp_downcast_rounding, overflow_mode) - if overflow_mode is not None: - if overflow_mode in overflow_modes: - compile_hint_impl(ret, "overflow_mode", overflow_mode, _builder) - else: - raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") - return ret +# insert_slice +# extract_slice +# get_element +# sort +# flip +# gather + +import triton.language as tl +from triton.language import semantic, core, standard +from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, + tensor, check_bit_width, _unwrap_if_constexpr, range) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from . import is_compile_on_910_95 +from .aux_ops import compile_hint_impl + +from typing import Optional, Tuple, List, overload +from triton._C.libtriton import ir + + +@_tensor_member_fn +@builtin +def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor into another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor that receives the inserted tensor. + :type ful: Tensor + :param sub: The tensor to be inserted.
+ :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) + return sub + + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. 
+ :type src: Tensor + :param indice: + :type indice: tuple of ints + """ + + def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + # Handle both tensor and int indices (for interpreter mode) + new_indice = [] + for i in indice: + if isinstance(i, tensor): + new_indice.append(i.handle) + elif isinstance(i, int): + # For interpreter mode: keep the index as a plain int + new_indice.append(i) + else: + # Try to use .handle attribute if available + new_indice.append(i.handle if hasattr(i, 'handle') else i) + + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + + assert len(src.shape) > 0 + new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] + return get_element_impl(src, new_indice, _builder) + + +@builtin +def flip(ptr, dim=-1, _builder=None, _generator=None): + + def flip_impl(ptr: tensor, dim: int, builder: ir.builder, generator=None): + """ + Flips a tensor `ptr` along the dimension `dim`. + + :param ptr: the first input tensor + :type ptr: tensor + :param dim: the dimension to flip along + :type dim: int + :param generator: the code generator (required for reduce operations) + :type generator: generator object + """ + + def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index + dim += len(shape) + return constexpr(dim) + + def _log2(i: core.constexpr): + log2 = 0 + n = core.constexpr(i).value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + def flip_simd(ptr: tensor, dim: int, builder: ir.builder): + """ + Triton flip operation for simd + + Args: + ptr: tensor, input tensor + dim: int, dimension to flip (can be negative, normalized here) + builder: ir.builder, underlying IR builder + Returns: + flipped: tensor, same type and shape as input + """ + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.flip requires tensor rank >= 1") + norm_dim = dim if dim >= 0 else dim + rank + if not (0 <= norm_dim < rank): + raise ValueError(f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}") + dim = norm_dim + else: + if dim < 0: + raise ValueError("ascend.flip with unknown rank requires non-negative dim") + + flipped_vals = builder.create_flip(ptr.handle, dim) + flipped = tensor(flipped_vals, type=ptr.type) + return flipped + + # If compile_mode is not simt, use the simd implementation + if not builder.is_simt_mode(): + return flip_simd(ptr, dim, builder) + core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) + _dim: core.constexpr = _get_flip_dim(dim, ptr.shape) + core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) + steps: core.constexpr = _log2(ptr.shape[_dim]) + # If steps is 0, return the original tensor + if steps == 0: + return ptr + # reshape the swap dimension to (2, 2, ..., 2) + idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) + y = core.reshape( + ptr.to(idtype, bitcast=True, 
_builder=builder), + ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), + _builder=builder) + for i in static_range(steps): + y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) + ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) + return ptr + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e + + dim = len(ptr.shape) - 1 if dim == -1 else dim + return flip_impl(ptr, dim, _builder, _generator) + + +class static_range: + """ + Iterator for non-JIT Python functions that need to iterate over constexpr values. + This is used in functions like flip that are called during compilation. + """ + + def __init__(self, arg1, arg2=None, step=None): + if step is None: + self.step = core.constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = core.constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + + def __iter__(self): + # Extract actual values from constexpr objects for iteration + start_val = core._constexpr_to_value(self.start) + end_val = core._constexpr_to_value(self.end) + step_val = core._constexpr_to_value(self.step) + # Store as regular Python integers for iteration + self._current = start_val + self._end = end_val + self._step = step_val + return self + + def __next__(self): + if self._current >= self._end: + raise StopIteration + value = self._current + self._current += self._step + return value + + +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + """ + Sort the input tensor along 'dim'. + + param: + ptr: tensor, input tensor + dim: int or tl.constexpr[int], dimension to sort + descending: bool or tl.constexpr[bool], the result is descending or not + _builder: ir.builder + return: + values: tensor, the sorted tensor + """ + + def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): + allowed_types = { + tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 + } + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty not in allowed_types: + raise TypeError( + f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5, " + f"but got {ptr.type}") + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.sort requires tensor rank >= 1") + last_dim = rank - 1 + norm_dim = dim if dim >= 0 else dim + rank + if norm_dim != last_dim: + raise ValueError(f"ascend.sort only supports sorting along the last dimension " + f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") + dim = last_dim + else: + if dim != -1: + raise ValueError("ascend.sort only supports the last dimension; when rank is unknown " + "you must pass dim=-1") + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + sorted_vals = builder.create_sort(ptr.handle, dim, descending) + + values = tensor(sorted_vals, type=ptr.type) + + return values + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception 
as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + ret = sort_impl(ptr, dim, descending, _builder) + # interpreter mode does not support the compile_hint overflow_mode; return directly + from triton.runtime.interpreter import InterpreterBuilder + if isinstance(_builder, InterpreterBuilder): + return ret + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty.is_int8() or base_ty.is_int16(): + compile_hint_impl(ret, "overflow_mode", constexpr("saturate"), _builder) + return ret + + +def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, + overflow_mode: Optional[str] = None) -> tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if not is_compile_on_910_95: + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type."
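+ # It is assumed here (not verified in this file) that the target-registered + # codegen_fns["convert_custom_types"] hook takes the source tensor, the + # destination type and the rounding mode, and returns the fully converted + # tensor; the assert above only guards that the hook exists before the + # fp8e4b15 conversion is delegated to it below.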
+ return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif overflow_mode == "saturate" and \ + (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ + src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: + if is_compile_on_910_95: + result = tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + compile_hint_impl(result, "saturate_src_unsigned", src_sca_ty.is_int_unsigned(), builder) + compile_hint_impl(result, "saturate_dst_unsigned", dst_sca_ty.is_int_unsigned(), builder) + return result + else: + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + return tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 
64:
+        return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
+    if bitwidth == 1:
+        return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64),
+                         builder)
+
+    # Casting integer types to pointer types
+    if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
+        return tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+    # Casting pointer types to pointer types
+    if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
+        return tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+    assert False, f'cannot cast {input} to {dst_ty}'
+
+
+@_tensor_member_fn
+@builtin
+def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False,
+         overflow_mode: Optional[str] = None, _builder=None):
+    """
+    Casts a tensor to the given :code:`dtype`.
+
+    :param dtype: The target data type.
+    :type dtype: dtype
+    :param fp_downcast_rounding: The rounding mode for downcasting
+        floating-point values. This parameter is only used when self is a
+        floating-point tensor and dtype is a floating-point type with a
+        smaller bitwidth. Supported values are :code:`"rtne"` (round to
+        nearest, ties to even) and :code:`"rtz"` (round towards zero).
+    :type fp_downcast_rounding: str, optional
+    :param bitcast: If true, the tensor is bitcasted to the given
+        :code:`dtype`, instead of being numerically casted.
+    :type bitcast: bool, optional
+    :param overflow_mode: When overflow_mode is not set or is "trunc",
+        overflow is handled by truncation (cut-off). When overflow_mode is
+        "saturate", overflowing values are clamped to the representable
+        range of the data type.
+    :type overflow_mode: str, optional
+    """
+    overflow_modes = ["trunc", "saturate"]
+    input = semantic.to_tensor(input, _builder)
+    if isinstance(bitcast, constexpr):
+        bitcast = bitcast.value
+    if bitcast:
+        return semantic.bitcast(input, dtype, _builder)
+    ret = ascend_cast_impl(input, dtype, _builder, fp_downcast_rounding, overflow_mode)
+    if overflow_mode is not None:
+        if overflow_mode in overflow_modes:
+            from triton.runtime.interpreter import InterpreterBuilder
+            if isinstance(_builder, InterpreterBuilder):
+                overflow_mode = constexpr(overflow_mode)
+            compile_hint_impl(ret, "overflow_mode", overflow_mode, _builder)
+        else:
+            raise ValueError(f"Unknown overflow_mode: {overflow_mode}.")
+    return ret
diff --git a/third_party/ascend/language/cann/libdevice.py b/third_party/ascend/language/cann/libdevice.py
index eaba0a831e..07836b7142 100644
--- a/third_party/ascend/language/cann/libdevice.py
+++ b/third_party/ascend/language/cann/libdevice.py
@@ -22,7 +22,8 @@
 from triton.language import core, math, semantic
 from triton._C.libtriton import ir
 from triton.runtime.jit import jit
-from triton.backends.ascend.utils import get_ascend_arch_from_env
+from triton.backends.ascend.utils import get_ascend_arch_from_env, triton_enable_libdevice_simt
+from triton.tools.get_ascend_devices import is_compile_on_910_95
 
 
 @core.extern
@@ -82,11 +83,16 @@ def atan(arg0, _builder=None):
 
 @core.extern
 def tanh(arg0, _builder=None):
-    return core.extern_elementwise(
-        "", "", [arg0], {
-            (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")),
-            (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")),
+    if triton_enable_libdevice_simt() and is_compile_on_910_95:
+        return core.extern_elementwise("", "", [arg0], {
+            (core.dtype("fp32"), ): ("__hmf_tanh_fp32", core.dtype("fp32")),
         }, is_pure=True, _builder=_builder)
+    else:
+
return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) @core.extern @@ -109,16 +115,17 @@ def ldexp(arg0, arg1, _builder=None): @core.extern def pow(arg0, arg1, _builder=None): - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")), - (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powf", core.dtype("fp16")), - (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powf", core.dtype("bf16")), - (core.dtype("int64"), core.dtype("int64")): ("__hmf_powi", core.dtype("int64")), - (core.dtype("int32"), core.dtype("int32")): ("__hmf_powi", core.dtype("int32")), - (core.dtype("int16"), core.dtype("int16")): ("__hmf_powi", core.dtype("int16")), - (core.dtype("int8"), core.dtype("int8")): ("__hmf_powi", core.dtype("int8")), + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_pow_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) + else: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")), + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powDh", core.dtype("fp16")), + (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powDb", core.dtype("bf16")), + }, is_pure=True, _builder=_builder) @core.extern @@ -133,20 +140,72 @@ def isnan(arg0, _builder=None): @core.extern def div_rz(arg0, arg1, _builder=None): - core.static_print("tl.div_rz is unsupported for now. Use libdevice.div_rz instead.") - core.static_assert(False) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_div_rz_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.builtin +def fast_dividef(arg0, arg1, _builder=None): + arg0 = semantic.to_tensor(arg0, _builder) + arg1 = semantic.to_tensor(arg1, _builder) + ret = semantic.fdiv(arg0, arg1, False, _builder) + return ret + + +@core.builtin +def fast_expf(arg0, _builder=None): + arg0 = semantic.to_tensor(arg0, _builder) + ret = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + return ret @core.extern def fmod(arg0, arg1, _builder=None): - core.static_print("tl.fmod is unsupported for now. 
Use libdevice.fmod instead.") - core.static_assert(False) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_fmod_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_float_as_int_fp32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern +def atan2(arg0, arg1, _builder): + if arg0.dtype == core.dtype("bf16") or arg1.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.atan2 for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_atan2_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_atan2_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.builtin +@math._check_dtype(dtypes=["fp32"]) +@math._add_math_1arg_docstr("trunc") def trunc(arg0, _builder=None): - core.static_print("tl.trunc is unsupported for now. Use libdevice.trunc instead.") - core.static_assert(False) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_trunc_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_trunc_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + condition = semantic.greater_equal(arg0, zero, _builder) + + floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) + ceil_result = core.tensor(_builder.create_ceil(arg0.handle), arg0.type) + + return semantic.where(condition, floor_result, ceil_result, _builder) @core.extern @@ -160,169 +219,255 @@ def round(arg0, _builder=None): @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("acos") def acos(arg0: core.tensor, _builder: ir.builder): - pi = 3.1415926536 - pi_half = 1.5707963268 - sqrt2 = 1.4142135624 - eps = 1e-8 + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.acos for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_acos_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_acos_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + pi = 3.1415926536 + pi_half = 1.5707963268 + sqrt2 = 1.4142135624 + eps = 1e-8 - # |x| < 0.5, acos(x) = pi/2 - [x + x*x²*(0.1666667 + x²*(0.075 + x²*(0.0446429 + 0.0303810*x²))] - arg0 = semantic.to_tensor(arg0, _builder) - abs_x = math.abs(arg0, _builder=_builder) - dtype = arg0.dtype - arg0_2 = semantic.mul(arg0, arg0, True, _builder) - arg0_4 = semantic.mul(arg0_2, arg0_2, True, _builder) - arg0_6 = semantic.mul(arg0_4, arg0_2, True, _builder) - arg0_8 = semantic.mul(arg0_6, arg0_2, True, _builder) - arg0_10 = semantic.mul(arg0_8, arg0_2, True, _builder) - poly = semantic.add(1.0, semantic.mul(0.166667, arg0_2, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.075, arg0_4, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.044643, arg0_6, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.030380, arg0_8, True, _builder), 
True, _builder) - poly = semantic.add(poly, semantic.mul(0.022372, arg0_10, True, _builder), True, _builder) - acos_center = semantic.sub(pi_half, semantic.mul(arg0, poly, True, _builder), True, _builder) - - # 0.5<|x|<0.9, acos(x) = 2*arctan(t), t=sqrt((1-abs_x)/(1+abs_x)) - numerator_mid = semantic.sub(1.0, abs_x, True, _builder) - denom_mid = semantic.add(1.0, abs_x, True, _builder) - div_mid = semantic.truediv(numerator_mid, denom_mid, _builder) - t_mid = math.sqrt(div_mid, _builder=_builder) - t2_mid = semantic.mul(t_mid, t_mid, True, _builder) - t4_mid = semantic.mul(t2_mid, t2_mid, True, _builder) - t6_mid = semantic.mul(t4_mid, t2_mid, True, _builder) - - # 1 + t2*(-0.3333310 + t2*(0.1999341 + t2*(-0.1420890 + t2*0.1065976))) - poly_mid1 = semantic.mul(0.1065976, t2_mid, True, _builder) - poly_mid2 = semantic.add(-0.1420890, poly_mid1, True, _builder) - poly_mid3 = semantic.mul(poly_mid2, t2_mid, True, _builder) - poly_mid4 = semantic.add(0.1999341, poly_mid3, True, _builder) - poly_mid5 = semantic.mul(poly_mid4, t2_mid, True, _builder) - poly_mid6 = semantic.add(-0.3333310, poly_mid5, True, _builder) - poly_mid = semantic.add(1.0, semantic.mul(poly_mid6, t2_mid, True, _builder), True, _builder) - arctan_t = semantic.mul(t_mid, poly_mid, True, _builder) - acos_mid = semantic.mul(2.0, arctan_t, True, _builder) - is_neg_mid = semantic.less_than(arg0, 0.0, _builder) - acos_mid_signed = semantic.where(is_neg_mid, semantic.sub(pi, acos_mid, True, _builder), acos_mid, _builder) - - is_center = semantic.less_than(abs_x, 0.5, _builder) - res_mid_boundary = semantic.where(is_center, acos_center, acos_mid_signed, _builder) - return res_mid_boundary + # |x| < 0.5, acos(x) = pi/2 - [x + x*x²*(0.1666667 + x²*(0.075 + x²*(0.0446429 + 0.0303810*x²))] + arg0 = semantic.to_tensor(arg0, _builder) + abs_x = math.abs(arg0, _builder=_builder) + dtype = arg0.dtype + arg0_2 = semantic.mul(arg0, arg0, True, _builder) + arg0_4 = semantic.mul(arg0_2, arg0_2, True, _builder) + arg0_6 = semantic.mul(arg0_4, arg0_2, True, _builder) + arg0_8 = semantic.mul(arg0_6, arg0_2, True, _builder) + arg0_10 = semantic.mul(arg0_8, arg0_2, True, _builder) + poly = semantic.add(1.0, semantic.mul(0.166667, arg0_2, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.075, arg0_4, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.044643, arg0_6, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.030380, arg0_8, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.022372, arg0_10, True, _builder), True, _builder) + acos_center = semantic.sub(pi_half, semantic.mul(arg0, poly, True, _builder), True, _builder) + + # 0.5<|x|<0.9, acos(x) = 2*arctan(t), t=sqrt((1-abs_x)/(1+abs_x)) + numerator_mid = semantic.sub(1.0, abs_x, True, _builder) + denom_mid = semantic.add(1.0, abs_x, True, _builder) + div_mid = semantic.truediv(numerator_mid, denom_mid, _builder) + t_mid = math.sqrt(div_mid, _builder=_builder) + t2_mid = semantic.mul(t_mid, t_mid, True, _builder) + t4_mid = semantic.mul(t2_mid, t2_mid, True, _builder) + t6_mid = semantic.mul(t4_mid, t2_mid, True, _builder) + + poly_mid1 = semantic.mul(0.1065976, t2_mid, True, _builder) + poly_mid2 = semantic.add(-0.1420890, poly_mid1, True, _builder) + poly_mid3 = semantic.mul(poly_mid2, t2_mid, True, _builder) + poly_mid4 = semantic.add(0.1999341, poly_mid3, True, _builder) + poly_mid5 = semantic.mul(poly_mid4, t2_mid, True, _builder) + poly_mid6 = semantic.add(-0.3333310, poly_mid5, True, 
_builder) + poly_mid = semantic.add(1.0, semantic.mul(poly_mid6, t2_mid, True, _builder), True, _builder) + arctan_t = semantic.mul(t_mid, poly_mid, True, _builder) + acos_mid = semantic.mul(2.0, arctan_t, True, _builder) + is_neg_mid = semantic.less_than(arg0, 0.0, _builder) + acos_mid_signed = semantic.where(is_neg_mid, semantic.sub(pi, acos_mid, True, _builder), acos_mid, _builder) + + is_center = semantic.less_than(abs_x, 0.6, _builder) + res_mid_boundary = semantic.where(is_center, acos_center, acos_mid_signed, _builder) + return res_mid_boundary @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("sinh") def sinh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - exp1 = semantic.truediv(1.0, exp0, _builder) - tmp = semantic.sub(exp0, exp1, True, _builder) - ret = semantic.truediv(tmp, 2.0, _builder) - return ret + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.sinh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_sinh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_sinh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + exp1 = semantic.truediv(1.0, exp0, _builder) + tmp = semantic.sub(exp0, exp1, True, _builder) + ret = semantic.truediv(tmp, 2.0, _builder) + return ret @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("cosh") def cosh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - exp1 = semantic.truediv(1.0, exp0, _builder) - tmp = semantic.add(exp0, exp1, True, _builder) - ret = semantic.truediv(tmp, 2.0, _builder) - return ret + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.cosh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_cosh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_cosh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + exp1 = semantic.truediv(1.0, exp0, _builder) + tmp = semantic.add(exp0, exp1, True, _builder) + ret = semantic.truediv(tmp, 2.0, _builder) + return ret @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("acosh") def acosh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = semantic.sub(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) - sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) - sum_res = semantic.add(arg0, sqrt_res, True, _builder) - return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.acosh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", 
[arg0], { + (core.dtype("fp16"), ): ("__hmf_acosh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_acosh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = semantic.sub(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) + sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) + sum_res = semantic.add(arg0, sqrt_res, True, _builder) + return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("asinh") def asinh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = semantic.add(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) - sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) - sum_res = semantic.add(arg0, sqrt_res, True, _builder) - return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.asinh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_asinh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_asinh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = semantic.add(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) + sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) + sum_res = semantic.add(arg0, sqrt_res, True, _builder) + return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("atanh") def atanh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - a = semantic.add(1.0, arg0, True, _builder) - b = semantic.sub(1.0, arg0, True, _builder) - lna = core.tensor(_builder.create_log(a.handle), a.type) - lnb = core.tensor(_builder.create_log(b.handle), b.type) - tmp = semantic.sub(lna, lnb, True, _builder) - return semantic.mul(tmp, 0.5, True, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.atanh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_atanh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_atanh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + a = semantic.add(1.0, arg0, True, _builder) + b = semantic.sub(1.0, arg0, True, _builder) + lna = core.tensor(_builder.create_log(a.handle), a.type) + lnb = core.tensor(_builder.create_log(b.handle), b.type) + tmp = semantic.sub(lna, lnb, True, _builder) + return semantic.mul(tmp, 0.5, True, _builder) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("expm1") def expm1(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - return semantic.sub(tmp, 1, True, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.expm1 for dtype bf16 is 
unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_expm1_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_expm1_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + return semantic.sub(tmp, 1, True, _builder) @core.builtin -@math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@math._check_dtype(dtypes=["fp16", "fp32"]) @math._add_math_2arg_docstr("nextafter") def nextafter(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - x = semantic.to_tensor(arg0, _builder) - y = semantic.to_tensor(arg1, _builder) - dtype_map = {"bf16": core.int16, "fp16": core.int16, "fp32": core.int32} - min_pos_bit = {"bf16": 0x0001, "fp16": 0x0001, "fp32": 0x00000001} - max_neg_bit = {"bf16": 0x8001, "fp16": 0x8001, "fp32": 0x80000001} - int_type = dtype_map[x.type.scalar.name] - x_eq_y = semantic.equal(x, y, _builder) - x_gt_0 = semantic.greater_than(x, 0, _builder) - y_gt_x = semantic.greater_than(y, x, _builder) - next_neg = semantic.xor_(x_gt_0, y_gt_x, _builder) - next_pos = semantic.not_(next_neg, _builder) - - p1 = semantic.full(x.shape, 1, int_type, _builder) - n1 = semantic.full(x.shape, -1, int_type, _builder) - dir_xy = semantic.where(next_pos, p1, n1, _builder) - x_abs = math.abs(x, _builder=_builder) - x_is_0 = semantic.equal(x_abs, 0, _builder) - - min_pos = semantic.full(x.shape, min_pos_bit[x.type.scalar.name], int_type, _builder) - max_neg = semantic.full(x.shape, max_neg_bit[x.type.scalar.name], int_type, _builder) - min_pos = semantic.bitcast(min_pos, x.dtype, _builder) - max_neg = semantic.bitcast(max_neg, x.dtype, _builder) - bits_x = semantic.bitcast(x, int_type, _builder) - bits_next = semantic.add(bits_x, dir_xy, True, _builder) - next_val = semantic.bitcast(bits_next, x.dtype, _builder) - - need_min_pos = semantic.logical_and(x_is_0, next_pos, _builder) - need_max_neg = semantic.logical_and(x_is_0, next_neg, _builder) - next_val = semantic.where(need_min_pos, min_pos, next_val, _builder) - next_val = semantic.where(need_max_neg, max_neg, next_val, _builder) - return semantic.where(x_eq_y, x, next_val, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_nextafter_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_nextafter_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + x = semantic.to_tensor(arg0, _builder) + y = semantic.to_tensor(arg1, _builder) + dtype_map = {"bf16": core.int16, "fp16": core.int16, "fp32": core.int32} + min_pos_bit = {"bf16": 0x0001, "fp16": 0x0001, "fp32": 0x00000001} + max_neg_bit = {"bf16": 0x8001, "fp16": 0x8001, "fp32": 0x80000001} + int_type = dtype_map[x.type.scalar.name] + x_eq_y = semantic.equal(x, y, _builder) + x_gt_0 = semantic.greater_than(x, 0, _builder) + y_gt_x = semantic.greater_than(y, x, _builder) + next_neg = semantic.xor_(x_gt_0, y_gt_x, _builder) + next_pos = semantic.not_(next_neg, _builder) + + p1 = semantic.full(x.shape, 1, int_type, _builder) + n1 = semantic.full(x.shape, -1, int_type, _builder) + dir_xy = semantic.where(next_pos, p1, n1, _builder) + x_abs = math.abs(x, _builder=_builder) + x_is_0 = semantic.equal(x_abs, 0, _builder) + + min_pos = semantic.full(x.shape, min_pos_bit[x.type.scalar.name], int_type, 
_builder) + max_neg = semantic.full(x.shape, max_neg_bit[x.type.scalar.name], int_type, _builder) + min_pos = semantic.bitcast(min_pos, x.dtype, _builder) + max_neg = semantic.bitcast(max_neg, x.dtype, _builder) + bits_x = semantic.bitcast(x, int_type, _builder) + bits_next = semantic.add(bits_x, dir_xy, True, _builder) + next_val = semantic.bitcast(bits_next, x.dtype, _builder) + + need_min_pos = semantic.logical_and(x_is_0, next_pos, _builder) + need_max_neg = semantic.logical_and(x_is_0, next_neg, _builder) + next_val = semantic.where(need_min_pos, min_pos, next_val, _builder) + next_val = semantic.where(need_max_neg, max_neg, next_val, _builder) + return semantic.where(x_eq_y, x, next_val, _builder) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_2arg_docstr("hypot(Euclidean Distance)") def hypot(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - arg1 = semantic.to_tensor(arg1, _builder) - x2 = semantic.mul(arg0, arg0, True, _builder) - y2 = semantic.mul(arg1, arg1, True, _builder) - sum_res = semantic.add(x2, y2, True, _builder) - return core.tensor(_builder.create_sqrt(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.hypot for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_hypot_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_hypot_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + arg1 = semantic.to_tensor(arg1, _builder) + x2 = semantic.mul(arg0, arg0, True, _builder) + y2 = semantic.mul(arg1, arg1, True, _builder) + sum_res = semantic.add(x2, y2, True, _builder) + return core.tensor(_builder.create_sqrt(sum_res.handle), sum_res.type) # This function is derived from the Cephes Math Library release 2.8: June, 2000 @@ -333,117 +478,132 @@ def hypot(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): @math._check_dtype(dtypes=["fp16", "fp32"]) @math._add_math_2arg_docstr("besseli0 (Modified Bessel function of the first kind, order 0).") def cyl_bessel_i0(arg0: core.tensor, _builder: ir.builder): - param1 = [ - -4.41534164647933937950e-18, - +3.33079451882223809783e-17, - -2.43127984654795469359e-16, - +1.71539128555513303061e-15, - -1.16853328779934516808e-14, - +7.67618549860493561688e-14, - -4.85644678311192946090e-13, - +2.95505266312963983461e-12, - -1.72682629144155570723e-11, - +9.67580903537323691224e-11, - -5.18979560163526290666e-10, - +2.65982372468238665035e-09, - -1.30002500998624804212e-08, - +6.04699502254191894932e-08, - -2.67079385394061173391e-07, - +1.11738753912010371815e-06, - -4.41673835845875056359e-06, - +1.64484480707288970893e-05, - -5.75419501008210370398e-05, - +1.88502885095841655729e-04, - -5.76375574538582365885e-04, - +1.63947561694133579842e-03, - -4.32430999505057594430e-03, - +1.05464603945949983183e-02, - -2.37374148058994688156e-02, - +4.93052842396707084878e-02, - -9.49010970480476444210e-02, - +1.71620901522208775349e-01, - -3.04682672343198398683e-01, - +6.76795274409476084995e-01, - ] - param2 = [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - +4.46562142029675999901e-17, - +3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - 
+1.77256013305652638360e-15, - +3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - +1.54008621752140982691e-14, - +3.85277838274214270114e-13, - +7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - +1.18891471078464383424e-11, - +4.94060238822496958910e-10, - +3.39623202570838634515e-09, - +2.26666899049817806459e-08, - +2.04891858946906374183e-07, - +2.89137052083475648297e-06, - +6.88975834691682398426e-05, - +3.36911647825569408990e-03, - +8.04490411014108831608e-01, - ] - arg0 = semantic.to_tensor(arg0, _builder) - abs_x = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) - x_a = semantic.sub(semantic.mul(abs_x, 0.5, True, _builder), 2.0, True, _builder) - a_n_2 = 0 - a_n_1 = 0 - a_n = param1[0] - for i in range(1, 30): - a_n_2 = a_n_1 - a_n_1 = a_n - a_n = semantic.sub(semantic.mul(x_a, a_n_1, True, _builder), a_n_2, True, _builder) - a_n = semantic.add(a_n, param1[i], True, _builder) - - f_32 = semantic.full(abs_x.shape, 32.0, abs_x.type.scalar, _builder) - x_b = semantic.sub(semantic.fdiv(f_32, abs_x, True, _builder), 2.0, True, _builder) - b_n_2 = 0 - b_n_1 = 0 - b_n = param2[0] - for i in range(1, 25): - b_n_2 = b_n_1 - b_n_1 = b_n - b_n = semantic.sub(semantic.mul(x_b, b_n_1, True, _builder), b_n_2, True, _builder) - b_n = semantic.add(b_n, param2[i], True, _builder) - - half_exp = semantic.mul(core.tensor(_builder.create_exp(abs_x.handle), abs_x.type), 0.5, True, _builder) - res_a = semantic.mul(half_exp, semantic.sub(a_n, a_n_2, True, _builder), True, _builder) - res_b = semantic.fdiv(semantic.mul(half_exp, semantic.sub(b_n, b_n_2, True, _builder), True, _builder), \ - core.tensor(_builder.create_sqrt(abs_x.handle), abs_x.type), True, _builder) - cond = semantic.less_equal(abs_x, 8.0, _builder) - res = semantic.where(cond, res_a, res_b, _builder) - return res + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("fp16"): + core.static_print("extern livdevice.cyl_bessel_i0 for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_cyl_bessel_i0_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + param1 = [ + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + ] + param2 = [ + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, 
+ -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + ] + arg0 = semantic.to_tensor(arg0, _builder) + abs_x = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) + x_a = semantic.sub(semantic.mul(abs_x, 0.5, True, _builder), 2.0, True, _builder) + a_n_2 = 0 + a_n_1 = 0 + a_n = param1[0] + for i in range(1, 30): + a_n_2 = a_n_1 + a_n_1 = a_n + a_n = semantic.sub(semantic.mul(x_a, a_n_1, True, _builder), a_n_2, True, _builder) + a_n = semantic.add(a_n, param1[i], True, _builder) + + f_32 = semantic.full(abs_x.shape, 32.0, abs_x.type.scalar, _builder) + x_b = semantic.sub(semantic.fdiv(f_32, abs_x, True, _builder), 2.0, True, _builder) + b_n_2 = 0 + b_n_1 = 0 + b_n = param2[0] + for i in range(1, 25): + b_n_2 = b_n_1 + b_n_1 = b_n + b_n = semantic.sub(semantic.mul(x_b, b_n_1, True, _builder), b_n_2, True, _builder) + b_n = semantic.add(b_n, param2[i], True, _builder) + + half_exp = semantic.mul(core.tensor(_builder.create_exp(abs_x.handle), abs_x.type), 0.5, True, _builder) + res_a = semantic.mul(half_exp, semantic.sub(a_n, a_n_2, True, _builder), True, _builder) + res_b = semantic.fdiv(semantic.mul(half_exp, semantic.sub(b_n, b_n_2, True, _builder), True, _builder), \ + core.tensor(_builder.create_sqrt(abs_x.handle), abs_x.type), True, _builder) + cond = semantic.less_equal(abs_x, 8.0, _builder) + res = semantic.where(cond, res_a, res_b, _builder) + return res @core.extern @math._check_dtype(dtypes=["fp16", "fp32"]) def signbit(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - if arg0_scalar_ty == core.float32: - int_ty = core.int32 - else: # arg0 type: float16 / bfloat16 - int_ty = core.int16 + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_signbit_fp16", core.dtype("int32")), + (core.dtype("fp32"), ): ("__hmf_signbit_fp32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + if arg0_scalar_ty == core.float32: + int_ty = core.int32 + else: # arg0 type: float16 / bfloat16 + int_ty = core.int16 - arg0 = semantic.to_tensor(arg0, _builder) - int_tensor = semantic.bitcast(arg0, int_ty, _builder) - if int_ty == core.int32: - shift = 31 - elif int_ty == core.int16: - shift = 15 + arg0 = semantic.to_tensor(arg0, _builder) + int_tensor = semantic.bitcast(arg0, int_ty, _builder) + if int_ty == core.int32: + shift = 31 + elif int_ty == core.int16: + shift = 15 - shift = semantic.full(arg0.shape, shift, int_ty, _builder) - sign_bit_tensor = semantic.lshr(int_tensor, shift, _builder) - sign_bit_tensor = semantic.and_(sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) - return semantic.equal(sign_bit_tensor, 1, _builder) + shift = semantic.full(arg0.shape, shift, int_ty, _builder) + sign_bit_tensor = semantic.lshr(int_tensor, shift, _builder) + sign_bit_tensor = semantic.and_(sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) + return semantic.equal(sign_bit_tensor, 1, _builder) # Note: @@ -455,100 +615,105 @@ def signbit(arg0, _builder=None): 
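The erfinv hunk below pairs two rational approximations (switching at |x| >= 0.7) with two Newton-Raphson refinement steps. A host-side sketch of that refinement, for illustration only (refine_erfinv is a hypothetical helper, not part of the patch), using the derivative d/dx erf(x) = (2/sqrt(pi)) * exp(-x^2):

import math

def refine_erfinv(y, x0, steps=2):
    # Newton-Raphson on f(x) = erf(x) - y, with f'(x) = (2/sqrt(pi)) * exp(-x*x).
    x = x0
    for _ in range(steps):
        x -= (math.erf(x) - y) / (2.0 / math.sqrt(math.pi) * math.exp(-x * x))
    return x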
@core.extern @math._check_dtype(dtypes=["fp32"]) def erfinv(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - arg0 = semantic.to_tensor(arg0, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_erfinv_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + arg0 = semantic.to_tensor(arg0, _builder) - inv_sqrt_pi_times_2 = semantic.full(arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) - coeff_low_numerator = [-0.140543331, 0.914624893, -1.645349621, 0.886226899] - coeff_low_denominator = [0.012229801, -0.329097515, 1.442710462, -2.118377725, 1.0] - coeff_high_numerator = [1.641345311, 3.429567803, -1.624906493, -1.970840454] - coeff_high_denominator = [1.6370678, 3.5438892, 1.0] - - # low cal - arg0_squared = _builder.create_fmul(arg0.handle, arg0.handle) - numerator_low_range = semantic.full(arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_low_numerator)): - numerator_low_range = _builder.create_fma( - numerator_low_range, arg0_squared, - semantic.full(arg0.shape, coeff_low_numerator[i], arg0_scalar_ty, _builder).handle) - - denominator_low_range = semantic.full(arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_low_denominator)): - denominator_low_range = _builder.create_fma( - denominator_low_range, arg0_squared, - semantic.full(arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) - - low_res = _builder.create_fmul(arg0.handle, _builder.create_fdiv(numerator_low_range, denominator_low_range)) - - # high cal - arg0_erf_trans = _builder.create_sqrt( # (log2-log(1-|arg0|))^1/2 - _builder.create_fmul( - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_log( + inv_sqrt_pi_times_2 = semantic.full(arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) + coeff_low_numerator = [-0.140543331, 0.914624893, -1.645349621, 0.886226899] + coeff_low_denominator = [0.012229801, -0.329097515, 1.442710462, -2.118377725, 1.0] + coeff_high_numerator = [1.641345311, 3.429567803, -1.624906493, -1.970840454] + coeff_high_denominator = [1.6370678, 3.5438892, 1.0] + + # low cal + arg0_squared = _builder.create_fmul(arg0.handle, arg0.handle) + numerator_low_range = semantic.full(arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_low_numerator)): + numerator_low_range = _builder.create_fma( + numerator_low_range, arg0_squared, + semantic.full(arg0.shape, coeff_low_numerator[i], arg0_scalar_ty, _builder).handle) + + denominator_low_range = semantic.full(arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_low_denominator)): + denominator_low_range = _builder.create_fma( + denominator_low_range, arg0_squared, + semantic.full(arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) + + low_res = _builder.create_fmul(arg0.handle, _builder.create_fdiv(numerator_low_range, denominator_low_range)) + + # high cal + arg0_erf_trans = _builder.create_sqrt( # (log2-log(1-|arg0|))^1/2 + _builder.create_fmul( + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_log( + _builder.create_fdiv( + _builder.create_fsub( + semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder).handle, + _builder.create_fabs(arg0.handle)), + 
semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle)))) + numerator_high_range = semantic.full(arg0.shape, coeff_high_numerator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_high_numerator)): + numerator_high_range = _builder.create_fma( + numerator_high_range, arg0_erf_trans, + semantic.full(arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) + + denominator_high_range = semantic.full(arg0.shape, coeff_high_denominator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_high_denominator)): + denominator_high_range = _builder.create_fma( + denominator_high_range, arg0_erf_trans, + semantic.full(arg0.shape, coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) + + high_res = _builder.create_fdiv(numerator_high_range, denominator_high_range) + high_res = semantic.mul( + semantic.where(signbit(arg0, _builder=_builder), semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), + semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder), + core.tensor(high_res, arg0.type), True, _builder).handle + + for _ in range(2): + low_res = _builder.create_fsub( + low_res, + _builder.create_fdiv( + _builder.create_fsub(_builder.create_erf(low_res), arg0.handle), + _builder.create_fmul( + inv_sqrt_pi_times_2, + _builder.create_exp( + _builder.create_fmul( + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_fmul(low_res, low_res)))))) + + high_res = _builder.create_fsub( + high_res, _builder.create_fdiv( - _builder.create_fsub( - semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder).handle, - _builder.create_fabs(arg0.handle)), - semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle)))) - numerator_high_range = semantic.full(arg0.shape, coeff_high_numerator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_high_numerator)): - numerator_high_range = _builder.create_fma( - numerator_high_range, arg0_erf_trans, - semantic.full(arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) - - denominator_high_range = semantic.full(arg0.shape, coeff_high_denominator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_high_denominator)): - denominator_high_range = _builder.create_fma( - denominator_high_range, arg0_erf_trans, - semantic.full(arg0.shape, coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) - - high_res = _builder.create_fdiv(numerator_high_range, denominator_high_range) - high_res = semantic.mul( - semantic.where(signbit(arg0, _builder=_builder), semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), - semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder), - core.tensor(high_res, arg0.type), True, _builder).handle - - for i in range(2): - low_res = _builder.create_fsub( - low_res, - _builder.create_fdiv( - _builder.create_fsub(_builder.create_erf(low_res), arg0.handle), - _builder.create_fmul( - inv_sqrt_pi_times_2, - _builder.create_exp( - _builder.create_fmul( - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(low_res, low_res)))))) - - high_res = _builder.create_fsub( - high_res, - _builder.create_fdiv( - _builder.create_fsub(_builder.create_erf(high_res), arg0.handle), - _builder.create_fmul( - inv_sqrt_pi_times_2, - _builder.create_exp( - _builder.create_fmul( - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(high_res, high_res)))))) - - arg0_abs = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) - # Check if |arg0| > 1 - 
arg0_over = semantic.greater_than(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) - nan_tensor = semantic.full(arg0.shape, float("nan"), arg0_scalar_ty, _builder) - # Check if |arg0| = 1 - arg0_equal1 = semantic.equal(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) - pos_inf_tensor = semantic.full(arg0.shape, float("inf"), arg0_scalar_ty, _builder) - neg_inf_tensor = semantic.full(arg0.shape, float("-inf"), arg0_scalar_ty, _builder) - inf_res = semantic.where(signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder) - # Check if |arg0| >= 0.7 - arg0_high = semantic.greater_equal(arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder) - - return semantic.where( - arg0_equal1, inf_res, - semantic.where( - arg0_over, nan_tensor, - semantic.where(arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder), - _builder), _builder) + _builder.create_fsub(_builder.create_erf(high_res), arg0.handle), + _builder.create_fmul( + inv_sqrt_pi_times_2, + _builder.create_exp( + _builder.create_fmul( + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_fmul(high_res, high_res)))))) + + arg0_abs = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) + # Check if |arg0| > 1 + arg0_over = semantic.greater_than(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) + nan_tensor = semantic.full(arg0.shape, float("nan"), arg0_scalar_ty, _builder) + # Check if |arg0| = 1 + arg0_equal1 = semantic.equal(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) + pos_inf_tensor = semantic.full(arg0.shape, float("inf"), arg0_scalar_ty, _builder) + neg_inf_tensor = semantic.full(arg0.shape, float("-inf"), arg0_scalar_ty, _builder) + inf_res = semantic.where(signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder) + # Check if |arg0| >= 0.7 + arg0_high = semantic.greater_equal(arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder) + + return semantic.where( + arg0_equal1, inf_res, + semantic.where( + arg0_over, nan_tensor, + semantic.where(arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder), + _builder), _builder) # Note: @@ -581,6 +746,8 @@ def gamma(arg0, _builder=None): t = semantic.add(reflect_arg0, 6.5, True, _builder) gamma_res = _builder.create_fmul( + _builder.create_fmul(sqrt_2pi_tensor, + pow(t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder).handle), _builder.create_fmul(sqrt_2pi_tensor, pow(t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder).handle), _builder.create_fmul( @@ -617,43 +784,20 @@ def gamma(arg0, _builder=None): @core.extern @math._check_dtype(dtypes=["fp32"]) def lgamma(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - arg0 = semantic.to_tensor(arg0, _builder) - - inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) - is_inf = semantic.equal(core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder) - gamma_res = _builder.create_fabs(gamma(arg0, _builder=_builder).handle) - lgamma_res = _builder.create_log(gamma_res) - - return semantic.where(is_inf, inf_tensor, core.tensor(lgamma_res, arg0.type), _builder) - - -@core.builtin -@math._check_dtype(dtypes=[ - "fp32", -]) -@math._add_math_1arg_docstr("trunc") -def trunc(arg0: core.tensor, _builder: ir.builder): - """ - Truncate the input to the nearest integer toward zero. 
- - For positive numbers, this is equivalent to floor(x). - For negative numbers, this is equivalent to ceil(x). - - Special cases: - - trunc(±0) returns ±0. - - trunc(±inf) returns ±inf. - - trunc(NaN) returns NaN. - """ - arg0 = semantic.to_tensor(arg0, _builder) - - zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) - condition = semantic.greater_equal(arg0, zero, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_lgamma_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + arg0 = semantic.to_tensor(arg0, _builder) - floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) - ceil_result = core.tensor(_builder.create_ceil(arg0.handle), arg0.type) + inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) + is_inf = semantic.equal(core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder) + gamma_res = _builder.create_fabs(gamma(arg0, _builder=_builder).handle) + lgamma_res = _builder.create_log(gamma_res) - return semantic.where(condition, floor_result, ceil_result, _builder) + return semantic.where(is_inf, inf_tensor, core.tensor(lgamma_res, arg0.type), _builder) @core.builtin @@ -662,47 +806,52 @@ def trunc(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("nearbyint") def nearbyint(arg0: core.tensor, _builder: ir.builder): - """ - Round argument x to an integer value in floating-point format. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_nearbyint_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Round argument x to an integer value in floating-point format. - Uses the current rounding mode (round-to-nearest-even, aka banker's rounding). - """ - arg0 = semantic.to_tensor(arg0, _builder) + Uses the current rounding mode (round-to-nearest-even, aka banker's rounding). + """ + arg0 = semantic.to_tensor(arg0, _builder) - half = semantic.full(arg0.shape, 0.5, arg0.type.scalar, _builder) + half = semantic.full(arg0.shape, 0.5, arg0.type.scalar, _builder) - positive_adjust = semantic.add(arg0, half, True, _builder) - negative_adjust = semantic.sub(arg0, half, True, _builder) + positive_adjust = semantic.add(arg0, half, True, _builder) + negative_adjust = semantic.sub(arg0, half, True, _builder) - positive_result = core.tensor(_builder.create_floor(positive_adjust.handle), arg0.type) - negative_result = core.tensor(_builder.create_ceil(negative_adjust.handle), arg0.type) + positive_result = core.tensor(_builder.create_floor(positive_adjust.handle), arg0.type) + negative_result = core.tensor(_builder.create_ceil(negative_adjust.handle), arg0.type) - zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) - is_positive = semantic.greater_equal(arg0, zero, _builder) - basic_round = semantic.where(is_positive, positive_result, negative_result, _builder) + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + is_positive = semantic.greater_equal(arg0, zero, _builder) + basic_round = semantic.where(is_positive, positive_result, negative_result, _builder) - # Banker's rounding special treatment: For values exactly in the middle, round to the nearest even number. 
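The special case handled here is round-half-to-even (banker's rounding), which is also the behavior of Python's built-in round(); a quick host-side illustration:

# Half-way cases round to the nearest even integer; everything else rounds normally.
assert round(0.5) == 0 and round(1.5) == 2 and round(2.5) == 2
assert round(1.4) == 1 and round(1.6) == 2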
- fractional = semantic.sub(arg0, basic_round, True, _builder) - abs_fractional = core.tensor(_builder.create_fabs(fractional.handle), fractional.type) + # Banker's rounding special treatment: For values exactly in the middle, round to the nearest even number. + fractional = semantic.sub(arg0, basic_round, True, _builder) + abs_fractional = core.tensor(_builder.create_fabs(fractional.handle), fractional.type) - is_half = semantic.equal(abs_fractional, half, _builder) + is_half = semantic.equal(abs_fractional, half, _builder) - two = semantic.full(arg0.shape, 2.0, arg0.type.scalar, _builder) + two = semantic.full(arg0.shape, 2.0, arg0.type.scalar, _builder) - half_value = math.fdiv(basic_round, two, _builder=_builder) - half_floor = core.tensor(_builder.create_floor(half_value.handle), half_value.type) - double_half = semantic.mul(half_floor, two, True, _builder) + half_value = math.fdiv(basic_round, two, _builder=_builder) + half_floor = core.tensor(_builder.create_floor(half_value.handle), half_value.type) + double_half = semantic.mul(half_floor, two, True, _builder) - is_even = semantic.equal(basic_round, double_half, _builder) + is_even = semantic.equal(basic_round, double_half, _builder) - adjustment = semantic.where(is_positive, semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), - semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), _builder) + adjustment = semantic.where(is_positive, semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), + semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), _builder) - banker_result = semantic.where(is_even, basic_round, semantic.add(basic_round, adjustment, True, _builder), - _builder) + banker_result = semantic.where(is_even, basic_round, semantic.add(basic_round, adjustment, True, _builder), + _builder) - # Final result: Use banker's rounding for cases exactly at 0.5, otherwise use basic rounding. - return semantic.where(is_half, banker_result, basic_round, _builder) + # Final result: Use banker's rounding for cases exactly at 0.5, otherwise use basic rounding. + return semantic.where(is_half, banker_result, basic_round, _builder) @core.builtin @@ -711,18 +860,25 @@ def nearbyint(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("arcsine") def asin(arg0: core.tensor, _builder: ir.builder): - """ - Calculate the principal value of the arc sine of the input argument x. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_asin_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_asin_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Calculate the principal value of the arc sine of the input argument x. - Returns result in radians, in the interval [-π/2, +π/2] for x inside [-1, +1]. - Returns NaN for x outside [-1, +1]. - """ - arg0 = semantic.to_tensor(arg0, _builder) + Returns result in radians, in the interval [-π/2, +π/2] for x inside [-1, +1]. + Returns NaN for x outside [-1, +1]. 
+ """ + arg0 = semantic.to_tensor(arg0, _builder) - # asin(x) = π/2 - acos(x) - half_pi = semantic.full(arg0.shape, 1.5707963267948966, arg0.type.scalar, _builder) # π/2 - acos_val = acos(arg0, _builder=_builder) - return semantic.sub(half_pi, acos_val, True, _builder) + # asin(x) = π/2 - acos(x) + half_pi = semantic.full(arg0.shape, 1.5707963267948966, arg0.type.scalar, _builder) # π/2 + acos_val = acos(arg0, _builder=_builder) + return semantic.sub(half_pi, acos_val, True, _builder) @core.builtin @@ -731,18 +887,23 @@ def asin(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("base-10 logarithm") def log10(arg0: core.tensor, _builder: ir.builder): - """ - Calculate the base 10 logarithm of the input argument x. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_log10_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Calculate the base 10 logarithm of the input argument x. - Returns NaN for x < 0, -inf for x = 0, and +0 for x = 1. - log10(x) = log(x) / log(10) - """ - arg0 = semantic.to_tensor(arg0, _builder) + Returns NaN for x < 0, -inf for x = 0, and +0 for x = 1. + log10(x) = log(x) / log(10) + """ + arg0 = semantic.to_tensor(arg0, _builder) - log_val = math.log(arg0, _builder=_builder) - log10_const = semantic.full(arg0.shape, 2.302585092994046, arg0.type.scalar, _builder) + log_val = math.log(arg0, _builder=_builder) + log10_const = semantic.full(arg0.shape, 2.302585092994046, arg0.type.scalar, _builder) - return math.fdiv(log_val, log10_const, _builder=_builder) + return math.fdiv(log_val, log10_const, _builder=_builder) @core.builtin @@ -751,29 +912,34 @@ def log10(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_2arg_docstr("copysign") def copysign(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - """ - Create a floating-point value with the magnitude of x and the sign of y. - """ - x = semantic.to_tensor(arg0, _builder) - y = semantic.to_tensor(arg1, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_copysign_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Create a floating-point value with the magnitude of x and the sign of y. 
+ """ + x = semantic.to_tensor(arg0, _builder) + y = semantic.to_tensor(arg1, _builder) - magnitude = core.tensor(_builder.create_fabs(x.handle), x.type) + magnitude = core.tensor(_builder.create_fabs(x.handle), x.type) - zero = semantic.full(y.shape, 0.0, y.type.scalar, _builder) - one = semantic.full(y.shape, 1.0, y.type.scalar, _builder) + zero = semantic.full(y.shape, 0.0, y.type.scalar, _builder) + one = semantic.full(y.shape, 1.0, y.type.scalar, _builder) - is_zero = semantic.equal(y, zero, _builder) - reciprocal = math.fdiv(one, y, _builder=_builder) - is_negative_reciprocal = semantic.less_than(reciprocal, zero, _builder) - is_negative_zero = semantic.and_(is_zero, is_negative_reciprocal, _builder) + is_zero = semantic.equal(y, zero, _builder) + y_reciprocal = math.fdiv(one, y, _builder=_builder) + is_negative_reciprocal = semantic.less_than(y_reciprocal, zero, _builder) + is_negative_zero = semantic.and_(is_zero, is_negative_reciprocal, _builder) - is_negative_nonzero = semantic.less_than(y, zero, _builder) - is_negative = semantic.or_(is_negative_zero, is_negative_nonzero, _builder) + is_negative_nonzero = semantic.less_than(y, zero, _builder) + is_negative = semantic.or_(is_negative_zero, is_negative_nonzero, _builder) - neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), True, - _builder) + neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), + True, _builder) - return semantic.where(is_negative, neg_magnitude, magnitude, _builder) + return semantic.where(is_negative, neg_magnitude, magnitude, _builder) if get_ascend_arch_from_env() == "Ascend910_9589": diff --git a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp new file mode 100644 index 0000000000..fa82404dc8 --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp @@ -0,0 +1,362 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/AutoBlockify.h" +#include "AutoBlockify/Utils.h" +#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify" + +using namespace mlir; +using namespace triton; + +PropagateUnrealizedCastDown::PropagateUnrealizedCastDown(MLIRContext *context, + Value logicalBlockId, + Value logicalBlockNum, + int autoBlockifySize) + : OpRewritePattern(context), + logicalBlockId(logicalBlockId), logicalBlockNum(logicalBlockNum), + autoBlockifySize(autoBlockifySize) {} + +LogicalResult +PropagateUnrealizedCastDown::matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const { + if (op.getInputs().size() != 2) + return failure(); + auto funcOp = op->getParentOfType(); + auto input = op.getInputs()[0]; + auto res = op->getResult(0); + SmallPtrSet users(op->user_begin(), op->user_end()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Handling UnrealizedConversionCastOp:\n" << op << "\n"; + os << "Users:\n"; + for (auto *user : users) + os << *user << "\n"; + }); + for (auto *user : users) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(user); + if (auto uccOp = dyn_cast(user)) { + if (uccOp->getResultTypes()[0] != input.getType()) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << *user << "\n"; + }); + return op.emitError("UnrealizedConversionCastOp cannot be resolved\n"); + } + rewriter.replaceOp(user, input); + } else if (auto blockifyLoop = getBlockifyLoop(user)) { + handleBlockifyLoop(blockifyLoop.value(), user, rewriter); + } else if (auto splatOp = dyn_cast(user)) { + rewriteSplat(op, splatOp, rewriter); + } else if (auto expandDimsOp = dyn_cast(user)) { + rewriteExpandDims(op, expandDimsOp, rewriter); + } else if (auto reduceOp = dyn_cast(user)) { + rewriteReduce(op, reduceOp, rewriter); + } else if (auto scanOp = dyn_cast(user)) { + rewriteScan(op, scanOp, rewriter); + } else if (auto loadOp = dyn_cast(user)) { + rewriteLoad(op, loadOp, rewriter); + } else if (auto storeOp = dyn_cast(user)) { + rewriteStore(op, storeOp, rewriter); + } else if (auto atomicRMWOp = dyn_cast(user)) { + rewriteAtomicRMW(op, atomicRMWOp, rewriter); + } else if (auto assertOp = dyn_cast(user)) { + rewriteAssert(op, assertOp, rewriter); + } else if (auto extractSliceOp = dyn_cast(user)) { + rewriteExtractSlice(op, extractSliceOp, rewriter); + } else if (auto insertSliceOp = dyn_cast(user)) { + rewriteInsertSlice(op, insertSliceOp, rewriter); + } else if (auto whileOp = dyn_cast(user)) { + rewriteWhile(op, whileOp, rewriter); + } else if (auto loopOp = dyn_cast(user)) { + rewriteLoop(op, loopOp, rewriter); + } else if (auto yieldOp = dyn_cast(user)) { + rewriteYield(op, yieldOp, rewriter); + } else if (auto conditionOp = dyn_cast(user)) { + rewriteCondition(op, conditionOp, rewriter); + } else if (user->hasTrait() || + isa(user)) { + rewriteGeneraleOp(op, user, rewriter); + } else if (isa(user)) { + auto *newOp = + createBlockifyLoop(user, op, logicalBlockId, logicalBlockNum, + autoBlockifySize, rewriter); + rewriter.setInsertionPoint(newOp); + handleBlockifyLoop(*getBlockifyLoop(newOp), newOp, rewriter); + } else { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Unhandled Op\n" << *user << "\n"; + }); + llvm_unreachable("Unhandled operation"); + } + } + LLVM_DEBUG({ + auto &os 
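// Editorial note (illustrative, not part of the patch): the pattern above
// relies on an internal convention where a two-input
// builtin.unrealized_conversion_cast carries (blockified value, validity mask)
// and yields the original, pre-blockified type, e.g. with an illustrative
// block size of 8:
//
//   %id = builtin.unrealized_conversion_cast %blockified, %mask
//           : tensor<8xi32>, tensor<8xi1> to i32
//
// matchAndRewrite peels one such cast and pushes it down through every user
// (splat, expand_dims, reduce, loads/stores, scf control flow, ...) until a
// sink consumes the blockified form directly; a user it cannot classify is a
// hard error via llvm_unreachable.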
= llvm::dbgs(); + os << "After successful conversion\n"; + os << funcOp << "\n"; + }); + rewriter.eraseOp(op); + return success(); +} + +AutoBlockifyPass::AutoBlockifyPass(const AutoBlockifyOptions &options) + : AutoBlockifyBase(options) {} + +bool AutoBlockifyPass::checkBlockifiable(Value v) { + if (!checkedValues.insert(v).second) + return true; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Checking blockifiable:\n" << v << "\n"; + }); + for (auto &use : v.getUses()) { + auto *user = use.getOwner(); + auto opNum = use.getOperandNumber(); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "User:\n" << *user << "\n"; + }); + if (isa( + user) || + llvm::any_of(user->getOperandTypes(), isTensorPtrType)) + return false; + if (auto ifOp = dyn_cast(user)) { + user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext())); + return true; + } else if (auto whileOp = dyn_cast(user)) { + if (!checkBlockifiable(whileOp.getBeforeArguments()[opNum])) + return false; + } else if (auto loopOp = dyn_cast(user)) { + auto regionIterArg = loopOp.getTiedLoopRegionIterArg(&use); + auto loopResult = loopOp.getTiedLoopResult(&use); + if (!regionIterArg || !loopResult) { + user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext())); + return true; + } + if (!checkBlockifiable(regionIterArg) || !checkBlockifiable(loopResult)) + return false; + } else if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(user->getParentOp()); + if (opNum == 0) { + whileOp->setAttr(autoBlockifyRegionOpAttr, + UnitAttr::get(v.getContext())); + return true; + } + if (!checkBlockifiable(whileOp.getAfterArguments()[opNum - 1]) || + !checkBlockifiable(whileOp->getResult(opNum - 1))) + return false; + } else if (auto conditionOp = dyn_cast(user)) { + if (auto loopOp = dyn_cast(user->getParentOp()); + loopOp && !checkBlockifiable(loopOp.getInits()[opNum])) + return false; + } else { + for (auto res : user->getResults()) { + if (!checkBlockifiable(res)) + return false; + } + } + } + return true; +} + +void AutoBlockifyPass::preProcess(triton::FuncOp func) { + IRRewriter rewriter(func.getContext()); + rewriter.setInsertionPointToStart(&func.getBody().front()); + auto loc = rewriter.getUnknownLoc(); + // Get logical block num + auto xNum = + rewriter.create(loc, triton::ProgramIDDim::X); + auto yNum = + rewriter.create(loc, triton::ProgramIDDim::Y); + auto zNum = + rewriter.create(loc, triton::ProgramIDDim::Z); + auto yzNum = rewriter.create(loc, yNum, zNum); + logicalBlockNum = rewriter.create(loc, yzNum, xNum); + + // Get logical block id + auto xDim = + rewriter.create(loc, triton::ProgramIDDim::X); + auto yDim = + rewriter.create(loc, triton::ProgramIDDim::Y); + auto zDim = + rewriter.create(loc, triton::ProgramIDDim::Z); + xDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + yDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + zDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + auto xFlatten = rewriter.create(loc, xDim, yzNum); + auto yFlatten = rewriter.create(loc, yDim, zNum); + logicalBlockId = rewriter.create(loc, xFlatten, yFlatten); + logicalBlockId = rewriter.create(loc, logicalBlockId, zDim); + + // get blockified block id + auto blockifyTensorType = + RankedTensorType::get({autoBlockifySize}, rewriter.getI32Type()); + auto blockfyRange = rewriter.create( + loc, blockifyTensorType, 0, autoBlockifySize); + auto splatedLogicalBlockId = rewriter.create( + loc, blockfyRange.getType(), logicalBlockId); + Value blockifiedId = + rewriter.create(loc, splatedLogicalBlockId, 
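// Editorial note (illustrative, not part of the patch): preProcess linearizes
// the 3-D launch grid into a single logical block id,
//
//   logicalBlockId  = x * (Ny * Nz) + y * Nz + z
//   logicalBlockNum = Nx * Ny * Nz
//
// and the program-id replacement further below inverts it with div/rem. A
// plain C++ model of the inverse mapping (parameter names are illustrative):
//
//   void delinearize(int id, int ny, int nz, int &x, int &y, int &z) {
//     x = id / (ny * nz);
//     y = (id / nz) % ny;
//     z = id % nz;
//   }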
blockfyRange); + + // get mask + auto splatedBlockNum = rewriter.create( + loc, blockfyRange.getType(), logicalBlockNum); + auto upperboundMask = rewriter.create( + loc, arith::CmpIPredicate::slt, blockifiedId, splatedBlockNum); + auto splatedZero = rewriter.create( + loc, DenseElementsAttr::get(blockifyTensorType, + rewriter.getI32IntegerAttr(0))); + auto lowerboundMask = rewriter.create( + loc, arith::CmpIPredicate::sge, blockifiedId, splatedZero); + Value blockifiedIdMask = + rewriter.create(loc, upperboundMask, lowerboundMask); + + blockifiedId = rewriter + .create( + loc, logicalBlockId.getType(), + ValueRange({blockifiedId, blockifiedIdMask})) + ->getResult(0); + + // replace program id to be computed from blockified id + SmallVector toReplace; + func.walk([&](triton::GetProgramIdOp id) { + if (id->hasAttr(logicalBlockIdAttr)) + return; + toReplace.push_back(id); + }); + for (auto id : toReplace) { + rewriter.setInsertionPoint(id); + Value newId; + if (id.getAxis() == triton::ProgramIDDim::X) { + newId = rewriter.create(id.getLoc(), blockifiedId, yzNum); + newId = rewriter.create(id.getLoc(), newId, xNum); + } else if (id.getAxis() == triton::ProgramIDDim::Y) { + newId = rewriter.create(id.getLoc(), blockifiedId, zNum); + newId = rewriter.create(id.getLoc(), newId, yNum); + } else { + newId = rewriter.create(id.getLoc(), blockifiedId, zNum); + } + rewriter.replaceOp(id, newId); + } + + // Create for loop for region ops + func.walk([&](Operation *op) { + if (op->hasAttr(autoBlockifyRegionOpAttr)) { + auto *newOp = createBlockifyLoop( + op, blockifiedId.getDefiningOp(), + logicalBlockId, logicalBlockNum, autoBlockifySize, rewriter); + newOp->removeAttr(autoBlockifyRegionOpAttr); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); +} + +void AutoBlockifyPass::runOnOperation() { + if (autoBlockifySize == 1) + return; + ModuleOp moduleOp = getOperation(); + if (autoBlockifySize <= 0) { + moduleOp->emitWarning("[AutoBlockify V2] AutoBlockifySize cannot be " + "negative integer, skipping."); + return signalPassFailure(); + } + + MLIRContext *ctx = &getContext(); + + moduleOp.walk([&](triton::FuncOp func) { + LogicalResult result = success(); + func.walk([&](triton::GetProgramIdOp id) { + if (!checkBlockifiable(id.getResult())) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (failed(result)) { + func->emitWarning("Cannot apply auto blockify"); + return WalkResult::skip(); + } + preProcess(func); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After preprocess:\n" << func << "\n"; + }); + + RewritePatternSet patterns(ctx); + patterns.add( + ctx, logicalBlockId, logicalBlockNum, autoBlockifySize); + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + signalPassFailure(); + return WalkResult::interrupt(); + } + + IRRewriter rewriter(ctx); + func->walk([&](UnrealizedConversionCastOp op) { + rewriter.setInsertionPoint(op); + auto input = op.getInputs()[0]; + auto resType = cast(op->getResultTypes()[0]); + if (auto constantOp = input.getDefiningOp()) { + Attribute val = constantOp.getValue(); + if (auto denseAttr = dyn_cast(val)) + val = denseAttr.getSplatValue(); + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(resType, val)); + } else if (auto tensorType = + dyn_cast(input.getType())) { + input = rewriter.create(input.getLoc(), input, 0); + rewriter.replaceOpWithNewOp(op, resType, input); + } else { + rewriter.replaceOpWithNewOp(op, 
resType, input); + } + }); + func->setAttr(autoBlockifySizeAttr, + rewriter.getI32IntegerAttr(autoBlockifySize)); + return WalkResult::skip(); + }); + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, moduleOp))) { + signalPassFailure(); + } +} + +std::unique_ptr> +triton::createAutoBlockifyPass(const AutoBlockifyOptions &options) { + return std::make_unique(options); +} diff --git a/third_party/ascend/lib/AutoBlockify/CMakeLists.txt b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt new file mode 100644 index 0000000000..20ffce4753 --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(AutoBlockify + AutoBlockify.cpp + RewriteOperation.cpp + Utils.cpp + + DEPENDS + AutoBlockifyPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + TritonIR + TritonTransforms + TritonAnalysis + MLIRTransforms + MLIRSupport + MLIRSCFTransforms +) diff --git a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp new file mode 100644 index 0000000000..0b610ebf0c --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp @@ -0,0 +1,508 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/AutoBlockify.h" +#include "AutoBlockify/Utils.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify-rewrite-operation" + +using namespace mlir; +using namespace triton; + +void PropagateUnrealizedCastDown::handleBlockifyLoop( + scf::ForOp blockifyLoop, Operation *op, PatternRewriter &rewriter) const { + SmallVector newOperands; + for (auto opr : op->getOperands()) { + auto uccOp = opr.getDefiningOp(); + if (!uccOp) { + newOperands.push_back(opr); + continue; + } + auto input = uccOp.getInputs()[0]; + auto tensorType = cast(input.getType()); + Value newOperand; + if (tensorType.getRank() > 1) { + SmallVector offsets(tensorType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector sizes(1, rewriter.getIndexAttr(1)); + SmallVector strides(tensorType.getRank(), + rewriter.getIndexAttr(1)); + offsets[0] = blockifyLoop.getInductionVar(); + for (auto dim : llvm::drop_begin(tensorType.getShape())) + sizes.push_back(rewriter.getIndexAttr(dim)); + newOperand = rewriter.create( + input.getLoc(), cast(opr.getType()), input, offsets, + sizes, strides); + } else { + newOperand = rewriter.create( + input.getLoc(), input, ValueRange{blockifyLoop.getInductionVar()}); + if (isa(opr.getType())) { + newOperand = rewriter.create( + input.getLoc(), rewriter.getIndexType(), newOperand); + } + } + newOperands.push_back(newOperand); + } + rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); }); +} + +void PropagateUnrealizedCastDown::rewriteGeneraleOp( + UnrealizedConversionCastOp op, Operation *generalOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + auto inputType = cast(input.getType()); + SmallVector newOperands; + SmallVector newResults; + SmallVector newResultTypes; + + for (auto operand : generalOp->getOperands()) + newOperands.push_back(rewriteValue(operand, op, rewriter)); + for (auto resType : generalOp->getResultTypes()) { + newResultTypes.push_back(getExpandedType(resType, op)); + } + auto *newOp = + rewriter.create(generalOp->getLoc(), generalOp->getName().getIdentifier(), + newOperands, newResultTypes, generalOp->getAttrs()); + replaceValue(newOp, generalOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteSplat( + UnrealizedConversionCastOp op, triton::SplatOp splatOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto resType = cast(splatOp.getResult().getType()); + auto curShape = + llvm::to_vector(cast(input.getType()).getShape()); + auto splatedShape = resType.getShape(); + for (auto dim : splatedShape) { + input = rewriter.create(input.getLoc(), input, + curShape.size()); + curShape.push_back(dim); + input = rewriter.create( + input.getLoc(), + RankedTensorType::get(curShape, getElementTypeOrSelf(input)), input); + } + replaceValue(input.getDefiningOp(), splatOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteExpandDims( + UnrealizedConversionCastOp op, triton::ExpandDimsOp expandDimsOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto newOp = rewriter.create( + expandDimsOp.getLoc(), input, expandDimsOp.getAxis() + 1); + for (auto attr : expandDimsOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, expandDimsOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteReduce( + 
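// Editorial note (illustrative, not part of the patch): inside a blockify
// loop, handleBlockifyLoop above replaces each (value, mask) cast operand with
// the per-iteration element of the blockified tensor -- a rank-reducing
// tensor.extract_slice along the new leading dimension for rank > 1, or a
// plain tensor.extract (plus index_cast where an index is expected) for 1-D
// values. With an illustrative block size of 8:
//
//   %v = tensor.extract_slice %blockified[%iv, 0] [1, 128] [1, 1]
//          : tensor<8x128xf32> to tensor<128xf32>
//   %s = tensor.extract %blockified1d[%iv] : tensor<8xi32>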
UnrealizedConversionCastOp op, triton::ReduceOp reduceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto srcs = llvm::map_to_vector(reduceOp.getSrcs(), [&](Value src) { + return rewriteValue(src, op, rewriter); + }); + auto newOp = rewriter.create(reduceOp.getLoc(), srcs, + reduceOp.getAxis() + 1); + auto &newCombineOp = newOp.getCombineOp(); + rewriter.cloneRegionBefore(reduceOp.getCombineOp(), newCombineOp, + newCombineOp.end()); + for (auto attr : reduceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, reduceOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteScan(UnrealizedConversionCastOp op, + triton::ScanOp scanOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto srcs = llvm::map_to_vector(scanOp.getSrcs(), [&](Value src) { + return rewriteValue(src, op, rewriter); + }); + auto newOp = rewriter.create( + scanOp.getLoc(), srcs, scanOp.getAxis() + 1, scanOp.getReverse()); + auto &newCombineOp = newOp.getCombineOp(); + rewriter.cloneRegionBefore(scanOp.getCombineOp(), newCombineOp, + newCombineOp.end()); + for (auto attr : scanOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, scanOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteLoad(UnrealizedConversionCastOp op, + triton::LoadOp loadOp, + PatternRewriter &rewriter) const { + auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(loadOp.getPtr(), op, rewriter); + auto other = rewriteValue(loadOp.getOther(), op, rewriter); + auto mask = rewriteValue(loadOp.getMask(), op, rewriter); + auto res = loadOp.getResult(); + auto resType = getExpandedType(res.getType(), op); + if (!other) { + other = rewriter.create( + rewriter.getUnknownLoc(), + DenseElementsAttr::get( + resType, rewriter.getZeroAttr(getElementTypeOrSelf(res)))); + } + mask = createMask(mask, uccMask, resType.getShape(), rewriter); + auto boundaryCheck = llvm::map_to_vector(loadOp.getBoundaryCheck(), + [](int32_t idx) { return idx + 1; }); + auto newOp = rewriter.create( + loadOp.getLoc(), ptr, mask, other, boundaryCheck, loadOp.getPadding(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + for (auto attr : loadOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, loadOp, uccMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteStore( + UnrealizedConversionCastOp op, triton::StoreOp storeOp, + PatternRewriter &rewriter) const { + auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(storeOp.getPtr(), op, rewriter); + auto value = rewriteValue(storeOp.getValue(), op, rewriter); + auto mask = rewriteValue(storeOp.getMask(), op, rewriter); + auto ptrShape = cast(ptr.getType()).getShape(); + mask = createMask(mask, uccMask, ptrShape, rewriter); + auto boundaryCheck = llvm::map_to_vector(storeOp.getBoundaryCheck(), + [](int32_t idx) { return idx + 1; }); + auto newOp = rewriter.create( + storeOp.getLoc(), ptr, value, mask, boundaryCheck, storeOp.getCache(), + storeOp.getEvict()); + for (auto attr : storeOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(storeOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteAtomicRMW( + UnrealizedConversionCastOp op, triton::AtomicRMWOp atomicRMWOp, + PatternRewriter &rewriter) const { + 
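// Editorial note (illustrative, not part of the patch): for the memory ops
// rewritten above, the block-validity mask is folded into the access mask via
// createMask (broadcast + AND), so logical blocks past the real grid become
// no-ops; loads additionally get a zero-splat `other` when none exists, and
// every boundaryCheck axis shifts by one to account for the new leading block
// dimension, e.g. boundaryCheck = [0, 1] becomes [1, 2].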
auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(atomicRMWOp.getPtr(), op, rewriter); + auto val = rewriteValue(atomicRMWOp.getVal(), op, rewriter); + auto mask = rewriteValue(atomicRMWOp.getMask(), op, rewriter); + auto resType = getExpandedType(atomicRMWOp.getResult().getType(), op); + mask = createMask(mask, uccMask, resType.getShape(), rewriter); + auto newOp = rewriter.create( + atomicRMWOp.getLoc(), resType, atomicRMWOp.getAtomicRmwOp(), ptr, val, + mask, atomicRMWOp.getSem(), atomicRMWOp.getScope()); + for (auto attr : atomicRMWOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, atomicRMWOp, uccMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteAssert( + UnrealizedConversionCastOp op, triton::AssertOp assertOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto inputShape = cast(input.getType()).getShape(); + auto conditionType = cast(mask.getType()); + auto oneAttr = rewriter.getIntegerAttr(getElementTypeOrSelf(mask), 1); + auto one = rewriter.create( + mask.getLoc(), DenseElementsAttr::get(conditionType, oneAttr)); + Value condition = rewriter.create(input.getLoc(), mask, one); + condition = createMask(nullptr, condition, inputShape, rewriter); + condition = + rewriter.create(condition.getLoc(), condition, input); + auto newOp = rewriter.create(assertOp.getLoc(), condition, + assertOp.getMessage()); + for (auto attr : assertOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(assertOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteExtractSlice( + UnrealizedConversionCastOp op, tensor::ExtractSliceOp extractSliceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto src = rewriteValue(extractSliceOp.getSource(), op, rewriter); + auto offsets = llvm::to_vector(extractSliceOp.getMixedOffsets()); + auto sizes = llvm::to_vector(extractSliceOp.getMixedSizes()); + auto strides = llvm::to_vector(extractSliceOp.getMixedStrides()); + auto srcType = cast(src.getType()); + offsets.insert(offsets.begin(), rewriter.getIndexAttr(0)); + sizes.insert(sizes.begin(), rewriter.getIndexAttr(srcType.getShape()[0])); + strides.insert(strides.begin(), rewriter.getIndexAttr(1)); + auto newOp = rewriter.create( + extractSliceOp.getLoc(), src, offsets, sizes, strides); + auto newMask = rewriter.create( + mask.getLoc(), mask, offsets, sizes, strides); + for (auto attr : extractSliceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, extractSliceOp, newMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteInsertSlice( + UnrealizedConversionCastOp op, tensor::InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto src = rewriteValue(insertSliceOp.getSource(), op, rewriter); + auto dst = rewriteValue(insertSliceOp.getDest(), op, rewriter); + auto offsets = llvm::to_vector(insertSliceOp.getMixedOffsets()); + auto sizes = llvm::to_vector(insertSliceOp.getMixedSizes()); + auto strides = llvm::to_vector(insertSliceOp.getMixedStrides()); + auto srcType = cast(src.getType()); + offsets.insert(offsets.begin(), rewriter.getIndexAttr(0)); + sizes.insert(sizes.begin(), rewriter.getIndexAttr(srcType.getShape()[0])); + strides.insert(strides.begin(), rewriter.getIndexAttr(1)); + auto newOp = rewriter.create( 
+ insertSliceOp.getLoc(), src, dst, offsets, sizes, strides); + for (auto attr : insertSliceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, insertSliceOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteWhile( + UnrealizedConversionCastOp op, scf::WhileOp whileOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + SmallVector newInits; + IRMapping mapping; + for (auto [idx, init] : llvm::enumerate(whileOp.getInits())) { + if (init == res) { + indices.push_back(idx); + newInits.push_back(input); + } else { + newInits.push_back(init); + } + } + auto newOp = rewriter.create( + whileOp.getLoc(), whileOp->getResultTypes(), newInits, + [&](OpBuilder &b, Location loc, ValueRange args) { + mapRegionIterArg(mapping, whileOp.getBeforeArguments(), args, indices, + mask, b); + for (auto &bodyOp : *whileOp.getBeforeBody()) + b.clone(bodyOp, mapping); + }, + [&](OpBuilder &b, Location loc, ValueRange args) { + mapRegionIterArg(mapping, whileOp.getAfterArguments(), args, {}, mask, + b); + for (auto &bodyOp : whileOp.getAfterBody()->without_terminator()) + b.clone(bodyOp, mapping); + auto yieldOp = + cast(whileOp.getAfterBody()->getTerminator()); + mapYieldedValue(mapping, yieldOp, indices, op, b); + }); + for (auto attr : whileOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(whileOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteLoop(UnrealizedConversionCastOp op, + LoopLikeOpInterface loopOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + SmallVector newInits; + IRMapping mapping; + for (auto [idx, init] : llvm::enumerate(loopOp.getInits())) { + if (init == res) { + indices.push_back(idx); + newInits.push_back(input); + } else { + newInits.push_back(init); + } + } + LoopLikeOpInterface newOp; + if (auto forOp = dyn_cast(loopOp.getOperation())) { + newOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInits, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + mapping.map(forOp.getInductionVar(), iv); + mapRegionIterArg(mapping, forOp.getRegionIterArgs(), args, indices, + mask, b); + for (auto &bodyOp : forOp.getBody()->without_terminator()) + b.clone(bodyOp, mapping); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + mapYieldedValue(mapping, yieldOp, indices, op, b); + }); + for (auto attr : forOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + } else { + llvm_unreachable("Unhandled loopOp"); + } + replaceValue(newOp, loopOp, mask, rewriter, indices); +} + +void PropagateUnrealizedCastDown::rewriteIf(UnrealizedConversionCastOp &op, + scf::IfOp ifOp, + ArrayRef indices, + PatternRewriter &rewriter) const { + IRMapping mapping; + auto mask = op.getInputs()[1]; + auto thenBlockBuilder = [&](OpBuilder &b, Location loc) { + for (auto &bodyOp : *ifOp.thenBlock()) + b.clone(bodyOp, mapping); + }; + function_ref elseBlockBuilder = + [&](OpBuilder &b, Location loc) { + for (auto &bodyOp : *ifOp.elseBlock()) + b.clone(bodyOp, mapping); + }; + if (!ifOp.elseBlock()) + elseBlockBuilder = nullptr; + auto newOp = rewriter.create(ifOp.getLoc(), ifOp.getCondition(), + 
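// Editorial note (illustrative, not part of the patch): rewriteWhile and
// rewriteLoop above thread the blockified value through loop state. Every init
// that was the cast's result is replaced by the raw blockified tensor; inside
// the rebuilt loop, mapRegionIterArg re-wraps the matching region arguments in
// fresh (value, mask) casts so the body still sees the original type, and
// mapYieldedValue unwraps the corresponding yielded values on the way out.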
thenBlockBuilder, elseBlockBuilder); + for (auto attr : ifOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + if (mapping.contains(op)) + op = cast(mapping.lookup(op)); + replaceValue(newOp, ifOp, mask, rewriter, indices); +} + +void PropagateUnrealizedCastDown::rewriteYield( + UnrealizedConversionCastOp &op, scf::YieldOp yieldOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + auto newOperands = llvm::to_vector(yieldOp.getOperands()); + for (auto [idx, opr] : llvm::enumerate(newOperands)) { + if (opr == res) + indices.push_back(idx); + } + if (auto loopOp = dyn_cast(yieldOp->getParentOp())) { + auto uccOp = rewriter.create( + op.getLoc(), res.getType(), ValueRange({input})); + for (auto curIdx : indices) + newOperands[curIdx] = uccOp->getResult(0); + auto newOp = rewriter.create(yieldOp.getLoc(), newOperands); + for (auto attr : yieldOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(yieldOp, newOp); + rewriter.setInsertionPoint(loopOp); + for (auto curIdx : indices) { + auto &initArg = loopOp.getInitsMutable()[curIdx]; + auto initVal = initArg.get(); + uccOp = rewriter.create( + initVal.getLoc(), input.getType(), ValueRange({initVal})); + uccOp = rewriter.create( + initVal.getLoc(), initVal.getType(), + ValueRange({uccOp->getResult(0), mask})); + rewriter.modifyOpInPlace(loopOp, + [&]() { initArg.set(uccOp->getResult(0)); }); + } + } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + for (auto curIdx : indices) + newOperands[curIdx] = input; + auto newOp = rewriter.create(yieldOp.getLoc(), newOperands); + for (auto attr : yieldOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(yieldOp, newOp); + yieldOp = ifOp.thenYield() == yieldOp ? 
ifOp.elseYield() : ifOp.thenYield(); + if (yieldOp) { + rewriter.setInsertionPoint(yieldOp); + newOperands = llvm::to_vector(yieldOp.getOperands()); + for (auto curIdx : indices) { + auto uccOp = rewriter.create( + op.getLoc(), input.getType(), ValueRange({newOperands[curIdx]})); + newOperands[curIdx] = uccOp->getResult(0); + } + rewriter.replaceOpWithNewOp(yieldOp, newOperands); + } + rewriter.setInsertionPoint(ifOp); + rewriteIf(op, ifOp, indices, rewriter); + } +} + +void PropagateUnrealizedCastDown::rewriteCondition( + UnrealizedConversionCastOp op, scf::ConditionOp conditionOp, + PatternRewriter &rewriter) const { + auto whileOp = cast(conditionOp->getParentOp()); + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + int64_t curIdx = -1; + auto args = llvm::to_vector(conditionOp.getArgs()); + for (auto [idx, opr] : llvm::enumerate(args)) { + if (opr == res) + curIdx = idx; + } + args[curIdx] = input; + auto newOp = rewriter.create( + conditionOp.getLoc(), conditionOp.getCondition(), args); + for (auto attr : conditionOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(conditionOp, newOp); + + res = whileOp->getResult(curIdx); + auto oldResType = res.getType(); + auto newResType = getExpandedType(oldResType, op); + rewriter.modifyOpInPlace(whileOp, [&]() { res.setType(newResType); }); + rewriter.setInsertionPointAfter(whileOp); + auto newUccOp = rewriter.create( + res.getLoc(), oldResType, ValueRange({res, mask})); + rewriter.replaceAllUsesExcept(res, newUccOp->getResult(0), newUccOp); + auto arg = whileOp.getAfterArguments()[curIdx]; + auto oldArgType = arg.getType(); + auto newArgType = getExpandedType(oldArgType, op); + rewriter.modifyOpInPlace(whileOp, [&]() { arg.setType(newArgType); }); + rewriter.setInsertionPointToStart(whileOp.getAfterBody()); + newUccOp = rewriter.create( + arg.getLoc(), oldArgType, ValueRange({arg, mask})); + rewriter.replaceAllUsesExcept(arg, newUccOp->getResult(0), newUccOp); +} diff --git a/third_party/ascend/lib/AutoBlockify/Utils.cpp b/third_party/ascend/lib/AutoBlockify/Utils.cpp new file mode 100644 index 0000000000..c6deed7c78 --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/Utils.cpp @@ -0,0 +1,210 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/Utils.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify-utils" + +using namespace mlir; +using namespace triton; + +RankedTensorType getExpandedType(Type type, UnrealizedConversionCastOp op) { + auto target = op.getInputs()[0]; + auto targetType = cast(target.getType()); + SmallVector targetShape{targetType.getShape()[0]}; + if (auto valueType = dyn_cast(type)) { + targetShape.append(valueType.getShape().begin(), + valueType.getShape().end()); + } + return RankedTensorType::get(targetShape, getElementTypeOrSelf(type)); +} + +Value rewriteValue(Value value, UnrealizedConversionCastOp op, + OpBuilder &builder) { + if (value == nullptr) + return nullptr; + if (value == op->getResult(0)) + return op.getInputs()[0]; + return builder + .create( + value.getLoc(), getExpandedType(value.getType(), op), value) + ->getResult(0); +} + +void replaceValue(Operation *newOp, Operation *oldOp, Value newMask, + RewriterBase &rewriter, ArrayRef replaceIndices) { + int64_t idx = 0; + for (auto [res, oldRes] : + llvm::zip_equal(newOp->getResults(), oldOp->getResults())) { + if (replaceIndices.empty() || + llvm::find(replaceIndices, idx) != replaceIndices.end()) { + auto resType = res.getType(); + auto newUccOp = rewriter.create( + newOp->getLoc(), oldRes.getType(), ValueRange({res, newMask})); + rewriter.replaceAllUsesExcept(oldRes, newUccOp->getResult(0), newUccOp); + } else { + rewriter.replaceAllUsesWith(oldRes, res); + } + idx++; + } + rewriter.eraseOp(oldOp); +} + +Value createMask(Value mask, Value uccMask, ArrayRef targetShape, + RewriterBase &rewriter) { + SmallVector curShape{targetShape[0]}; + for (auto [idx, dim] : llvm::drop_begin(llvm::enumerate(targetShape))) { + curShape.push_back(dim); + uccMask = + rewriter.create(uccMask.getLoc(), uccMask, idx); + uccMask = rewriter.create( + uccMask.getLoc(), + RankedTensorType::get(curShape, getElementTypeOrSelf(uccMask)), + uccMask); + } + if (mask) { + mask = rewriter.create(mask.getLoc(), mask, uccMask); + } else { + mask = uccMask; + } + return mask; +} + +void mapRegionIterArg(IRMapping &mapping, ValueRange oldArgs, + ValueRange newArgs, ArrayRef indices, Value mask, + OpBuilder &builder) { + auto newArgIter = newArgs.begin(); + for (auto [idx, oldArg] : llvm::enumerate(oldArgs)) { + if (llvm::find(indices, idx) != indices.end()) { + auto newUccOp = builder.create( + oldArg.getLoc(), oldArg.getType(), ValueRange({*newArgIter, mask})); + mapping.map(oldArg, newUccOp->getResult(0)); + } else { + mapping.map(oldArg, *newArgIter); + } + ++newArgIter; + } +} + +void mapYieldedValue(IRMapping &mapping, scf::YieldOp yieldOp, + ArrayRef indices, UnrealizedConversionCastOp op, + OpBuilder &builder) { + SmallVector newOperands; + for (auto [idx, operand] : llvm::enumerate(yieldOp.getOperands())) { + operand = mapping.lookup(operand); + if (llvm::find(indices, idx) != indices.end()) + newOperands.push_back(rewriteValue(operand, op, builder)); + else + newOperands.push_back(operand); + } + builder.create(yieldOp.getLoc(), newOperands); +} + +Operation *createBlockifyLoop(Operation *targetOp, + UnrealizedConversionCastOp op, + Value logicalBlockId, Value logicalBlockNum, + int autoBlockifySize, RewriterBase &rewriter) { + auto loc = targetOp->getLoc(); + rewriter.setInsertionPoint(targetOp); + auto initVal = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto stepVal = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto blockifySizeVal = rewriter.create( + loc, 
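// Editorial note (illustrative, not part of the patch): the helpers above
// define the type discipline of the whole pass. getExpandedType prepends the
// block dimension taken from the cast's input (f32 -> tensor<8xf32>,
// tensor<128xi1> -> tensor<8x128xi1> for an illustrative block size of 8),
// rewriteValue wraps any other value in a one-input cast to that expanded
// type, and createMask broadcasts the 1-D block mask across the trailing
// dimensions (expand_dims + broadcast) before AND-ing it with any existing
// access mask.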
rewriter.getIndexAttr(autoBlockifySize)); + Value upperBound = + rewriter.create(loc, logicalBlockNum, logicalBlockId); + auto i32Zero = + rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + upperBound = rewriter.create(loc, upperBound, i32Zero); + upperBound = rewriter.create(loc, rewriter.getIndexType(), + upperBound); + upperBound = + rewriter.create(loc, upperBound, blockifySizeVal); + SmallVector inits; + if (auto loopOp = dyn_cast(targetOp)) { + inits = llvm::map_to_vector(loopOp.getInits(), + [&rewriter, &op](Value v) -> Value { + return rewriteValue(v, op, rewriter); + }); + } else { + auto resultTypes = + llvm::map_to_vector(targetOp->getResultTypes(), [&op](Type type) { + return getExpandedType(type, op); + }); + inits = + llvm::map_to_vector(resultTypes, [&rewriter, &loc](Type type) -> Value { + auto tensorType = cast(type); + return rewriter.create(loc, tensorType.getShape(), + tensorType.getElementType()); + }); + } + auto mask = op.getInputs()[1]; + Operation *newOp; + auto blockifyLoop = rewriter.create( + loc, initVal, upperBound, stepVal, inits, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + newOp = b.clone(*targetOp); + + SmallVector newResults; + for (auto [arg, res] : llvm::zip_equal(args, newOp->getResults())) { + auto tensorType = cast(arg.getType()); + auto rank = tensorType.getRank(); + Value newRes; + if (rank > 1) { + SmallVector offsets(tensorType.getRank(), + b.getIndexAttr(0)); + SmallVector sizes(1, b.getIndexAttr(1)); + SmallVector strides(tensorType.getRank(), + b.getIndexAttr(1)); + offsets[0] = iv; + for (auto dim : llvm::drop_begin(tensorType.getShape())) + sizes.push_back(b.getIndexAttr(dim)); + newRes = b.create(loc, res, arg, offsets, + sizes, strides); + } else { + newRes = b.create(loc, res, arg, ValueRange{iv}); + } + newResults.push_back(newRes); + } + b.create(loc, newResults); + }); + + replaceValue(blockifyLoop, targetOp, mask, rewriter); + blockifyLoop->setAttr(autoBlockifyLoopAttr, rewriter.getUnitAttr()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After creating blockify loop:\n" << blockifyLoop << "\n"; + }); + return newOp; +} + +std::optional getBlockifyLoop(Operation *op) { + while (auto forOp = op->getParentOfType()) { + if (forOp->hasAttr(autoBlockifyLoopAttr)) + return forOp; + op = forOp; + } + return std::nullopt; +} diff --git a/third_party/ascend/lib/CMakeLists.txt b/third_party/ascend/lib/CMakeLists.txt index bd3c0c6c01..b04c8d981d 100644 --- a/third_party/ascend/lib/CMakeLists.txt +++ b/third_party/ascend/lib/CMakeLists.txt @@ -1 +1,35 @@ -add_subdirectory(Conversion) +add_subdirectory(AutoBlockify) +add_subdirectory(Conversion/TritonToHFusion) +add_subdirectory(Conversion/TritonToHIVM) +add_subdirectory(Conversion/TritonToLLVM) +add_subdirectory(TritonAffinityOpt) + +if(TRITON_ENABLE_COVERAGE_HITEST) + set(_instrument_targets + DiscreteMaskAccessConversion + TritonToAnnotation + TritonToHFusion + TritonToHIVM + TritonToLinalg + TritonToLLVM + TritonToStructured + TritonToUnstructure + MLIRTritonNPUUtils # from Utils + TritonAscendIR # from Dialect/TritonAscend/IR + TritonStructuredIR # from Dialect/TritonStructured/IR + AutoBlockify + TritonAffinityOpt + ) + + foreach(_target ${_instrument_targets}) + if(TARGET ${_target}) + set_target_properties(${_target} PROPERTIES + RULE_LAUNCH_COMPILE "hitestwrapper" + RULE_LAUNCH_LINK "hitestwrapper" + ) + message(STATUS "Enabled hitestwrapper for target: ${_target}") + else() + message(WARNING "Target ${_target} not found, please check the actual target name") 
+ endif() + endforeach() +endif() diff --git a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp index 9c4d41f9ff..7184b0c402 100644 --- a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp +++ b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp @@ -36,14 +36,9 @@ static Type getElementType(Value value) { return type; } -static int64_t get1DTensorLength(Value tensor) { - auto type = mlir::cast(tensor.getType()); - auto shape = type.getShape(); - - assert(shape.size() == 1 && - "ElementwiseInlineAsm now can operate only with 1D tensors"); - - return shape[0]; +static int64_t getTensorNumElements(Value tensor) { + auto type = mlir::cast(tensor.getType()); + return type.getNumElements(); } static Value getInt32Value(RewriterBase &rewriter, Location loc, int val) { @@ -83,15 +78,21 @@ SmallVector packOperands(mlir::triton::ElementwiseInlineAsmOp op, static SmallVector unpackElements(Location loc, Value packedValues, RewriterBase &rewriter) { - auto type = mlir::cast(packedValues.getType()); + auto type = mlir::cast(packedValues.getType()); auto elementType = type.getElementType(); + auto shape = type.getShape(); - int64_t length = get1DTensorLength(packedValues); + int64_t numElements = type.getNumElements(); SmallVector result; - for (int64_t idx = 0; idx < length; idx++) { - SmallVector indexes{ - rewriter.create(loc, idx)}; + for (int64_t linearIdx = 0; linearIdx < numElements; linearIdx++) { + SmallVector indexes(shape.size()); + int64_t remaining = linearIdx; + for (int64_t dim = shape.size() - 1; dim >= 0; dim--) { + indexes[dim] = + rewriter.create(loc, remaining % shape[dim]); + remaining /= shape[dim]; + } Value extracted = rewriter.create(loc, elementType, packedValues, indexes); result.push_back(extracted); @@ -175,55 +176,76 @@ createDestOps(triton::ElementwiseInlineAsmOp op, RewriterBase &rewriter, return ret; } -} // namespace +static LogicalResult processScalarInlineAsm(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); -struct ElementwiseInlineAsmOpConversion - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + auto outsWrapped = createDestOps(op, rewriter, {}, loc); - LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); + SmallVector outs; + for (const auto &resWrapped : outsWrapped) { + outs.push_back(resWrapped[0]); + } + rewriter.replaceOp(op, outs); - SmallVector> unpackedOperands; - for (auto operand : op.getOperands()) { - auto unpackedOperand = unpackElements(loc, operand, rewriter); - unpackedOperands.push_back(unpackedOperand); - } + return success(); +} - int64_t resultLength = get1DTensorLength(op->getResult(0)); - if (resultLength % op.getPackedElement()) { - op.emitError("Result tensor should be diveded to pack"); - return failure(); - } +static LogicalResult processVectorInlineAsm(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); - SmallVector> unpackedResults(op->getNumResults()); - for (int64_t i = 0; i < resultLength; i += op.getPackedElement()) { - // Block of elements to process with one call to the inline asm. This is - // ordered opposite `unpackedResults`: The outer dim is - // op.getPackedElement(), and the inner dim is the operand. 
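// Editorial sketch (not part of the patch): the new unpackElements walks a
// row-major linear index, last dimension fastest, exactly as this
// self-contained model of the index computation does:
//
//   #include <cstdint>
//   #include <vector>
//
//   std::vector<int64_t> delinearize(int64_t linearIdx,
//                                    const std::vector<int64_t> &shape) {
//     std::vector<int64_t> indexes(shape.size());
//     for (int64_t dim = shape.size() - 1; dim >= 0; --dim) {
//       indexes[dim] = linearIdx % shape[dim];   // index within this dim
//       linearIdx /= shape[dim];                 // carry to the next-outer dim
//     }
//     return indexes;
//   }
//
// e.g. delinearize(5, {2, 3}) == {1, 2}, matching element (1, 2) of a 2x3
// tensor.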
- SmallVector> block(op.getPackedElement()); - for (auto &os : unpackedOperands) { - for (int j = 0; j < op.getPackedElement(); j++) { - block[j].push_back(os[i + j]); - } - } - auto cur = createDestOps(op, rewriter, block, loc); - assert(cur.size() == unpackedResults.size()); - for (unsigned j = 0; j < cur.size(); j++) { - unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), - cur[j].end()); + SmallVector> unpackedOperands; + for (auto operand : op.getOperands()) { + auto unpackedOperand = unpackElements(loc, operand, rewriter); + unpackedOperands.push_back(unpackedOperand); + } + + int64_t resultLength = getTensorNumElements(op->getResult(0)); + if (resultLength % op.getPackedElement()) { + op.emitError("Result tensor size should be divisible by the packed element count"); + return failure(); + } + + SmallVector> unpackedResults(op->getNumResults()); + for (int64_t i = 0; i < resultLength; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); } } - // Reorder and pack the results. - SmallVector outs; - for (int i = 0; i < unpackedResults.size(); i++) { - outs.push_back(rewriter.create( - loc, op->getResult(i).getType(), unpackedResults[i])); + auto cur = createDestOps(op, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); } - rewriter.replaceOp(op, outs); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + outs.push_back(rewriter.create( + loc, op->getResult(i).getType(), unpackedResults[i])); + } + rewriter.replaceOp(op, outs); + + return success(); +} - return success(); +} // namespace + +struct ElementwiseInlineAsmOpConversion + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) const final { + return op.getOperands().empty() ?
processScalarInlineAsm(op, rewriter) + : processVectorInlineAsm(op, rewriter); } }; diff --git a/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt new file mode 100644 index 0000000000..925ca52d92 --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(TritonAffinityOpt + DAGSSBuffer.cpp + DAG.cpp + DAGSync.cpp + DAGScope.cpp + + DEPENDS + TritonAffinityOptConversionPassIncGen + + LINK_LIBS + BiShengIRHIVMDialect + BiShengIRScopeDialect + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR + MLIRSCFDialect +) diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp new file mode 100644 index 0000000000..d0222255e1 --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp @@ -0,0 +1,518 @@ +#include "TritonAffinityOpt/DAG.h" +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include + +namespace mlir { +namespace AffinityDAG { + +const auto printFlags = + OpPrintingFlags().enableDebugInfo(true, true).skipRegions(); + +const char *literalCoreType(CoreType ct) { + switch (ct) { + case VECTOR_ONLY: + return "VECTOR_ONLY"; + case CUBE_ONLY: + return "CUBE_ONLY"; + case CUBE_AND_VECTOR: + return "CUBE_AND_VECTOR"; + case UNDETERMINED: + return "UNDETERMINED"; + } + return "Unknown"; +} + +bool opIsScf(Operation *op) { + if (!llvm::isa(op->getDialect())) + return false; + return true; +} + +Graph::Graph(Block *block, Graph *parent, OpMap opMap, ValueMap valueMap, + bool inheritParent) + : block(block), parent(parent), opMap(opMap), valueMap(valueMap) { + + if (parent && inheritParent) { + if (!this->opMap) { + this->opMap = parent->opMap; + } + + if (!this->valueMap) { + this->valueMap = parent->valueMap; + } + } + + if (!this->opMap) { + this->opMap = std::make_shared(); + } + + if (!this->valueMap) { + this->valueMap = std::make_shared(); + } + + for (auto blockArg : block->getArguments()) { + (*this->valueMap)[blockArg] = std::make_unique(blockArg); + blockArgs.push_back((*this->valueMap)[blockArg].get()); + } + + for (auto &opRef : block->getOperations()) { + opCount += 1; + auto op = &opRef; + auto opNodeUnique = std::make_unique(op, this); + auto opNode = opNodeUnique.get(); + (*this->opMap)[op] = std::move(opNodeUnique); + + if (block->mightHaveTerminator() && op == block->getTerminator()) { + terminator = opNode; + } + + for (auto &subgraph : opNode->subgraphs) { + opCount += subgraph.opCount; + 
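// Editorial note (an assumption, not confirmed by this patch): absorbCommon
// below combines core types with `|`, which only works as a join if CoreType
// uses the usual bitmask encoding, e.g.
//
//   enum CoreType { UNDETERMINED = 0, VECTOR_ONLY = 1, CUBE_ONLY = 2,
//                   CUBE_AND_VECTOR = VECTOR_ONLY | CUBE_ONLY };
//
// Under that reading, constraints only accumulate: a node marked CUBE_ONLY by
// one user and VECTOR_ONLY by another becomes CUBE_AND_VECTOR, and
// UNDETERMINED is the identity element.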
} + } +}; + +bool valueIsScalar(Value value) { + auto type = value.getType(); + + if (type.isIntOrIndexOrFloat()) { + return true; + } + + if (auto tensorType = llvm::dyn_cast(type)) { + return tensorType.getRank() == 0; + } + + if (auto _ = llvm::dyn_cast(type)) { + return true; + } + + return false; +} + +bool valueIsTensorOfPtr(Value value) { + auto type = value.getType(); + if (auto tensorType = llvm::dyn_cast(type)) { + auto elementType = tensorType.getElementType(); + if (llvm::isa(elementType)) { + return true; + } + } + + return false; +} + +OpAbility OpNode::canRunOn() const { + if (opIsScf(op)) { + return OpAbility::CUBE_AND_VECTOR; + } + return llvm::TypeSwitch(op) + .Case([](auto) { return OpAbility::CUBE_ONLY; }) + .Case([](auto) { return OpAbility::CUBE_AND_VECTOR; }) + .Case([](arith::SelectOp op) { + // when cond is vector, selectOp should be vector, otherwise scalar + return (valueIsScalar(op.getCondition()) ? OpAbility::CUBE_AND_VECTOR + : OpAbility::PREFER_VECTOR); + }) + .Default([](Operation *op) { + auto isVector = false; + for (auto operand : op->getOperands()) { + if (!valueIsScalar(operand)) { + // if (valueIsTensorOfPtr(operand)) { + // return SCALAR; + // } + isVector = true; + } + } + + for (auto result : op->getResults()) { + if (!valueIsScalar(result)) { + // if (valueIsTensorOfPtr(result)) { + // return SCALAR; + // } + isVector = true; + } + } + + if (isVector) { + return OpAbility::PREFER_VECTOR; + } + + return OpAbility::CUBE_AND_VECTOR; + }); +} + +OpNode::OpNode(Operation *op, Graph *graph) : Node(Node::NK_Op), op(op) { + if (op == nullptr) { + return; + } + + llvm::outs() << op << "\n"; + + auto &valueMap = *graph->valueMap.get(); + auto &opMap = *graph->opMap.get(); + for (const auto operand : op->getOperands()) { + auto valueNode = valueMap.at(operand).get(); + valueNode->outputs.push_back(this); + inputs.push_back(valueNode); + } + + for (const auto &result : op->getResults()) { + auto valueNodeUnique = std::make_unique(result); + auto valueNode = valueNodeUnique.get(); + valueMap[result] = std::move(valueNodeUnique); + valueNode->source = this; + outputs.push_back(valueNode); + } + + // if (!op->hasTrait()) { + // llvm::dbgs() << "Not building subgraph because op is not SingleBlock: " + // << op << '\n'; return; + // } + + if (auto branchOp = llvm::dyn_cast(op)) { + + OpNode *terminator = nullptr; + llvm::SmallVector, 2> validRegions; + + for (auto ®ion : branchOp->getRegions()) { + if (region.getBlocks().empty()) + continue; + subgraphs.emplace_back(®ion.getBlocks().front(), graph); + validRegions.emplace_back(region, subgraphs.back()); + } + + for (auto [region, subgraph] : validRegions) { + SmallVector succRegions; + + branchOp.getSuccessorRegions(region, succRegions); + if (auto currTerminator = dyn_cast( + subgraph.terminator->op)) { + for (auto &succ : succRegions) { + auto forwardedVal = currTerminator.getSuccessorOperands(succ); + if (succ.isParent()) { + // Step1: first yield to parent -> results: double direction + if (!terminator && subgraph.terminator) { + terminator = subgraph.terminator; + for (auto [forwardedVal, resultNode] : + llvm::zip_equal(forwardedVal, outputs)) { + auto resultValueNode = llvm::dyn_cast(resultNode); + assert(resultValueNode && + "Output of a OpNode should be ValueNode!"); + auto forwardedNode = valueMap[forwardedVal].get(); + resultValueNode->source = forwardedNode; + forwardedNode->outputs.push_back(resultNode); + } + } + + } else { + // Step2: Region terminator -> Succ Operands + auto succRegion = 
succ.getSuccessor(); + + for (auto [operand, succInput] : + llvm::zip_equal(forwardedVal, succ.getSuccessorInputs())) { + auto forwardedNode = valueMap[operand].get(); + auto succNode = valueMap[succInput].get(); + forwardedNode->outputs.push_back(succNode); + succNode->source = forwardedNode; + } + } + } + } + } + + if (auto loopOp = llvm::dyn_cast(op)) { + // Step3: inits->iter_args (single directional) (should be handled in step + // 2: ) last terminator -> iter_args (bidirectional) + for (auto [init, iterArgVal] : + llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs())) { + auto &initNode = valueMap[init]; + auto &iterArgNode = valueMap[iterArgVal]; + initNode->outputs.push_back(iterArgNode.get()); + } + // for(auto [init, iterArgVal, yieldNode] : + // llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs(), + // terminator->outputs)) { + // auto& initNode = valueMap[init]; + // auto& iterArgNode = valueMap[iterArgVal]; + // initNode->outputs.push_back(iterArgNode.get()); + // yieldNode->outputs.push_back(iterArgNode.get()); + // iterArgNode->source = yieldNode; + // } + } + } +} + +// llvm::SmallVector getWriteOperandPriority(OpNode* op) { + +// llvm::SmallVector result(op->getInputs()); + +// auto getPriority = [](ValueNode* node) { +// auto typ = getElementTypeOrSelf(node->value); +// if (typ.isInteger(1)) { +// return 2; +// } +// if (llvm::isa(typ)) { +// return 1; +// } +// return 0; +// }; + +// std::stable_sort(result.begin(), result.end(), [&](ValueNode* a, ValueNode* +// b) { +// return getPriority(a) < getPriority(b); +// }); + +// return result; +// } + +ValueNode *getWriteDataSource(OpNode *op) { + auto inputRange = op->getInputs(); + for (auto node : inputRange.drop_front()) { + auto typ = getElementTypeOrSelf(node->value); + if (!typ.isInteger(1)) { + return node; + } + }; + + return nullptr; +} + +enum class MemPolicy { NONE, READ, WRITE }; + +CoreType Node::absorbCommon() { + + auto sourceNode = getSourceOpNode(); + auto op = sourceNode ? 
+CoreType Node::absorbCommon() {
+
+  auto sourceNode = getSourceOpNode();
+  auto op = sourceNode ? sourceNode->op : nullptr;
+
+  if (!sourceNode || !op) {
+    CoreType newCoreType = isOnPrivate;
+    for (auto output : outputs) {
+      newCoreType = newCoreType | output->isOn();
+      isUpstreamOfCubeMem = isUpstreamOfCubeMem || output->isUpstreamOfCubeMem;
+    }
+    return newCoreType;
+  }
+
+  CoreType newCoreType = sourceNode->isOn();
+
+  OpAbility ability = sourceNode->canRunOn();
+
+  if (ability == OpAbility::CUBE_ONLY) {
+    return CUBE_ONLY;
+  }
+
+  auto memIface = llvm::dyn_cast<MemoryEffectOpInterface>(op);
+  auto memPolicy = MemPolicy::NONE;
+
+  if (memIface) {
+    // Possible improvement: determine the policy based on shapes, inputs,
+    // outputs, etc.
+    if (memIface.hasEffect<MemoryEffects::Write>()) {
+      memPolicy = MemPolicy::WRITE;
+    } else if (memIface.hasEffect<MemoryEffects::Read>()) {
+      memPolicy = MemPolicy::READ;
+    }
+  }
+
+  if (memPolicy == MemPolicy::WRITE) {
+    if (auto data = getWriteDataSource(sourceNode)) {
+      auto currCt = data->isOn();
+      if (exactlyOneType(currCt)) {
+        if (currCt == CUBE_ONLY) {
+          isUpstreamOfCubeMem = true;
+        }
+        return currCt;
+      }
+    }
+
+    // The written data is not cube-only.
+    return VECTOR_ONLY;
+  }
+
+  for (auto output : outputs) {
+    switch (output->isOn()) {
+    case CUBE_AND_VECTOR:
+      newCoreType = newCoreType | VECTOR_ONLY;
+      // Deliberately falling through: the cube part still has to be handled.
+    case CUBE_ONLY:
+      if (ability != OpAbility::PREFER_VECTOR || output->isUpstreamOfCubeMem ||
+          memPolicy == MemPolicy::READ) {
+        isUpstreamOfCubeMem =
+            (isUpstreamOfCubeMem || output->isUpstreamOfCubeMem ||
+             memPolicy == MemPolicy::READ);
+        newCoreType = newCoreType | CUBE_ONLY;
+      }
+      break;
+    case VECTOR_ONLY:
+      newCoreType = newCoreType | VECTOR_ONLY;
+    default: // UNDETERMINED: skip.
+      break;
+    };
+  }
+
+  return newCoreType;
+}
+
+CoreType OpNode::absorbImpl() {
+  if (opIsScf(op)) {
+    return CUBE_AND_VECTOR;
+  }
+
+  auto newCoreType = absorbCommon();
+
+  // if (canRunOn() == OpAbility::CUBE_AND_VECTOR) {
+  //   for (auto input : inputs) {
+  //     newCoreType = newCoreType | input->isOn();
+  //   }
+  // }
+
+  return newCoreType;
+}
+
+CoreType ValueNode::absorbImpl() { return absorbCommon(); }
+
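+// fromMultiBlockFunc() builds one affinity subgraph per block of the
+// function, then runs a bounded worklist "diffusion": every node re-absorbs
+// the core types of its neighbours until a fixed point, or until the budget
+// of 5x the node count is exhausted. Nodes still UNDETERMINED after the
+// first round are defaulted to VECTOR_ONLY and the diffusion is re-run to
+// propagate that decision. The per-node dumps below are debug output.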
+std::unique_ptr<Graph> Graph::fromMultiBlockFunc(triton::FuncOp funcOp) {
+
+  auto dummyBlock = new Block();
+  auto dummyGraph = std::make_unique<Graph>(dummyBlock);
+  auto dummyNode = std::make_unique<OpNode>(nullptr, dummyGraph.get());
+  size_t opCount = 0;
+
+  for (auto &block : funcOp.getBody()) {
+    auto &subgraph =
+        dummyNode->subgraphs.emplace_back(&block, dummyGraph.get());
+    opCount += subgraph.opCount;
+  }
+
+  auto &opMap = *dummyGraph->opMap.get();
+  auto &valueMap = *dummyGraph->valueMap.get();
+
+  llvm::SmallVector<Node *> nodes;
+  nodes.reserve(opMap.size() + valueMap.size());
+
+  for (auto &[_, node] : opMap) {
+    if (node.get())
+      nodes.push_back(node.get());
+  }
+
+  for (auto &[_, node] : valueMap) {
+    if (node.get())
+      nodes.push_back(node.get());
+  }
+
+  auto diffuse = [&]() {
+    // Not sure if determinism is required.
+    llvm::SmallSetVector<Node *, 16> worklist(nodes.begin(), nodes.end());
+
+    size_t threshold = worklist.size() * 5;
+
+    for (size_t i = 0; i < threshold; i++) {
+      if (worklist.empty()) {
+        break;
+      }
+
+      auto node = worklist.pop_back_val();
+
+      if (node->absorb()) {
+        auto affected = node->getAffected();
+        worklist.insert(affected.begin(), affected.end());
+      }
+    }
+  };
+
+  diffuse();
+
+  for (auto node : nodes) {
+    if (node->isOn() == UNDETERMINED) {
+      node->isOnPrivate = VECTOR_ONLY;
+    }
+  }
+
+  diffuse();
+
+  OpPrintingFlags flags;
+  flags.skipRegions();
+
+  for (auto [idx, node] : llvm::enumerate(nodes)) {
+    llvm::TypeSwitch<Node *>(node)
+        .Case([&, idx = idx](OpNode *node) {
+          if (node->op) {
+            llvm::dbgs() << llvm::formatv(
+                "\n\n====== OpNode on: {1} @ {0} ======\n", node->op,
+                literalCoreType(node->isOn()));
+            node->op->print(llvm::dbgs(), flags);
+            llvm::dbgs() << "\nAbility: "
+                         << literalCoreType(toCoreType(node->canRunOn()));
+            llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->op);
+          }
+        })
+        .Case([&, idx = idx](ValueNode *node) {
+          if (node->value) {
+            llvm::dbgs() << llvm::formatv(
+                "\n\n====== ValueNode on {1} @ {0} ======\n", node->value,
+                literalCoreType(node->isOn()));
+            node->value.print(llvm::dbgs(), flags);
+            llvm::dbgs() << llvm::formatv("\n====== {0} ======\n",
+                                          node->value);
+          }
+        });
+    // if (auto opNode = llvm::dyn_cast(node)) {
+    //   if (auto forOp = llvm::dyn_cast_if_present(opNode->op)) {
+    //     llvm::dbgs() << "\n==== ForOp ====\n";
+    //     llvm::dbgs() << forOp << "\n";
+    //     llvm::dbgs() << "\n---- IterArgs ----\n";
+    //     for (auto iterArg : forOp.getRegionIterArgs()) {
+    //       auto &valueNode = valueMap[iterArg];
+    //       llvm::dbgs() << llvm::formatv(
+    //           "{0}: {1} upstream: {2} definingOp: {3} \n",
+    //           iterArg.getArgNumber(), literalCoreType(valueNode->isOn()),
+    //           literalCoreType(valueNode->source->isOn()),
+    //           valueNode->getSourceOp()->op);
+    //     }
+    //     llvm::dbgs() << "\n---- Results ----\n";
+    //     for (auto result : forOp.getResults()) {
+    //       llvm::dbgs() << result.getResultNumber() << ' '
+    //                    << literalCoreType(valueMap[result]->isOn()) << '\n';
+    //     }
+    //   }
+    // }
+  }
+
+  return dummyGraph;
+}
+
+} // namespace AffinityDAG
+} // namespace mlir
diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp
new file mode 100644
index 0000000000..21b1af55a2
--- /dev/null
+++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp
@@ -0,0 +1,5497 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "TritonAffinityOpt/Passes.h"
+
+#include "bishengir/Dialect/HIVM/IR/HIVM.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h"
+#include "bishengir/Dialect/HIVM/Transforms/Passes.h"
+#include "bishengir/Dialect/HIVM/Utils/Utils.h"
+#include "bishengir/Dialect/Scope/IR/Scope.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include <deque>
+
+// #include "mlir/Pass/Pass.h"
+// #include "mlir/Pass/PassManager.h"
+// #include "mlir/Transforms/Canonicalizer.h"
+// #include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_DAGSSBUFFER
+#include "ascend/include/TritonAffinityOpt/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+using namespace hivm;
+
+namespace {
+struct DAGSSBufferPass
+    : public mlir::triton::impl::DAGSSBufferBase<DAGSSBufferPass> {
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+    registry.insert();
+  }
+};
+} // namespace
+
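+// ControlSsbufV2 inserts the cross-core handshake for the shared-buffer
+// protocol: for every scope.scope region that contains a sync_block_wait
+// (other than on PIPE_S) inside an scf.for, it pairs the cube core (flag 15)
+// with the vector core (flag 14) via sync_block_set/sync_block_wait on
+// PIPE_S, and zero-initialises the flag words at offsets 0, 32, 64 and 96 of
+// address space 11 before the first scope.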
+void ControlSsbufV2(ModuleOp module) {
+  mlir::OpBuilder builder(module.getContext());
+  // Records the scope.scope ops that have already been processed.
+  llvm::DenseSet<mlir::Operation *> processedScopes;
+
+  auto aiCAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+  int cubeControlIndex = 15;
+  int vectorControlIndex = 14;
+
+  llvm::DenseSet<mlir::Operation *> processedScopes2;
+  module->walk([&](SyncBlockWaitOp op) {
+    auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S);
+    if (op.getTpipe() == pipeS || op.getPipe() == pipeS) {
+      return;
+    }
+
+    // Look upwards for the parent scope.scope op.
+    mlir::Operation *parentOp = op->getParentOp();
+    mlir::Operation *scopeOp = nullptr;
+    mlir::Operation *forOp = nullptr;
+
+    // Walk upwards to find the scope.scope op.
+    while (parentOp) {
+      if (dyn_cast<scope::ScopeOp>(parentOp)) {
+        scopeOp = parentOp;
+        break;
+      }
+      parentOp = parentOp->getParentOp();
+    }
+    parentOp = op->getParentOp();
+    while (parentOp) {
+      if (dyn_cast<scf::ForOp>(parentOp)) {
+        forOp = parentOp;
+        break;
+      }
+      parentOp = parentOp->getParentOp();
+    }
+    // Skip if no scope.scope op was found.
+    if (!scopeOp) {
+      return;
+    }
+    if (!forOp) {
+      return;
+    }
+
+    // Skip if this loop has already been processed.
+    if (processedScopes2.count(forOp) > 0)
+      return;
+
+    // Mark this loop as processed.
+    processedScopes2.insert(forOp);
+  });
+  bool firstSet = true;
+  bool firstWait = true;
+  for (auto forOp : processedScopes2) {
+    mlir::Operation *parentOp = forOp->getParentOp();
+    mlir::Operation *scopeOp = nullptr;
+
+    // Walk upwards to find the scope.scope op.
+    while (parentOp) {
+      if (dyn_cast<scope::ScopeOp>(parentOp)) {
+        scopeOp = parentOp;
+        break;
+      }
+      parentOp = parentOp->getParentOp();
+    }
+    bool isAIC = false;
+    // 1. First check whether the scope op carries the core-type attribute.
+    if (scopeOp->hasAttr("hivm.tcore_type")) {
+      auto attr = scopeOp->getAttr("hivm.tcore_type");
+      if (attr == aiCAttr) {
+        isAIC = true;
+      }
+    }
+
+    if (isAIC) {
+      // Insert code at the beginning of the for loop.
+      builder.setInsertionPoint(scopeOp);
+      // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64
+      auto i64Type = builder.getIntegerType(64);
+      auto i32Type = builder.getIntegerType(32);
+
+      builder.setInsertionPointToStart(&forOp->getRegion(0).front());
+      // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64
+      // add sync_block_wait
+      auto coreAttr =
+          hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE);
+      auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      auto flagId =
+          builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex);
+      builder.create<SyncBlockWaitOp>(forOp->getLoc(), coreAttr, setPipe,
+                                      waitPipe, flagId);
+
+      // Insert code at the end of the loop body (before the yield).
+      auto &loopBody = forOp->getRegion(0).front();
+      // Find the loop body's terminator (expected to be a yield op).
+      auto *terminator = loopBody.getTerminator();
+      builder.setInsertionPoint(terminator);
+
+      // add sync_block_set
+      coreAttr =
+          hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE);
+      setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex);
+      builder.create<SyncBlockSetOp>(forOp->getLoc(), coreAttr, setPipe,
+                                     waitPipe, flagId);
+
+      if (firstWait) {
+        auto &scopeBlock = scopeOp->getRegion(0).front();
+        auto *scope_terminator = scopeBlock.getTerminator();
+        builder.setInsertionPoint(scope_terminator);
+        // add sync_block_wait
+        coreAttr = hivm::TCoreTypeAttr::get(module.getContext(),
+                                            hivm::TCoreType::CUBE);
+        setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+        waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+        flagId =
+            builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex);
+        builder.create<SyncBlockWaitOp>(forOp->getLoc(), coreAttr, setPipe,
+                                        waitPipe, flagId);
+        firstWait = false;
+      }
+    } else {
+      // 1. Insert code at the start of the scope op.
+      // Assuming scopeOp is an op with a region, take its first block.
+      if (firstSet) {
+        auto &scopeBlock = scopeOp->getRegion(0).front();
+        builder.setInsertionPointToStart(&scopeBlock);
+
+        // add sync_block_wait
+        auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(),
+                                                 hivm::TCoreType::VECTOR);
+        auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+        auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+        auto flagId =
+            builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex);
+        builder.create<SyncBlockWaitOp>(forOp->getLoc(), coreAttr, setPipe,
+                                        waitPipe, flagId);
+        firstSet = false;
+      }
+
+      auto i64Type = builder.getIntegerType(64);
+      auto i32Type = builder.getIntegerType(32);
+
+      // Create the required constants.
+      auto c32ConstAttr = mlir::IntegerAttr::get(i64Type, 32);
+      auto c32ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+          scopeOp->getLoc(), i64Type, c32ConstAttr);
+
+      auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0);
+      auto c0i64ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+          scopeOp->getLoc(), i64Type, c0i64ConstAttr);
+
+      auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0);
+      auto c0i32ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+          scopeOp->getLoc(), i32Type, c0i32ConstAttr);
+
+      auto c1i32ConstAttr = mlir::IntegerAttr::get(i32Type, 1);
+      auto c1i32ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+          scopeOp->getLoc(), i32Type, c1i32ConstAttr);
+
+      // %sub_id = hivm.hir.get_sub_block_idx -> i64
+      // This assumes a GetSubBlockIdxOp op exists for this purpose.
+      auto subIdOp =
+          builder.create<GetSubBlockIdxOp>(scopeOp->getLoc(), i64Type);
+
+      // %ssb_addr_offset = arith.muli %sub_id, %c32_i64 : i64
+      auto ssbAddrOffsetOp = builder.create<arith::MulIOp>(
+          scopeOp->getLoc(), subIdOp.getResult(), c32ConstOp.getResult());
+
+      // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64
+      auto ssbAddrOp = builder.create<arith::AddIOp>(
+          scopeOp->getLoc(), ssbAddrOffsetOp.getResult(),
+          c32ConstOp.getResult());
+
+      // %vec_id = arith.cmpi eq, %sub_id, %c0_i64 : i64
+      auto vecIdOp = builder.create<arith::CmpIOp>(
+          scopeOp->getLoc(), mlir::arith::CmpIPredicate::eq,
+          subIdOp.getResult(), c0i64ConstOp.getResult());
+
+      // 2. Insert code at the start of the parent for op.
+      builder.setInsertionPointToStart(&forOp->getRegion(0).front());
+
+      // add sync_block_wait
+      auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(),
+                                               hivm::TCoreType::VECTOR);
+      auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      auto flagId =
+          builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex);
+      builder.create<SyncBlockWaitOp>(forOp->getLoc(), coreAttr, setPipe,
+                                      waitPipe, flagId);
+
+      // Insert code at the end of the loop body (before the yield).
+      auto &loopBody = forOp->getRegion(0).front();
+      // Find the loop body's terminator (expected to be a yield op).
+      auto *terminator = loopBody.getTerminator();
+      builder.setInsertionPoint(terminator);
+
+      // add sync_block_wait
+      coreAttr = hivm::TCoreTypeAttr::get(module.getContext(),
+                                          hivm::TCoreType::VECTOR);
+      setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+      flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex);
+      builder.create<SyncBlockWaitOp>(forOp->getLoc(), coreAttr, setPipe,
+                                      waitPipe, flagId);
+    }
+  }
+
+  auto i64Type = builder.getIntegerType(64);
+  auto i32Type = builder.getIntegerType(32);
+  auto initPtrType = mlir::LLVM::LLVMPointerType::get(builder.getContext(), 11);
+  SmallVector<scope::ScopeOp> scopeOps;
+  module->walk([&](mlir::Operation *op) {
+    // Check whether this is a target op.
+    if (auto scopeOp = dyn_cast<scope::ScopeOp>(op)) {
+      scopeOps.push_back(scopeOp);
+    }
+  });
+  if (!scopeOps.empty()) {
+    auto scopeOp = scopeOps[0];
+    builder.setInsertionPoint(scopeOp);
+    auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0);
+    auto c0i64ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+        scopeOp->getLoc(), i64Type, c0i64ConstAttr);
+    auto c32i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32);
+    auto c32i64ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+        scopeOp->getLoc(), i64Type, c32i64ConstAttr);
+    auto c64i64ConstAttr = mlir::IntegerAttr::get(i64Type, 64);
+    auto c64i64ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+        scopeOp->getLoc(), i64Type, c64i64ConstAttr);
+    auto c96i64ConstAttr = mlir::IntegerAttr::get(i64Type, 96);
+    auto c96i64ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+        scopeOp->getLoc(), i64Type, c96i64ConstAttr);
+    auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0);
+    auto c0i32ConstOp = builder.create<mlir::LLVM::ConstantOp>(
+        scopeOp->getLoc(), i32Type, c0i32ConstAttr);
+
+    auto c0initInttoptrOp = builder.create<mlir::LLVM::IntToPtrOp>(
+        scopeOp->getLoc(), initPtrType, c0i64ConstOp.getResult());
+    auto c32initInttoptrOp = builder.create<mlir::LLVM::IntToPtrOp>(
+        scopeOp->getLoc(), initPtrType, c32i64ConstOp.getResult());
+    auto c64initInttoptrOp = builder.create<mlir::LLVM::IntToPtrOp>(
+        scopeOp->getLoc(), initPtrType, c64i64ConstOp.getResult());
+    auto c96initInttoptrOp = builder.create<mlir::LLVM::IntToPtrOp>(
+        scopeOp->getLoc(), initPtrType, c96i64ConstOp.getResult());
+
+    builder.create<mlir::LLVM::StoreOp>(scopeOp->getLoc(), c0i32ConstOp,
+                                        c0initInttoptrOp);
+    builder.create<mlir::LLVM::StoreOp>(scopeOp->getLoc(), c0i32ConstOp,
+                                        c32initInttoptrOp);
+    builder.create<mlir::LLVM::StoreOp>(scopeOp->getLoc(), c0i32ConstOp,
+                                        c64initInttoptrOp);
+    builder.create<mlir::LLVM::StoreOp>(scopeOp->getLoc(), c0i32ConstOp,
+                                        c96initInttoptrOp);
+  }
+}
+
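+// transformLoop extends an scf.for that contains "ssbuffer"-tagged scf.if
+// ops: it appends one counter per tagged if to the iteration arguments and
+// raises the upper bound by (count - 1) * step so the software pipeline can
+// drain. The original loop body is cloned into the new loop and the original
+// loop is erased.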
+scf::ForOp transformLoop(scf::ForOp forOp, OpBuilder &builder) {
+
+  // 1. Collect the original loop's information.
+  Value originalLowerBound = forOp.getLowerBound();
+  Value originalUpperBound = forOp.getUpperBound();
+  Value originalStep = forOp.getStep();
+  SmallVector<Value> iterArgs;
+  for (auto arg : forOp.getInitArgs()) {
+    iterArgs.push_back(arg);
+  }
+  auto yields = forOp.getBody()->getTerminator();
+
+  // 2. Check whether the loop body contains the target ops.
+  int hasTargetOps = 0;
+  forOp.walk([&](Operation *op) {
+    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+      if (ifOp->hasAttr("ssbuffer")) {
+        hasTargetOps++;
+      }
+    }
+  });
+  // 3. If target ops exist, add counters to the iteration arguments.
+  Value counterInit = nullptr;
+  mlir::Operation *parentOp = forOp->getParentOp();
+  mlir::Operation *scopeOp = nullptr;
+  // Walk upwards to find the scope.scope op.
+  while (parentOp) {
+    if (dyn_cast<scope::ScopeOp>(parentOp)) {
+      scopeOp = parentOp;
+      break;
+    }
+    parentOp = parentOp->getParentOp();
+  }
+
+  builder.setInsertionPoint(scopeOp);
+  for (int i = 0; i < hasTargetOps; i++) {
+    Location loc = forOp.getLoc();
+    auto argType = originalLowerBound.getType();
+
+    // Append to the iteration argument list.
+    iterArgs.push_back(originalLowerBound);
+  }
+  // 4. Create the new upper bound: originalUpperBound + (count - 1) * step.
+  Location loc = forOp.getLoc();
+  Type ubType = originalStep.getType();
+  builder.setInsertionPoint(forOp);
+
+  int count = 0;
+  for (auto &op : forOp.getBody()->getOperations()) {
+    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+      auto parentOp = ifOp->getParentOp();
+      if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) {
+        count++;
+      }
+    }
+  }
+
+  Value two;
+  if (ubType.isIndex()) {
+    two = builder.create<arith::ConstantIndexOp>(loc, count - 1);
+  } else if (auto intType = dyn_cast<IntegerType>(ubType)) {
+    // For integer types, create the constant with the matching width.
+    two = builder.create<arith::ConstantIntOp>(loc, count - 1, intType);
+  } else {
+    // Other types may need special handling.
+    llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType
+                 << "\n";
+    // Try creating an index constant and casting it.
+    auto indexTwo = builder.create<arith::ConstantIndexOp>(loc, count - 1);
+    two = builder.create<arith::IndexCastOp>(loc, ubType, indexTwo);
+  }
+
+  auto steps =
+      builder.create<arith::MulIOp>(forOp.getLoc(), originalStep, two);
+
+  auto nowUpperBound = builder.create<arith::AddIOp>(
+      forOp.getLoc(), originalUpperBound, steps);
+
+  // 5. Create a new for loop.
+  auto newForOp =
+      builder.create<scf::ForOp>(forOp.getLoc(), originalLowerBound,
+                                 nowUpperBound, originalStep, iterArgs);
+
+  // 6. Set up the IR mapping from the old loop to the new one.
+  IRMapping mapper;
+
+  // Map the induction variable.
+  mapper.map(forOp.getInductionVar(), newForOp.getInductionVar());
+
+  // Map the iteration arguments.
+  for (auto [oldArg, newArg] :
+       llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) {
+    mapper.map(oldArg, newArg);
+  }
+
+  SmallVector<Value> newCounterArgs;
+  for (int i = forOp.getRegionIterArgs().size();
+       i < newForOp.getRegionIterArgs().size(); i++) {
+    newCounterArgs.push_back(newForOp.getRegionIterArgs()[i]);
+  }
+  // 7. Clone the loop body into the new loop.
+  auto &newLoopBody = *newForOp.getBody();
+  builder.setInsertionPointToStart(&newLoopBody);
+
+  for (auto &op : forOp.getBody()->without_terminator()) {
+    builder.clone(op, mapper);
+  }
+
+  // 8. Clone the yield op.
+  if (auto yieldOp = dyn_cast<scf::YieldOp>(yields)) {
+    SmallVector<Value> newYieldOperands;
+    for (auto operand : yieldOp.getOperands()) {
+      newYieldOperands.push_back(mapper.lookupOrDefault(operand));
+    }
+    if (hasTargetOps != 0) {
+      for (auto currentCounter : newCounterArgs) {
+        // Append the updated counters to the yield operands.
+        newYieldOperands.push_back(currentCounter);
+      }
+    }
+    builder.create<scf::YieldOp>(yieldOp.getLoc(), newYieldOperands);
+  }
+
+  // 9. Replace the original loop's results.
+  if (hasTargetOps != 0) {
+    // The new loop has extra counter results the original loop lacks, so only
+    // replace the results the original loop actually had.
+    unsigned numOriginalResults = forOp.getNumResults();
+    SmallVector<Value> originalResults;
+    for (unsigned i = 0; i < numOriginalResults; i++) {
+      originalResults.push_back(newForOp.getResult(i));
+    }
+    forOp.replaceAllUsesWith(originalResults);
+  } else {
+    forOp.replaceAllUsesWith(newForOp.getResults());
+  }
+
+  // 10. Erase the original loop.
+  forOp.erase();
+  return newForOp;
+}
+
+// Find the first occurrence of a convert_layout or fixpipe operation after
+// the specified operation.
+Value findFirstTargetOpAfterWait(SyncBlockWaitOp waitOp,
+                                 SmallVector<Value> &excludedValues) {
+  bool startSearching = false;
+
+  for (Operation &op : waitOp->getBlock()->getOperations()) {
+    Value res = nullptr;
+    if (&op == waitOp) {
+      startSearching = true;
+      continue;
+    }
+
+    if (startSearching) {
+      if (isa(op)) {
+        res = op.getOperands()[0];
+      }
+      if (isa(op)) {
+        res = op.getOperands()[1];
+      }
+      if (isa(op)) {
+        res = op.getOperands()[1];
+      }
+      if (isa(op)) {
+        res = op.getOperands()[0];
+      }
+    }
+    if (res) {
+      if (llvm::is_contained(excludedValues, res)) {
+        continue;
+      }
+      excludedValues.push_back(res);
+      return res;
+    }
+  }
+
+  return nullptr;
+}
+
+void getWaitType(std::string CoreType, scf::ForOp forOp,
+                 SmallVector<int> &waitTypes, SmallVector<Value> &allocTypes) {
+  auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S);
+  auto cubeWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_FIX);
+  auto vectorWaitPipe =
+      PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_MTE3);
+  SmallVector<Value> excludedValues;
+  forOp.walk([&](Operation *op) {
+    if (auto waitOp = dyn_cast<SyncBlockWaitOp>(op)) {
+      auto parentOp = op->getParentOp();
+      if (isa<scf::IfOp>(parentOp) && parentOp->hasAttr("ssbuffer")) {
+        auto ifOp = dyn_cast<scf::IfOp>(parentOp);
+        if (forOp == ifOp->getParentOp()) {
+          auto waitPipe = waitOp.getPipe();
+          if ((waitPipe == cubeWaitPipe && CoreType == "cube") ||
+              (waitPipe == vectorWaitPipe && CoreType == "vector")) {
+            auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues);
+            waitTypes.push_back(0);
+            allocTypes.push_back(allocOp);
+          } else if (waitPipe != scalarWaitPipe) {
+            auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues);
+            waitTypes.push_back(1);
+            allocTypes.push_back(allocOp);
+          }
+        }
+      }
+    }
+  });
+}
+
+DenseMap<int, int> getCounterOffset(scf::ForOp forOp) {
+  int i = 0;
+  DenseMap<int, int> bufferMap;
+  auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S);
+  forOp.walk([&](Operation *op) {
+    bufferMap[i] = 0;
+    auto ifOp = dyn_cast<scf::IfOp>(op);
+    if (ifOp && ifOp->hasAttr("ssbuffer") && ifOp->getParentOp() == forOp) {
+      ifOp.walk([&](Operation *op) {
+        if (auto waitOp = dyn_cast<SyncBlockWaitOp>(op)) {
+          if (auto waitIfOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
+            if (waitIfOp == ifOp) {
+              auto waitPipe = waitOp.getPipe();
+              if ((waitPipe != scalarWaitPipe)) {
+                bufferMap[i]++;
+              }
+            }
+          }
+        }
+      });
+      i++;
+    }
+  });
+  return bufferMap;
+}
+
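+// addBufValLoop materialises, for each "ssbuffer" if in the loop, the
+// condition that decides whether the next buffer slot is free: it loads the
+// shared flag words from address space 11 (offsets 32/64 on the cube core,
+// a sub-block-id based address on the vector core), tests the bit that
+// getAllocBit assigned to the buffer, and combines the result with a check
+// that the per-if counter is still below the drained upper bound. The
+// returned values become the conditions of the ifs rewritten by ReplaceIf.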
+SmallVector<Value> addBufValLoop(scf::ForOp forOp,
+                                 DenseMap<Value, int> VecBitMap,
+                                 DenseMap<Value, int> CubeBitMap,
+                                 OpBuilder &builder) {
+  auto aiCAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+  bool isAIC = false;
+  // Look upwards for the parent scope.scope op.
+  mlir::Operation *parentOp = forOp->getParentOp();
+  mlir::Operation *scopeOp = nullptr;
+  // Walk upwards to find the scope.scope op.
+  while (parentOp) {
+    if (dyn_cast<scope::ScopeOp>(parentOp)) {
+      scopeOp = parentOp;
+      break;
+    }
+    parentOp = parentOp->getParentOp();
+  }
+  if (scopeOp->hasAttr("hivm.tcore_type")) {
+    auto attr = scopeOp->getAttr("hivm.tcore_type");
+    if (attr == aiCAttr) {
+      isAIC = true;
+    }
+  }
+  auto bufferMap = getCounterOffset(forOp);
+  SmallVector<Value> buf_vals;
+  SmallVector<Value> if_conditions;
+  builder.setInsertionPointToStart(&scopeOp->getRegion(0).front());
+
+  // 1. Extract the bound values.
+  Value startValue = forOp.getLowerBound();
+  Value endValue = forOp.getUpperBound();
+  // 2. Extract the step value.
+  Value stepValue = forOp.getStep();
+  builder.setInsertionPoint(forOp);
+  Location loc = forOp.getLoc();
+  int count = 0;
+  for (auto &op : forOp.getBody()->getOperations()) {
+    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+      auto parentOp = ifOp->getParentOp();
+      if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) {
+        count++;
+      }
+    }
+  }
+
+  Value two;
+  Type ubType = stepValue.getType();
+  if (ubType.isIndex()) {
+    two = builder.create<arith::ConstantIndexOp>(loc, count - 1);
+  } else if (auto intType = dyn_cast<IntegerType>(ubType)) {
+    // For integer types, create the constant with the matching width.
+    two = builder.create<arith::ConstantIntOp>(loc, count - 1, intType);
+  } else {
+    // Other types may need special handling.
+    llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType
+                 << "\n";
+    // Try creating an index constant and casting it.
+    auto indexTwo = builder.create<arith::ConstantIndexOp>(loc, count - 1);
+    two = builder.create<arith::IndexCastOp>(loc, ubType, indexTwo);
+  }
+
+  auto steps = builder.create<arith::MulIOp>(
+      forOp.getLoc(), endValue.getType(), stepValue, two);
+
+  auto subLoopValue = builder.create<arith::SubIOp>(
+      forOp.getLoc(), endValue.getType(), endValue, steps);
+
+  SmallVector<int> WaitType;
+  SmallVector<Value> AllocType;
+  SmallVector<Value> bufferPtrs;
+  if (isAIC) {
+    builder.setInsertionPointToStart(&forOp->getRegion(0).front());
+    // Create the constants 0 (i32), 32 and 64 (i64).
+    Value c0 = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
+    Value c32 = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 32, 64);
+    Value c64 = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 64, 64);
+    // Create the inttoptr ops (address space 11).
+    Value ssb_vec0_ptr = builder.create<LLVM::IntToPtrOp>(
+        forOp.getLoc(),
+        LLVM::LLVMPointerType::get(builder.getContext(), 11), c32);
+    Value ssb_vec1_ptr = builder.create<LLVM::IntToPtrOp>(
+        forOp.getLoc(),
+        LLVM::LLVMPointerType::get(builder.getContext(), 11), c64);
+    bufferPtrs.push_back(ssb_vec0_ptr);
+    bufferPtrs.push_back(ssb_vec1_ptr);
+    // Create the load ops.
+    Value status_vec0 = builder.create<LLVM::LoadOp>(
+        forOp.getLoc(), builder.getI32Type(), ssb_vec0_ptr);
+
+    Value status_vec1 = builder.create<LLVM::LoadOp>(
+        forOp.getLoc(), builder.getI32Type(), ssb_vec1_ptr);
+
+    getWaitType("cube", forOp, WaitType, AllocType);
+
+    for (auto i = 0; i < WaitType.size(); i++) {
+      auto correspondAlloc = CubeBitMap[AllocType[i]];
+      auto i32ConstAttr =
+          mlir::IntegerAttr::get(builder.getI32Type(), 1 << correspondAlloc);
+      auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+          scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+      Value bufi_vec0_val = builder.create<arith::AndIOp>(
+          forOp.getLoc(), status_vec0, buf_constant_set);
+      Value bufi_vec1_val = builder.create<arith::AndIOp>(
+          forOp.getLoc(), status_vec1, buf_constant_set);
+      Value flag_bufi_vec0;
+      Value flag_bufi_vec1;
+      // Create the comparisons.
+      if (WaitType[i] == 0) {
+        flag_bufi_vec0 = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, c0);
+        flag_bufi_vec1 = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, c0);
+      } else {
+        flag_bufi_vec0 = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val,
+            buf_constant_set);
+        flag_bufi_vec1 = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val,
+            buf_constant_set);
+      }
+      // Create the final and.
+      Value bufi_val = builder.create<arith::AndIOp>(
+          forOp.getLoc(), flag_bufi_vec0, flag_bufi_vec1);
+      buf_vals.push_back(bufi_val);
+    }
+
+  } else {
+    builder.setInsertionPointToStart(&scopeOp->getRegion(0).front());
+    Value c0 = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
+    auto i64Type = builder.getIntegerType(64);
+    // %sub_id = hivm.hir.get_sub_block_idx -> i64
+    // This assumes a GetSubBlockIdxOp op exists for this purpose.
+    auto subIdOp = builder.create<GetSubBlockIdxOp>(scopeOp->getLoc(), i64Type);
+    auto i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32);
+    auto cst_offset = builder.create<LLVM::ConstantOp>(
+        scopeOp->getLoc(), i64Type, i64ConstAttr);
+    auto ssb_addr_offset =
+        builder.create<arith::MulIOp>(scopeOp->getLoc(), subIdOp, cst_offset);
+    auto ssb_addr = builder.create<arith::AddIOp>(scopeOp->getLoc(),
+                                                  ssb_addr_offset, cst_offset);
+    builder.setInsertionPointToStart(&forOp->getRegion(0).front());
+    // Create the inttoptr op (address space 11).
+    Value ssb_cube_ptr = builder.create<LLVM::IntToPtrOp>(
+        forOp.getLoc(),
+        LLVM::LLVMPointerType::get(builder.getContext(), 11), ssb_addr);
+    bufferPtrs.push_back(ssb_cube_ptr);
+    // Create the load op.
+    Value status_cube = builder.create<LLVM::LoadOp>(
+        forOp.getLoc(), builder.getI32Type(), ssb_cube_ptr);
+
+    getWaitType("vector", forOp, WaitType, AllocType);
+    for (auto i = 0; i < WaitType.size(); i++) {
+      auto correspondAlloc = VecBitMap[AllocType[i]];
+      auto i32ConstAttr =
+          mlir::IntegerAttr::get(builder.getI32Type(), 1 << correspondAlloc);
+      auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+          scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+      Value bufi_cube_val = builder.create<arith::AndIOp>(
+          forOp.getLoc(), status_cube, buf_constant_set);
+
+      Value flag_bufi_cube;
+      // Create the comparison.
+      if (WaitType[i] == 0) {
+        flag_bufi_cube = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, c0);
+      } else {
+        flag_bufi_cube = builder.create<arith::CmpIOp>(
+            forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val,
+            buf_constant_set);
+      }
+      buf_vals.push_back(flag_bufi_cube);
+    }
+  }
+  int bufIdx = 0;
+  int groupIdx = 0;
+
+  for (const auto &pair : bufferMap) {
+    if (bufferMap[groupIdx] == 0) {
+      continue;
+    }
+
+    // Compare the corresponding region iteration argument.
+    Value cnti = builder.create<arith::CmpIOp>(
+        forOp.getLoc(), arith::CmpIPredicate::slt,
+        forOp.getRegionIterArgs()[forOp.getRegionIterArgs().size() -
+                                  (bufferMap.size() - 1 - groupIdx)],
+        subLoopValue);
+
+    // AND together all buffer flags of this group.
+    Value finalBufVal = buf_vals[bufIdx];
+    for (int count = 1; count < bufferMap[groupIdx]; count++) {
+      finalBufVal = builder.create<arith::AndIOp>(forOp.getLoc(), finalBufVal,
+                                                  buf_vals[bufIdx + count]);
+    }
+
+    auto cond =
+        builder.create<arith::AndIOp>(forOp.getLoc(), finalBufVal, cnti);
+    if_conditions.push_back(cond);
+
+    // Advance the indices.
+    bufIdx += bufferMap[groupIdx];
+    groupIdx++;
+  }
+  int ifIndex = 0;
+  int acc = 0;
+  int bufferBit = 0;
+  for (int i = 0; i < CubeBitMap.size(); i++) {
+    bufferBit += (1 << i);
+  }
+  forOp.getBody()->walk([&](Operation *op) {
+    auto ifOp = dyn_cast<scf::IfOp>(op);
+    if (ifOp && ifOp->hasAttr("ssbuffer")) {
+      // Get the then region.
+      Block *thenBlock = &ifOp.getThenRegion().front();
+
+      // Find the yield op in the then region.
+      Operation *yieldOp = nullptr;
+      for (auto &op : *thenBlock) {
+        if (isa<scf::YieldOp>(op)) {
+          yieldOp = &op;
+          break;
+        }
+      }
+      if (yieldOp) {
+        builder.setInsertionPoint(yieldOp);
+
+        if (isAIC) {
+          // Insert the flag update:
+          // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32
+          Value status_v2_0 = builder.create<LLVM::LoadOp>(
+              yieldOp->getLoc(), builder.getIntegerType(32), bufferPtrs[0]);
+          Value status_v2_1 = builder.create<LLVM::LoadOp>(
+              yieldOp->getLoc(), builder.getIntegerType(32), bufferPtrs[1]);
+          Value buf_val_new_0 = status_v2_0;
+          Value buf_val_new_1 = status_v2_1;
+          auto bufferNum = bufferMap[ifIndex];
+          for (int i = 0; i < bufferNum; i++) {
+            if (WaitType[acc + i] == 0) {
+              auto correspondAlloc = CubeBitMap[AllocType[acc + i]];
+              auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(),
+                                                         1 << correspondAlloc);
+              auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+                  scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+              // Set the buffer's bit in both flag words.
+              buf_val_new_0 = builder.create<arith::OrIOp>(
+                  yieldOp->getLoc(), buf_val_new_0, buf_constant_set);
+              buf_val_new_1 = builder.create<arith::OrIOp>(
+                  yieldOp->getLoc(), buf_val_new_1, buf_constant_set);
+            } else {
+              auto correspondAlloc = CubeBitMap[AllocType[acc + i]];
+              int bitPos = correspondAlloc;
+              int basePattern = bufferBit;
+              int finalValue = basePattern ^ (1 << bitPos);
+              auto i32ConstAttr =
+                  mlir::IntegerAttr::get(builder.getI32Type(), finalValue);
+              auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+                  scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+              // Clear the buffer's bit in both flag words.
+              buf_val_new_0 = builder.create<arith::AndIOp>(
+                  yieldOp->getLoc(), buf_val_new_0, buf_constant_set);
+              buf_val_new_1 = builder.create<arith::AndIOp>(
+                  yieldOp->getLoc(), buf_val_new_1, buf_constant_set);
+            }
+          }
+          acc += bufferNum;
+          builder.create<LLVM::StoreOp>(yieldOp->getLoc(), buf_val_new_0,
+                                        bufferPtrs[0]);
+          builder.create<LLVM::StoreOp>(yieldOp->getLoc(), buf_val_new_1,
+                                        bufferPtrs[1]);
+
+        } else {
+          // Insert the flag update:
+          // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32
+          Value status_v2 = builder.create<LLVM::LoadOp>(
+              yieldOp->getLoc(), builder.getIntegerType(32), bufferPtrs[0]);
+          Value buf_val_new = status_v2;
+          auto bufferNum = bufferMap[ifIndex];
+          for (int i = 0; i < bufferNum; i++) {
+            if (WaitType[acc + i] == 0) {
+              auto correspondAlloc = VecBitMap[AllocType[acc + i]];
+              auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(),
+                                                         1 << correspondAlloc);
+              auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+                  scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+              // Set the buffer's bit.
+              buf_val_new = builder.create<arith::OrIOp>(
+                  yieldOp->getLoc(), buf_val_new, buf_constant_set);
+            } else {
+              auto correspondAlloc = VecBitMap[AllocType[acc + i]];
+              int bitPos = correspondAlloc;
+              int basePattern = bufferBit;
+              int finalValue = basePattern ^ (1 << bitPos);
+              auto i32ConstAttr =
+                  mlir::IntegerAttr::get(builder.getI32Type(), finalValue);
+              auto buf_constant_set = builder.create<LLVM::ConstantOp>(
+                  scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr);
+              // Clear the buffer's bit.
+              buf_val_new = builder.create<arith::AndIOp>(
+                  yieldOp->getLoc(), buf_val_new, buf_constant_set);
+            }
+          }
+          acc += bufferNum;
+          builder.create<LLVM::StoreOp>(yieldOp->getLoc(), buf_val_new,
+                                        bufferPtrs[0]);
+        }
+        ifIndex++;
+      }
+    }
+  });
+
+  return if_conditions;
+}
+
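+// ReplaceIf rewrites every "ssbuffer" if directly under forOp: the new
+// scf.if takes the buffer-availability condition computed by addBufValLoop
+// and yields the per-if counter as an extra result (incremented by the loop
+// step in the then branch, passed through unchanged in the else branch).
+// All later uses of the old counter iter_arg in the same block are
+// redirected to that new result. The replaced ifs are collected in
+// opsToErase and the newIfOp -> counter mapping in ifArgMap.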
+void ReplaceIf(scf::ForOp forOp, SmallVector<Value> conditions,
+               SmallVector<Operation *> &opsToErase,
+               DenseMap<Operation *, Value> &ifArgMap, OpBuilder &builder,
+               ModuleOp moduleOp) {
+  SmallVector<scf::IfOp> ifToProcess;
+  llvm::outs() << "enter replaceif\n";
+  Value step = forOp.getStep();
+  auto aiCAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+  forOp.getBody()->walk([&](Operation *op) {
+    auto ifOp = dyn_cast<scf::IfOp>(op);
+    if (ifOp && ifOp->hasAttr("ssbuffer") && forOp == ifOp->getParentOp()) {
+      ifToProcess.push_back(ifOp);
+    }
+  });
+
+  IRMapping IRMap;
+  for (int i = 0; i < ifToProcess.size(); i++) {
+    auto ifOp = ifToProcess[i];
+    auto parentOp = ifOp->getParentOp();
+    auto loc = ifOp.getLoc();
+    // Get the for loop's region iteration arguments.
+    auto iterArgs = forOp.getRegionIterArgs();
+    if (iterArgs.size() < conditions.size()) {
+      return;
+    }
+    auto thenYieldOp =
+        dyn_cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
+    SmallVector<Value> thenResults;
+    if (thenYieldOp) {
+      // Keep any values that are already yielded.
+      for (auto result : thenYieldOp.getResults()) {
+        thenResults.push_back(result);
+      }
+    }
+    // Build the new else-region results, which return the iteration argument.
+    SmallVector<Value> elseResults;
+    scf::YieldOp elseYieldOp = nullptr;
+    bool hasElse = false;
+    if (!ifOp.getElseRegion().empty()) {
+      elseYieldOp =
+          dyn_cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
+      hasElse = true;
+    }
+    if (elseYieldOp) {
+      for (auto result : elseYieldOp.getResults()) {
+        elseResults.push_back(result);
+      }
+    }
+    // Fetch the counter iteration argument for this if.
+    Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)];
+    // The new then and else regions both yield the counter.
+    thenResults.push_back(iterArgMinus);
+    elseResults.push_back(iterArgMinus);
+
+    // Save the original ops so they can be cloned later.
+    SmallVector<Operation *> thenOps;
+    for (auto &op : ifOp.getThenRegion().front()) {
+      thenOps.push_back(&op);
+    }
+
+    SmallVector<Operation *> elseOps;
+    if (!ifOp.getElseRegion().empty()) {
+      for (auto &op : ifOp.getElseRegion().front()) {
+        elseOps.push_back(&op);
+      }
+    }
+    SmallVector<Type> resultTypes;
+    for (auto val : thenResults) {
+      resultTypes.push_back(val.getType());
+    }
+    // Create the new scf.if op.
+    builder.setInsertionPoint(ifOp);
+    auto newIfOp = builder.create<scf::IfOp>(loc, resultTypes, conditions[i],
+                                             /*withElseRegion=*/true);
+    newIfOp->setAttr("ssbuffer", builder.getUnitAttr());
+    // Populate the then region.
+    auto &newThenBlock = newIfOp.getThenRegion().front();
+    builder.setInsertionPointToStart(&newThenBlock);
+
+    // Clone the ops of the then region.
+    for (auto op : thenOps) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+        // Map the yield operands.
+        SmallVector<Value> mappedOperands;
+        for (auto operand : yieldOp->getOperands()) {
+          mappedOperands.push_back(IRMap.lookupOrDefault(operand));
+        }
+        // Fetch the counter iteration argument for this if.
+        Value iterArgMinus =
+            iterArgs[iterArgs.size() - (conditions.size() - i)];
+
+        // Increment the counter by the loop step.
+        auto AddIOp = builder.create<arith::AddIOp>(forOp->getLoc(),
+                                                    iterArgMinus, step);
+        mappedOperands.push_back(AddIOp);
+        builder.create<scf::YieldOp>(loc, mappedOperands);
+      } else {
+        auto newOp = builder.clone(*op, IRMap);
+        IRMap.map(op->getResults(), newOp->getResults());
+      }
+    }
+
+    // Populate the else region.
+    auto &newElseBlock = newIfOp.getElseRegion().front();
+    builder.setInsertionPointToStart(&newElseBlock);
+    // Clone the ops of the else region.
+    if (hasElse) {
+      for (auto op : elseOps) {
+        if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+          // Map the yield operands.
+          SmallVector<Value> mappedOperands;
+          for (auto operand : yieldOp->getOperands()) {
+            mappedOperands.push_back(IRMap.lookupOrDefault(operand));
+          }
+          Value iterArgMinus =
+              iterArgs[iterArgs.size() - (conditions.size() - i)];
+          mappedOperands.push_back(iterArgMinus);
+          builder.create<scf::YieldOp>(loc, mappedOperands);
+        } else {
+          auto newOp = builder.clone(*op, IRMap);
+          IRMap.map(op->getResults(), newOp->getResults());
+        }
+      }
+    } else {
+      SmallVector<Value> cntOperands;
+      cntOperands.push_back(iterArgMinus);
+      builder.create<scf::YieldOp>(loc, cntOperands);
+    }
+
+    // Replace the uses of the original if op.
+    // First map the old results to the new if's corresponding results.
+    for (unsigned j = 0; j < ifOp.getNumResults(); ++j) {
+      ifOp.getResult(j).replaceAllUsesWith(newIfOp.getResult(j));
+    }
+    // The block that contains the new if op.
+    Block *newIfBlock = ifOp->getBlock();
+    // Redirect uses of the counter iter_arg inside the loop body.
+    forOp.getBody()->walk([&](Operation *op) {
+      // Only consider ops in the same block as the new if op.
+      Block *opBlock = op->getBlock();
+      if (opBlock != newIfBlock) {
+        return;
+      }
+      if (op->isBeforeInBlock(newIfOp)) {
+        return; // Only handle uses after the if op.
+      }
+      for (unsigned j = 0; j < op->getNumOperands(); ++j) {
+        for (auto argIndex = 0; argIndex < conditions.size(); argIndex++) {
+          Value iterArgMinus =
+              iterArgs[iterArgs.size() - (conditions.size() - i)];
+          if (op->getOperand(j) == iterArgMinus) {
+            op->setOperand(j,
+                           newIfOp.getResults()[newIfOp.getNumResults() - 1]);
+          }
+        }
+      }
+    });
+
+    // Queue the original if op for deletion.
+    opsToErase.push_back(ifOp);
+    if (ifArgMap.find(newIfOp) == ifArgMap.end()) {
+      ifArgMap[newIfOp] = iterArgMinus;
+    }
+  }
+}
+
+int getNestingDepth(scf::ForOp forOp) {
+  int depth = 0;
+  Operation *op = forOp.getOperation();
+  while (op) {
+    if (op->getDialect() && op->getDialect()->getNamespace() == "scf") {
+      ++depth;
+    }
+    op = op->getParentOp();
+  }
+  return depth;
+}
+
+void printDenseMap(const mlir::DenseMap<mlir::Value, int> &Map) {
+  for (const auto &pair : Map) {
+    mlir::Value val = pair.first;
+    int bitValue = pair.second;
+    llvm::outs() << val << " " << bitValue << " allocmap\n\n\n";
+    llvm::outs().flush();
+  }
+  llvm::outs() << "------------------------------\n\n\n";
+}
+
+void getAllocBit(ModuleOp module, DenseMap<Value, int> &VecBitMap,
+                 DenseMap<Value, int> &CubeBitMap, OpBuilder builder) {
+  auto aiCAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+  auto scalarWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S);
+  auto cubeWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_FIX);
+  auto vectorWaitPipe =
+      PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_MTE3);
+
+  int cubeAcc = 0;
+  int vecAcc = 0;
+  SmallVector<scope::ScopeOp> scopeOpToEdit;
+  module.walk(
+      [&](scope::ScopeOp scopeOp) { scopeOpToEdit.push_back(scopeOp); });
+  for (auto scopeOp : scopeOpToEdit) {
+    SmallVector<Value> excludedValues;
+    if (scopeOp->hasAttr("hivm.tcore_type")) {
+      auto attr = scopeOp->getAttr("hivm.tcore_type");
+      if (attr == aiCAttr) {
+        scopeOp.walk([&](SyncBlockWaitOp waitOp) {
+          auto parentOp = waitOp->getParentOp();
+          if (isa<scf::IfOp>(parentOp) && parentOp->hasAttr("ssbuffer")) {
+            auto waitPipe = waitOp.getPipe();
+            if (waitPipe != scalarWaitPipe) {
+              auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues);
+              if (VecBitMap.find(allocOp) != VecBitMap.end()) {
+                CubeBitMap[allocOp] = VecBitMap[allocOp];
+              } else {
+                CubeBitMap[allocOp] = cubeAcc;
+                cubeAcc++;
+              }
+            }
+          }
+        });
+      } else {
+        scopeOp.walk([&](SyncBlockWaitOp waitOp) {
+          auto parentOp = waitOp->getParentOp();
+          if (isa<scf::IfOp>(parentOp) && parentOp->hasAttr("ssbuffer")) {
+            auto waitPipe = waitOp.getPipe();
+            if (waitPipe != scalarWaitPipe) {
+              auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues);
+              if (VecBitMap.find(allocOp) == VecBitMap.end()) {
+                VecBitMap[allocOp] = vecAcc;
+                vecAcc++;
+              }
+            }
+          }
+        });
+      }
+    }
+  }
+}
+
+void modifyForIterargDeps(scf::ForOp forOp,
+                          DenseMap<Operation *, Value> ifCounters) {
+  Value iterArg = forOp.getInductionVar();
+
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+      if (ifCounters.find(ifOp) != ifCounters.end()) {
+        Value counter = ifCounters[ifOp];
+
+        ifOp.walk([&](Operation *opInIf) {
+          for (auto [i, operand] : llvm::enumerate(opInIf->getOperands())) {
+            if (operand == iterArg) {
+              opInIf->setOperand(i, counter);
+            }
+          }
+        });
+      }
+    }
+  }
+}
+
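+// FlowSssbuf is the driver: it finds every scf.for whose "ssbuffer" ifs
+// contain a sync_block_set, extends each such loop with per-if counters
+// (transformLoop), assigns a flag bit per shared buffer (getAllocBit),
+// builds the availability conditions (addBufValLoop), swaps the ifs over to
+// those conditions (ReplaceIf), and finally rewires induction-variable uses
+// inside each rewritten if to its counter (modifyForIterargDeps). Loops are
+// processed innermost-first by sorting on getNestingDepth.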
+void FlowSssbuf(ModuleOp module) {
+  mlir::OpBuilder builder(module.getContext());
+  // Collect all loops that need to be transformed.
+  SmallVector<scf::ForOp> targetLoops;
+  llvm::outs() << "enter flowsssbuf\n\n";
+  module.walk([&](Operation *op) {
+    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+      // Check whether the loop contains the relevant sync_block_set ops.
+      bool hasSyncBlockSet = false;
+      forOp.walk([&](Operation *op) {
+        if (isa<SyncBlockSetOp>(op)) {
+          if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
+            if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) {
+              hasSyncBlockSet = true;
+            }
+          }
+        }
+      });
+
+      if (hasSyncBlockSet) {
+        if (llvm::find(targetLoops, forOp) == targetLoops.end()) {
+          targetLoops.push_back(forOp);
+        }
+      }
+    }
+  });
+  llvm::outs() << "enter flowsssbuf\n\n";
+
+  SmallVector<scf::ForOp> transformLoops;
+  // Transform each target loop.
+  for (scf::ForOp forOp : targetLoops) {
+    auto newforOp = transformLoop(forOp, builder);
+  }
+
+  module.walk([&](Operation *op) {
+    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+      // Check whether the loop contains the relevant sync_block_set ops.
+      bool hasSyncBlockSet = false;
+      forOp.walk([&](Operation *op) {
+        if (isa<SyncBlockSetOp>(op)) {
+          if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
+            if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) {
+              hasSyncBlockSet = true;
+            }
+          }
+        }
+      });
+
+      if (hasSyncBlockSet) {
+        if (llvm::find(transformLoops, forOp) == transformLoops.end()) {
+          transformLoops.push_back(forOp);
+        }
+      }
+    }
+  });
+
+  llvm::sort(transformLoops, [](scf::ForOp a, scf::ForOp b) {
+    return getNestingDepth(a) > getNestingDepth(b);
+  });
+  DenseMap<Value, int> VecBitMap;
+  DenseMap<Value, int> CubeBitMap;
+  getAllocBit(module, VecBitMap, CubeBitMap, builder);
+  printDenseMap(CubeBitMap);
+  printDenseMap(VecBitMap);
+  SmallVector<Operation *> opsToErase;
+  for (scf::ForOp forOp : transformLoops) {
+    DenseMap<Operation *, Value> ifArgMap;
+    llvm::outs() << "before replaceif\n";
+    auto bufvals = addBufValLoop(forOp, VecBitMap, CubeBitMap, builder);
+    ReplaceIf(forOp, bufvals, opsToErase, ifArgMap, builder, module);
+    llvm::outs() << "after replaceif\n";
+    for (const auto &pair : ifArgMap) {
+      auto val = pair.first;
+      auto bitValue = pair.second;
+      llvm::outs() << val << " " << bitValue << " ifargmrp\n\n\n";
+      llvm::outs().flush();
+    }
+
+    modifyForIterargDeps(forOp, ifArgMap);
+  }
+  for (auto op : opsToErase) {
+    op->erase();
+  }
+}
+
+bool isTransOp(mlir::Operation *op) {
+  auto fixpipeOp = dyn_cast<FixpipeOp>(op);
+  if (fixpipeOp)
+    return true;
+
+  auto copyOp = dyn_cast<CopyOp>(op);
+  if (!copyOp)
+    return false;
+
+  Value copySrc = copyOp.getODSOperands(0).front();
+  MemRefType copySrcTy = dyn_cast<MemRefType>(copySrc.getType());
+  auto SrcAddrSpace =
+      dyn_cast_or_null<hivm::AddressSpaceAttr>(copySrcTy.getMemorySpace());
+  bool isSrcUbSpace =
+      SrcAddrSpace.getAddressSpace() == hivm::AddressSpace::UB;
+
+  Value copyDst = copyOp.getODSOperands(1).front();
+  MemRefType copyDstTy = dyn_cast<MemRefType>(copyDst.getType());
+  auto DstAddrSpace =
+      dyn_cast_or_null<hivm::AddressSpaceAttr>(copyDstTy.getMemorySpace());
+  bool isDstCbufSpace =
+      DstAddrSpace.getAddressSpace() == hivm::AddressSpace::L1;
+
+  return isSrcUbSpace && isDstCbufSpace;
+}
+
+void FindAndMarkBuffer(ModuleOp module) {
+  OpBuilder builder(module.getContext());
+  unsigned int BufferIdx = 0;
+  Type idxType = builder.getI32Type();
+  StringAttr setFlagAttr = builder.getStringAttr("Set flag");
+  StringAttr waitFlagAttr = builder.getStringAttr("Wait flag");
+  IntegerAttr idxAttr = builder.getI32IntegerAttr(BufferIdx);
+
+  module.walk([&](mlir::Operation *op) {
+    if (isTransOp(op)) {
+      llvm::outs() << "Buffer idx" << BufferIdx << "\n";
+      llvm::outs() << "Trans Op" << *op << "\n";
+      Value SharedBuffer;
+      if (auto fixpipeOp = dyn_cast<FixpipeOp>(op)) {
+        SharedBuffer = fixpipeOp.getODSOperands(1).front();
+      } else {
+        auto copyOp = dyn_cast<CopyOp>(op);
+        SharedBuffer = copyOp.getODSOperands(1).front();
+      }
+      llvm::outs() << "SharedBuffer" << SharedBuffer << "\n";
+
+      if (!SharedBuffer) {
+        op->emitWarning("fixpipe op has empty output operand!");
+        return;
+      }
+
+      // Mark a "set flag" after the buffer's producer op and a "wait flag"
+      // before each of the buffer's consumer ops.
+      op->setAttr("Buffer idx", builder.getI32IntegerAttr(BufferIdx));
+      op->setAttr("Wait Flag", builder.getI32IntegerAttr(0));
+      op->setAttr("Set Flag", builder.getI32IntegerAttr(1));
+
+      for (Operation *consumerOp : SharedBuffer.getUsers()) {
+        if (consumerOp == op)
+          continue;
+        if (!consumerOp)
+          continue;
+
+        llvm::outs() << "consumerOp: " << *consumerOp << "\n";
+
+        consumerOp->setAttr("Buffer idx",
+                            builder.getI32IntegerAttr(BufferIdx));
+        consumerOp->setAttr("Wait Flag", builder.getI32IntegerAttr(0));
+      }
+      BufferIdx++;
+    }
+  });
+}
+
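+// The helpers below group the loop body into "wait-set regions" (the ops
+// between a sync_block_wait and its matching sync_block_set) and merge them
+// into MergedRegion buckets. A MergedRegion appears to be outlined into an
+// scf.if later on: each ComputeYieldForMergedRegion variant decides which
+// results must be yielded out of that if, and the ExpandMergedRegionOps*
+// variants pull in the operand or result dependencies that have to move
+// along with it.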
+// Stores the information of one wait-set region.
+struct WaitSetRegion {
+  Operation *waitOp;
+  Operation *lastSetOp;
+  SmallVector<Operation *> opsToMove;
+  bool hasCopyOrFixpipe = false;
+};
+
+struct MergedRegion {
+  SmallVector<WaitSetRegion> regions;
+  SmallVector<Operation *> opsToMove;
+  SmallVector<Value> yieldValues;
+  SmallVector<Type> resultTypes;
+};
+
+void MoveIterArgUsersIntoIf(scf::ForOp forOp,
+                            SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // iter_arg -> mergedRegion index
+  DenseMap<BlockArgument, int> iterArgToRegion;
+
+  for (int r = 0; r < mergedRegions.size(); ++r) {
+    MergedRegion &mr = mergedRegions[r];
+
+    for (Operation *op : mr.opsToMove) {
+      for (Value v : op->getOperands()) {
+        if (auto barg = mlir::dyn_cast<BlockArgument>(v)) {
+          if (barg.getOwner() == &body) {
+            iterArgToRegion.try_emplace(barg, r);
+          }
+        }
+      }
+    }
+  }
+
+  if (iterArgToRegion.empty())
+    return;
+
+  // Find the last op of the last mergedRegion.
+  Operation *lastOp = nullptr;
+  for (MergedRegion &mr : mergedRegions)
+    lastOp = mr.opsToMove.back();
+
+  if (!lastOp)
+    return;
+
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body)
+    opIndex[&op] = idx++;
+
+  int startIdx = opIndex[lastOp] + 1;
+
+  // Scan the ops at the tail of the for body.
+  for (Operation &op : body) {
+    if (opIndex[&op] < startIdx)
+      continue;
+
+    llvm::SmallDenseSet<int> usedRegions;
+    for (Value v : op.getOperands()) {
+      if (auto barg = mlir::dyn_cast<BlockArgument>(v)) {
+        auto it = iterArgToRegion.find(barg);
+        if (it != iterArgToRegion.end())
+          usedRegions.insert(it->second);
+      }
+    }
+
+    // The op must depend on exactly one mergedRegion.
+    if (usedRegions.size() != 1)
+      continue;
+
+    int target = *usedRegions.begin();
+
+    mergedRegions[target].opsToMove.push_back(&op);
+  }
+}
+
+void ComputeYieldForMergedRegion(MergedRegion &mr, Block &body) {
+
+  mr.yieldValues.clear();
+  mr.resultTypes.clear();
+
+  SmallPtrSet<Operation *, 16> inRegion(mr.opsToMove.begin(),
+                                        mr.opsToMove.end());
+
+  for (Operation *op : mr.opsToMove) {
+    for (Value res : op->getResults()) {
+      bool usedOutside = false;
+
+      for (OpOperand &use : res.getUses()) {
+        Operation *user = use.getOwner();
+
+        // Not in the same for body; left to the outer level (should not
+        // normally happen).
+        if (user->getBlock() != &body)
+          continue;
+
+        // A single use outside the region forces a yield.
+        if (!inRegion.contains(user)) {
+          usedOutside = true;
+          break;
+        }
+      }
+
+      if (usedOutside) {
+        mr.yieldValues.push_back(res);
+        mr.resultTypes.push_back(res.getType());
+      }
+    }
+  }
+}
+
+static void ComputeYieldForMergedRegionV2(MergedRegion &mr, Block &body) {
+
+  mr.yieldValues.clear();
+  mr.resultTypes.clear();
+
+  // The ops inside the current region.
+  SmallPtrSet<Operation *, 16> inRegion(mr.opsToMove.begin(),
+                                        mr.opsToMove.end());
+
+  for (Operation *op : mr.opsToMove) {
+    for (Value res : op->getResults()) {
+
+      bool usedOutside = false;
+
+      for (OpOperand &use : res.getUses()) {
+        Operation *user = use.getOwner();
+
+        // Skip uses by ops inside the region.
+        if (inRegion.contains(user))
+          continue;
+
+        // Used outside the region, including blocks of nested regions.
+        usedOutside = true;
+        break;
+      }
+
+      if (usedOutside) {
+        mr.yieldValues.push_back(res);
+        mr.resultTypes.push_back(res.getType());
+      }
+    }
+  }
+}
+
+static void ComputeYieldForMergedRegionV3(MergedRegion &mr) {
+  mr.yieldValues.clear();
+  mr.resultTypes.clear();
+
+  // Keep all ops of the current region in a DenseSet.
+  DenseSet<Operation *> regionOps(mr.opsToMove.begin(), mr.opsToMove.end());
+
+  for (Operation *op : mr.opsToMove) {
+    for (Value res : op->getResults()) {
+
+      bool needsYield = false;
+
+      for (OpOperand &use : res.getUses()) {
+        Operation *user = use.getOwner();
+
+        // If the user is not in the current region, a yield is needed.
+        if (!regionOps.contains(user)) {
+          needsYield = true;
+          break;
+        }
+      }
+
+      if (needsYield) {
+        mr.yieldValues.push_back(res);
+        mr.resultTypes.push_back(res.getType());
+      }
+    }
+  }
+}
+
+// Recursively collect an op and all ops inside its regions.
+static void CollectAllNestedOps(Operation *op,
+                                DenseSet<Operation *> &regionOps) {
+  if (!op)
+    return;
+
+  if (regionOps.contains(op))
+    return; // Already collected.
+
+  regionOps.insert(op);
+
+  // Walk all regions and collect recursively.
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &nestedOp : block) {
+        CollectAllNestedOps(&nestedOp, regionOps);
+      }
+    }
+  }
+}
+
+static void ComputeYieldForMergedRegionV4(MergedRegion &mr) {
+  mr.yieldValues.clear();
+  mr.resultTypes.clear();
+
+  // Keep all ops of the current region in a DenseSet,
+  // seeded with the top-level opsToMove.
+  DenseSet<Operation *> regionOps;
+  for (Operation *op : mr.opsToMove) {
+    CollectAllNestedOps(op, regionOps); // Fully expand nested ops.
+  }
+
+  for (Operation *op : mr.opsToMove) {
+    for (Value res : op->getResults()) {
+
+      bool needsYield = false;
+
+      for (OpOperand &use : res.getUses()) {
+        Operation *user = use.getOwner();
+
+        // If the user is not in the current region, a yield is needed.
+        if (!regionOps.contains(user)) {
+          needsYield = true;
+          break;
+        }
+      }
+
+      if (needsYield) {
+        mr.yieldValues.push_back(res);
+        mr.resultTypes.push_back(res.getType());
+      }
+    }
+  }
+}
+
+int findTargetRegion(Operation *startOp, Block &body,
+                     DenseMap<Operation *, int> &opToRegion) {
+
+  SmallVector<Operation *> worklist{startOp};
+  SmallPtrSet<Operation *, 16> visited;
+
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!visited.insert(op).second)
+      continue;
+
+    auto it = opToRegion.find(op);
+    if (it != opToRegion.end())
+      return it->second;
+
+    for (Value operand : op->getOperands()) {
+      if (isa<BlockArgument>(operand))
+        continue;
+
+      Operation *defOp = operand.getDefiningOp();
+      if (defOp && defOp->getBlock() == &body)
+        worklist.push_back(defOp);
+    }
+  }
+
+  return -1;
+}
+
+void greedyAbsorbToRegion(Operation *startOp, int regionIdx, int lowerBound,
+                          Block &body, DenseMap<Operation *, int> &opIndex,
+                          DenseMap<Operation *, int> &opToRegion,
+                          SmallVector<MergedRegion> &mergedRegions) {
+
+  auto &mr = mergedRegions[regionIdx];
+
+  SmallVector<Operation *> worklist;
+  SmallPtrSet<Operation *, 16> visited(mr.opsToMove.begin(),
+                                       mr.opsToMove.end());
+
+  // Absorb startOp itself first (if not absorbed yet).
+  if (!opToRegion.count(startOp)) {
+    mr.opsToMove.push_back(startOp);
+    opToRegion[startOp] = regionIdx;
+    visited.insert(startOp);
+  }
+
+  worklist.push_back(startOp);
+
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+
+    for (Value operand : op->getOperands()) {
+      if (isa<BlockArgument>(operand))
+        continue;
+
+      Operation *defOp = operand.getDefiningOp();
+      if (!defOp || defOp->getBlock() != &body)
+        continue;
+
+      int defIdx = opIndex[defOp];
+
+      // Beyond the end of the previous region.
+      if (defIdx < lowerBound)
+        continue;
+
+      auto it = opToRegion.find(defOp);
+
+      // Must not cross into another region.
+      if (it != opToRegion.end() && it->second != regionIdx)
+        continue;
+
+      // Deduplicate.
+      if (!visited.insert(defOp).second)
+        continue;
+
+      // Absorb defOp.
+      mr.opsToMove.push_back(defOp);
+      opToRegion[defOp] = regionIdx;
+      worklist.push_back(defOp);
+    }
+  }
+}
+
+SmallVector<Value>
+getOperationInput(Operation *op, SmallVector<Value> dependValues,
+                  DenseMap<Value, std::pair<Value, SmallVector<Operation *>>>
+                      &collectDepValueMap) {
+  // Analyse each op's inputs.
+  DenseSet<Value> opInput;
+  if (isa<scf::IfOp>(op) || isa<scf::ForOp>(op)) {
+    SmallVector<Block *> regionBlocks;
+    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+      regionBlocks.push_back(&(ifOp.getThenRegion().front()));
+      regionBlocks.push_back(&(ifOp.getElseRegion().front()));
+    } else {
+      auto forOp = dyn_cast<scf::ForOp>(op);
+      regionBlocks.push_back(forOp.getBody());
+    }
+
+    // Recursively walk the scf op.
+    for (Block *curBlock : regionBlocks) {
+      for (auto &curOp : *curBlock) {
+        for (auto operand :
+             getOperationInput(&curOp, dependValues, collectDepValueMap)) {
+          Operation *defOp;
+          if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+            Block *ownerBlock = blockArg.getOwner();
+            defOp = ownerBlock->getParentOp();
+          } else {
+            defOp = operand.getDefiningOp();
+          }
+          Block *defBlock = defOp->getBlock();
+
+          if (!(defOp == op || llvm::is_contained(regionBlocks, defBlock))) {
+            opInput.insert(operand);
+          }
+        }
+      }
+    }
+    SmallVector<Value> retVector(opInput.begin(), opInput.end());
+    return retVector;
+  } else {
+    SmallVector<Value> operands = op->getOperands();
+    // Record the if-result values that will be replaced later.
+    for (auto operand : operands) {
+      if (llvm::is_contained(dependValues, operand)) {
+        if (collectDepValueMap.find(operand) != collectDepValueMap.end()) {
+          collectDepValueMap[operand].second.push_back(op);
+        } else {
+          SmallVector<Operation *> userOps;
+          userOps.push_back(op);
+          collectDepValueMap[operand] = {operand, userOps};
+        }
+      }
+    }
+    return operands;
+  }
+}
+
+SmallVector<Operation *> collectDepValuesCalculation(
+    DenseSet<Operation *> forRegionOps, DenseSet<Operation *> regionOps,
+    Operation *op, SmallVector<Value> dependValues,
+    DenseMap<Value, std::pair<Value, SmallVector<Operation *>>>
+        &collectDepValueMap) {
+  DenseSet<Operation *> collectOps;
+  std::deque<Operation *> opStack;
+  bool flag = false;
+
+  opStack.push_back(op);
+  while (opStack.size()) {
+    Operation *curOp = opStack.front();
+    opStack.pop_front();
+
+    for (auto operand :
+         getOperationInput(curOp, dependValues, collectDepValueMap)) {
+      if (llvm::is_contained(dependValues, operand)) {
+        flag = true;
+      }
+
+      Operation *parentOp = operand.getDefiningOp();
+      if (llvm::is_contained(regionOps, parentOp)) {
+        opStack.push_back(parentOp);
+        continue;
+      } else if (llvm::is_contained(forRegionOps, parentOp)) {
+        opStack.push_back(parentOp);
+        collectOps.insert(parentOp);
+      }
+    }
+  }
+
+  if (flag) {
+    SmallVector<Operation *> retVector(collectOps.begin(), collectOps.end());
+    return retVector;
+  } else {
+    collectDepValueMap.clear();
+    SmallVector<Operation *> emptyVector;
+    emptyVector.clear();
+    return emptyVector;
+  }
+}
+
+void copyOpsToMergedRegion(
+    scf::ForOp forOp, SmallVector<Operation *> collectOps,
+    MergedRegion &mergedRegion,
+    DenseMap<Value, std::pair<Value, SmallVector<Operation *>>>
+        &collectDepValueMap) {
+  Block *forBodyBlock = forOp.getBody();
+  OpBuilder builder(forOp);
+  SmallVector<Operation *> clonedOps;
+  IRMapping mapper;
+
+  // Clone the computation of the if-result values used by load/store ops.
+  int cnt = 0;
+  for (Operation &origOp : forBodyBlock->without_terminator()) {
+    if (cnt >= collectOps.size())
+      break;
+
+    if (llvm::is_contained(collectOps, &origOp)) {
+      builder.setInsertionPointAfter(&origOp);
+
+      Operation *clonedOp = (&origOp)->clone(mapper);
+      builder.insert(clonedOp);
+      mapper.map(&origOp, clonedOp);
+
+      clonedOps.push_back(clonedOp);
+      cnt++;
+
+      // Replace the if-result value by the cloned op's result.
+      SmallVector<Value> results = origOp.getResults();
+      for (auto [idx, result] : llvm::enumerate(origOp.getResults())) {
+        if (collectDepValueMap.find(result) != collectDepValueMap.end()) {
+          collectDepValueMap[result].first = clonedOp->getResult(idx);
+        }
+      }
+    }
+  }
+
+  DenseSet<Operation *> mergedRegionOps;
+  for (Operation *op : mergedRegion.opsToMove) {
+    CollectAllNestedOps(op, mergedRegionOps);
+  }
+
+  // Replace the if-result values by the cloned ops' results.
+  for (Operation *op : mergedRegionOps) {
+    for (auto [idx, operand] : llvm::enumerate(op->getOperands())) {
+      if (collectDepValueMap.find(operand) != collectDepValueMap.end()) {
+        op->setOperand(idx, collectDepValueMap[operand].first);
+      }
+    }
+  }
+
+  // Update the MergedRegion.
+  clonedOps.append(mergedRegion.opsToMove);
+  mergedRegion.opsToMove = clonedOps;
+}
+
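+// copyLoadCalculation handles vector-core regions whose load/store ops
+// depend on values produced by the rewritten ifs: it re-clones the dependent
+// computation inside the loop body (copyOpsToMergedRegion) and redirects the
+// region's ops to the cloned results, so the merged region no longer reads
+// across the scf.if boundary.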
+void copyLoadCalculation(scf::ForOp forOp, SmallVector<Value> dependValues,
+                         SmallVector<MergedRegion> &mergedRegions) {
+  mlir::Operation *parentOp = forOp->getParentOp();
+  mlir::Operation *scopeOp = nullptr;
+  while (parentOp) {
+    if (dyn_cast<scope::ScopeOp>(parentOp)) {
+      scopeOp = parentOp;
+      break;
+    }
+    parentOp = parentOp->getParentOp();
+  }
+  auto coreTypeAttr =
+      scopeOp->getAttrOfType<hivm::TCoreTypeAttr>(hivm::TCoreTypeAttr::name);
+  // Only process the vector core.
+  if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) {
+    return;
+  }
+
+  // Recursively collect all ops in forOp.
+  DenseSet<Operation *> forRegionOps;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    CollectAllNestedOps(&op, forRegionOps);
+  }
+
+  for (MergedRegion &mr : mergedRegions) {
+    DenseSet<Operation *> regionOps;
+    for (Operation *op : mr.opsToMove) {
+      CollectAllNestedOps(op, regionOps);
+    }
+
+    for (Operation *op : regionOps) {
+      if (isa(op) || isa(op)) {
+        // Recursively check whether the load/store op's operands originate
+        // from if results.
+        DenseMap<Value, std::pair<Value, SmallVector<Operation *>>>
+            collectDepValueMap;
+        SmallVector<Operation *> collectOps = collectDepValuesCalculation(
+            forRegionOps, regionOps, op, dependValues, collectDepValueMap);
+        copyOpsToMergedRegion(forOp, collectOps, mr, collectDepValueMap);
+      }
+    }
+  }
+}
+
+// Starting from the forOp's yield values, decide which mergedRegion each one
+// belongs to, then absorb operands backwards.
+void ExpandMergedRegionOpsForAIV(scf::ForOp forOp,
+                                 SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Record the op order inside the block.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body)
+    opIndex[&op] = idx++;
+
+  // Build the op -> region mapping.
+  DenseMap<Operation *, int> opToRegion;
+  for (int r = 0; r < mergedRegions.size(); ++r)
+    for (Operation *op : mergedRegions[r].opsToMove)
+      opToRegion[op] = r;
+
+  // Take the scf.yield.
+  auto yieldOp = cast<scf::YieldOp>(body.getTerminator());
+
+  // Process each yield value in order.
+  for (Value yv : yieldOp.getOperands()) {
+
+    Operation *defOp = yv.getDefiningOp();
+    if (!defOp || defOp->getBlock() != &body)
+      continue;
+
+    int targetRegion = -1;
+
+    // Already in a region?
+    auto it = opToRegion.find(defOp);
+    if (it != opToRegion.end()) {
+      targetRegion = it->second;
+    } else {
+      // Otherwise search backwards to determine the owner.
+      targetRegion = findTargetRegion(defOp, body, opToRegion);
+    }
+
+    if (targetRegion == -1)
+      continue;
+
+    // Compute the lowerBound boundary.
+    int lowerBound = 0;
+
+    if (targetRegion > 0) {
+      Operation *prevLast = mergedRegions[targetRegion - 1].opsToMove.back();
+      lowerBound = opIndex[prevLast] + 1;
+    }
+
+    // Do the actual greedy absorption.
+    greedyAbsorbToRegion(defOp, targetRegion, lowerBound, body, opIndex,
+                         opToRegion, mergedRegions);
+  }
+
+  // Sort each region by block order.
+  for (auto &mr : mergedRegions) {
+    llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) {
+      return opIndex[a] < opIndex[b];
+    });
+  }
+}
+
+// Starting from each mergedRegion, absorb operands backwards.
+void ExpandMergedRegionOpsForAIC(scf::ForOp forOp,
+                                 SmallVector<MergedRegion> &mergedRegions) {
+  Block &body = forOp.getRegion().front();
+
+  // Record each op's index inside the block.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body) {
+    opIndex[&op] = idx++;
+  }
+
+  for (int r = 0; r < mergedRegions.size(); ++r) {
+    MergedRegion &mr = const_cast<MergedRegion &>(mergedRegions[r]);
+
+    // The earliest op of this mergedRegion.
+    Operation *firstOp = mr.opsToMove.front();
+    int lowerBound = 0;
+
+    // Boundary: the last op of the previous mergedRegion.
+    if (r > 0) {
+      Operation *prevLast = mergedRegions[r - 1].opsToMove.back();
+      lowerBound = opIndex[prevLast] + 1;
+    }
+
+    SmallVector<Operation *> worklist(mr.opsToMove.begin(),
+                                      mr.opsToMove.end());
+    SmallPtrSet<Operation *, 16> visited(mr.opsToMove.begin(),
+                                         mr.opsToMove.end());
+
+    while (!worklist.empty()) {
+      Operation *op = worklist.pop_back_val();
+
+      // Absorb operands backwards.
+      for (Value operand : op->getOperands()) {
+        // Skip block arguments.
+        if (mlir::isa<BlockArgument>(operand))
+          continue;
+
+        Operation *defOp = operand.getDefiningOp();
+
+// Centered on each mergedRegion: absorb operand chains backwards.
+void ExpandMergedRegionOpsForAIC(scf::ForOp forOp,
+                                 SmallVector<MergedRegion> &mergedRegions) {
+  Block &body = forOp.getRegion().front();
+
+  // Record the index of every op in the block.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body) {
+    opIndex[&op] = idx++;
+  }
+
+  for (int r = 0; r < mergedRegions.size(); ++r) {
+    MergedRegion &mr = mergedRegions[r];
+
+    // Earliest op of this mergedRegion.
+    Operation *firstOp = mr.opsToMove.front();
+    int lowerBound = 0;
+
+    // Boundary: the last op of the previous mergedRegion.
+    if (r > 0) {
+      Operation *prevLast = mergedRegions[r - 1].opsToMove.back();
+      lowerBound = opIndex[prevLast] + 1;
+    }
+
+    SmallVector<Operation *> worklist(mr.opsToMove.begin(),
+                                      mr.opsToMove.end());
+    SmallPtrSet<Operation *, 8> visited(mr.opsToMove.begin(),
+                                        mr.opsToMove.end());
+
+    while (!worklist.empty()) {
+      Operation *op = worklist.pop_back_val();
+
+      // Absorb operands backwards.
+      for (Value operand : op->getOperands()) {
+        // Skip block arguments.
+        if (mlir::isa<BlockArgument>(operand))
+          continue;
+
+        Operation *defOp = operand.getDefiningOp();
+        if (!defOp)
+          continue;
+
+        // Not in the for body.
+        if (defOp->getBlock() != &body)
+          continue;
+
+        int defIdx = opIndex[defOp];
+
+        // Past the backward-absorption boundary.
+        if (defIdx < lowerBound)
+          continue;
+
+        // Already in opsToMove.
+        if (!visited.insert(defOp).second)
+          continue;
+
+        // Absorb this defOp.
+        mr.opsToMove.push_back(defOp);
+        worklist.push_back(defOp);
+      }
+    }
+
+    // Finally sort by original block order.
+    llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) {
+      return opIndex[a] < opIndex[b];
+    });
+  }
+}
+
+static void pullInRegionDependencies(Operation *regionOp, int regionId,
+                                     DenseMap<Operation *, int> &opToRegion,
+                                     Block &body) {
+
+  SmallVector<Operation *> worklist;
+
+  // Seed the worklist with the ops nested inside regionOp.
+  for (Region &region : regionOp->getRegions())
+    for (Block &block : region)
+      for (Operation &inner : block)
+        worklist.push_back(&inner);
+
+  SmallPtrSet<Operation *, 8> visited;
+
+  while (!worklist.empty()) {
+    Operation *innerOp = worklist.pop_back_val();
+
+    if (!visited.insert(innerOp).second)
+      continue;
+
+    // Defining ops of the operands.
+    for (Value operand : innerOp->getOperands()) {
+
+      Operation *def = operand.getDefiningOp();
+      if (!def)
+        continue;
+
+      if (def->getBlock() != &body)
+        continue;
+
+      if (!opToRegion.count(def)) {
+
+        opToRegion[def] = regionId;
+
+        // If def is itself a region-bearing op, keep expanding.
+        if (def->getNumRegions() > 0)
+          worklist.push_back(def);
+      }
+    }
+
+    // Keep walking nested regions.
+    for (Region &r : innerOp->getRegions())
+      for (Block &b : r)
+        for (Operation &child : b)
+          worklist.push_back(&child);
+  }
+}
+
+// BFS: find the earliest region that uses a given op.
+static int findEarliestRegion(Operation *startOp,
+                              const DenseMap<Operation *, int> &seedRegionMap,
+                              Block &body) {
+
+  SmallVector<Operation *> worklist{startOp};
+  SmallPtrSet<Operation *, 8> visited;
+  int earliestRegion = -1;
+
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+
+    if (!visited.insert(op).second)
+      continue;
+
+    for (Value result : op->getResults()) {
+      for (OpOperand &use : result.getUses()) {
+        Operation *user = use.getOwner();
+
+        if (user->getBlock() != &body)
+          continue;
+
+        auto it = seedRegionMap.find(user);
+        if (it != seedRegionMap.end()) {
+          int region = it->second;
+          if (earliestRegion == -1 || region < earliestRegion)
+            earliestRegion = region;
+        } else {
+          worklist.push_back(user);
+        }
+      }
+    }
+  }
+
+  return earliestRegion;
+}
+
+void ExpandMergedRegionOpsForAll(scf::ForOp forOp,
+                                 SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Op order within the block.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body)
+    opIndex[&op] = idx++;
+
+  // Seed region map.
+  DenseMap<Operation *, int> seedRegionMap;
+  for (int r = 0; r < mergedRegions.size(); r++) {
+    for (Operation *op : mergedRegions[r].opsToMove) {
+      seedRegionMap[op] = r;
+    }
+  }
+
+  // Final op -> region mapping.
+  DenseMap<Operation *, int> opToRegion = seedRegionMap;
+
+  // ---------- Step 1: sequential scan ----------
+  for (Operation &op : body) {
+
+    if (isa<scf::YieldOp>(&op))
+      continue;
+
+    if (opToRegion.count(&op))
+      continue;
+
+    int region = findEarliestRegion(&op, seedRegionMap, body);
+
+    if (region != -1)
+      opToRegion[&op] = region;
+  }
+
+  // ---------- Step 2: complete region-op dependencies ----------
+  for (Operation &op : body) {
+
+    auto it = opToRegion.find(&op);
+    if (it == opToRegion.end())
+      continue;
+
+    if (op.getNumRegions() == 0)
+      continue;
+
+    pullInRegionDependencies(&op, it->second, opToRegion, body);
+  }
+
+  // ---------- Step 3: append ops ----------
+  // Pre-seed `seen` with the ops already present so the seeds are not
+  // appended a second time.
+  SmallPtrSet<Operation *, 8> seen;
+  for (auto &mr : mergedRegions)
+    for (Operation *op : mr.opsToMove)
+      seen.insert(op);
+
+  for (Operation &op : body) {
+
+    auto it = opToRegion.find(&op);
+    if (it == opToRegion.end())
+      continue;
+
+    if (!seen.insert(&op).second)
+      continue;
+
+    int region = it->second;
+    mergedRegions[region].opsToMove.push_back(&op);
+  }
+
+  // ---------- Step 4: sort ----------
+  for (auto &mr : mergedRegions) {
+    llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) {
+      return opIndex[a] < opIndex[b];
+    });
+  }
+}
+
+// Same strategy as ExpandMergedRegionOpsForAll, but driven by the inputs.
+void ExpandMergedRegionOpsByInput(scf::ForOp forOp,
+                                  SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Op order within the block.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body)
+    opIndex[&op] = idx++;
+
+  // Seed region map.
+  DenseMap<Operation *, int> seedRegionMap;
+  for (int r = 0; r < mergedRegions.size(); r++) {
+    for (Operation *op : mergedRegions[r].opsToMove) {
+      seedRegionMap[op] = r;
+    }
+  }
+
+  // Final op -> region mapping.
+  DenseMap<Operation *, int> opToRegion = seedRegionMap;
+
+  // ---------- Step 1: sequential scan ----------
+  for (Operation &op : body) {
+
+    if (isa<scf::YieldOp>(&op))
+      continue;
+
+    if (opToRegion.count(&op))
+      continue;
+
+    int region = findEarliestRegion(&op, seedRegionMap, body);
+
+    if (region != -1)
+      opToRegion[&op] = region;
+  }
+
+  // ---------- Step 2: complete region-op dependencies ----------
+  for (Operation &op : body) {
+
+    auto it = opToRegion.find(&op);
+    if (it == opToRegion.end())
+      continue;
+
+    if (op.getNumRegions() == 0)
+      continue;
+
+    pullInRegionDependencies(&op, it->second, opToRegion, body);
+  }
+
+  // ---------- Step 3: append ops ----------
+  SmallPtrSet<Operation *, 8> seen;
+  for (auto &mr : mergedRegions)
+    for (Operation *op : mr.opsToMove)
+      seen.insert(op);
+
+  for (Operation &op : body) {
+
+    auto it = opToRegion.find(&op);
+    if (it == opToRegion.end())
+      continue;
+
+    if (!seen.insert(&op).second)
+      continue;
+
+    int region = it->second;
+    mergedRegions[region].opsToMove.push_back(&op);
+  }
+
+  // ---------- Step 4: sort ----------
+  for (auto &mr : mergedRegions) {
+    llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) {
+      return opIndex[a] < opIndex[b];
+    });
+  }
+}
+
+static void
+ExpandMergedRegionOpsByOutput(scf::ForOp forOp,
+                              SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Block order (preserves IR order).
+  DenseMap<Operation *, int> opOrder;
+  int idx = 0;
+  for (Operation &op : body)
+    opOrder[&op] = idx++;
+
+  for (auto &merged : mergedRegions) {
+
+    // Collect the values currently produced by the region.
+    SmallPtrSet<Value, 8> regionValues;
+
+    for (Operation *op : merged.opsToMove)
+      for (Value res : op->getResults())
+        regionValues.insert(res);
+
+    bool changed = true;
+
+    while (changed) {
+      changed = false;
+
+      for (Operation &op : body) {
+
+        if (isa(op) || isa(op))
+          continue;
+
+        if (llvm::is_contained(merged.opsToMove, &op))
+          continue;
+
+        bool depends = false;
+        for (Value operand : op.getOperands()) {
+          if (regionValues.contains(operand)) {
+            depends = true;
+            break;
+          }
+        }
+
+        if (!depends)
+          continue;
+
+        // Add to the region.
+        merged.opsToMove.push_back(&op);
+
+        // Update regionValues.
+        for (Value res : op.getResults())
+          regionValues.insert(res);
+
+        changed = true;
+      }
+    }
+
+    // Sort to preserve the original block order.
+    llvm::sort(merged.opsToMove, [&](Operation *a, Operation *b) {
+      return opOrder[a] < opOrder[b];
+    });
+  }
+}
+
+static void
+MoveIndependentOpsIntoIf(scf::ForOp forOp,
+                         SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Ops that already belong to a region.
+  SmallPtrSet<Operation *, 8> alreadyAssigned;
+  for (auto &mr : mergedRegions)
+    for (Operation *op : mr.opsToMove)
+      alreadyAssigned.insert(op);
+
+  // iter_arg -> region mapping.
+  DenseMap<Value, int> iterArgToRegion;
+
+  for (int r = 0; r < mergedRegions.size(); r++) {
+    for (Operation *op : mergedRegions[r].opsToMove) {
+      for (Value operand : op->getOperands()) {
+        if (auto barg = mlir::dyn_cast<BlockArgument>(operand)) {
+          if (barg.getOwner() == &body)
+            iterArgToRegion[barg] = r;
+        }
+      }
+    }
+  }
+
+  // Block order.
+  DenseMap<Operation *, int> opIndex;
+  int idx = 0;
+  for (Operation &op : body)
+    opIndex[&op] = idx++;
+
+  // Scan all ops.
+  for (Operation &op : body) {
+
+    if (isa(op) || isa(op))
+      continue;
+
+    if (alreadyAssigned.contains(&op))
+      continue;
+
+    int targetRegion = -1;
+
+    // Does any operand come from an iter_arg?
+    for (Value operand : op.getOperands()) {
+      if (auto barg = mlir::dyn_cast<BlockArgument>(operand)) {
+        if (barg.getOwner() != &body)
+          continue;
+
+        auto it = iterArgToRegion.find(barg);
+        if (it != iterArgToRegion.end()) {
+          targetRegion = it->second;
+          break;
+        }
+      }
+    }
+
+    if (targetRegion == -1)
+      continue;
+
+    mergedRegions[targetRegion].opsToMove.push_back(&op);
+    alreadyAssigned.insert(&op);
+  }
+
+  // Sort to preserve block order.
+  for (auto &mr : mergedRegions) {
+    llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) {
+      return opIndex[a] < opIndex[b];
+    });
+  }
+}
+
+// Brute-force wrapping: absorb everything around each region.
+static void
+ExpandMergedRegionOpsGreedyMaximum(scf::ForOp forOp,
+                                   SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Ops that already belong to a region.
+  DenseSet<Operation *> regionOps;
+  for (auto &region : mergedRegions)
+    for (Operation *op : region.opsToMove)
+      regionOps.insert(op);
+
+  // Flat list of ops in the block.
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  DenseMap<Operation *, int> opIndex;
+  for (int i = 0; i < ops.size(); i++)
+    opIndex[ops[i]] = i;
+
+  for (auto &region : mergedRegions) {
+
+    if (region.opsToMove.empty())
+      continue;
+
+    // Find the region's span within the block.
+    int start = ops.size();
+    int end = -1;
+    for (Operation *op : region.opsToMove) {
+      int idx = opIndex[op];
+      start = std::min(start, idx);
+      end = std::max(end, idx);
+    }
+
+    SmallVector<Operation *> newOps;
+
+    // ---------- backward expansion ----------
+    for (int i = start - 1; i >= 0; i--) {
+      Operation *op = ops[i];
+
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      if (regionOps.contains(op))
+        break;
+
+      newOps.push_back(op);
+    }
+
+    // ---------- forward expansion ----------
+    for (int i = end + 1; i < ops.size(); i++) {
+      Operation *op = ops[i];
+
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      if (regionOps.contains(op))
+        break;
+
+      newOps.push_back(op);
+    }
+
+    // Add to the region.
+    for (Operation *op : newOps) {
+      region.opsToMove.push_back(op);
+      regionOps.insert(op);
+    }
+  }
+
+  // Finally keep block order and drop duplicates.
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+static void CollectForYieldRelatedOps(scf::ForOp forOp,
+                                      SmallVector<MergedRegion> &mergedRegions,
+                                      DenseSet<Operation *> &yieldRelatedOps) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Ops that already belong to a region.
+  DenseSet<Operation *> regionOps;
+  for (auto &region : mergedRegions)
+    for (Operation *op : region.opsToMove)
+      regionOps.insert(op);
+
+  auto yield = cast<scf::YieldOp>(body.getTerminator());
+
+  SmallVector<Value> worklist;
+  DenseSet<Value> visited;
+
+  // Seed the worklist.
+  for (Value v : yield.getOperands())
+    worklist.push_back(v);
+
+  while (!worklist.empty()) {
+    Value v = worklist.pop_back_val();
+
+    if (!visited.insert(v).second)
+      continue;
+
+    Operation *def = v.getDefiningOp();
+    if (!def)
+      continue;
+
+    // Only ops inside the for body.
+    if (def->getBlock() != &body)
+      continue;
+
+    // Already inside a region.
+    if (regionOps.contains(def))
+      continue;
+
+    // Record it and keep chasing its dependencies upwards.
+    if (yieldRelatedOps.insert(def).second) {
+      for (Value operand : def->getOperands())
+        worklist.push_back(operand);
+    }
+  }
+}
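+
+// Sketch of how the collector above interacts with the greedy expansion
+// below (hypothetical call sequence): the yield chain is collected first and
+// passed in as `skipOps`, so the greedy scan steps over those ops; they are
+// re-attached later by MoveForYieldOpIntoRegion /
+// MoveRemainingYieldOpsToPrevRegion:
+//
+//   DenseSet<Operation *> yieldRelatedOps;
+//   CollectForYieldRelatedOps(forOp, mergedRegions, yieldRelatedOps);
+//   ExpandMergedRegionOpsGreedyV2(forOp, mergedRegions, yieldRelatedOps);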
+
+// Greedily absorb the ops before and after each region.
+static void
+ExpandMergedRegionOpsGreedy(scf::ForOp forOp,
+                            SmallVector<MergedRegion> &mergedRegions,
+                            DenseSet<Operation *> &skipOps) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Ops that already belong to a region.
+  DenseSet<Operation *> regionOps;
+  for (auto &region : mergedRegions)
+    for (Operation *op : region.opsToMove)
+      regionOps.insert(op);
+
+  // Flat list of ops in the block.
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  // op -> index
+  DenseMap<Operation *, int> opIndex;
+  for (int i = 0; i < ops.size(); i++)
+    opIndex[ops[i]] = i;
+
+  for (auto &region : mergedRegions) {
+
+    if (region.opsToMove.empty())
+      continue;
+
+    // Find the region's span within the block.
+    int start = ops.size();
+    int end = -1;
+    for (Operation *op : region.opsToMove) {
+      int idx = opIndex[op];
+      start = std::min(start, idx);
+      end = std::max(end, idx);
+    }
+
+    SmallVector<Operation *> newOps;
+
+    // ---------- backward expansion ----------
+    for (int i = start - 1; i >= 0; i--) {
+      Operation *op = ops[i];
+
+      // Block terminator.
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      // Hit an op that belongs to another region.
+      if (regionOps.contains(op))
+        break;
+
+      // Yield-related op: skip it but keep scanning.
+      if (skipOps.contains(op))
+        continue;
+
+      newOps.push_back(op);
+    }
+
+    // ---------- forward expansion ----------
+    for (int i = end + 1; i < ops.size(); i++) {
+      Operation *op = ops[i];
+
+      // Block terminator.
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      // Hit an op that belongs to another region.
+      if (regionOps.contains(op))
+        break;
+
+      // Yield-related op: skip it.
+      if (skipOps.contains(op))
+        continue;
+
+      newOps.push_back(op);
+    }
+
+    // Add to the region.
+    for (Operation *op : newOps) {
+      region.opsToMove.push_back(op);
+      regionOps.insert(op);
+    }
+  }
+
+  // Finally keep block order and drop duplicates.
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+// Greedily absorb only the ops before each region.
+static void
+ExpandMergedRegionOpsGreedyV2(scf::ForOp forOp,
+                              SmallVector<MergedRegion> &mergedRegions,
+                              DenseSet<Operation *> &skipOps) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Flat list of ops in the block.
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  // op -> index
+  DenseMap<Operation *, int> opIndex;
+  for (int i = 0; i < ops.size(); i++)
+    opIndex[ops[i]] = i;
+
+  // Ops that already belong to a region.
+  DenseSet<Operation *> regionOps;
+  for (auto &region : mergedRegions)
+    for (Operation *op : region.opsToMove)
+      regionOps.insert(op);
+
+  for (int r = 0; r < mergedRegions.size(); r++) {
+
+    auto &region = mergedRegions[r];
+    if (region.opsToMove.empty())
+      continue;
+
+    // ---------- span of the current region in the block ----------
+    int start = ops.size();
+    int end = -1;
+    for (Operation *op : region.opsToMove) {
+      int idx = opIndex[op];
+      start = std::min(start, idx);
+      end = std::max(end, idx);
+    }
+
+    // ---------- end of the previous region ----------
+    int prevEnd = -1;
+    if (r > 0 && !mergedRegions[r - 1].opsToMove.empty()) {
+      for (Operation *op : mergedRegions[r - 1].opsToMove) {
+        prevEnd = std::max(prevEnd, opIndex[op]);
+      }
+    }
+
+    SmallVector<Operation *> newOps;
+
+    // ---------- backward expansion ----------
+    for (int i = start - 1; i > prevEnd; i--) {
+
+      Operation *op = ops[i];
+
+      // Terminator.
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      // Already belongs to a region.
+      if (regionOps.contains(op))
+        break;
+
+      // Yield-chain op: skip.
+      if (skipOps.contains(op))
+        continue;
+
+      newOps.push_back(op);
+    }
+
+    // ---------- add to the region ----------
+    for (Operation *op : newOps) {
+      region.opsToMove.push_back(op);
+      regionOps.insert(op);
+    }
+  }
+
+  // ---------- keep block order ----------
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+// Greedily absorb the ops before each region (AIC variant, no skip set).
+static void
+ExpandMergedRegionOpsGreedyV2ForAIC(scf::ForOp forOp,
+                                    SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+
+  // Flat list of ops in the block.
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  // op -> index
+  DenseMap<Operation *, int> opIndex;
+  for (int i = 0; i < ops.size(); i++)
+    opIndex[ops[i]] = i;
+
+  // Ops that already belong to a region.
+  DenseSet<Operation *> regionOps;
+  for (auto &region : mergedRegions)
+    for (Operation *op : region.opsToMove)
+      regionOps.insert(op);
+
+  for (int r = 0; r < mergedRegions.size(); r++) {
+
+    auto &region = mergedRegions[r];
+    if (region.opsToMove.empty())
+      continue;
+
+    // ---------- span of the current region in the block ----------
+    int start = ops.size();
+    int end = -1;
+    for (Operation *op : region.opsToMove) {
+      int idx = opIndex[op];
+      start = std::min(start, idx);
+      end = std::max(end, idx);
+    }
+
+    // ---------- end of the previous region ----------
+    int prevEnd = -1;
+    if (r > 0 && !mergedRegions[r - 1].opsToMove.empty()) {
+      for (Operation *op : mergedRegions[r - 1].opsToMove) {
+        prevEnd = std::max(prevEnd, opIndex[op]);
+      }
+    }
+
+    SmallVector<Operation *> newOps;
+
+    // ---------- backward expansion ----------
+    for (int i = start - 1; i > prevEnd; i--) {
+
+      Operation *op = ops[i];
+
+      // Terminator.
+      if (isa<scf::YieldOp>(op))
+        break;
+
+      // Already belongs to a region.
+      if (regionOps.contains(op))
+        break;
+
+      newOps.push_back(op);
+    }
+
+    // ---------- add to the region ----------
+    for (Operation *op : newOps) {
+      region.opsToMove.push_back(op);
+      regionOps.insert(op);
+    }
+  }
+
+  // ---------- keep block order ----------
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+static void
+MoveForYieldOpIntoRegion(scf::ForOp forOp,
+                         DenseSet<Operation *> &yieldRelatedOps,
+                         SmallVector<MergedRegion> &mergedRegions) {
+
+  DenseMap<Operation *, int> opToRegion;
+  for (int i = 0; i < mergedRegions.size(); i++)
+    for (Operation *op : mergedRegions[i].opsToMove)
+      opToRegion[op] = i;
+
+  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+
+  for (int i = 0; i < yield.getNumOperands(); i++) {
+
+    Value iterArg = forOp.getRegionIterArgs()[i];
+    Value yieldVal = yield.getOperand(i);
+
+    Operation *def = yieldVal.getDefiningOp();
+    if (!def)
+      continue;
+
+    if (!yieldRelatedOps.contains(def))
+      continue;
+
+    int targetRegion = -1;
+    for (Operation *user : iterArg.getUsers()) {
+      if (opToRegion.count(user)) {
+        targetRegion = opToRegion[user];
+        break;
+      }
+    }
+
+    if (targetRegion == -1)
+      continue;
+
+    SmallVector<Operation *> stack;
+    stack.push_back(def);
+
+    while (!stack.empty()) {
+      Operation *op = stack.pop_back_val();
+
+      if (!yieldRelatedOps.contains(op))
+        continue;
+
+      mergedRegions[targetRegion].opsToMove.push_back(op);
+      yieldRelatedOps.erase(op);
+
+      for (Value operand : op->getOperands()) {
+        if (Operation *dep = operand.getDefiningOp())
+          stack.push_back(dep);
+      }
+    }
+  }
+
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
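+
+// Example of the iter_arg routing above (hypothetical indices): given
+// `scf.yield %v0, %v1`, if region 1 is the only region reading iter_arg #0,
+// then %v0's defining op (when it is yield-related) is moved into region 1,
+// keeping the producer next to the consumer of the corresponding
+// loop-carried value.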
+
+static void
+MoveRemainingYieldOpsToPrevRegion(scf::ForOp forOp,
+                                  DenseSet<Operation *> &yieldRelatedOps,
+                                  SmallVector<MergedRegion> &mergedRegions) {
+
+  if (yieldRelatedOps.empty())
+    return;
+
+  Block &body = forOp.getRegion().front();
+
+  // op -> region index
+  DenseMap<Operation *, int> opToRegion;
+  for (int i = 0; i < mergedRegions.size(); i++)
+    for (Operation *op : mergedRegions[i].opsToMove)
+      opToRegion[op] = i;
+
+  // Block order.
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  DenseMap<Operation *, int> opIndex;
+  for (int i = 0; i < ops.size(); i++)
+    opIndex[ops[i]] = i;
+
+  for (Operation *op : yieldRelatedOps) {
+
+    if (op->getBlock() != &body)
+      continue;
+
+    int idx = opIndex[op];
+    int targetRegion = -1;
+
+    // Search backwards for the nearest region.
+    for (int i = idx - 1; i >= 0; i--) {
+      Operation *prev = ops[i];
+      if (opToRegion.count(prev)) {
+        targetRegion = opToRegion[prev];
+        break;
+      }
+    }
+
+    if (targetRegion == -1)
+      continue;
+
+    mergedRegions[targetRegion].opsToMove.push_back(op);
+  }
+
+  // Sort + dedup.
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+static void MoveIndependentOpsIntoRegionBackwardV2(
+    scf::ForOp forOp, SmallVector<MergedRegion> &mergedRegions) {
+
+  Block &body = forOp.getRegion().front();
+  SmallVector<Operation *> ops;
+  for (Operation &op : body)
+    ops.push_back(&op);
+
+  DenseMap<Operation *, int> opToRegion;
+  for (int i = 0; i < mergedRegions.size(); i++)
+    for (Operation *op : mergedRegions[i].opsToMove)
+      opToRegion[op] = i;
+
+  // ----------- collect the move plan -----------
+  DenseMap<Operation *, int> movePlan;
+
+  for (int i = 0; i < mergedRegions.size(); i++) {
+    MergedRegion &region = mergedRegions[i];
+    if (region.opsToMove.empty())
+      continue;
+
+    Operation *firstOp = region.opsToMove.front();
+    Operation *lastOp = region.opsToMove.back();
+    auto itFirst = std::find(ops.begin(), ops.end(), firstOp);
+    auto itLast = std::find(ops.begin(), ops.end(), lastOp);
+    if (itFirst == ops.end() || itLast == ops.end())
+      continue;
+
+    int startIdx = std::distance(ops.begin(), itFirst);
+    int endIdx = std::distance(ops.begin(), itLast);
+
+    // ----------- collect the wait-set intervals -----------
+    SmallVector<std::pair<int, int>> waitIntervals;
+    bool inWait = false;
+    int begin = -1;
+    for (int j = startIdx; j <= endIdx; j++) {
+      Operation *op = ops[j];
+      if (op->getName().getStringRef().contains("sync_block_wait")) {
+        inWait = true;
+        begin = j + 1;
+        continue;
+      }
+      if (op->getName().getStringRef().contains("sync_block_set") && inWait) {
+        inWait = false;
+        waitIntervals.push_back({begin, j - 1});
+      }
+    }
+    auto isInWaitSet = [&](int idx) {
+      for (auto &p : waitIntervals)
+        if (idx >= p.first && idx <= p.second)
+          return true;
+      return false;
+    };
+
+    // ----------- scan the region's ops from back to front -----------
+    for (int j = endIdx; j >= startIdx; j--) {
+      Operation *op = ops[j];
+      if (isa<scf::YieldOp>(op) || isInWaitSet(j))
+        continue;
+
+      // ---------- does any operand depend on this region? ----------
+      bool dependCurrentRegion = false;
+      for (Value operand : op->getOperands()) {
+        Operation *def = operand.getDefiningOp();
+        if (!def)
+          continue;
+        if (std::find(region.opsToMove.begin(), region.opsToMove.end(), def) !=
+            region.opsToMove.end()) {
+          dependCurrentRegion = true;
+          break;
+        }
+      }
+      if (dependCurrentRegion)
+        continue;
+
+      // ---------- is the result used later within this region? ----------
+      bool usedLaterInSameRegion = false;
+      for (Value result : op->getResults())
+        for (Operation *user : result.getUsers())
+          if (std::find(region.opsToMove.begin(), region.opsToMove.end(),
+                        user) != region.opsToMove.end() &&
+              std::find(region.opsToMove.begin(), region.opsToMove.end(), op) <
+                  std::find(region.opsToMove.begin(), region.opsToMove.end(),
+                            user)) {
+            usedLaterInSameRegion = true;
+            break;
+          }
+      if (usedLaterInSameRegion)
+        continue;
+
+      // ---------- find a later region that uses this op ----------
+      int targetRegion = -1;
+      for (int k = i + 1; k < mergedRegions.size() && targetRegion == -1;
+           ++k) {
+        for (Operation *candidate : mergedRegions[k].opsToMove) {
+          for (Value operand : candidate->getOperands()) {
+            if (operand.getDefiningOp() == op) {
+              targetRegion = k;
+              break;
+            }
+          }
+          if (targetRegion != -1)
+            break;
+        }
+      }
+      if (targetRegion == -1)
+        continue;
+
+      movePlan[op] = targetRegion;
+    }
+  }
+
+  // ----------- apply the moves -----------
+  for (auto &it : movePlan) {
+    Operation *op = it.first;
+    int targetRegionIdx = it.second;
+    MergedRegion &targetRegion = mergedRegions[targetRegionIdx];
+    targetRegion.opsToMove.push_back(op);
+
+    llvm::outs() << "MJ: move " << *op << " -> region " << targetRegionIdx
+                 << "\n";
+  }
+
+  // ----------- update the source regions' opsToMove -----------
+  for (int i = 0; i < mergedRegions.size(); ++i) {
+    MergedRegion &region = mergedRegions[i];
+    SmallVector<Operation *> newOps;
+    for (Operation *op : region.opsToMove) {
+      auto it = movePlan.find(op);
+      if (it == movePlan.end() || it->second == i) {
+        // No move planned, or the target is this region: keep it.
+        newOps.push_back(op);
+      }
+    }
+    region.opsToMove.swap(newOps);
+  }
+
+  // ----------- sort + dedup -----------
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+// Debug hack: if the last three ops of a forOp's first region are
+//   %27 = tt.expand_dims %25#1 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
+//   %28 = tt.broadcast %27 : tensor<64x1xf32> -> tensor<64x128xf32>
+//   %29 = arith.mulf %arg10, %28 : tensor<64x128xf32>
+// move them directly into the second region.
+static void TempChange(scf::ForOp forOp,
+                       SmallVector<MergedRegion> &mergedRegions) {
+
+  if (mergedRegions.size() < 2)
+    return;
+
+  auto &srcRegion = mergedRegions[0];
+  auto &dstRegion = mergedRegions[1];
+
+  if (srcRegion.opsToMove.size() < 3)
+    return;
+
+  Operation *op1 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 3];
+  Operation *op2 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 2];
+  Operation *op3 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 1];
+
+  // ---------- pattern match ----------
+  if (!op1->getName().getStringRef().contains("tt.expand_dims"))
+    return;
+  if (!op2->getName().getStringRef().contains("tt.broadcast"))
+    return;
+  if (!op3->getName().getStringRef().contains("arith.mulf"))
+    return;
+
+  llvm::outs() << "TempChange triggered\n";
+
+  SmallVector<Operation *> opsToMove = {op1, op2, op3};
+
+  // ---------- move to the end of region 2 ----------
+  for (Operation *op : opsToMove) {
+    dstRegion.opsToMove.push_back(op);
+    llvm::outs() << "TempChange move: " << *op << "\n";
+  }
+
+  // ---------- remove from region 1 ----------
+  srcRegion.opsToMove.resize(srcRegion.opsToMove.size() - 3);
+
+  // ---------- sort ----------
+  for (auto &region : mergedRegions) {
+    llvm::sort(region.opsToMove, [](Operation *a, Operation *b) {
+      return a->isBeforeInBlock(b);
+    });
+    region.opsToMove.erase(
+        std::unique(region.opsToMove.begin(), region.opsToMove.end()),
+        region.opsToMove.end());
+  }
+}
+
+static void sortOperationsByDataFlow(llvm::SmallVector<Operation *> &ops) {
+  llvm::DenseSet<Operation *> visited;
+  llvm::SmallVector<Operation *> result;
+
+  std::function<void(Operation *)> dfs = [&](Operation *op) {
+    if (!visited.insert(op).second)
+      return;
+
+    // Visit producers that are themselves in `ops` first.
+    for (Value operand : op->getOperands()) {
+      if (Operation *def = operand.getDefiningOp()) {
+        if (llvm::is_contained(ops, def))
+          dfs(def);
+      }
+    }
+
+    result.push_back(op);
+  };
+
+  for (Operation *op : ops)
+    dfs(op);
+
+  ops.assign(result.begin(), result.end());
+}
+
+static void rewriteOperandsRecursively(Operation *op,
+                                       DenseMap<Value, Value> &valueMap) {
+
+  // 1. Rewrite the current op's operands.
+  for (OpOperand &operand : op->getOpOperands()) {
+    Value v = operand.get();
+    auto it = valueMap.find(v);
+    if (it != valueMap.end())
+      operand.set(it->second);
+  }
+
+  // 2. Recurse into nested regions.
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &nestedOp : block) {
+        rewriteOperandsRecursively(&nestedOp, valueMap);
+      }
+    }
+  }
+}
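+
+// Usage sketch for the two helpers above (hypothetical names): after cloning
+// a producer, record the result substitution and rewrite a consumer tree:
+//
+//   DenseMap<Value, Value> valueMap;
+//   valueMap[oldOp->getResult(0)] = clonedOp->getResult(0);
+//   rewriteOperandsRecursively(consumerOp, valueMap);
+//
+// rewriteOperandsRecursively also walks into nested regions, so consumers
+// inside scf.if/scf.for bodies are rewired as well.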
+
+static void CopyOpsToAfterwardRegions(
+    SmallVector<MergedRegion> &mergedRegions,
+    DenseMap<Value, Operation *> &yieldMap,
+    DenseMap<Operation *, Operation *> &cloneAndOriYieldMap,
+    SmallVector<scf::ForOp> &copiedForOps) {
+
+  if (mergedRegions.size() <= 1)
+    return;
+
+  // Build a set so we can quickly tell which ops define yield values.
+  DenseSet<Operation *> yieldDefOps;
+  for (auto &it : yieldMap)
+    yieldDefOps.insert(it.second);
+
+  // Walk the regions in reverse order.
+  for (int i = mergedRegions.size() - 1; i >= 0; --i) {
+    MergedRegion &curRegion = mergedRegions[i];
+
+    DenseMap<Value, Value> valueMap;
+    SmallVector<Operation *> clonedOps;
+
+    // Walk all preceding regions.
+    for (int k = 0; k < i; ++k) {
+      MergedRegion &prevRegion = mergedRegions[k];
+
+      int waitSetLevel = 0;
+
+      for (Operation *op : prevRegion.opsToMove) {
+
+        if (isa(op)) {
+          waitSetLevel++;
+          continue;
+        }
+
+        if (isa(op)) {
+          waitSetLevel = std::max(waitSetLevel - 1, 0);
+          continue;
+        }
+
+        if (waitSetLevel > 0)
+          continue;
+
+        if (isa<triton::DotOp>(op))
+          continue;
+
+        IRMapping mapper;
+        for (auto result : op->getResults())
+          if (valueMap.count(result))
+            mapper.map(result, valueMap[result]);
+
+        Operation *insertPoint =
+            curRegion.opsToMove.empty() ? nullptr : curRegion.opsToMove.front();
+        OpBuilder builder(insertPoint ? insertPoint : op);
+
+        Operation *cloned = builder.clone(*op, mapper);
+
+        // Record the result mapping.
+        for (auto it : llvm::zip(op->getResults(), cloned->getResults()))
+          valueMap[std::get<0>(it)] = std::get<1>(it);
+
+        // If this op defines a yield value, record clone -> original.
+        if (yieldDefOps.contains(op))
+          cloneAndOriYieldMap[cloned] = op;
+
+        // Remember any copied for ops.
+        if (auto forOp = dyn_cast<scf::ForOp>(cloned))
+          copiedForOps.push_back(forOp);
+
+        clonedOps.push_back(cloned);
+      }
+    }
+
+    // Prepend the clones to the current region.
+    curRegion.opsToMove.insert(curRegion.opsToMove.begin(), clonedOps.begin(),
+                               clonedOps.end());
+
+    // Rebuild SSA uses.
+    for (Operation *op : curRegion.opsToMove) {
+      rewriteOperandsRecursively(op, valueMap);
+    }
+
+    // Sort to restore a topological order.
+    sortOperationsByDataFlow(curRegion.opsToMove);
+  }
+}
+
+/// Record the mapping from each of forOp's yield values to the op that
+/// originally produced it.
+static void GetYieldMap(scf::ForOp forOp,
+                        DenseMap<Value, Operation *> &yieldMap) {
+  yieldMap.clear();
+
+  // Fetch the scf.yield of the forOp body.
+  auto yieldOp = dyn_cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  if (!yieldOp)
+    return;
+
+  for (Value yieldVal : yieldOp.getOperands()) {
+    // Op that produces yieldVal.
+    Operation *defOp = yieldVal.getDefiningOp();
+
+    // Block arguments (e.g. iter_args) have no defining op; skip them.
+    if (!defOp)
+      continue;
+
+    yieldMap[yieldVal] = defOp;
+  }
+}
+
+static Value findIterArgForAIC(Value v, scf::ForOp forOp) {
+  while (true) {
+    if (auto arg = dyn_cast<BlockArgument>(v)) {
+      if (arg.getOwner() == forOp.getBody())
+        return v;
+      return Value();
+    }
+
+    Operation *def = v.getDefiningOp();
+    if (!def)
+      return Value();
+
+    if (def->getNumOperands() == 0)
+      return Value();
+
+    v = def->getOperand(0);
+  }
+}
+
+static Operation *
+findCloneOfYieldOp(Operation *oriYieldOp,
+                   DenseMap<Operation *, Operation *> &cloneAndOriYieldMap,
+                   MergedRegion &region) {
+
+  for (Operation *op : region.opsToMove) {
+    auto it = cloneAndOriYieldMap.find(op);
+    if (it != cloneAndOriYieldMap.end() && it->second == oriYieldOp)
+      return op;
+  }
+  return nullptr;
+}
+
+static void RebuildForYielValuesForAIC(
+    scf::ForOp forOp, SmallVector<MergedRegion> &mergedRegions,
+    DenseMap<Value, Operation *> &yieldMap,
+    DenseMap<Operation *, Operation *> &cloneAndOriYieldMap) {
+
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+
+  for (MergedRegion &region : mergedRegions) {
+
+    triton::DotOp dotOp = nullptr;
+    for (Operation *op : region.opsToMove) {
+      if (auto d = dyn_cast<triton::DotOp>(op)) {
+        dotOp = d;
+        break;
+      }
+    }
+    if (!dotOp)
+      continue;
+
+    // Handle the dot operands.
+    for (Value operand : dotOp->getOperands()) {
+
+      Value iterArg = findIterArgForAIC(operand, forOp);
+      if (!iterArg)
+        continue;
+
+      auto arg = cast<BlockArgument>(iterArg);
+      int idx = arg.getArgNumber();
+
+      if (idx >= yieldOp.getNumOperands())
+        continue;
+
+      Value oriYieldValue = yieldOp.getOperand(idx);
+
+      auto it = yieldMap.find(oriYieldValue);
+      if (it == yieldMap.end())
+        continue;
+
+      Operation *oriYieldOp = it->second;
+
+      Operation *cloneOp =
+          findCloneOfYieldOp(oriYieldOp, cloneAndOriYieldMap, region);
+      if (!cloneOp)
+        continue;
+
+      yieldOp.setOperand(idx, cloneOp->getResult(0));
+    }
+  }
+}
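+
+// Net effect of the AIC rebuild above (sketch, hypothetical SSA names): if a
+// tt.dot operand traces back to iter_arg %arg0 whose yielded value %y was
+// defined by op O, and O was cloned into this region as %y', then the loop's
+// `scf.yield ..., %y, ...` is rewritten to `scf.yield ..., %y', ...` so the
+// CUBE region consumes only values it produced itself.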
+
+void ExpandMergedRegionOps(scf::ForOp forOp,
+                           SmallVector<MergedRegion> &mergedRegions,
+                           SmallVector<scf::ForOp> &copiedForOps) {
+  bool isInAIV = false;
+  auto scopeOp = forOp->getParentOfType();
+  if (!scopeOp)
+    return;
+
+  auto coreTypeAttr =
+      scopeOp->getAttrOfType<hivm::TCoreTypeAttr>(hivm::TCoreTypeAttr::name);
+
+  if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) {
+    isInAIV = true;
+  }
+
+  if (isInAIV) {
+    DenseSet<Operation *> yieldRelatedOps;
+
+    // 1. Collect the yield-related ops.
+    CollectForYieldRelatedOps(forOp, mergedRegions, yieldRelatedOps);
+
+    // 2. Greedy expansion.
+    // ExpandMergedRegionOpsGreedy(forOp, mergedRegions, yieldRelatedOps);
+    ExpandMergedRegionOpsGreedyV2(forOp, mergedRegions, yieldRelatedOps);
+
+    // 3. Ops independent of an earlier wait-set region should be pushed into
+    //    the later region that actually uses them.
+    MoveIndependentOpsIntoRegionBackwardV2(forOp, mergedRegions);
+
+    // 4. Place yield ops into regions according to where their iter_args are
+    //    used.
+    MoveForYieldOpIntoRegion(forOp, yieldRelatedOps, mergedRegions);
+
+    // 5. Put the remaining yield chain into the preceding region.
+    MoveRemainingYieldOpsToPrevRegion(forOp, yieldRelatedOps, mergedRegions);
+  } else { // Handle AIC separately to avoid tensor dependencies inside CUBE.
+    // Record the original mapping from for-yield values to their defining
+    // ops.
+    DenseMap<Value, Operation *> yieldMap;
+    GetYieldMap(forOp, yieldMap);
+
+    llvm::outs() << "YieldMap:\n";
+    for (auto it : yieldMap) {
+      llvm::outs() << *(it.second) << "\n";
+    }
+
+    // 2. Greedy expansion; yield values are handled afterwards.
+    ExpandMergedRegionOpsGreedyV2ForAIC(forOp, mergedRegions);
+
+    // Copy the current region's ops -- except tt.dot and anything between a
+    // [wait, set] pair -- into all later MergedRegions, processed in reverse
+    // order, recording the clone -> original map for yield-defining ops.
+    DenseMap<Operation *, Operation *> cloneAndOriYieldMap;
+    CopyOpsToAfterwardRegions(mergedRegions, yieldMap, cloneAndOriYieldMap,
+                              copiedForOps);
+
+    // 4. For each MergedRegion, first determine which of the for's iter_args
+    //    feeds its tt.dot (found recursively, say %arg0); yieldMap then gives
+    //    the original yield op. Scan the region's ops for the one whose
+    //    cloneAndOriYieldMap entry points at that original (say %45), and
+    //    finally replace the corresponding operand of the for's yield with
+    //    %45.
+    RebuildForYielValuesForAIC(forOp, mergedRegions, yieldMap,
+                               cloneAndOriYieldMap);
+  }
+}
+
+void MergeWaitSetRegions(SmallVector<WaitSetRegion> &regions,
+                         SmallVector<MergedRegion> &merged) {
+  for (int i = 0; i < regions.size();) {
+    MergedRegion mr;
+    mr.regions.push_back(&regions[i]);
+    mr.opsToMove.append(regions[i].opsToMove);
+
+    int j = i;
+    while (!regions[j].hasCopyOrFixpipe && j + 1 < regions.size()) {
+      j++;
+      mr.regions.push_back(&regions[j]);
+      mr.opsToMove.append(regions[j].opsToMove);
+    }
+
+    merged.push_back(std::move(mr));
+    i = j + 1;
+  }
+
+  for (MergedRegion &mr : merged) {
+    SmallPtrSet<Value, 8> regionValues;
+    SmallPtrSet<Operation *, 8> opSet;
+
+    for (Operation *op : mr.opsToMove)
+      opSet.insert(op);
+
+    for (Operation *op : mr.opsToMove) {
+      for (Value v : op->getResults()) {
+        bool usedOutside = false;
+        for (OpOperand &use : v.getUses()) {
+          Operation *user = use.getOwner();
+          if (!opSet.contains(user) && user->getBlock() == op->getBlock()) {
+            usedOutside = true;
+            break;
+          }
+        }
+        if (usedOutside) {
+          mr.yieldValues.push_back(v);
+          mr.resultTypes.push_back(v.getType());
+        }
+      }
+    }
+  }
+}
+
+void GetBlockInfos(SmallVector<WaitSetRegion> &regions, Block &body) {
+  for (auto it = body.begin(); it != body.end();) {
+    Operation *op = &*it;
+
+    auto waitOp = dyn_cast(op);
+    if (!waitOp) {
+      it++;
+      continue;
+    }
+
+    auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S);
+    if (auto syncWait = dyn_cast(op)) {
+      if (syncWait.getTpipe() == pipeS || syncWait.getPipe() == pipeS) {
+        return;
+      }
+    }
+    Operation *lastSetOp = nullptr;
+
+    // Scan forward to the next wait, collecting all set ops on the way.
+    auto curIt = std::next(it);
+    auto endIt = curIt;
+    int setOpCount = 0;
+    SmallVector<Operation *> opsInRegion;
+    for (; curIt != body.end(); ++curIt) {
+      Operation *curOp = &*curIt;
+      if (isa(curOp) && setOpCount >= 1)
+        break;
+      if (isa(curOp)) {
+        setOpCount++;
+        endIt = curIt;     // position of the set op
+        lastSetOp = curOp; // the last set
+      }
+    }
+
+    // No set op: nothing to wrap.
+    if (!lastSetOp) {
+      it = curIt;
+      continue;
+    }
+
+    // Collect the ops in [wait, ..., lastSet].
+    bool hasCopyOrFixpipe = false;
+    for (auto it2 = it; it2 != std::next(endIt); ++it2) {
+      Operation *curOp = &*it2;
+      opsInRegion.push_back(curOp);
+      if (isa(curOp) || isa(curOp)) {
+        hasCopyOrFixpipe = true;
+      }
+    }
+
+    it = endIt++;
+    regions.push_back({waitOp, lastSetOp, opsInRegion, hasCopyOrFixpipe});
+  }
+}
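+
+// The WaitSetRegions collected by GetBlockInfos correspond to spans of the
+// following shape (sketch; op spellings are illustrative, not exact):
+//
+//   ... = sync_block_wait ...   // waitOp
+//   ... payload ops ...
+//   sync_block_set ...          // lastSetOp
+//
+// A region is recorded only when at least one set op follows the wait, and
+// `hasCopyOrFixpipe` marks regions containing a copy or fixpipe op, which
+// MergeWaitSetRegions uses as merge boundaries.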
+
+Value findIterArg(Value v, Type t) {
+  SmallVector<Value> worklist = {v};
+  SmallPtrSet<Value, 8> visited;
+
+  while (!worklist.empty()) {
+    Value cur = worklist.front();
+    worklist.erase(worklist.begin());
+    if (!visited.insert(cur).second)
+      continue;
+
+    // If cur matches an original scf.for iteration argument, return it
+    // directly.
+    if (auto b = mlir::dyn_cast<BlockArgument>(cur)) {
+      auto forOp = mlir::dyn_cast<scf::ForOp>(b.getOwner()->getParentOp());
+      if (forOp && b.getType() == t) {
+        for (Value iterArg : forOp.getRegionIterArgs()) {
+          if (iterArg.getAsOpaquePointer() == b.getAsOpaquePointer()) {
+            return b;
+          }
+        }
+      }
+    }
+
+    Operation *defOp = cur.getDefiningOp();
+    if (!defOp)
+      continue;
+
+    // Core logic: if the current value is an scf.if result, descend into the
+    // then-block to find its source.
+    if (auto ifOp = mlir::dyn_cast<scf::IfOp>(defOp)) {
+      Block &thenBlock = ifOp.getThenRegion().front();
+      // Find the then-block's terminating scf.yield and take its operands
+      // (the sources of the ifOp results).
+      for (auto &innerOp : llvm::reverse(thenBlock)) {
+        if (auto yieldOp = mlir::dyn_cast<scf::YieldOp>(&innerOp)) {
+          // Match by index: cur is the n-th result of ifOp, so take the n-th
+          // operand of yieldOp.
+          for (auto [idx, res] : llvm::enumerate(ifOp.getResults())) {
+            if (res.getAsOpaquePointer() == cur.getAsOpaquePointer()) {
+              Value srcVal = yieldOp.getOperand(idx);
+              if (!visited.count(srcVal))
+                worklist.push_back(srcVal);
+              break;
+            }
+          }
+          break; // Found the yield; no need to scan further.
+        }
+      }
+    } else {
+      // Not an if result: trace back through the operands as usual.
+      for (Value operand : defOp->getOperands()) {
+        if (!visited.count(operand))
+          worklist.push_back(operand);
+      }
+    }
+  }
+
+  llvm::outs() << "Iteration argument not found; returning the original "
+                  "value: ";
+  v.print(llvm::outs());
+  llvm::outs() << "\n";
+  return v;
+}
+
+// If v is ultimately consumed by an scf.for's yield, return the
+// corresponding iter_arg of that forOp; if v merely flows into a later
+// wait-set region or other ops, return v unchanged.
+Value findIterArgForAll(Value v, Type t) {
+  for (Operation *user : v.getUsers()) {
+
+    if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+
+      if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+
+        for (auto [idx, operand] : llvm::enumerate(yieldOp.getOperands())) {
+
+          if (operand.getAsOpaquePointer() == v.getAsOpaquePointer()) {
+
+            Value iterArg = forOp.getRegionIterArgs()[idx];
+
+            if (iterArg.getType() == t)
+              return iterArg;
+          }
+        }
+      }
+    }
+  }
+
+  return v;
+}
+
+void FindDependValues(SmallVector<Value> &dependValues,
+                      SmallVector<MergedRegion> mergedRegions) {
+  dependValues.clear();
+  for (auto &curMR : mergedRegions) {
+    for (Value yieldValue : curMR.yieldValues) {
+      // Walk all user ops of this region's yieldValue and check for
+      // cross-region dependencies.
+      for (OpOperand &use : yieldValue.getUses()) {
+        Operation *userOp = use.getOwner();
+
+        bool isUserInOtherRegion = false;
+        for (auto &otherMR : mergedRegions) {
+          // Skip the current region; only check whether the yieldValue is
+          // used by another region.
+          if (&otherMR == &curMR)
+            continue;
+
+          // As soon as one userOp appears in otherMR's opsToMove, the value
+          // counts as a dependValue. A plain
+          //   llvm::is_contained(otherMR.opsToMove, userOp)
+          // would miss users nested inside region-bearing ops, so fully
+          // expand the region's ops into a DenseSet first.
+          DenseSet<Operation *> otherOps;
+          for (Operation *op : otherMR.opsToMove) {
+            CollectAllNestedOps(op, otherOps); // fully expand nested ops
+          }
+          if (otherOps.contains(userOp)) {
+            isUserInOtherRegion = true;
+            break;
+          }
+        }
+
+        // Record the dependency without duplicates.
+        if (isUserInOtherRegion) {
+          if (!llvm::is_contained(dependValues, yieldValue)) {
+            dependValues.push_back(yieldValue);
+          }
+          break;
+        }
+      }
+    }
+  }
+}
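+
+// A dependValue in the sense of FindDependValues (sketch, hypothetical IR):
+//
+//   scf.if ... {            // merged region 0
+//     %v = ...
+//   }
+//   scf.if ... {            // merged region 1
+//     ... use of %v ...     // %v becomes a dependValue
+//   }
+//
+// Such cross-region values are later threaded through new loop iter_args by
+// AddArgsForDependValues.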
+
+void UpdateMergedRegionsWithNewForOp(SmallVector<MergedRegion> &mergedRegions,
+                                     IRMapping &mapper) {
+  for (auto &mr : mergedRegions) {
+    // The WaitSetRegions are no longer used past this point; release them so
+    // we don't keep dangling pointers around.
+    mr.regions.clear();
+
+    // Remap the opsToMove list.
+    SmallVector<Operation *> newOpsToMove;
+    for (Operation *op : mr.opsToMove) {
+      if (op) {
+        Operation *newOp = mapper.lookupOrNull(op);
+        newOpsToMove.push_back(newOp);
+      }
+    }
+    mr.opsToMove = newOpsToMove;
+
+    // Remap the yieldValues list.
+    SmallVector<Value> newYieldValues;
+    for (Value v : mr.yieldValues) {
+      if (v) {
+        newYieldValues.push_back(mapper.lookupOrNull(v));
+      }
+    }
+    mr.yieldValues = newYieldValues;
+    // resultTypes holds Types and needs no remapping.
+  }
+}
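+
+// Usage sketch (assuming `mapper` was populated while cloning the loop
+// body):
+//
+//   IRMapping mapper;
+//   // ... clone the old body into the new scf.for, filling mapper ...
+//   UpdateMergedRegionsWithNewForOp(mergedRegions, mapper);
+//
+// Afterwards every Operation* / Value cached in the merged regions refers to
+// the clone inside the new loop rather than the about-to-be-erased original.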
+
+void AddArgsForDependValues(scf::ForOp forOp, SmallVector<Value> &dependValues,
+                            SmallVector<MergedRegion> &mergedRegions,
+                            ModuleOp module) {
+  OpBuilder moduleBuilder(module.getContext());
+  SmallVector<Type> valueTypes;
+
+  if (dependValues.empty()) {
+    return;
+  }
+  for (Value v : dependValues) {
+    valueTypes.push_back(v.getType());
+  }
+
+  // Create an initial value for every dependValue (a constant tensor of the
+  // right shape and type may not exist yet).
+  SmallVector<Value> initTensors;
+  module.walk([&](Operation *op) {
+    if (auto constOp = dyn_cast<arith::ConstantOp>(op)) {
+      moduleBuilder.setInsertionPoint(constOp);
+      for (Type valueType : valueTypes) {
+        auto tensorType = dyn_cast<RankedTensorType>(valueType);
+        triton::PointerType ptrType;
+        ptrType =
+            (tensorType)
+                ? dyn_cast<triton::PointerType>(tensorType.getElementType())
+                : dyn_cast<triton::PointerType>(valueType);
+        if (ptrType) {
+          // The depend value is pointer-typed:
+          // 1. create an i64 0,
+          // 2. cast it to !tt.ptr<...>,
+          Value zero = moduleBuilder.create<arith::ConstantIntOp>(
+              constOp.getLoc(), 0, 64);
+          Value ptrValue = moduleBuilder.create<triton::IntToPtrOp>(
+              constOp.getLoc(), ptrType, zero);
+          if (tensorType) {
+            // 3. splat it into tensor<...x!tt.ptr<...>>.
+            Value ptrTensor = moduleBuilder.create<triton::SplatOp>(
+                constOp.getLoc(), tensorType, ptrValue);
+            initTensors.push_back(ptrTensor);
+          } else {
+            initTensors.push_back(ptrValue);
+          }
+        } else if (auto memrefType = dyn_cast<MemRefType>(valueType)) {
+          // The depend value is memref-typed: create a memref.alloc for the
+          // iter_arg. Only #hivm.address_space is supported here; for the
+          // other address space there is no cbuf-to-cbuf copy behavior.
+          auto spaceAttr =
+              dyn_cast<hivm::AddressSpaceAttr>(memrefType.getMemorySpace());
+          if (spaceAttr &&
+              spaceAttr.getAddressSpace() == hivm::AddressSpace::L1) {
+            llvm::dbgs() << "AddArgsForDependValues: dependValue type is a "
+                            "memref hivm::AddressSpace::L1 type!!!\n";
+            return mlir::WalkResult::interrupt();
+          } else {
+            mlir::Value alloc = moduleBuilder.create<memref::AllocOp>(
+                constOp.getLoc(), memrefType);
+            initTensors.push_back(alloc);
+          }
+        } else {
+          // Non-pointer types get a zero constant.
+          auto zeroAttr = moduleBuilder.getZeroAttr(valueType);
+          Value zeroTensor = moduleBuilder.create<arith::ConstantOp>(
+              constOp.getLoc(), zeroAttr);
+          initTensors.push_back(zeroTensor);
+        }
+      }
+      return mlir::WalkResult::interrupt();
+    }
+    return mlir::WalkResult::advance();
+  });
+
+  auto initArgs = forOp.getInitArgs();
+
+  // Build the new init-arg list: original args plus the dependValue inits.
+  SmallVector<Value> newInitArgs(initArgs.begin(), initArgs.end());
+  for (Value initTensor : initTensors) {
+    newInitArgs.push_back(initTensor);
+  }
+
+  // Bounds and step of the original loop.
+  Value lb = forOp.getLowerBound();
+  Value ub = forOp.getUpperBound();
+  Value step = forOp.getStep();
+
+  // Create the new ForOp, inserted right before the original one.
+  OpBuilder builder(forOp);
+  auto newForOp =
+      builder.create<scf::ForOp>(forOp.getLoc(), lb, ub, step, newInitArgs);
+
+  // The new loop's block already carries the induction variable and iter
+  // args.
+  Block &newBlock = newForOp.getRegion().front();
+  Block &oldBlock = forOp.getRegion().front();
+
+  // Map old block arguments to new block arguments.
+  IRMapping mapper;
+  for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) {
+    mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i));
+  }
+  // Clone the old loop body into the new block, tracking the cloned
+  // dependValues as we go.
+  SmallVector<Value> newDependValues = dependValues;
+  int cnt = 0;
+  builder.setInsertionPointToStart(&newBlock);
+  for (auto &op : oldBlock) {
+    auto newOp = builder.clone(op, mapper);
+    // A dependValue's defining op may have several results.
+    for (size_t i = 0; i < dependValues.size(); i++) {
+      Operation *defineOp = dependValues[i].getDefiningOp();
+      if (defineOp == &op) {
+        unsigned int index = cast<OpResult>(dependValues[i]).getResultNumber();
+        newDependValues[i] = newOp->getResult(index);
+        cnt++;
+        break;
+      }
+    }
+  }
+  // Make sure every dependValue was found.
+  if (newDependValues.size() != cnt) {
+    llvm::outs() << "cannot find all depend values!\n";
+    return;
+  }
+  dependValues = newDependValues;
+
+  // Point the mergedRegions at the ops of the new loop.
+  UpdateMergedRegionsWithNewForOp(mergedRegions, mapper);
+
+  // Create the new loop yield: original operands + dependValues.
+  auto oldYield = cast<scf::YieldOp>(newBlock.getTerminator());
+  SmallVector<Value> newYieldOps(oldYield.getOperands());
+  // Append the dependValues we found, in order.
+  for (Value v : newDependValues) {
+    newYieldOps.push_back(v);
+  }
+  builder.setInsertionPointToEnd(&newBlock);
+  builder.create<scf::YieldOp>(oldYield.getLoc(), newYieldOps);
+  oldYield.erase();
+
+  // Replace all uses of the old forOp with the new one.
+  int oldResultNum = forOp->getResults().size();
+  for (auto it : llvm::zip(forOp->getResults(),
+                           newForOp->getResults().take_front(oldResultNum))) {
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+  forOp.erase();
+}
\n"; + return; + } + dependValues = newDependValues; + + // 更新 mergedRegions 中的 op 为新的for循环的 op + UpdateMergedRegionsWithNewForOp(mergedRegions, mapper); + + // 创建新的循环 yield 操作:原操作数 + dependValues + auto oldYield = cast(newBlock.getTerminator()); + SmallVector newYieldOps(oldYield.getOperands()); + // 按顺序增加找到的 dependvalue + for (Value v : newDependValues) { + newYieldOps.push_back(v); + } + builder.setInsertionPointToEnd(&newBlock); + builder.create(oldYield.getLoc(), newYieldOps); + oldYield.erase(); + + // 将原 forOp 的所有使用替换为新 forOp + int oldResultNum = forOp->getResults().size(); + for (auto it : llvm::zip(forOp->getResults(), + newForOp->getResults().take_front(oldResultNum))) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + } + forOp.erase(); +} + +void ComputeElseYieldValues(MergedRegion mergedRegion, + SmallVector &elseYieldValues, + SmallVector dependValues) { + int idx = 0; + for (Value v : mergedRegion.yieldValues) { + Type yieldType = mergedRegion.resultTypes[idx]; + elseYieldValues.push_back(findIterArg(v, yieldType)); + idx++; + } +} + +void ComputeElseYieldValuesV2(MergedRegion mergedRegion, + SmallVector &elseYieldValues, + SmallVector dependValues) { + // 对于yieldValues,其中的 yield value 一定是被 for op yield + // 所引用,或者被其他 region 所使用 + auto forOp = dyn_cast( + mergedRegion.yieldValues[0].getDefiningOp()->getBlock()->getParentOp()); + if (!forOp) { + llvm::outs() << "define op's parent is not ForOp \n"; + return; + } + auto iterArgs = forOp.getRegionIterArgs(); + auto forYieldValues = forOp.getYieldedValues(); + + // 新增的与 dependvalue 相关的 initarg + // 是接在原本for循环args后面,数量与dependvalue数量相等 + int baseDependIdx = iterArgs.size() - dependValues.size(); + + int idx = 0; + for (Value v : mergedRegion.yieldValues) { + Type yieldType = mergedRegion.resultTypes[idx]; + // yieldValue 中是dependvalue 的情况下 + // else yield value 使用对应的新增 iterargs + if (llvm::is_contained(dependValues, v)) { + int dependIdx = 0; + for (; dependIdx < dependValues.size(); dependIdx++) { + if (v == dependValues[dependIdx]) { + break; + } + } + // llvm::outs()<<"v2for:"< newYieldValues; + SmallVector newResultTypes; + + SmallPtrSet seen; + + for (auto [idx, v] : llvm::enumerate(region.yieldValues)) { + if (seen.insert(v).second) { + newYieldValues.push_back(v); + newResultTypes.push_back(region.resultTypes[idx]); + } + } + + region.yieldValues.swap(newYieldValues); + region.resultTypes.swap(newResultTypes); +} + +static void replaceExternalIfOpUses(scf::IfOp ifOp, + ArrayRef oldYieldValues) { + + for (size_t i = 0; i < oldYieldValues.size(); ++i) { + Value oldVal = oldYieldValues[i]; + Value newVal = ifOp.getResult(i); + + SmallVector usesToReplace; + + for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { + + Operation *user = use.getOwner(); + + // 跳过 ifOp 内部的使用(then / else region) + if (ifOp->isAncestor(user)) + continue; + + // 只替换 ifOp 之后的使用 + if (user->getBlock() == ifOp->getBlock()) { + if (!ifOp->isBeforeInBlock(user)) + continue; + } + + usesToReplace.push_back(&use); + } + + for (OpOperand *use : usesToReplace) + use->set(newVal); + } +} + +void CreateIfOps(SmallVector &mergedRegions, + SmallVector dependValues) { + for (auto ®ion : mergedRegions) { + + // 去重yieldvalues + RemoveRedundantYieldValues(region); + + Operation *insertPt = region.opsToMove.front(); + OpBuilder builder(insertPt); + Location loc = insertPt->getLoc(); + Value cond = builder.create(loc, builder.getI1Type(), + builder.getBoolAttr(true)); + + bool needsYield = !region.yieldValues.empty(); + scf::IfOp ifOp; + if 
+
+void CreateIfOps(SmallVector<MergedRegion> &mergedRegions,
+                 SmallVector<Value> dependValues) {
+  for (auto &region : mergedRegions) {
+
+    // Deduplicate the yield values.
+    RemoveRedundantYieldValues(region);
+
+    Operation *insertPt = region.opsToMove.front();
+    OpBuilder builder(insertPt);
+    Location loc = insertPt->getLoc();
+    Value cond = builder.create<arith::ConstantOp>(loc, builder.getI1Type(),
+                                                   builder.getBoolAttr(true));
+
+    bool needsYield = !region.yieldValues.empty();
+    scf::IfOp ifOp;
+    if (needsYield)
+      ifOp = builder.create<scf::IfOp>(loc, region.resultTypes, cond, true);
+    else
+      ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, cond, false);
+
+    // Tag the op.
+    ifOp->setAttr("ssbuffer", builder.getUnitAttr());
+
+    // Compute the values the else block must yield.
+    SmallVector<Value> elseYieldValues;
+
+    llvm::outs() << "before ComputeElseYieldValuesV2"
+                 << "\n";
+    if (needsYield) {
+      // ComputeElseYieldValues(region, elseYieldValues, dependValues);
+      ComputeElseYieldValuesV2(region, elseYieldValues, dependValues);
+    }
+    llvm::outs() << "after ComputeElseYieldValuesV2"
+                 << "\n";
+
+    // Move the ops into the then block.
+    Block &thenBlock = ifOp.getThenRegion().front();
+    for (Operation *m : llvm::reverse(region.opsToMove)) {
+      m->moveBefore(&thenBlock, thenBlock.begin());
+    }
+
+    // Create the then/else yields.
+    if (needsYield) {
+      OpBuilder thenBuilder(builder.getContext());
+      thenBuilder.setInsertionPointToEnd(&thenBlock);
+      thenBuilder.create<scf::YieldOp>(loc, region.yieldValues);
+
+      // else block
+      Block &elseBlock = ifOp.getElseRegion().front();
+      OpBuilder elseBuilder(&elseBlock, elseBlock.end());
+      elseBuilder.create<scf::YieldOp>(loc, elseYieldValues);
+
+      // Replace external uses.
+      replaceExternalIfOpUses(ifOp, region.yieldValues);
+    }
+
+    llvm::outs() << "Create ifOp: " << *ifOp << "\n";
+  }
+}
+
+void CreateIfOpsOrigin(SmallVector<MergedRegion> &mergedRegions) {
+  for (auto &region : mergedRegions) {
+
+    // Deduplicate the yield values.
+    RemoveRedundantYieldValues(region);
+
+    Operation *insertPt = region.opsToMove.front();
+    OpBuilder builder(insertPt);
+    Location loc = insertPt->getLoc();
+    Value cond = builder.create<arith::ConstantOp>(loc, builder.getI1Type(),
+                                                   builder.getBoolAttr(true));
+
+    bool needsYield = !region.yieldValues.empty();
+    scf::IfOp ifOp;
+    if (needsYield)
+      ifOp = builder.create<scf::IfOp>(loc, region.resultTypes, cond, true);
+    else
+      ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, cond, false);
+
+    // Tag the op.
+    ifOp->setAttr("ssbuffer", builder.getUnitAttr());
+
+    // Move the ops into the then block.
+    Block &thenBlock = ifOp.getThenRegion().front();
+    for (Operation *m : llvm::reverse(region.opsToMove)) {
+      m->moveBefore(&thenBlock, thenBlock.begin());
+    }
+
+    // Create the then/else yields.
+    if (needsYield) {
+      OpBuilder thenBuilder(builder.getContext());
+      thenBuilder.setInsertionPointToEnd(&thenBlock);
+      thenBuilder.create<scf::YieldOp>(loc, region.yieldValues);
+
+      // else block
+      SmallVector<Value> elseYieldValues;
+      int idx = 0;
+      for (Value v : region.yieldValues) {
+        Type yieldType = region.resultTypes[idx];
+        elseYieldValues.push_back(findIterArgForAll(v, yieldType));
+        idx++;
+      }
+      Block &elseBlock = ifOp.getElseRegion().front();
+      OpBuilder elseBuilder(&elseBlock, elseBlock.end());
+      elseBuilder.create<scf::YieldOp>(loc, elseYieldValues);
+
+      // Replace external uses: same block, user after the ifOp, not nested
+      // inside the ifOp (then / else).
+      for (size_t i = 0; i < region.yieldValues.size(); ++i) {
+        Value oldVal = region.yieldValues[i];
+        Value newVal = ifOp.getResult(i);
+
+        SmallVector<OpOperand *> usesToReplace;
+
+        for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
+          Operation *user = use.getOwner();
+          if (user->getBlock() != ifOp->getBlock() ||
+              !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp)
+            continue;
+          usesToReplace.push_back(&use);
+        }
+
+        for (OpOperand *use : usesToReplace)
+          use->set(newVal);
+      }
+    }
+
+    llvm::outs() << "Create ifOp: " << *ifOp << "\n";
+  }
+}
+
+void AddIfCondition(ModuleOp module) {
+  SmallVector<scf::ForOp> copiedForOps;
+  SmallVector<scf::ForOp> forOpList;
+  SmallVector<SmallVector<MergedRegion>, 1> regionList;
+
+  module.walk([&](scf::ForOp forOp) {
+    Block &body = forOp.getRegion().front();
+    SmallVector<WaitSetRegion> regions;
+
+    // Gather the basic wait-set block info.
+    GetBlockInfos(regions, body);
+
+    SmallVector<MergedRegion> mergedRegions;
+    // Merge wait-set blocks, splitting at copy / fixpipe ops.
+    MergeWaitSetRegions(regions, mergedRegions);
+
+    // Expand the op range wrapped by the if; AIV and AIC are handled
+    // differently.
+    ExpandMergedRegionOps(forOp, mergedRegions, copiedForOps);
+
+    // Move the iter_arg self-increment ops at the end of the forOp (e.g.
+    // tt.advance) into the matching if op.
+    MoveIterArgUsersIntoIf(forOp, mergedRegions);
+
+    // Compute the if-yield values and retarget users of in-if ops to them.
+    for (MergedRegion &mr : mergedRegions) {
+      // ComputeYieldForMergedRegion(mr, body);
+      ComputeYieldForMergedRegionV4(mr);
+    }
+
+    // // Create the final if ops.
+    // CreateIfOpsOrigin(mergedRegions);
+
+    forOpList.push_back(forOp);
+    regionList.push_back(mergedRegions);
+  });
+
+  llvm::outs() << "CopyForOp:\n";
+  for (auto op : copiedForOps) {
+    llvm::outs() << *op << "\n";
+  }
+
+  SmallVector<scf::ForOp> tmpOps;
+  for (auto copiedOp : copiedForOps) {
+    Block &body = copiedOp.getRegion().front();
+    SmallVector<WaitSetRegion> regions;
+
+    // Gather the basic wait-set block info.
+    GetBlockInfos(regions, body);
+
+    SmallVector<MergedRegion> mergedRegions;
+    // Merge wait-set blocks, splitting at copy / fixpipe ops.
+    MergeWaitSetRegions(regions, mergedRegions);
+
+    // Expand the op range wrapped by the if; AIV and AIC are handled
+    // differently.
+    ExpandMergedRegionOps(copiedOp, mergedRegions, tmpOps);
+
+    // Move the iter_arg self-increment ops at the end of the forOp (e.g.
+    // tt.advance) into the matching if op.
+    MoveIterArgUsersIntoIf(copiedOp, mergedRegions);
+
+    // Compute the if-yield values and retarget users of in-if ops to them.
+    for (MergedRegion &mr : mergedRegions) {
+      // ComputeYieldForMergedRegion(mr, body);
+      ComputeYieldForMergedRegionV4(mr);
+    }
+
+    // // Create the final if ops.
+    // CreateIfOpsOrigin(mergedRegions);
+
+    forOpList.push_back(copiedOp);
+    regionList.push_back(mergedRegions);
+  }
+
+  for (size_t i = 0; i < forOpList.size(); ++i) {
+    scf::ForOp oldForOp = forOpList[i];
+    SmallVector<MergedRegion> newMergedRegions = regionList[i];
+
+    // Find all VV or CC dependencies.
+    SmallVector<Value> dependValues;
+    llvm::outs() << "FindDependValues!\n";
+    FindDependValues(dependValues, newMergedRegions);
+
+    if (dependValues.size() != 0) {
+      copyLoadCalculation(oldForOp, dependValues, newMergedRegions);
+
+      // Repeat the previous steps.
+      for (MergedRegion &mr : newMergedRegions) {
+        mr.yieldValues.clear();
+        mr.resultTypes.clear();
+        ComputeYieldForMergedRegionV4(mr);
+      }
+      FindDependValues(dependValues, newMergedRegions);
+    }
+
+    // If VV or CC dependencies exist, add matching new args to the ForOp.
+    if (dependValues.size() != 0) {
+      AddArgsForDependValues(oldForOp, dependValues, newMergedRegions, module);
+    }
+
+    // Create the final if ops.
+    llvm::outs() << "before create if ops" << '\n';
+    CreateIfOps(newMergedRegions, dependValues);
+  }
+}
\n "; + FindDependValues(dependValues, newMergedRegions); + + if (dependValues.size() != 0) { + copyLoadCalculation(oldForOp, dependValues, newMergedRegions); + + // repeat previous operations + for (MergedRegion &mr : newMergedRegions) { + mr.yieldValues.clear(); + mr.resultTypes.clear(); + ComputeYieldForMergedRegionV4(mr); + } + FindDependValues(dependValues, newMergedRegions); + } + + // 如果存在VV或CC依赖,更新ForOp添加新的对应args + if (dependValues.size() != 0) { + AddArgsForDependValues(oldForOp, dependValues, newMergedRegions, module); + } + + // 创建最终的if op + llvm::outs() << "before create if ops" << '\n'; + CreateIfOps(newMergedRegions, dependValues); + } +} + +void ChangeAdvanceOpForm(ModuleOp module) { + module.walk([&](scf::ForOp forOp) { + Block &body = forOp.getRegion().front(); + constexpr int num = 8; + SmallVector ifOps; + for (Operation &op : body) + if (auto ifOp = dyn_cast(&op)) + ifOps.push_back(ifOp); + + for (scf::IfOp ifOp : ifOps) { + // 找 then region 中的 advance + triton::AdvanceOp advanceOp; + for (Operation &thenOp : ifOp.getThenRegion().front()) { + if (auto adv = dyn_cast(thenOp)) { + advanceOp = adv; + break; + } + } + if (!advanceOp) + continue; + + // base 必须是 for的iter_arg + Value base = advanceOp.getPtr(); + auto barg = dyn_cast(base); + if (!barg || barg.getOwner() != &body) + continue; + + // yield 去掉 advance 的返回值 + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + int advanceIdx = -1; + for (auto it : llvm::enumerate(thenYield.getOperands())) { + if (it.value() == advanceOp.getResult()) { + advanceIdx = it.index(); + break; + } + } + + if (advanceIdx == -1) + continue; + + // 删除 advance + SmallVector thenOps(thenYield.getOperands().begin(), + thenYield.getOperands().end()); + SmallVector elseOps(elseYield.getOperands().begin(), + elseYield.getOperands().end()); + + thenOps.erase(thenOps.begin() + advanceIdx); + elseOps.erase(elseOps.begin() + advanceIdx); + + thenYield->setOperands(thenOps); + elseYield->setOperands(elseOps); + + // 重建 ifOp(去掉 advance 对应的 result) + OpBuilder ifBuilder(ifOp); + ifBuilder.setInsertionPoint(ifOp); + + // 构造新的 result types + SmallVector newResultTypes; + for (int i = 0; i < ifOp.getNumResults(); ++i) { + if (i != advanceIdx) + newResultTypes.push_back(ifOp.getResult(i).getType()); + } + + // 创建新的 if + auto newIf = ifBuilder.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), + /*withElseRegion=*/true); + newIf->setAttr("ssbuffer", ifBuilder.getUnitAttr()); + // 把已经修改过 yield 的 region 搬过去 + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + // 替换if result的user + int newIdx = 0; + for (int oldIdx = 0; oldIdx < ifOp.getNumResults(); ++oldIdx) { + if (oldIdx == advanceIdx) + continue; + ifOp.getResult(oldIdx).replaceAllUsesWith(newIf.getResult(newIdx++)); + } + + OpBuilder builder(newIf); + builder.setInsertionPointAfter(newIf); + + Value flag = newIf.getCondition(); + + SmallVector newOffsets; + for (Value off : advanceOp.getOffsets()) { + auto intTy = cast(off.getType()); + auto zero = builder.create(newIf.getLoc(), 0, + intTy.getWidth()); + auto sel = + builder.create(newIf.getLoc(), flag, off, zero); + newOffsets.push_back(sel); + } + + auto newAdvance = builder.create( + newIf.getLoc(), base.getType(), base, newOffsets); + + // 原 if 的 advance result 的 users,接到 newAdvance + ifOp.getResult(advanceIdx).replaceAllUsesWith(newAdvance.getResult()); + + // 删除旧的ifOp和advance + 
+
+void processRedudantIf(ModuleOp module) {
+  SmallVector<scf::ForOp> forOps;
+  llvm::outs() << module << "\n";
+  module.walk([&](scf::ForOp forOp) {
+    auto initArgs = forOp.getInitArgs();
+    if (initArgs.size() == 5) {
+      forOps.push_back(forOp);
+    }
+  });
+
+  for (auto forOp : forOps) {
+    auto initArgs = forOp.getInitArgs();
+    Value newInit = initArgs[2];
+
+    // Build the new init-arg list.
+    SmallVector<Value> newInitArgs(initArgs.begin(), initArgs.end());
+    newInitArgs.push_back(newInit);
+
+    // Bounds and step of the original loop.
+    Value lb = forOp.getLowerBound();
+    Value ub = forOp.getUpperBound();
+    Value step = forOp.getStep();
+
+    // Create the new ForOp right before the original one.
+    OpBuilder builder(forOp);
+    auto newForOp =
+        builder.create<scf::ForOp>(forOp.getLoc(), lb, ub, step, newInitArgs);
+
+    // The new loop's block already carries the induction variable and iter
+    // args.
+    Block &newBlock = newForOp.getRegion().front();
+    Block &oldBlock = forOp.getRegion().front();
+
+    // Map the old block arguments to the new ones.
+    IRMapping mapper;
+    for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) {
+      mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i));
+    }
+    // Clone the old loop body into the new block.
+    builder.setInsertionPointToStart(&newBlock);
+    for (auto &op : oldBlock) {
+      builder.clone(op, mapper);
+    }
+
+    // Find the first scf::IfOp in the new block.
+    scf::IfOp firstIfOp = nullptr;
+    for (auto &op : newBlock.getOperations()) {
+      if (auto ifOp = dyn_cast<scf::IfOp>(&op)) {
+        firstIfOp = ifOp;
+        break;
+      }
+    }
+    assert(firstIfOp && "Expected at least one if op in the loop body");
+
+    // Patch the first if's else-branch yield: replace its second operand
+    // (index 1), originally %arg9, with the new iteration argument (block
+    // argument #6).
+    Block &elseBlock = firstIfOp.getElseRegion().front();
+    auto elseYield = cast<scf::YieldOp>(elseBlock.getTerminator());
+    SmallVector<Value> newElseYieldOps(elseYield.getOperands());
+    newElseYieldOps[1] = newBlock.getArgument(6); // new iteration argument
+    builder.setInsertionPoint(elseYield);
+    builder.create<scf::YieldOp>(elseYield.getLoc(), newElseYieldOps);
+    elseYield->erase();
+
+    // New loop yield: the original five operands plus the first if's second
+    // result.
+    auto oldYield = cast<scf::YieldOp>(newBlock.getTerminator());
+    SmallVector<Value> newYieldOps(oldYield.getOperands());
+    newYieldOps.push_back(firstIfOp.getResult(1));
+    builder.setInsertionPointToEnd(&newBlock);
+    builder.create<scf::YieldOp>(oldYield.getLoc(), newYieldOps);
+    oldYield.erase();
+
+    // Replace all uses of the old forOp with the new loop's first five
+    // results.
+    for (auto it : llvm::zip(forOp->getResults(),
+                             newForOp->getResults().take_front(5))) {
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+    }
+  }
+  for (auto forOp : forOps) {
+    forOp.erase();
+  }
+}
+    // %38#2 corresponds to operand #2 of the scf.if else-branch yield,
+    // i.e. %arg10.
+    Operation *elseYield = targetIfOp.elseYield();
+    Value dependencyArg = elseYield->getOperand(
+        depValueIdx); // here depValueIdx == 2, operand #2 of the else yield
+
+    int64_t depValueForIdx = -1;
+    for (auto [idx, result] : llvm::enumerate(forOp.getRegionIterArgs())) {
+      if (result == dependencyArg) {
+        depValueForIdx = idx;
+        break;
+      }
+    }
+    depValueForIdxs.push_back(depValueForIdx);
+  }
+
+  // Gather the original loop's info.
+  Value originalLowerBound = forOp.getLowerBound();
+  Value originalUpperBound = forOp.getUpperBound();
+  Value originalStep = forOp.getStep();
+  SmallVector<Value> originalInitArgs = forOp.getInitArgs();
+  SmallVector<Value> iterArgs;
+  for (auto arg : originalInitArgs) {
+    iterArgs.push_back(arg);
+  }
+  auto yields = forOp.getBody()->getTerminator();
+
+  // Create the counter's initial zero value.
+  Value counterInit = nullptr;
+  mlir::Operation *parentOp = forOp->getParentOp();
+  mlir::Operation *scopeOp = nullptr;
+  // Walk upwards looking for the scope.scope op.
+  while (parentOp) {
+    if (dyn_cast<scope::ScopeOp>(parentOp)) {
+      scopeOp = parentOp;
+      break;
+    }
+    parentOp = parentOp->getParentOp();
+  }
+
+  builder.setInsertionPoint(scopeOp);
+  Location loc = forOp.getLoc();
+  auto boundType = originalLowerBound.getType();
+  counterInit = builder.create<arith::ConstantIntOp>(loc, 0, boundType);
+
+  // Append one group of iter args per entry of depValueForIdxs, plus the
+  // counters.
+  for (int64_t idx : depValueForIdxs) {
+    for (int i = 0; i < bufferNum - 1; i++) {
+      iterArgs.push_back(originalInitArgs[idx]);
+    }
+
+    // Add the two counters to the iter args.
+    for (int i = 0; i < 2; i++) {
+      iterArgs.push_back(counterInit);
+    }
+  }
+
+  builder.setInsertionPoint(forOp);
+  // Create the new for loop.
+  auto newForOp =
+      builder.create<scf::ForOp>(forOp.getLoc(), originalLowerBound,
+                                 originalUpperBound, originalStep, iterArgs);
+
+  // Set up the IR mapping from the old loop's values to the new loop.
+  IRMapping mapper;
+
+  // Map the induction variable.
+  mapper.map(forOp.getInductionVar(), newForOp.getInductionVar());
+
+  // Map the iter args.
+  for (auto [oldArg, newArg] :
+       llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) {
+    mapper.map(oldArg, newArg);
+  }
+
+  SmallVector<Value> newArgs;
+  for (int i = forOp.getRegionIterArgs().size();
+       i < newForOp.getRegionIterArgs().size(); i++) {
+    newArgs.push_back(newForOp.getRegionIterArgs()[i]);
+  }
+  // Clone the loop body into the new loop.
+  auto &newLoopBody = *newForOp.getBody();
+  builder.setInsertionPointToStart(&newLoopBody);
+
+  for (auto &op : forOp.getBody()->without_terminator()) {
+    builder.clone(op, mapper);
+  }
+
+  // Clone the yield op.
+  if (auto yieldOp = dyn_cast<scf::YieldOp>(yields)) {
+    SmallVector<Value> newYieldOperands;
+    for (auto operand : yieldOp.getOperands()) {
+      newYieldOperands.push_back(mapper.lookupOrDefault(operand));
+    }
+    // Append the newly added iter args to the yield operands.
+    for (auto currentCounter : newArgs) {
+      newYieldOperands.push_back(currentCounter);
+    }
+    builder.create<scf::YieldOp>(yieldOp.getLoc(), newYieldOperands);
+  }
+
+  // Replace the results of the original loop.
+  unsigned numOriginalResults = forOp.getNumResults();
+  SmallVector<Value> originalResults;
+  for (unsigned i = 0; i < numOriginalResults; i++) {
+    originalResults.push_back(newForOp.getResult(i));
+  }
+  forOp.replaceAllUsesWith(originalResults);
+
+  // Erase the original loop.
+  forOp.erase();
+
+  return newForOp;
+}
+
+SmallVector<Value> buildNBufferProducer(OpBuilder &builder, Location loc,
+                                        Value frontCnt, Value newDepVal,
+                                        ArrayRef<Value> buffs,
+                                        ArrayRef<Value> constants) {
+  // N-buffer producer: frontCnt % N decides which buffer newDepVal is
+  // written into.
+  const int N = buffs.size();
+  SmallVector<Value> results;
+
+  // idx = frontCnt % N
+  Value bufferIndex =
+      builder.create<arith::RemSIOp>(loc, frontCnt, constants[N]);
+
+  // 1. buffer0: handle the first buffer separately.
+  Value isBuffer0 = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, bufferIndex, constants[0]);
+
+  auto dstShapedType = mlir::dyn_cast<ShapedType>(newDepVal.getType());
+  auto maskType =
+      RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType());
+  Value mask = builder.create<triton::SplatOp>(loc, maskType, isBuffer0);
+  Value newBuff0 =
+      builder.create<arith::SelectOp>(loc, mask, newDepVal, buffs[0]);
+
+  results.push_back(newBuff0);
+
+  // 2. Double-buffer specialization (when N == 2, a direct select is
+  // sufficient).
+  if (N == 2) {
+    Value newBuff1 =
+        builder.create<arith::SelectOp>(loc, mask, buffs[1], newDepVal);
+
+    auto nextCnt = builder.create<arith::AddIOp>(loc, frontCnt, constants[1]);
+
+    results.push_back(newBuff1);
+    results.push_back(nextCnt.getResult());
+
+    return results;
+  }
+
+  // 3. Build the root IF: when idx == 0,
+  // use the first buffer; otherwise enter the nested-if chain to use other
+  // buffers.
+  SmallVector<Type> resultTypes;
+  for (int i = 1; i < N; ++i)
+    resultTypes.push_back(buffs[i].getType());
+
+  auto rootIf = builder.create<scf::IfOp>(loc, resultTypes, isBuffer0, true);
+
+  // ---- THEN: buffers are directly forwarded ----
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&rootIf.getThenRegion().front());
+
+    SmallVector<Value> unchangedBuffers(buffs.begin() + 1, buffs.end());
+
+    builder.create<scf::YieldOp>(loc, unchangedBuffers);
+  }
+
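+  // Worked example (hypothetical values): with N == 3 buffers and
+  // frontCnt == 4, bufferIndex = 4 % 3 == 1, so buffs[1] receives newDepVal
+  // this iteration while the remaining buffers are forwarded unchanged.
+
+  // 4.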
Construct the nested-if chain, updating one buffer at each level + Block *currentElseBlock = &rootIf.getElseRegion().front(); + + scf::IfOp parentIf = rootIf; + + for (int i = 1; i < N - 1; ++i) { + + builder.setInsertionPointToStart(currentElseBlock); + + // Check whether the current buffer is selected + Value isCurrent = builder.create( + loc, arith::CmpIPredicate::eq, bufferIndex, constants[i]); + + // Update buffer[i] + dstShapedType = mlir::dyn_cast(newDepVal.getType()); + maskType = + RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + mask = builder.create(loc, maskType, isCurrent); + Value updatedBuffer = + builder.create(loc, mask, newDepVal, buffs[i]); + + // If this is the last level: directly yield both buffers + if (i == N - 2) { + + dstShapedType = mlir::dyn_cast(newDepVal.getType()); + maskType = + RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + mask = builder.create(loc, maskType, isCurrent); + Value lastBuffer = + builder.create(loc, mask, buffs[N - 1], newDepVal); + + builder.create(loc, ValueRange{updatedBuffer, lastBuffer}); + + break; + } + + // Create the next nested if + SmallVector subResultTypes; + for (int j = i + 1; j < N; ++j) + subResultTypes.push_back(buffs[j].getType()); + + auto nextIf = + builder.create(loc, subResultTypes, isCurrent, true); + + // THEN: forward the remaining buffers + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&nextIf.getThenRegion().front()); + + SmallVector remainingBuffers(buffs.begin() + i + 1, buffs.end()); + + builder.create(loc, remainingBuffers); + } + + // Update the else yield + builder.setInsertionPointToEnd(&parentIf.getElseRegion().front()); + + SmallVector yields; + yields.push_back(updatedBuffer); + yields.append(nextIf.getResults().begin(), nextIf.getResults().end()); + + builder.create(loc, yields); + + parentIf = nextIf; + currentElseBlock = &nextIf.getElseRegion().front(); + } + + // 5. Update the frontCnt counter + builder.setInsertionPointAfter(rootIf); + + auto nextCnt = builder.create(loc, frontCnt, constants[1]); + + // Collect results + results.append(rootIf.getResults().begin(), rootIf.getResults().end()); + + results.push_back(nextCnt.getResult()); + + return results; +} + +SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, + Value postCnt, ArrayRef oldBuffs, + ArrayRef constants) { + // Consumer: selects which buffer to read based on postCnt % N + const int bufferNum = oldBuffs.size(); + SmallVector results; + + // idx = postCnt % N + Value bufferIndex = + builder.create(loc, postCnt, constants[bufferNum]); + + Value isBuffer0 = builder.create(loc, arith::CmpIPredicate::eq, + bufferIndex, constants[0]); + auto dstShapedType = mlir::dyn_cast(oldBuffs[0].getType()); + auto maskType = + RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); + auto mask = builder.create(loc, maskType, isBuffer0); + + // 1. Double-buffer specialization (avoid generating scf.if) + if (bufferNum == 2) { + Value selected = + builder.create(loc, mask, oldBuffs[0], oldBuffs[1]); + auto nextCnt = builder.create(loc, postCnt, constants[1]); + + results.push_back(selected); + results.push_back(nextCnt); + + return results; + } + + // 2. 
Build the root IF: + // when idx == 0, use the first buffer; otherwise enter the nestedIf chain to + // use other buffers + SmallVector resultTypes{oldBuffs[0].getType()}; + + auto rootIf = builder.create(loc, resultTypes, isBuffer0, true); + + // ---- THEN: directly return buffer0 ---- + { + builder.setInsertionPointToStart(&rootIf.getThenRegion().front()); + + builder.create(loc, oldBuffs[0]); + } + + // 3. Construct the nested-if chain + Block *currentElse = &rootIf.getElseRegion().front(); + + for (int i = 1; i < bufferNum - 2; ++i) { + + builder.setInsertionPointToStart(currentElse); + + Value isCurrent = builder.create( + loc, arith::CmpIPredicate::eq, bufferIndex, constants[i]); + + auto nestedIf = builder.create( + loc, TypeRange{oldBuffs[0].getType()}, isCurrent, true); + + // THEN → return the current buffer + { + builder.setInsertionPointToStart(&nestedIf.getThenRegion().front()); + + builder.create(loc, oldBuffs[i]); + } + + // ELSE → yield nested result + builder.setInsertionPointToEnd(currentElse); + builder.create(loc, nestedIf.getResult(0)); + + // Enter the next else branch + currentElse = &nestedIf.getElseRegion().front(); + } + + // 4. Final level (use select to finish) + builder.setInsertionPointToStart(currentElse); + + int last = bufferNum - 2; + + Value isLast = builder.create(loc, arith::CmpIPredicate::eq, + bufferIndex, constants[last]); + + maskType = RankedTensorType::get({}, isLast.getType()); + dstShapedType = mlir::dyn_cast(oldBuffs[last].getType()); + maskType = RankedTensorType::get(dstShapedType.getShape(), isLast.getType()); + mask = builder.create(loc, maskType, isLast); + + Value finalSelect = builder.create(loc, mask, oldBuffs[last], + oldBuffs[last + 1]); + + builder.create(loc, finalSelect); + + // rootIf result = selected buffer + results.push_back(rootIf.getResult(0)); + + // 5. 
Update the postCnt counter + builder.setInsertionPointAfter(rootIf); + + auto nextCnt = builder.create(loc, postCnt, constants[1]); + + results.push_back(nextCnt); + + return results; +} + +void replaceDepsMap(scf::IfOp oldIfOp, scf::IfOp newIfOp, + SmallVector &newDeps, bool isFront, + DenseMap> &newIfResultDeps) { + mlir::IRMapping valueMap; + + // old result -> new result + for (unsigned i = 0; i < oldIfOp.getNumResults(); ++i) { + valueMap.map(oldIfOp.getResult(i), newIfOp.getResult(i)); + } + + if (isFront) { + for (int i = 0; i < newDeps.size(); i++) { + Value v = newDeps[i]; + if (valueMap.contains(v)) + newDeps[i] = valueMap.lookup(v); + } + } + + // rewrite deps in-place + for (auto &it : newIfResultDeps) { + auto &deps = it.second; + + for (auto &value : deps) { + if (auto mapped = valueMap.lookupOrNull(value)) + value = mapped; + } + } +} + +scf::IfOp addResultsForFrontIfOp( + scf::IfOp frontIfOp, OpBuilder builder, int bufferNum, Value depValue, + SmallVector constants, SmallVector buffs, Value frontCnt, + Value postCnt, SmallVector &extraResultIndices, + SmallVector &newDeps, + DenseMap> &newIfResultDeps) { + OpBuilder::InsertionGuard guard(builder); + + Location loc = frontIfOp.getLoc(); + Value cond = frontIfOp.getCondition(); + + auto &oldThenBlock = frontIfOp.getThenRegion().front(); + auto &oldElseBlock = frontIfOp.getElseRegion().front(); + + // New result types = old results + extra buffers + counter + SmallVector newResultTypes(frontIfOp.getResultTypes().begin(), + frontIfOp.getResultTypes().end()); + + for (int i = 1; i < bufferNum; ++i) + newResultTypes.push_back(buffs[i].getType()); + + newResultTypes.push_back(frontCnt.getType()); + + unsigned oldNumResults = frontIfOp.getNumResults(); + + // Create new IfOp + builder.setInsertionPoint(frontIfOp); + auto newIfOp = + builder.create(loc, newResultTypes, cond, /*hasElse=*/true); + + SmallVector bufferIndices(bufferNum); + SmallVector newBuffs; + int frontCntIndex = -1; + + // THEN region + { + mlir::IRMapping mapping; + Block &newThenBlock = newIfOp.getThenRegion().front(); + + builder.setInsertionPointToStart(&newThenBlock); + + // Clone original then body + for (auto &op : oldThenBlock.without_terminator()) + builder.clone(op, mapping); + + // Update dependency value position inf ifOp results + auto result = dyn_cast(depValue); + if (!result) { + llvm::outs() << "depValue is not a result Value!\n"; + return nullptr; + } + + int depIdx = result.getResultNumber(); + Value depYieldValue = frontIfOp.thenYield()->getOperand(depIdx); + + Value newDepVal = mapping.contains(depYieldValue) + ? 
mapping.lookup(depYieldValue) + : depYieldValue; + + builder.setInsertionPointAfter(newDepVal.getDefiningOp()); + + // Create N buffer + SmallVector produced = buildNBufferProducer( + builder, loc, frontCnt, newDepVal, buffs, constants); + + // Last value in newBuffs is the counter + newBuffs.append(produced.begin(), produced.end() - 1); + + // Rebuild new yield + SmallVector thenOperands; + + for (Value v : oldThenBlock.getTerminator()->getOperands()) { + Value mapped = mapping.lookupOrDefault(v); + + // Replace first buffer + if (mapped == newDepVal) { + thenOperands.push_back(newBuffs[0]); + bufferIndices[0] = thenOperands.size() - 1; + } else { + thenOperands.push_back(mapped); + } + } + + // Replace other buffer + for (int i = 1; i < bufferNum; ++i) { + thenOperands.push_back(newBuffs[i]); + bufferIndices[i] = thenOperands.size() - 1; + } + + // Add counter + thenOperands.push_back(produced.back()); + frontCntIndex = thenOperands.size() - 1; + + builder.setInsertionPointToEnd(&newThenBlock); + builder.create(loc, thenOperands); + + // record new result indices + for (int idx : bufferIndices) + extraResultIndices.push_back(idx); + + extraResultIndices.push_back(frontCntIndex); + } + + // ELSE region + { + mlir::IRMapping mapping; + Block &newElseBlock = newIfOp.getElseRegion().front(); + + builder.setInsertionPointToStart(&newElseBlock); + + // Clone original else body + for (auto &op : oldElseBlock.without_terminator()) + builder.clone(op, mapping); + + builder.setInsertionPointToEnd(&newElseBlock); + + SmallVector elseOperands; + + for (Value v : oldElseBlock.getTerminator()->getOperands()) + elseOperands.push_back(mapping.lookupOrDefault(v)); + + // Add buffer + for (int i = 1; i < bufferNum; ++i) + elseOperands.push_back(buffs[i]); + + // Add counter + elseOperands.push_back(frontCnt); + + builder.create(loc, elseOperands); + } + + // Update dependency value + replaceDepsMap(frontIfOp, newIfOp, newDeps, true, newIfResultDeps); + + // Replace old ifOp + frontIfOp.replaceAllUsesWith(newIfOp.getResults().take_front(oldNumResults)); + + frontIfOp.erase(); + + return newIfOp; +} + +scf::IfOp addResultsForPostIfOp( + scf::IfOp postIfOp, scf::IfOp newfrontIfOp, OpBuilder builder, + int bufferNum, Value newDepValue, SmallVector constants, + SmallVector buffs, Value frontCnt, Value postCnt, + SmallVector &extraResultIndices, SmallVector &newDeps, + DenseMap> &newIfResultDeps) { + // 1. Parse the extra result indices produced by frontIf (added buffers and + // counters) + SmallVector bufferIndices(extraResultIndices.begin(), + extraResultIndices.end() - 1); + int frontCntIndex = extraResultIndices[bufferNum]; + + Location ifLoc = postIfOp.getLoc(); + Value cond = postIfOp.getCondition(); + + auto &oldThenBlock = postIfOp.getThenRegion().front(); + auto &oldElseBlock = postIfOp.getElseRegion().front(); + + // 2. Create a new IfOp (add a new postCnt result) + SmallVector newResultTypes(postIfOp.getResultTypes().begin(), + postIfOp.getResultTypes().end()); + newResultTypes.push_back(postCnt.getType()); + + builder.setInsertionPoint(postIfOp); + auto newIfOp = builder.create(ifLoc, newResultTypes, cond, + /*hasElse=*/true); + + mlir::IRMapping mapping; + + // 3. 
THEN region: clone the original logic, insert the multibuffer consumer + // and update dependency buffers + auto &newThenBlock = newIfOp.getThenRegion().front(); + builder.setInsertionPointToStart(&newThenBlock); + + // clone then body + for (auto &op : oldThenBlock.without_terminator()) + builder.clone(op, mapping); + builder.setInsertionPointToStart(&newThenBlock); + + // Find dependency uses that need to be replaced (located inside the current + // IfOp) + SmallVector replaceUses; + for (auto &use : newDepValue.getUses()) { + if (newIfOp == dyn_cast(use.getOwner()->getParentOp())) { + replaceUses.push_back(&use); + } + } + + // Collect buffers produced by frontIf + SmallVector oldBuffers; + for (int i = 0; i < bufferIndices.size(); ++i) + oldBuffers.push_back(newfrontIfOp.getResult(bufferIndices[i])); + + // Multibuffer consumer caculation + SmallVector consumerResults = + buildNBufferConsumer(builder, ifLoc, postCnt, oldBuffers, constants); + + Value selectedBuffer = consumerResults[0]; + Value nextPostCnt = consumerResults[1]; + + // Replace dependent buffer + for (auto *usePtr : replaceUses) { + usePtr->set(selectedBuffer); + } + + // Create then yield + SmallVector thenOperands; + for (auto v : oldThenBlock.getTerminator()->getOperands()) + thenOperands.push_back(mapping.lookupOrDefault(v)); + + int postCntIndex = thenOperands.size(); + thenOperands.push_back(nextPostCnt); + + builder.setInsertionPointToEnd(&newThenBlock); + builder.create(ifLoc, thenOperands); + extraResultIndices.push_back(postCntIndex); + + // 4. ELSE region:forward counter directly + auto &newElseBlock = newIfOp.getElseRegion().front(); + + for (auto &op : oldElseBlock.without_terminator()) + builder.clone(op, mapping); + + builder.setInsertionPointToEnd(&newElseBlock); + + SmallVector elseOperands; + for (auto v : oldElseBlock.getTerminator()->getOperands()) + elseOperands.push_back(mapping.lookupOrDefault(v)); + + elseOperands.push_back(postCnt); + + builder.create(ifLoc, elseOperands); + + // 5. Replace old ifOp with new one + auto oldNumResults = postIfOp.getNumResults(); + + // Update depency value + replaceDepsMap(postIfOp, newIfOp, newDeps, false, newIfResultDeps); + + postIfOp.replaceAllUsesWith(newIfOp.getResults().take_front(oldNumResults)); + + postIfOp.erase(); + + return newIfOp; +} + +void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, + DenseMap> &ifResultDeps, + scf::ForOp &newForOp, int bufferNum) { + + // ============================================================ + // Overall Idea + // + // For each dependency Value: + // 1. Find the front IfOp that produces it + // 2. Add multi-buffer results to the front IfOp + // 3. Find the post IfOp that consumes the result and extend it accordingly + // 4. Update the for-loop yield so that buffer states are correctly propagated + // ============================================================ + + OpBuilder builder(module.getContext()); + int processedDepCount = 0; + + SmallVector postIfOps; + newForOp.walk([&](scf::IfOp postIfOp) { postIfOps.push_back(postIfOp); }); + for (auto postIfOp : postIfOps) { + if (!ifResultDeps.count(postIfOp)) { + continue; + } + auto newDeps = ifResultDeps[postIfOp]; + for (int depValueIdx = 0; depValueIdx < newDeps.size(); depValueIdx++) { + Value depValue = newDeps[depValueIdx]; + + // Step 1. 
Locate the front IfOp that produces depValue + Operation *defOp = depValue.getDefiningOp(); + if (!defOp || !isa(defOp)) { + llvm::outs() << "Error: depValue is not produced by scf.if\n"; + break; + } + + scf::IfOp frontIfOp = cast(defOp); + + // Position of depValue in the IfOp results + auto result = dyn_cast(depValue); + if (!result) { + llvm::outs() << "depValue is not an OpResult!\n"; + return; + } + + int64_t depResultIndex = result.getResultNumber(); + + // Position of depValue in the IfOp results + Value depYieldValue = frontIfOp.thenYield()->getOperand(depResultIndex); + + // Step 2. Find the multi-buffer position in the ForOp + int64_t extraArgBaseIdx = + newForOp.getRegionIterArgs().size() - + (2 + bufferNum - 1) * (newUniqueDeps.size() - processedDepCount++); + + // Collect all buffers + SmallVector buffers; + + // buffer0 来自 else yield + buffers.push_back(frontIfOp.elseYield()->getOperand(depResultIndex)); + + // Other buffers come from for iter args + for (int i = 1; i < bufferNum; ++i) { + buffers.push_back( + newForOp.getRegionIterArgs()[extraArgBaseIdx + i - 1]); + } + + // Two counters + Value frontCnt = + newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum - 1]; + Value postCnt = newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum]; + + // Step 3. Create constants (0 ~ bufferNum) for rem / cmp buffer selection + // logic + SmallVector constants; + builder.setInsertionPoint(frontIfOp); + + auto dataType = frontCnt.getType(); + for (int i = 0; i <= bufferNum; ++i) { + constants.push_back(builder.create( + frontIfOp.getLoc(), dataType, builder.getIntegerAttr(dataType, i))); + } + + // Record the positions of newly added results in the IfOp + SmallVector extraResultIndices(bufferNum + 1); + extraResultIndices.clear(); + + // Step 4. Extend the front IfOp + scf::IfOp newFrontIfOp = addResultsForFrontIfOp( + frontIfOp, builder, bufferNum, depValue, constants, buffers, frontCnt, + postCnt, extraResultIndices, newDeps, ifResultDeps); + + // buffer result indices + SmallVector bufferResultIndices(extraResultIndices.begin(), + extraResultIndices.end() - 1); + + int frontCntResultIndex = extraResultIndices[bufferNum]; + + Value newDepValue = newFrontIfOp.getResult(depResultIndex); + + // Step 5. Find the post IfOp that consumes the dependency value + scf::IfOp postIfOp = nullptr; + + for (auto &use : newDepValue.getUses()) { + if (auto candidate = + dyn_cast(use.getOwner()->getParentOp())) { + postIfOp = candidate; + break; + } + } + + if (!postIfOp) { + llvm::outs() << "Error: no consuming IfOp found.\n"; + return; + } + + // Step 6. Extend the post IfOp + + scf::IfOp newPostIfOp = addResultsForPostIfOp( + postIfOp, newFrontIfOp, builder, bufferNum, newDepValue, constants, + buffers, frontCnt, postCnt, extraResultIndices, newDeps, + ifResultDeps); + + llvm::outs() << "after addResultsForPostIfOp.\n"; + + int postCntResultIndex = extraResultIndices.back(); + + // Step 7. Update the ForOp yield (buffer propagation) + auto forYield = cast(newForOp.getBody()->getTerminator()); + + // Update buffer1 ~ bufferN + for (int i = 1; i < bufferNum; ++i) { + + int yieldIdx = extraArgBaseIdx + (i - 1); + + if (yieldIdx < forYield->getNumOperands() && + bufferResultIndices[i] < newFrontIfOp.getNumResults()) { + + forYield->setOperand(yieldIdx, + newFrontIfOp.getResult(bufferResultIndices[i])); + + llvm::outs() << "Replaced yield operand " << yieldIdx << "\n"; + } else { + llvm::errs() << "Warning: index out of range\n"; + } + } + + // Step 8. 
Update frontCnt
+      OpOperand *frontCntYieldUse = nullptr;
+
+      for (auto &use : frontCnt.getUses()) {
+        if (isa<scf::YieldOp>(use.getOwner()) &&
+            newForOp == use.getOwner()->getParentOp()) {
+          frontCntYieldUse = &use;
+          break;
+        }
+      }
+
+      frontCntYieldUse->set(newFrontIfOp.getResult(frontCntResultIndex));
+
+      // Step 9. Update postCnt
+      OpOperand *postCntYieldUse = nullptr;
+
+      for (auto &use : postCnt.getUses()) {
+        if (isa<scf::YieldOp>(use.getOwner()) &&
+            newForOp == use.getOwner()->getParentOp()) {
+          postCntYieldUse = &use;
+          break;
+        }
+      }
+
+      postCntYieldUse->set(newPostIfOp.getResult(postCntResultIndex));
+    }
+  }
+}
+
+// Compute the nesting level of an ifOp within the specified forOp.
+static int computeIfLevel(scf::IfOp ifOp, scf::ForOp rootForOp) {
+  int level = 1;
+
+  Operation *parent = ifOp->getParentOp();
+
+  while (parent && parent != rootForOp.getOperation()) {
+    if (isa<scf::IfOp>(parent))
+      level++;
+
+    parent = parent->getParentOp();
+  }
+
+  return level;
+}
+
+int assignIfOpLevels(scf::ForOp forOp) {
+  SmallVector<scf::IfOp> targetIfOps;
+  int maxLevel = 0;
+  // Collect every ifOp tagged with the ssbuffer attribute.
+  forOp.walk([&](scf::IfOp ifOp) {
+    if (ifOp->hasAttr("ssbuffer")) {
+      targetIfOps.push_back(ifOp);
+    }
+  });
+
+  // Calculate the buffer levels.
+  for (auto ifOp : targetIfOps) {
+    int level = computeIfLevel(ifOp, forOp);
+    maxLevel = std::max(level, maxLevel);
+    Builder builder(ifOp.getContext());
+    ifOp->setAttr("ssbuffer.level", builder.getI32IntegerAttr(level));
+  }
+  return maxLevel;
+}
+
+static bool hasSSBufferIf(scf::ForOp forOp) {
+  bool found = false;
+
+  forOp.walk([&](scf::IfOp ifOp) {
+    if (ifOp->hasAttr("ssbuffer")) {
+      found = true;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  return found;
+}
+
+static bool hasAncestorSSBufferFor(scf::ForOp forOp) {
+  Operation *parent = forOp->getParentOp();
+
+  while (parent) {
+    if (auto parentFor = dyn_cast<scf::ForOp>(parent)) {
+      if (hasSSBufferIf(parentFor))
+        return true;
+    }
+    parent = parent->getParentOp();
+  }
+
+  return false;
+}
+
+// Currently identical to hasAncestorSSBufferFor; kept separate for the call
+// sites below.
+static bool hasAncestorRootFor(scf::ForOp forOp) {
+  Operation *parent = forOp->getParentOp();
+
+  while (parent) {
+    if (auto parentFor = dyn_cast<scf::ForOp>(parent)) {
+      if (hasSSBufferIf(parentFor))
+        return true;
+    }
+    parent = parent->getParentOp();
+  }
+  return false;
+}
+
+SmallVector<Value>
+collectIfInfo(scf::ForOp &curForOp,
+              DenseMap<scf::IfOp, SmallVector<Value>> &ifDeps, int level) {
+  // Find all dependency values based on the inputs and outputs of the ifOps.
+  SmallVector<Value> allDeps;
+  DenseSet<Value> producedValues;
+  scf::ForOp newForOp = nullptr;
+  curForOp.walk([&](scf::IfOp ifOp) {
+    auto attr = ifOp->getAttrOfType<IntegerAttr>("ssbuffer.level");
+    // No level or level mismatch -> continue searching.
+    if (!attr || attr.getInt() != level)
+      return WalkResult::advance();
+
+    // Levels match -> check the direct parent and update curForOp.
+    if (auto parentFor = dyn_cast<scf::ForOp>(ifOp->getParentOp())) {
+      newForOp = parentFor;
+    }
+
+    // Stop walking regardless of whether the parent is a for loop.
+    return WalkResult::interrupt();
+  });
+
+  if (newForOp)
+    curForOp = newForOp;
+
+  // Step 1: Collect first to preserve order.
+  SmallVector<scf::IfOp> ifOps;
+  curForOp.walk([&](scf::IfOp ifOp) {
+    auto curLevel = ifOp->getAttrOfType<IntegerAttr>("ssbuffer.level");
+    if (!curLevel || curLevel.getInt() != level) {
+      return WalkResult::advance();
+    }
+    ifOps.push_back(ifOp);
+    return WalkResult::advance();
+  });
+
+  int miniDepNum = 2;
+  if (ifOps.size() < miniDepNum) {
+    return allDeps;
+  }
+  // Step 2: Process in order.
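+  // Invariant assumed by the scan below: producedValues already holds every
+  // scf.if result seen so far, so any operand of a later ifOp that is found
+  // in producedValues is exactly a cross-if dependency and gets recorded in
+  // ifDeps.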
+  for (auto ifOp : ifOps) {
+    SmallVector<Value> deps;
+
+    // Inputs: operands in the then region that were produced by an earlier
+    // ifOp.
+    Region &thenRegion = ifOp.getThenRegion();
+    for (Operation &op : thenRegion.front()) {
+      for (Value operand : op.getOperands()) {
+        if (producedValues.count(operand) &&
+            !llvm::is_contained(deps, operand)) {
+          deps.push_back(operand);
+        }
+      }
+    }
+
+    // Outputs.
+    for (Value result : ifOp.getResults()) {
+      producedValues.insert(result);
+    }
+
+    if (!deps.empty()) {
+      ifDeps[ifOp] = deps;
+      allDeps.append(deps.begin(), deps.end());
+    }
+  }
+  return allDeps;
+}
+
+bool isCube(scope::ScopeOp scope) {
+  bool ret = false;
+  scope.walk([&](Operation *op) {
+    if (isa(op)) {
+      ret = true;
+    }
+  });
+  return ret;
+}
+
+// Traverse each vector scope, find the outer ForOps, and process the IfOps
+// inside them.
+void WalkAIVNestedForAndProcess(
+    ModuleOp module, DenseMap<scf::IfOp, SmallVector<Value>> &ifResultDeps,
+    int bufferNum) {
+  if (bufferNum < 2) {
+    return;
+  }
+
+  module.walk([&](scope::ScopeOp scope) {
+    if (isCube(scope)) {
+      return;
+    }
+
+    // Traverse the ForOps inside this vector scope (outer loops).
+    SmallVector<scf::ForOp> targetFors;
+
+    scope.walk([&](scf::ForOp forOp) {
+      // Must contain an ssbuffer if.
+      if (!hasSSBufferIf(forOp))
+        return WalkResult::advance();
+
+      // Skip if an ancestor is already the root.
+      if (hasAncestorRootFor(forOp))
+        return WalkResult::advance();
+
+      // Record the rootForOp.
+      targetFors.push_back(forOp);
+
+      return WalkResult::advance();
+    });
+    int maxLevels = 0;
+    for (auto outerFor : targetFors) {
+      ifResultDeps.clear();
+      scf::ForOp currentFor = outerFor;
+      maxLevels = assignIfOpLevels(currentFor);
+      for (int level = 1; level <= maxLevels; level++) {
+        auto uniqueDeps = collectIfInfo(currentFor, ifResultDeps, level);
+        if (uniqueDeps.empty()) {
+          continue;
+        }
+        auto newForOp = addDoubleBuffForArgs(module, uniqueDeps, bufferNum);
+        DenseMap<scf::IfOp, SmallVector<Value>> newIfResultDeps;
+        auto uniqueList = collectIfInfo(newForOp, newIfResultDeps, level);
+        addMultiBuffCaculate(module, uniqueList, newIfResultDeps, newForOp,
+                             bufferNum);
+      }
+    }
+  });
+}
+
+void DAGSSBufferPass::runOnOperation() {
+  auto module = getOperation();
+
+  AddIfCondition(module);
+
+  FlowSssbuf(module);
+  ControlSsbufV2(module);
+
+  // An advance must not appear inside an if; rewrite the IR to avoid that.
+  ChangeAdvanceOpForm(module);
+
+  DenseMap<scf::IfOp, SmallVector<Value>> ifResultDeps;
+  WalkAIVNestedForAndProcess(module, ifResultDeps, 2);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::triton::createDAGSSBufferPass() {
+  return std::make_unique<DAGSSBufferPass>();
+}
diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp
new file mode 100644
index 0000000000..ed82b46084
--- /dev/null
+++ b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp
@@ -0,0 +1,1103 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "TritonAffinityOpt/Passes.h"
+
+#include "bishengir/Dialect/HIVM/IR/HIVM.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h"
+#include "bishengir/Dialect/HIVM/Transforms/Passes.h"
+#include "bishengir/Dialect/HIVM/Utils/Utils.h"
+#include "bishengir/Dialect/Scope/IR/Scope.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include
+
+#include "TritonAffinityOpt/DAG.h"
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_DAGSCOPE
+#include "ascend/include/TritonAffinityOpt/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+using namespace hivm;
+
+namespace {
+struct DAGScopePass : public mlir::triton::impl::DAGScopeBase<DAGScopePass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+static std::pair<Operation *, Operation *>
+encapsulateWithScope(triton::FuncOp funcOp) {
+  Block &lastBlock = funcOp.getBody().back();
+
+  // Helper: decide whether an op should be skipped.
+  auto shouldSkipOp = [](Operation *op) -> bool {
+    return isa(op) || isa(op) ||
+           isa(op);
+  };
+
+  // Prepare the list of ops to move (in order).
+  SmallVector<Operation *> opsToMove;
+  DenseMap<Operation *, int> opOrder;
+  int order = 0;
+
+  // Record the original order and collect the ops that need to move.
+  for (Operation &op : lastBlock.without_terminator()) {
+    opOrder[&op] = order++;
+    if (!shouldSkipOp(&op)) {
+      opsToMove.push_back(&op);
+    }
+  }
+
+  // Sort by the original order.
+  std::sort(
+      opsToMove.begin(), opsToMove.end(),
+      [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; });
+
+  if (opsToMove.empty()) {
+    return std::make_pair(nullptr, nullptr);
+  }
+
+  // Create the scope ops and move the operations into them.
+  Operation *lastOpToMove = opsToMove.back();
+  OpBuilder builder(&lastBlock, ++lastOpToMove->getIterator());
+
+  // Create the first scope.
+  auto scopeOp = builder.create<scope::ScopeOp>(builder.getUnknownLoc(),
+                                                llvm::ArrayRef<Type>{});
+  scopeOp.getBodyRegion().emplaceBlock();
+  Block *scopeBody = &scopeOp.getBodyRegion().front();
+
+  // Move the ops into the scope.
+  OpBuilder scopeBuilder(scopeBody, scopeBody->end());
+  DenseMap<Value, Value> valueMapping;
+
+  for (Operation *op : opsToMove) {
+    SmallVector<Value> originalResults(op->getResults().begin(),
+                                       op->getResults().end());
+    op->remove();
+    scopeBuilder.insert(op);
+
+    // Update the value mapping.
+    for (size_t i = 0; i < originalResults.size(); ++i) {
+      valueMapping[originalResults[i]] = op->getResult(i);
+    }
+  }
+
+  // Add the scope terminator.
+  scopeBuilder.create<scope::ReturnOp>(builder.getUnknownLoc());
+
+  // Create the second scope (if needed).
+  scopeBuilder.setInsertionPointAfter(scopeOp);
+  auto newScopeOp = scopeBuilder.create<scope::ScopeOp>(
+      builder.getUnknownLoc(), llvm::ArrayRef<Type>{});
+  newScopeOp.getRegion().emplaceBlock();
+
+  OpBuilder newScopeBuilder(&newScopeOp.getRegion().front(),
+                            newScopeOp.getRegion().front().begin());
+  newScopeBuilder.create<scope::ReturnOp>(scopeOp->getLoc());
+
+  // Set the core-type attributes.
+  auto vecAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR);
+  auto aicAttr =
+      hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+
+  scopeOp->setAttr(hivm::TCoreTypeAttr::name, vecAttr);
+  newScopeOp->setAttr(hivm::TCoreTypeAttr::name, aicAttr);
+
+  return std::make_pair(scopeOp, newScopeOp);
+}
+
+struct OpMoveInfo {
+  Operation *op;
+  Operation *targetParent; // target parent op (nullptr means the scope itself)
+};
+
+// Recursively collect the ops that have to move into each scope.
+void collectOpsToMove(Operation *op, AffinityDAG::Graph &graph,
+                      Operation *parentFor,
+                      llvm::SmallVector<OpMoveInfo> &aivToMove,
+                      llvm::SmallVector<OpMoveInfo> &cubeToMove) {
+  // Check whether the current op needs to be moved.
+  bool needsMoveAiv = false;
+  bool needsMoveCube = false;
+  auto &valueTypes = graph.getValueTypes();
+  // Check the result types.
+  for (auto res : op->getResults()) {
+    if (AffinityDAG::intersects(valueTypes[res],
+                                AffinityDAG::CoreType::VECTOR_ONLY)) {
+      needsMoveAiv = true;
+    }
+    if (AffinityDAG::intersects(valueTypes[res],
+                                AffinityDAG::CoreType::CUBE_ONLY)) {
+      needsMoveCube = true;
+    }
+  }
+
+  if (isa(op)) {
+    auto res = op->getOperand(0);
+    if (AffinityDAG::intersects(valueTypes[res],
+                                AffinityDAG::CoreType::VECTOR_ONLY)) {
+      needsMoveAiv = true;
+    }
+    if (AffinityDAG::intersects(valueTypes[res],
+                                AffinityDAG::CoreType::CUBE_ONLY)) {
+      needsMoveCube = true;
+    }
+  }
+  // Check specific op kinds.
+  if (isa(op)) {
+    needsMoveAiv = true;
+  }
+
+  if (isa(op)) {
+    needsMoveCube = true;
+  }
+
+  if (isa(op) || isa(op) || isa(op)) {
+    needsMoveAiv = true;
+    needsMoveCube = true;
+  }
+
+  if (isa(op)) {
+    if (auto storeOp = dyn_cast(op)) {
+      // Get the full operand list.
+      auto operands = storeOp.getOperands();
+      bool typeMatched = false;
+
+      // Check operands 1, 0 and 2, in that order.
+      std::vector<size_t> checkOrder = {1, 0, 2};
+      for (size_t idx : checkOrder) {
+        // Validate the operand index first to avoid out-of-bounds access.
+        if (idx >= operands.size()) {
+          continue;
+        }
+        auto operand = operands[idx];
+        auto coreType = valueTypes[operand];
+
+        if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) {
+          needsMoveAiv = true;
+          typeMatched = true;
+        } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) {
+          needsMoveCube = true;
+          typeMatched = true;
+        }
+      }
+      // If none of the checked operands matched, fall back to moving the op
+      // into both scopes.
+      if (!typeMatched) {
+        needsMoveAiv = true;
+        needsMoveCube = true;
+      }
+    }
+  }
+
+  if (isa(op)) {
+    if (auto assertOp = dyn_cast(op)) {
+      auto operand = assertOp.getCondition();
+
+      auto coreType = valueTypes[operand];
+      if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) {
+        needsMoveAiv = true;
+      } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) {
+        needsMoveCube = true;
+      } else {
+        needsMoveAiv = true;
+        needsMoveCube = true;
+      }
+    }
+  }
+
+  // Check the tcore_type attribute of sync ops.
+  if ((isa<hivm::SyncBlockSetOp>(op) || isa<hivm::SyncBlockWaitOp>(op))) {
+    mlir::OpBuilder builder(op);
+    auto coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+    if (op->getAttr("tcore_type") == coreAttr) {
+      needsMoveCube = true;
+    } else {
+      needsMoveAiv = true;
+    }
+  }
+
+  // If nothing marked the op, report it as unsupported.
+  if (!needsMoveAiv && !needsMoveCube) {
+    llvm::outs() << "Unsupported op: " << *op << " \n";
+  }
+
+  // Handle for loops.
+  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+    // The target parent is the enclosing for/if, if any.
+    Operation *targetParent = parentFor;
+    aivToMove.push_back({op, targetParent});
+    cubeToMove.push_back({op, targetParent});
+
+    // Recurse into the loop body.
+    for (auto &block : forOp.getRegion()) {
+      for (auto &innerOp : block) {
+        collectOpsToMove(&innerOp, graph, forOp, aivToMove, cubeToMove);
+      }
+    }
+  } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+    Operation *targetParent = parentFor;
+    aivToMove.push_back({op, targetParent});
+    cubeToMove.push_back({op, targetParent});
+
+    // Recurse into the then region.
+    for (auto &block : ifOp.getThenRegion()) {
+      for (auto &innerOp : block) {
+        collectOpsToMove(&innerOp, graph, ifOp, aivToMove, cubeToMove);
+      }
+    }
+
+    // Also traverse the IfOp's else branch, if present.
+    for (auto &block : ifOp.getElseRegion()) {
+      for (auto &innerOp : block) {
+        collectOpsToMove(&innerOp, graph, ifOp, aivToMove, cubeToMove);
+      }
+    }
+  } else {
+    if (needsMoveAiv) {
+      // Handle every other op.
+      aivToMove.push_back({op, parentFor});
+    }
+    if (needsMoveCube) {
+      cubeToMove.push_back({op, parentFor});
+    }
+  }
+}
+
+mlir::Block *getBlockByIndex(mlir::Region &region, int blockIndex) {
+  // Bounds check: reject invalid indices.
+  if (blockIndex < 0)
+    return nullptr;
+
+  int currentIdx = 0;
+  for (auto &block : region) {
+    if (currentIdx == blockIndex) {
+      return &block; // found the block at this index
+    }
+    currentIdx++;
+  }
+  // Index out of range.
+  return nullptr;
+}
+
+void processOperationToMove(
+    const OpMoveInfo &info,
+    llvm::DenseMap<mlir::Operation *, mlir::Operation *> &parentMap,
+    mlir::OpBuilder &builder, mlir::IRMapping &mapper, mlir::Block *aivBlock,
+    mlir::Operation *terminator, AffinityDAG::Graph &graph, int MoveType) {
+  // Work out the region/block indices of the op's original position.
+  mlir::Block *originalBlock = info.op->getBlock();
+  int originalRegionIndex = -1;
+  int originalBlockIndex = -1;
+  int blockCounter = 0;
+  auto &valueTypes = graph.getValueTypes();
+  if (originalBlock) {
+    mlir::Operation *parentOp = info.op->getParentOp(); // original parent op
+    if (parentOp) { // make sure the parent op exists
+      // Older MLIR uses getParent() instead of getParentRegion(); the return
+      // value is the Region*.
+      mlir::Region *blockBelongsToRegion = originalBlock->getParent();
+      int regionCounter = 0;
+      for (auto &region : parentOp->getRegions()) {
+        // Compare pointers to find the region that owns the block.
+        if (&region == blockBelongsToRegion) {
+          originalRegionIndex = regionCounter;
+          break;
+        }
+        regionCounter++;
+      }
+    }
+  }
+
+  if (originalBlock) {
+    for (auto &block : originalBlock->getParent()->getBlocks()) {
+      if (&block == originalBlock) {
+        originalBlockIndex = blockCounter;
+        break;
+      }
+      blockCounter++;
+    }
+  }
+
+  if (originalBlockIndex == -1) {
+    originalBlockIndex = 0;
+  }
+  if (originalRegionIndex == -1) {
+    originalRegionIndex = 0;
+  }
+
+  // Handle scf::ForOp.
+  if (mlir::isa<mlir::scf::ForOp>(info.op)) {
+    auto forOp = mlir::cast<mlir::scf::ForOp>(info.op);
+
+    auto getMapped = [&](mlir::Value v) { return mapper.lookupOrDefault(v); };
+    auto inputs = forOp.getInitArgs();
+    auto outputs = forOp.getResults();
+
+    // Separate out the iter args that move into this scope.
+    llvm::SmallVector<mlir::Value> aivInputs;
+    llvm::DenseMap<int, int> aivInputsMap;
+    int aivIndex = 1;
+
+    for (int i = 0; i < inputs.size(); ++i) {
+      if (valueTypes[outputs[i]] != MoveType) {
+        aivInputs.push_back(inputs[i]);
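+        // aivInputsMap maps the old block-argument index (offset by one for
+        // the induction variable) to the new one; e.g. (hypothetically) if
+        // only iter args 0 and 2 are kept, old block args 1 and 3 map to new
+        // args 1 and 2.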
aivInputsMap[i + 1] = aivIndex; + aivIndex++; + } + } + + // 创建新的for循环 + auto aivForOp = builder.create( + forOp.getLoc(), getMapped(forOp.getLowerBound()), + getMapped(forOp.getUpperBound()), getMapped(forOp.getStep()), + llvm::to_vector(llvm::map_range(aivInputs, getMapped))); + + // 清空循环体 + if (!aivForOp.getBody()->empty()) { + aivForOp.getBody()->getTerminator()->erase(); + } + + // 处理原始循环的yield操作 + auto oldBody = forOp.getBody(); + auto oldYield = + mlir::dyn_cast(oldBody->getTerminator()); + assert(oldYield && "scf::ForOp must have a yield terminator"); + + llvm::SmallVector aivYieldOperands; + for (int i = 0; i < inputs.size(); ++i) { + if (valueTypes[outputs[i]] != MoveType) { + aivYieldOperands.push_back(oldYield.getOperand(i)); + } + } + + // 映射循环参数 + auto oldBodyArgs = forOp.getBody()->getArguments(); + auto aivBodyArgs = aivForOp.getBody()->getArguments(); + + for (auto it = aivInputsMap.begin(); it != aivInputsMap.end(); ++it) { + int oldInputIndex = it->first; + int mappedNewIndex = it->second; + mapper.map(oldBodyArgs[oldInputIndex], aivBodyArgs[mappedNewIndex]); + mapper.map((*info.op).getResults()[oldInputIndex - 1], + aivForOp->getResults()[mappedNewIndex - 1]); + } + mapper.map(oldBodyArgs[0], aivBodyArgs[0]); + + // 将新循环移动到目标位置 + if (info.targetParent == nullptr) { + mlir::Block *targetBlock = aivBlock; + if (terminator) { + aivForOp->moveBefore(terminator); + } else { + aivForOp->moveBefore(targetBlock, targetBlock->end()); + } + parentMap[forOp] = aivForOp; + } else { + auto targetParent = parentMap[info.targetParent]; + auto ®ion = targetParent->getRegion(originalRegionIndex); + + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + aivForOp->moveBefore(targetBlock, targetBlock->end()); + parentMap[forOp] = aivForOp; + } else { + llvm::outs() << "Can't find block by index\n"; + } + } + } + + // 处理 scf::YieldOp 类型操作 + else if (mlir::isa(info.op)) { + auto yieldOp = mlir::cast(info.op); + + // 处理父节点为 scf::ForOp 的情况 + if (auto parentForOp = + mlir::dyn_cast(info.targetParent)) { + auto it = parentMap.find(parentForOp); + if (it == parentMap.end()) { + return; + } + auto targetOp = it->second; + auto newForOp = mlir::cast(targetOp); + + auto oldInputs = parentForOp.getInitArgs(); + auto oldOutputs = parentForOp.getResults(); + auto oldYieldOperands = yieldOp.getOperands(); + + llvm::SmallVector newYieldOperands; + for (int i = 0; i < oldInputs.size(); ++i) { + if (valueTypes[oldOutputs[i]] != MoveType) { + mlir::Value oldOperand = oldYieldOperands[i]; + mlir::Value newOperand = mapper.lookupOrDefault(oldOperand); + newYieldOperands.push_back(newOperand); + } + } + + auto newYieldOp = builder.create(yieldOp.getLoc(), + newYieldOperands); + auto ®ion = newForOp->getRegion(0); + mlir::Block *targetBlock = ®ion.front(); + newYieldOp->moveBefore(targetBlock, targetBlock->end()); + } + // 处理父节点为 scf::IfOp 的情况 + else if (auto parentIfOp = + mlir::dyn_cast(info.targetParent)) { + auto it = parentMap.find(parentIfOp); + if (it == parentMap.end()) { + return; + } + auto targetOp = it->second; + auto newIfOp = mlir::cast(targetOp); + + auto oldInputs = parentIfOp.getResults(); + auto oldOutputs = parentIfOp.getResults(); + auto oldYieldOperands = yieldOp.getOperands(); + + llvm::SmallVector newYieldOperands; + for (int i = 0; i < oldInputs.size(); ++i) { + if (valueTypes[oldOutputs[i]] != MoveType) { + mlir::Value oldOperand = oldYieldOperands[i]; + mlir::Value newOperand 
= mapper.lookupOrDefault(oldOperand); + newYieldOperands.push_back(newOperand); + } + } + + auto ®ion = newIfOp->getRegion(originalRegionIndex); + auto newYieldOp = builder.create(yieldOp.getLoc(), + newYieldOperands); + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + newYieldOp->moveBefore(targetBlock, targetBlock->end()); + } else { + llvm::outs() << "Can't find block by index\n"; + } + } + } + + // 处理 scf::IfOp 类型操作 + else if (mlir::isa(info.op)) { + auto ifOp = mlir::cast(info.op); + + auto getMapped = [&](mlir::Value v) { return mapper.lookupOrDefault(v); }; + mlir::Value condition = ifOp.getCondition(); + + // 分离需要移动到aivScope的结果 + llvm::SmallVector aivResults; + llvm::SmallVector aivResultTypes; + llvm::DenseMap aivResultMap; + int aivResultIndex = 0; + + for (int i = 0; i < ifOp.getNumResults(); ++i) { + mlir::Value result = ifOp.getResult(i); + if (valueTypes[result] != MoveType) { + aivResults.push_back(result); + aivResultTypes.push_back(result.getType()); + aivResultMap[i] = aivResultIndex; + aivResultIndex++; + } + } + + // 创建新的if操作 + auto aivIfOp = builder.create( + ifOp.getLoc(), aivResultTypes, getMapped(condition)); + + // 映射if操作结果 + for (auto &[oldIdx, newIdx] : aivResultMap) { + mapper.map(ifOp.getResult(oldIdx), aivIfOp.getResult(newIdx)); + } + + // 初始化then和else区域 + mlir::Region &thenRegion = aivIfOp.getThenRegion(); + mlir::Block *thenBlock = new mlir::Block(); + thenRegion.push_back(thenBlock); + + mlir::Region &elseRegion = ifOp.getElseRegion(); + if (!elseRegion.empty()) { + mlir::Region &elseRegion = aivIfOp.getElseRegion(); + mlir::Block *elseBlock = new mlir::Block(); + elseRegion.push_back(elseBlock); + } + + // 将新if操作移动到目标位置 + if (info.targetParent == nullptr) { + mlir::Block *targetBlock = aivBlock; + if (terminator) { + aivIfOp->moveBefore(terminator); + } else { + aivIfOp->moveBefore(targetBlock, targetBlock->end()); + } + parentMap[ifOp] = aivIfOp; + } else { + auto ®ion = + parentMap[info.targetParent]->getRegion(originalRegionIndex); + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + aivIfOp->moveBefore(targetBlock, targetBlock->end()); + parentMap[ifOp] = aivIfOp; + } else { + llvm::outs() << "Can't find block by index\n"; + } + } + } + + // 处理其他类型操作(克隆) + else { + auto clonedOp = builder.clone(*info.op, mapper); + auto numberRes = clonedOp->getNumResults(); + for (auto i = 0; i < numberRes; i++) { + mapper.map((*info.op).getResults()[i], clonedOp->getResults()[i]); + } + + if (info.targetParent == nullptr) { + mlir::Block *targetBlock = aivBlock; + clonedOp->moveBefore(terminator); + parentMap[info.op] = clonedOp; + } else { + auto parentIt = parentMap.find(info.targetParent); + auto mappedParentOp = parentIt->second; + auto ®ion = mappedParentOp->getRegion(originalRegionIndex); + + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + clonedOp->moveBefore(targetBlock, targetBlock->end()); + } else { + llvm::outs() << "Can't find block by index\n"; + } + } + } +} + +static void SplitScope(triton::FuncOp funcOp, AffinityDAG::Graph &graph, + Operation *aivScope, Operation *aicScope, + ModuleOp module) { + llvm::SmallVector aivToMove; + llvm::SmallVector cubeToMove; + for (auto &block : aivScope->getRegion(0)) { + for (auto &op : block) { + collectOpsToMove(&op, graph, nullptr, aivToMove, 
cubeToMove); + } + } + mlir::IRMapping aivmapper; + mlir::OpBuilder builder(aivScope); + llvm::DenseMap aivparentMap; + + // 第二遍:实际移动操作 + // 先移动for循环 + mlir::Block *aivBlock = + &aivScope->getRegion(0).front(); // 或者使用合适的block + SmallVector deleteOp; + auto *terminator = aivBlock->getTerminator(); + // 如果操作已被使用,直接跳过 + llvm::SmallVector + aivUsedOp; // 改为函数内静态,保持原有逻辑 + for (const auto &info : aivToMove) { + if (std::find(aivUsedOp.begin(), aivUsedOp.end(), info.op) != + aivUsedOp.end()) { + return; + } + aivUsedOp.push_back(info.op); + processOperationToMove(info, aivparentMap, builder, aivmapper, aivBlock, + terminator, graph, AffinityDAG::CoreType::CUBE_ONLY); + } + + llvm::DenseMap aicparentMap; + mlir::IRMapping aicmapper; + mlir::Block *aicBlock = + &aicScope->getRegion(0).front(); // 或者使用合适的block + terminator = aicBlock->getTerminator(); + llvm::SmallVector + aicUsedOp; // 改为函数内静态,保持原有逻辑 + for (const auto &info : cubeToMove) { + if (std::find(aicUsedOp.begin(), aicUsedOp.end(), info.op) != + aicUsedOp.end()) { + return; + } + aicUsedOp.push_back(info.op); + processOperationToMove(info, aicparentMap, builder, aicmapper, aicBlock, + terminator, graph, + AffinityDAG::CoreType::VECTOR_ONLY); + } + + for (const auto &info : aivToMove) { + if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == + deleteOp.end()) { + deleteOp.push_back(info.op); + } + } + for (const auto &info : cubeToMove) { + if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == + deleteOp.end()) { + deleteOp.push_back(info.op); + } + } + + // llvm::outs() << "\n" << module<<" ====== ddd ====== \n\n\n"; + // llvm::outs().flush(); + for (auto it = deleteOp.rbegin(); it != deleteOp.rend(); ++it) { + (*it)->erase(); // 解引用反向迭代器,调用 erase 方法 + } + return; +} + +/// 创建setop +static hivm::SyncBlockSetOp +createSyncBlockSetOp(OpBuilder &builder, Location loc, hivm::TCoreType coreType, + hivm::PIPE setPipeEnum, hivm::PIPE waitPipeEnum, + int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, + flagId); +} + +/// 创建waitop +static hivm::SyncBlockWaitOp +createSyncBlockWaitOp(OpBuilder &builder, Location loc, + hivm::TCoreType coreType, hivm::PIPE setPipeEnum, + hivm::PIPE waitPipeEnum, int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, + flagId); +} + +// 在scope return前插入wait +static void insertWaitBeforeFinalReturn(Region *region, OpBuilder &builder, + int64_t flag, bool coretypebool) { + for (Block &block : *region) { + if (auto returnOp = + dyn_cast_or_null(block.getTerminator())) { + builder.setInsertionPoint(returnOp); + if (coretypebool) { + createSyncBlockWaitOp(builder, returnOp->getLoc(), + hivm::TCoreType::CUBE, hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, flag); + return; + } else { + createSyncBlockWaitOp(builder, returnOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, flag); + return; + } + } + } +} + +/// 在scope内起始位置加上set +static void insertSetAtRegionStart(Region *region, 
OpBuilder &builder, + int64_t flag, bool coretypebool) { + if (!region->empty()) { + Block &entry = region->front(); + Location loc = entry.empty() ? region->getParentOp()->getLoc() + : entry.front().getLoc(); + builder.setInsertionPointToStart(&entry); + if (coretypebool) { + createSyncBlockSetOp(builder, loc, hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_V, hivm::PIPE::PIPE_FIX, flag); + } else { + createSyncBlockSetOp(builder, loc, hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_M, hivm::PIPE::PIPE_MTE3, flag); + } + } +} + +static Operation *findNextSyncBlockSetAfter(Operation *startOp) { + Block *block = startOp->getBlock(); + auto it = ++startOp->getIterator(); + for (; it != block->end(); ++it) { + if (isa(*it)) + return &*it; + } + return nullptr; +} + +static hivm::SyncBlockWaitOp findWaitOpInRegionWithFlag(Region *region, + int64_t flag) { + hivm::SyncBlockWaitOp result; + region->walk([&](hivm::SyncBlockWaitOp op) { + auto flagAttr = op->getAttrOfType("static_flag_id"); + if (flagAttr && flagAttr.getInt() == flag) { + result = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +static Operation *findInsertionPointAfterWaitForAIV(Operation *waitOp) { + Block *block = waitOp->getBlock(); + auto it = ++waitOp->getIterator(); + + for (; it != block->end(); ++it) { + if (isa(*it) || isa(*it)) { + break; + } + } + + while (it != block->begin()) { + auto prevIt = std::prev(it); + if (isa(*prevIt)) { + it = prevIt; + } else { + break; + } + } + + return &*it; +} + +static Operation *findInsertionPointAfterWaitForAIC(Operation *waitOp) { + Block *block = waitOp->getBlock(); + auto it = ++waitOp->getIterator(); + for (; it != block->end(); ++it) { + if (auto fixpipe = dyn_cast(*it)) { + if (it != block->begin()) { + auto prev = std::prev(it); + if (isa(*prev)) + return &*prev; + } + return &*it; + } + if (isa(*it)) + return &*it; + } + return nullptr; +} + +// 查找 FixpipeOp 下一行的 sync_block_set 操作的 flag 值 +static int findFixPipeFlagSafe(hivm::FixpipeOp fixpipeOp) { + mlir::Operation *fixpipeOperation = fixpipeOp.getOperation(); + if (!fixpipeOperation || !fixpipeOperation->getBlock()) { + return -1; + } + + // 获取 FixpipeOp 的迭代器 + auto it = ++fixpipeOperation->getIterator(); + + // 遍历后续操作直到找到 sync_block_set + while (it != fixpipeOperation->getBlock()->end()) { + mlir::Operation &op = *it++; + + if (op.getName().getStringRef() == "hivm.hir.sync_block_set") { + auto staticFlagAttr = + op.getAttrOfType("static_flag_id"); + return staticFlagAttr.getInt(); + break; + } + } + + return -1; +} + +/// cube处理逻辑 +static void processFixpipeOpsInAIC(Region *aicRegion, Region *aivRegion) { + + MLIRContext *ctx = aicRegion->getContext(); + OpBuilder builder(ctx); + SmallVector fixpipes; + aicRegion->walk([&](hivm::FixpipeOp op) { fixpipes.push_back(op); }); + + for (auto fixpipeOp : fixpipes) { + + auto newflag = findFixPipeFlagSafe(fixpipeOp); + // 1. 在 FixpipeOp 前插 Wait + builder.setInsertionPoint(fixpipeOp); + createSyncBlockWaitOp(builder, fixpipeOp->getLoc(), hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_V, hivm::PIPE::PIPE_FIX, newflag); + bool coretypebool = true; + + // 2. 在 aicRegion 末尾 Return 前插 Wait + insertWaitBeforeFinalReturn(aicRegion, builder, newflag, coretypebool); + + // 3. 在 aivRegion 开头插 Set + insertSetAtRegionStart(aivRegion, builder, newflag, coretypebool); + + // 4. 
在 aicRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(fixpipeOp)) { + auto setFlagAttr = + nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set + // llvm::dbgs() << "aicnextSetOp:"; + // nextSetOp->dump(); + if (!setFlagAttr) { + llvm::dbgs() << "AIC can not find setop in aic\n"; + continue; + } + int64_t setflag = setFlagAttr.getInt(); + + // 5. 在 aivRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aivRegion, setflag); + if (!targetWait) { + llvm::dbgs() << "AIC can not find waitop in aiv\n"; + continue; + } + + // 调试:打印wait + // llvm::dbgs() << "aictargetWait:"; + // llvm::dbgs() << targetWait << "\n"; + + // 6. 从该 Wait 向下找 ToMemrefOp 或 Yield,插 Set(newflag) + if (auto *insertPt = findInsertionPointAfterWaitForAIV(targetWait)) { + builder.setInsertionPoint(insertPt); + createSyncBlockSetOp(builder, fixpipeOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, newflag); + } + } + } +} + +// 查找 copyOp 下一行的 sync_block_set 操作的 flag 值 +static int findCopyFlagSafe(bufferization::ToMemrefOp toMemrefOp) { + mlir::Operation *toMemrefOperation = toMemrefOp.getOperation(); + if (!toMemrefOperation || !toMemrefOperation->getBlock()) { + return -1; + } + + // 获取 copyOp 的迭代器 + auto it = ++toMemrefOperation->getIterator(); + + // 遍历后续操作直到找到 sync_block_set + while (it != toMemrefOperation->getBlock()->end()) { + mlir::Operation &op = *it++; + + if (op.getName().getStringRef() == "hivm.hir.sync_block_set") { + auto staticFlagAttr = + op.getAttrOfType("static_flag_id"); + return staticFlagAttr.getInt(); + break; + } + } + + return -1; +} +/// vector处理逻辑 +static void processToMemrefOpsInAIV(Region *aivRegion, Region *aicRegion) { + + MLIRContext *ctx = aivRegion->getContext(); + OpBuilder builder(ctx); + SmallVector toMemrefs; + aivRegion->walk( + [&](bufferization::ToMemrefOp op) { toMemrefs.push_back(op); }); + + for (auto toMemrefOp : toMemrefs) { + auto newflag = findCopyFlagSafe(toMemrefOp); + + // 1. 在 ToMemrefOp 前插 Wait + builder.setInsertionPoint(toMemrefOp); + createSyncBlockWaitOp(builder, toMemrefOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, newflag); + bool coretypebool = false; + + // 2. 在 aivRegion 末尾 Return 前插 Wait + insertWaitBeforeFinalReturn(aivRegion, builder, newflag, coretypebool); + + // 3. 在 aicRegion 开头插 Set + insertSetAtRegionStart(aicRegion, builder, newflag, coretypebool); + + // 4. 在 aivRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(toMemrefOp)) { + auto setFlagAttr = + nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set及其所有attribute + // llvm::dbgs() << "aivnextSetOp:"; + // nextSetOp->dump(); + // llvm::dbgs() << "Attributes:\n"; + // for (auto namedAttr : nextSetOp->getAttrs()) { + // llvm::dbgs() << " " << namedAttr.getName() << " = "; + // namedAttr.getValue().print(llvm::dbgs()); + // llvm::dbgs() << "\n"; + // } + if (!setFlagAttr) { + llvm::dbgs() << "AIV can not find setop in aiv\n"; + continue; + } + int64_t setflag = setFlagAttr.getInt(); + + // 5. 在 aicRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aicRegion, setflag); + + if (!targetWait) { + llvm::dbgs() << "AIV can not find waitop in aic\n"; + continue; + } + + // 调试:打印wait + // llvm::dbgs() << "aivtargetWait:"; + // llvm::dbgs() << targetWait << "\n"; + + // 6. 
+      if (auto *insertPt = findInsertionPointAfterWaitForAIC(targetWait)) {
+        builder.setInsertionPoint(insertPt);
+        createSyncBlockSetOp(builder, toMemrefOp->getLoc(),
+                             hivm::TCoreType::CUBE, hivm::PIPE::PIPE_M,
+                             hivm::PIPE::PIPE_MTE3, newflag);
+      }
+    }
+  }
+}
+
+/// Strengthen the synchronization around buffer waits between the paired
+/// CUBE and VECTOR scopes.
+void addSyncOpsForBufferWait(ModuleOp module) {
+  for (auto funcOp :
+       llvm::make_early_inc_range(module.getOps<triton::FuncOp>())) {
+    if (funcOp.getBody().empty()) {
+      continue;
+    }
+
+    Region *aicRegion = nullptr;
+    Region *aivRegion = nullptr;
+
+    funcOp.walk([&](scope::ScopeOp scopeOp) {
+      auto coreTypeAttr = scopeOp->getAttrOfType<hivm::TCoreTypeAttr>(
+          hivm::TCoreTypeAttr::name);
+      if (!coreTypeAttr)
+        return;
+
+      if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) {
+        aicRegion = &scopeOp.getRegion();
+      }
+      if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) {
+        aivRegion = &scopeOp.getRegion();
+      }
+    });
+
+    if (!aicRegion || !aivRegion) {
+      continue;
+    }
+
+    processFixpipeOpsInAIC(aicRegion, aivRegion);
+    processToMemrefOpsInAIV(aivRegion, aicRegion);
+  }
+}
+
+void DAGScopePass::runOnOperation() {
+  auto module = getOperation();
+  for (auto funcOp :
+       llvm::make_early_inc_range(module.getOps<triton::FuncOp>())) {
+    // Skip invalid functions.
+    if (funcOp.getBody().empty()) {
+      continue;
+    }
+
+    // Collect all memref.alloc ops, walking nested regions as well.
+    llvm::SmallVector<mlir::Operation *> allocOps;
+    funcOp.walk([&](mlir::Operation *op) {
+      if (mlir::isa<mlir::memref::AllocOp>(op)) {
+        allocOps.push_back(op);
+      }
+    });
+
+    mlir::Block &entryBlock = funcOp.getBody().front();
+    mlir::Block::iterator insertPos = entryBlock.begin();
+
+    // Hoist the alloc ops to the very top of the function.
+    for (mlir::Operation *allocOp : allocOps) {
+      // Skip allocs that are already at the front.
+      if (allocOp->getBlock() == &entryBlock &&
+          allocOp->isBeforeInBlock(&*insertPos)) {
+        continue;
+      }
+      allocOp->moveBefore(&entryBlock, insertPos);
+    }
+
+    auto funcName = funcOp.getName();
+    auto *graph_ptr =
+        AffinityDAG::GraphManager::getInstance().getGraph(funcName);
+    if (!graph_ptr) {
+      continue;
+    }
+    auto &main_graph = *graph_ptr;
+
+    auto ScopeList = encapsulateWithScope(funcOp);
+    auto aivScope = ScopeList.first;
+    auto aicScope = ScopeList.second;
+
+    SplitScope(funcOp, main_graph, aivScope, aicScope, module);
+  }
+
+  addSyncOpsForBufferWait(module);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::triton::createDAGScopePass() {
+  return std::make_unique<DAGScopePass>();
+}
diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp
new file mode 100644
index 0000000000..2d01330aed
--- /dev/null
+++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp
@@ -0,0 +1,1391 @@
+#include "TritonAffinityOpt/Passes.h"
+
+#include "bishengir/Dialect/Annotation/IR/Annotation.h"
+#include "bishengir/Dialect/HIVM/IR/HIVM.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h"
+#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h"
+#include "bishengir/Dialect/HIVM/Transforms/Passes.h"
+#include "bishengir/Dialect/HIVM/Utils/Utils.h"
+#include "bishengir/Dialect/Scope/IR/Scope.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
"mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/Casting.h" + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" +#include +#include + +#include "TritonAffinityOpt/DAG.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DAGSYNC +#include "ascend/include/TritonAffinityOpt/Passes.h.inc" +} // namespace triton +} // namespace mlir + +// 使用 DAG 命名空间 +using namespace mlir; +using namespace hivm; +using namespace AffinityDAG; + +llvm::DenseMap *valueTypes; +// 修改类声明,将数据搬运逻辑集成到同步插入中 +namespace { +struct DAGSyncPass : public mlir::triton::impl::DAGSyncBase { + void runOnOperation() override; + +private: + // 原有的辅助函数 + CoreType getNodeDeviceType(OpNode *node, + llvm::DenseMap *valueTypes); + bool needVectorCubeSync(CoreType src, CoreType dst); + + // 修改后的同步插入函数,包含数据搬运 + void insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, + llvm::DenseMap *valueMap, + Graph &mainGraph); + + // 新增:处理跨 block 的同步和数据搬运 + void insertSyncAndMovementForCrossBlock( + mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType, + CoreType dstType, mlir::OpBuilder &builder, int flag, + bool dstIsInnerBlock, llvm::DenseMap *valueMap, + Graph &mainGraph); + + // 新增:处理 scf.for 循环迭代参数的同步 + void processScfForSync(mlir::scf::ForOp forOp, Node *forNode, + llvm::DenseMap *valueTypes, + mlir::OpBuilder &builder, int &flag); + + // 数据搬运相关的辅助函数 + void insertCubeToVectorDataMovement(mlir::Operation *srcOp, + mlir::Operation *dstOp, + mlir::Value srcResult, + mlir::OpBuilder &builder, + mlir::Location loc, mlir::Value iterArgs); + + void + insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + Operation *posOp, mlir::Value srcResult, + mlir::OpBuilder &builder, mlir::Location loc, + llvm::DenseMap *valueMap); + + // 获取或创建合适的 memref.alloc + mlir::Value getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, + hivm::AddressSpace addressSpace, + mlir::OpBuilder &builder, + mlir::Location loc); + + // 获取 tensor 的形状和元素类型 + mlir::RankedTensorType getTensorType(mlir::Value tensorValue); + + // 替换 dstOp 中使用 srcResult 的操作数 + void replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, + mlir::Value newValue); + + // Find sync position + Operation *FindLastestPosition(Operation *srcOp, Graph &mainGraph, + OpBuilder &builder); + Operation *FindEarliestPosition(Operation *dstOp, Graph &mainGraph, + OpBuilder &builder); +}; +} // namespace + +void DAGSyncPass::processScfForSync( + mlir::scf::ForOp forOp, Node *forNode, + llvm::DenseMap *valueTypes, mlir::OpBuilder &builder, + int &flag) { + + mlir::Block *loopBody = forOp.getBody(); + mlir::scf::YieldOp yieldOp = nullptr; + for (mlir::Operation &op : *loopBody) { + if (auto yield = mlir::dyn_cast(&op)) { + yieldOp = yield; + break; + } + } + Location loc = forOp.getLoc(); + + for (int i = 0; i < forOp.getInitArgs().size(); i++) { + mlir::BlockArgument iterArg = loopBody->getArgument(i + 1); + // 找到首次使用 + mlir::Operation *firstUser = nullptr; + + for (mlir::Operation &op : *loopBody) { + // 跳过 yield 操作 + if (mlir::isa(&op)) { + continue; + } + + // 检查是否使用该迭代参数 + bool usesIterArg = false; + for (mlir::Value operand : op.getOperands()) { + if (operand == iterArg) { + usesIterArg = true; + break; + } + } + + if (usesIterArg) { + 
+  for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) {
+    mlir::BlockArgument iterArg = loopBody->getArgument(i + 1);
+    // Find the first use inside the loop body.
+    mlir::Operation *firstUser = nullptr;
+
+    for (mlir::Operation &op : *loopBody) {
+      // Skip the yield op.
+      if (mlir::isa<mlir::scf::YieldOp>(&op)) {
+        continue;
+      }
+
+      // Does this op use the iter arg?
+      bool usesIterArg = false;
+      for (mlir::Value operand : op.getOperands()) {
+        if (operand == iterArg) {
+          usesIterArg = true;
+          break;
+        }
+      }
+
+      if (usesIterArg) {
+        firstUser = &op;
+        break;
+      }
+    }
+    // Look up the iter type in the map; it is defined by the first op in the
+    // loop that uses the arg.
+    if (!firstUser) {
+      continue;
+    }
+    CoreType iterType = CoreType::CUBE_AND_VECTOR;
+    if (valueTypes->find(firstUser->getResult(0)) != valueTypes->end()) {
+      iterType = valueTypes->find(firstUser->getResult(0))->second;
+    }
+
+    // Fetch the matching yield operand.
+    mlir::Value yieldOperand = yieldOp->getOperand(i);
+    CoreType yieldType = CoreType::CUBE_AND_VECTOR;
+    if (valueTypes->find(yieldOperand) != valueTypes->end()) {
+      yieldType = valueTypes->find(yieldOperand)->second;
+    }
+    mlir::Operation *yieldDefiningOp = yieldOperand.getDefiningOp();
+
+    // CUBE -> VECTOR
+    if (yieldType == CoreType::CUBE_ONLY && iterType == CoreType::VECTOR_ONLY) {
+
+      // Insert the sync instructions.
+      auto coreAttr =
+          hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+      auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX);
+      auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V);
+      auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+      // The set goes right after yieldDefiningOp.
+      builder.setInsertionPointAfter(yieldDefiningOp);
+      builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe, waitPipe,
+                                           flagId);
+
+      mlir::Value srcResult = yieldDefiningOp->getResult(0);
+
+      // Insert the data movement.
+      insertCubeToVectorDataMovement(yieldDefiningOp, firstUser, srcResult,
+                                     builder, loc, iterArg);
+
+      // The wait goes right before firstUser.
+      builder.setInsertionPoint(firstUser);
+      coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(),
+                                          hivm::TCoreType::VECTOR);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    }
+    // VECTOR -> CUBE
+    else if (yieldType == CoreType::VECTOR_ONLY &&
+             iterType == CoreType::CUBE_ONLY) {
+
+      // Insert the sync instructions.
+      auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(),
+                                               hivm::TCoreType::VECTOR);
+      auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3);
+      auto waitPipe =
+          PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1);
+      auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+      // The set goes right after yieldDefiningOp.
+      builder.setInsertionPointAfter(yieldDefiningOp);
+      builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe, waitPipe,
+                                           flagId);
+      // Data movement (currently disabled).
+      // insertVectorToCubeDataMovement(yieldDefiningOp, firstUser, srcResult,
+      //                                builder, loc, iterArg);
+
+      // The wait goes right before firstUser.
+      builder.setInsertionPoint(firstUser);
+      coreAttr =
+          hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    }
+  }
+}
+
+// Get the device type of a node.
+CoreType DAGSyncPass::getNodeDeviceType(
+    OpNode *node, llvm::DenseMap<mlir::Value, CoreType> *valueTypes) {
+  if (!node || !node->op) {
+    return CoreType::CUBE_AND_VECTOR;
+  }
+
+  // Use the first result to stand in for the node's device type.
+  if (node->op->getNumResults() > 0) {
+    mlir::Value result = node->op->getResult(0);
+    auto it = valueTypes->find(result);
+    if (it != valueTypes->end()) {
+      return it->second;
+    }
+  }
+
+  return CoreType::CUBE_AND_VECTOR; // default
+}
+
+// Is a vector <-> cube sync needed between the two core types?
+bool DAGSyncPass::needVectorCubeSync(CoreType src, CoreType dst) {
+  return (src == CoreType::VECTOR_ONLY && dst == CoreType::CUBE_ONLY) ||
+         (src == CoreType::CUBE_ONLY && dst == CoreType::VECTOR_ONLY);
+}
+
+// Get the ranked tensor type, if any.
+mlir::RankedTensorType DAGSyncPass::getTensorType(mlir::Value tensorValue) {
+  if (auto tensorType =
+          dyn_cast<mlir::RankedTensorType>(tensorValue.getType())) {
+    return tensorType;
+  }
+  return nullptr;
+}
+
+// Replace matching operands.
+void DAGSyncPass::replaceOperandWithNewValue(mlir::Operation *dstOp,
+                                             mlir::Value oldValue,
+                                             mlir::Value newValue) {
+  for (unsigned i = 0; i < dstOp->getNumOperands(); ++i) {
+    if (dstOp->getOperand(i) == oldValue) {
+      dstOp->setOperand(i, newValue);
+    }
+  }
+}
+
+// getOrCreateAllocation hoists the allocation to the entry block of the
+// enclosing function.
+mlir::Value DAGSyncPass::getOrCreateAllocation(mlir::Operation *op,
+                                               mlir::Type tensorType,
+                                               hivm::AddressSpace addressSpace,
+                                               mlir::OpBuilder &builder,
+                                               mlir::Location loc) {
+  auto rankedTensorType = cast<mlir::RankedTensorType>(tensorType);
+  auto elementType = rankedTensorType.getElementType();
+  auto shape = rankedTensorType.getShape();
+
+  auto addressSpaceAttr =
+      hivm::AddressSpaceAttr::get(builder.getContext(), addressSpace);
+  auto memrefType = mlir::MemRefType::get(shape, elementType,
+                                          /*layout=*/nullptr, addressSpaceAttr);
+
+  // Walk up to the enclosing function.
+  mlir::Operation *funcOp = op;
+  while (funcOp && !mlir::isa<triton::FuncOp>(funcOp)) {
+    funcOp = funcOp->getParentOp();
+  }
+
+  if (auto func = mlir::dyn_cast_or_null<triton::FuncOp>(funcOp)) {
+    // Create the allocation at the top of the function's entry block.
+    mlir::Block &entryBlock = func.getBody().front();
+    builder.setInsertionPointToStart(&entryBlock);
+    return builder.create<mlir::memref::AllocOp>(loc, memrefType);
+  }
+
+  // No enclosing function found; fall back to allocating right before op.
+  builder.setInsertionPoint(op);
+  return builder.create<mlir::memref::AllocOp>(loc, memrefType);
+}
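+// Illustrative result of getOrCreateAllocation (schematic; the exact
+// address-space attribute syntax depends on the HIVM dialect version): for a
+// 64x128xf16 tensor placed in UB, the helper creates, at the top of the
+// enclosing function, something like
+//   %buf = memref.alloc() : memref<64x128xf16, #hivm.address_space<ub>>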
+// Insert the CUBE -> VECTOR data movement.
+void DAGSyncPass::insertCubeToVectorDataMovement(
+    mlir::Operation *srcOp, mlir::Operation *dstOp, mlir::Value srcResult,
+    mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iterArgs) {
+  auto srcTensorType = getTensorType(srcResult);
+  if (!srcTensorType) {
+    return;
+  }
+
+  // 1. Create a UB-space memref.alloc after srcOp.
+  builder.setInsertionPointAfter(srcOp);
+  mlir::Value ubAlloc = getOrCreateAllocation(
+      srcOp, srcTensorType, hivm::AddressSpace::UB, builder, loc);
+
+  // 2. Create the fixpipe instruction.
+  builder.setInsertionPointAfter(srcOp);
+  FixpipeDMAModeAttr dmaModeAttr =
+      FixpipeDMAModeAttr::get(builder.getContext(), FixpipeDMAMode::NZ2ND);
+
+  builder.create<hivm::FixpipeOp>(loc, mlir::TypeRange{}, // no results
+                                  srcResult,              // src
+                                  ubAlloc,                // dst
+                                  /*unit_flag_cond=*/mlir::ValueRange{},
+                                  /*dma_mode=*/dmaModeAttr,
+                                  /*dual_dst_mode=*/nullptr,
+                                  /*pre_quant=*/nullptr,
+                                  /*pre_relu=*/nullptr,
+                                  /*channel_split=*/nullptr,
+                                  /*unit_flag_mode=*/mlir::ArrayAttr{});
+
+  // 3. Create memory_space_cast and to_tensor before dstOp.
+  builder.setInsertionPoint(dstOp);
+
+  // memory_space_cast (only needed when an address space is attached).
+  mlir::Value plainMemref = ubAlloc;
+  auto memrefType = cast<mlir::MemRefType>(ubAlloc.getType());
+  if (memrefType.getMemorySpace()) {
+    auto plainMemrefType = mlir::MemRefType::get(memrefType.getShape(),
+                                                 memrefType.getElementType());
+    plainMemref = builder.create<mlir::memref::MemorySpaceCastOp>(
+        loc, plainMemrefType, ubAlloc);
+    (*valueTypes)[plainMemref] = CoreType::VECTOR_ONLY;
+  }
+
+  // 4. Create the to_tensor.
+  auto toTensorOp = builder.create<bufferization::ToTensorOp>(
+      loc,
+      srcTensorType, // the original tensor type
+      plainMemref,
+      /*restrict=*/true,
+      /*writable=*/true);
+  (*valueTypes)[toTensorOp.getResult()] = CoreType::VECTOR_ONLY;
+
+  // 5. Replace dstOp's operand.
+  if (!iterArgs) {
+    replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult());
+  } else {
+    replaceOperandWithNewValue(dstOp, iterArgs, toTensorOp.getResult());
+  }
+}
+
+static uint64_t getElemBytesForAlign(Type t) {
+  if (auto ft = dyn_cast<FloatType>(t))
+    return (uint64_t)((ft.getWidth() + 7) / 8);
+  if (auto it = dyn_cast<IntegerType>(t))
+    return (uint64_t)((it.getWidth() + 7) / 8);
+  if (isa<IndexType>(t))
+    return 8ULL;
+  if (auto ct = dyn_cast<ComplexType>(t))
+    return 2ULL * getElemBytesForAlign(ct.getElementType());
+  return 0ULL;
+}
+
+static FailureOr<uint64_t> getBlockElemsFor32BAlign(Type elemType) {
+  constexpr uint64_t kAlignBytes = 32;
+  uint64_t elemBytes = getElemBytesForAlign(elemType);
+  if (elemBytes == 0)
+    return failure();
+  if (elemBytes >= kAlignBytes)
+    return 1;
+  if (kAlignBytes % elemBytes != 0)
+    return failure();
+  return kAlignBytes / elemBytes;
+}
+
+static std::optional<SmallVector<int64_t>>
+newCbubAllocShape(memref::AllocOp allocOp) {
+  auto type = dyn_cast<MemRefType>(allocOp.getType());
+  // Only static 2-D memrefs are supported.
+  if (!type || type.getRank() != 2)
+    return std::nullopt;
+
+  auto shape = type.getShape();
+  int64_t M = shape[0];
+  int64_t N = shape[1];
+  auto elemType = type.getElementType();
+  auto blkOr = getBlockElemsFor32BAlign(elemType);
+  if (failed(blkOr))
+    return std::nullopt;
+  int64_t blk = (int64_t)*blkOr;
+  // M and N must be static, and M must be a multiple of 16.
+  if (ShapedType::isDynamic(M) || ShapedType::isDynamic(N))
+    return std::nullopt;
+  if (M % 16 != 0)
+    return std::nullopt;
+
+  // New shape: (N/blk, M/16, 16, blk).
+  SmallVector<int64_t> newShape = {N / blk, M / 16, 16, blk};
+
+  return newShape;
+}
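+// Worked example for the 32-byte alignment above (illustrative): f16 is
+// 2 bytes, so blk = 32 / 2 = 16 elements per block; a 64x128xf16 buffer
+// (M = 64, N = 128) maps to {N/blk, M/16, 16, blk} = {8, 4, 16, 16}.
+// For f32 (4 bytes), blk = 8 and 64x128xf32 maps to {16, 4, 16, 8}.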
+// VECTOR -> CUBE data movement.
+void DAGSyncPass::insertVectorToCubeDataMovement(
+    mlir::Operation *srcOp, mlir::Operation *dstOp, Operation *posOp,
+    mlir::Value srcResult, mlir::OpBuilder &builder, mlir::Location loc,
+    llvm::DenseMap<mlir::Value, CoreType> *valueMap) {
+  auto srcTensorType = getTensorType(srcResult);
+  if (!srcTensorType) {
+    return;
+  }
+  if (isa(srcOp) && isa(dstOp)) {
+    return;
+  }
+
+  // 1. Create a UB-space memref.alloc after srcOp (for the to_memref).
+  builder.setInsertionPointAfter(srcOp);
+
+  // First build the UB-space memref type.
+  auto ubSpaceAttr =
+      hivm::AddressSpaceAttr::get(builder.getContext(), hivm::AddressSpace::UB);
+  auto ubMemrefType = mlir::MemRefType::get(srcTensorType.getShape(),
+                                            srcTensorType.getElementType(),
+                                            /*layout=*/nullptr, ubSpaceAttr);
+
+  // Create the bufferization.to_memref.
+  if (srcOp->getBlock() == dstOp->getBlock()) {
+    builder.setInsertionPoint(posOp);
+  }
+  auto toMemrefOp =
+      builder.create<bufferization::ToMemrefOp>(loc, ubMemrefType, srcResult);
+
+  // 2. Create a CBUF(L1)-space memref.alloc as the copy destination.
+  mlir::Value cbufAllocOld = getOrCreateAllocation(
+      srcOp, srcTensorType, hivm::AddressSpace::L1, builder, loc);
+  auto cbufShape = *newCbubAllocShape(
+      dyn_cast<memref::AllocOp>(cbufAllocOld.getDefiningOp()));
+  // Take the old memref type and derive the new one.
+  auto oldType = dyn_cast<MemRefType>(cbufAllocOld.getType());
+
+  // Rank of the new shape.
+  unsigned newRank = cbufShape.size();
+
+  // Option 1: build a fresh identity layout map.
+  AffineMap identityMap = builder.getMultiDimIdentityMap(newRank);
+  MemRefLayoutAttrInterface layout = AffineMapAttr::get(identityMap);
+
+  // Option 2 (safer): if the old type has a layout, try to adapt it.
+  if (auto oldLayout = oldType.getLayout()) {
+    if (auto affineMapAttr = dyn_cast<AffineMapAttr>(oldLayout)) {
+      // The rank changed, so the old affine map may no longer be valid;
+      // fall back to a fresh identity map.
+      layout = AffineMapAttr::get(identityMap);
+    } else {
+      // Other layout kinds may need special handling.
+      layout = oldLayout;
+    }
+  }
+
+  // Build the new alloc type.
+  auto newAllocType = MemRefType::get(cbufShape, oldType.getElementType(),
+                                      layout, oldType.getMemorySpace());
+
+  builder.setInsertionPoint(cbufAllocOld.getDefiningOp());
+  // Create the new alloc.
+  auto cbufAlloc = builder.create<memref::AllocOp>(
+      cbufAllocOld.getDefiningOp()->getLoc(), newAllocType);
+
+  builder.setInsertionPointAfter(toMemrefOp);
+  // 3. Create the copy (src is the UB memref, dst is the CBUF memref).
+  builder.create<hivm::CopyOp>(loc, mlir::TypeRange{}, // no results
+                               toMemrefOp.getResult(), // src (memref in UB)
+                               cbufAlloc               // dst (memref in CBUF)
+  );
+
+  // 4. Create the convert_layout before dstOp.
+  builder.setInsertionPoint(dstOp);
+  auto ndLayout =
+      hivm::DataLayoutAttr::get(builder.getContext(), hivm::DataLayout::ND);
+  auto convertLayoutOp = builder.create(
+      loc,
+      cbufAllocOld.getType(), // same type as the old alloc
+      cbufAlloc,
+      ndLayout, // srcLayout
+      ndLayout  // dstLayout
+  );
+  (*valueTypes)[convertLayoutOp.getResult()] = CoreType::CUBE_ONLY;
+
+  // 5. Create the memory_space_cast.
+  auto cbufMemrefType = cast<MemRefType>(convertLayoutOp.getType());
+  auto plainMemrefType = mlir::MemRefType::get(cbufMemrefType.getShape(),
+                                               cbufMemrefType.getElementType());
+
+  auto memspaceCastOp = builder.create<memref::MemorySpaceCastOp>(
+      loc, plainMemrefType, convertLayoutOp.getResult());
+  (*valueTypes)[memspaceCastOp.getResult()] = CoreType::CUBE_ONLY;
+
+  // 6. Create the to_tensor.
+  auto toTensorOp = builder.create<bufferization::ToTensorOp>(
+      loc,
+      srcTensorType, // the original tensor type
+      memspaceCastOp.getResult(),
+      /*restrict=*/true,
+      /*writable=*/true);
+  (*valueTypes)[toTensorOp.getResult()] = CoreType::CUBE_ONLY;
+
+  // 7. Replace dstOp's operand.
+  replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult());
+}
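+// Illustrative shape of the chain built by insertVectorToCubeDataMovement
+// (schematic IR, attribute syntax simplified):
+//   %ub   = bufferization.to_memref %src : memref<MxNxf16, UB>
+//   %l1   = memref.alloc()               : memref<N/blk x M/16 x 16 x blk, L1>
+//   copy %ub -> %l1
+//   %cvt  = convert_layout %l1 (ND -> ND)
+//   %cast = memref.memory_space_cast %cvt
+//   %t    = bufferization.to_tensor %cast restrict writable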
+Operation *DAGSyncPass::FindLastestPosition(Operation *srcOp, Graph &mainGraph,
+                                            OpBuilder &builder) {
+  Operation *insertPos = nullptr;
+  auto opMap = mainGraph.getOpMapLegacy();
+  auto valueTypes = &mainGraph.getValueTypes();
+  // Find the first cube-dependent vector core operation.
+  for (auto nextOp = srcOp->getNextNode(); nextOp != nullptr;
+       nextOp = nextOp->getNextNode()) {
+    auto nextType = getNodeDeviceType(opMap[nextOp], valueTypes);
+    if (nextType == CoreType::CUBE_ONLY)
+      continue;
+    // No memref ops in the IR yet; trace operands directly.
+    for (auto operand : nextOp->getOperands()) {
+      auto defOp = operand.getDefiningOp();
+      auto defType = getNodeDeviceType(opMap[defOp], valueTypes);
+      if (defType == CoreType::CUBE_ONLY) {
+        // To prevent UB overflow, break the dependency where the result
+        // shape is minimized, that is, trace upward to the first broadcast
+        // (and its leading expand_dims, if any; op kinds assumed).
+        for (auto prevOp = nextOp->getPrevNode();
+             prevOp != nullptr && prevOp != srcOp;
+             prevOp = prevOp->getPrevNode()) {
+          if (isa<triton::BroadcastOp>(prevOp)) {
+            if (prevOp->getPrevNode() &&
+                isa<triton::ExpandDimsOp>(prevOp->getPrevNode())) {
+              return prevOp->getPrevNode();
+            }
+            return prevOp;
+          }
+        }
+        // No shape-minimizing point was found.
+        return nextOp;
+      }
+    }
+
+    // Once we meet a VECTOR-side SyncBlockWaitOp, return immediately.
+    if (auto waitOp = dyn_cast<hivm::SyncBlockWaitOp>(nextOp)) {
+      if (waitOp.getTcoreType() ==
+          hivm::TCoreTypeAttr::get(builder.getContext(),
+                                   hivm::TCoreType::VECTOR)) {
+        return nextOp;
+      }
+    }
+    insertPos = nextOp;
+  }
+  return insertPos;
+}
+
+Operation *DAGSyncPass::FindEarliestPosition(Operation *dstOp, Graph &mainGraph,
+                                             OpBuilder &builder) {
+  auto insertPos = dstOp;
+  auto opMap = mainGraph.getOpMapLegacy();
+  auto valueTypes = &mainGraph.getValueTypes();
+  for (auto prevOp = dstOp->getPrevNode(); prevOp != nullptr;
+       prevOp = prevOp->getPrevNode()) {
+    if (dstOp->getBlock() != prevOp->getBlock())
+      continue;
+    // Once we meet a VECTOR-side SyncBlockSetOp, return immediately.
+    if (auto setOp = dyn_cast<hivm::SyncBlockSetOp>(prevOp)) {
+      if (setOp.getTcoreType() ==
+          hivm::TCoreTypeAttr::get(builder.getContext(),
+                                   hivm::TCoreType::VECTOR)) {
+        return insertPos;
+      }
+    }
+    insertPos = prevOp;
+  }
+  return insertPos;
+}
+
+// Main sync and data-movement insertion.
+void DAGSyncPass::insertSyncAndMovement(
+    mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType,
+    CoreType dstType, mlir::OpBuilder &builder, int flag,
+    llvm::DenseMap<mlir::Value, CoreType> *valueMap, Graph &mainGraph) {
+  mlir::Location loc = srcOp->getLoc();
+  // Preserve the current insertion point.
+  mlir::OpBuilder::InsertionGuard guard(builder);
+
+  // Is this a cross-block dependence?
+  mlir::Block *srcBlock = srcOp->getBlock();
+  mlir::Block *dstBlock = dstOp->getBlock();
+  bool sameBlock = (srcBlock == dstBlock);
+
+  if (!sameBlock) {
+    // Is dst nested inside src's block (outer-to-inner dependence)?
+    bool dstIsInnerBlock = false;
+    mlir::Operation *dstParentOp = dstBlock->getParentOp();
+    while (dstParentOp) {
+      if (dstParentOp->getBlock() == srcBlock) {
+        dstIsInnerBlock = true;
+        break;
+      }
+      if (dstParentOp->getBlock()) {
+        dstParentOp = dstParentOp->getBlock()->getParentOp();
+      } else {
+        break;
+      }
+    }
+
+    if (dstIsInnerBlock) {
+      insertSyncAndMovementForCrossBlock(srcOp, dstOp, srcType, dstType,
+                                         builder, flag, true, valueMap,
+                                         mainGraph);
+      return;
+    }
+  }
+
+  // Same-block handling.
+  // Use srcOp's output (assume the first result).
+  if (srcOp->getNumResults() == 0) {
+    return;
+  }
+  mlir::Value srcResult = srcOp->getResult(0);
+
+  // CUBE -> VECTOR
+  if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) {
+    // Insert the sync instructions.
+    auto coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+    auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX);
+    auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V);
+    auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+    // The set goes right after srcOp.
+    builder.setInsertionPointAfter(srcOp);
+    builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe, waitPipe,
+                                         flagId);
+
+    // The wait goes before dstOp, hoisted to the earliest safe position.
+    auto posOp = FindEarliestPosition(dstOp, mainGraph, builder);
+    builder.setInsertionPoint(posOp);
+    coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR);
+    builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                          flagId);
+
+    // Insert the data movement.
+    insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc,
+                                   nullptr);
+  }
+  // VECTOR -> CUBE
+  else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) {
+
+    // Insert the sync instructions.
+    auto coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR);
+    auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3);
+    auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1);
+    auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+    // The set goes after srcOp, delayed to the latest safe position.
+    auto posOp = FindLastestPosition(srcOp, mainGraph, builder);
+    if (posOp) {
+      builder.setInsertionPoint(posOp);
+    } else {
+      builder.setInsertionPointAfter(srcOp);
+    }
+    auto setOp = builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe,
+                                                      waitPipe, flagId);
+
+    // The wait goes right before dstOp.
+    builder.setInsertionPoint(dstOp);
+    coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+    builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                          flagId);
+
+    // Insert the data movement.
+    insertVectorToCubeDataMovement(srcOp, dstOp, setOp, srcResult, builder,
+                                   loc, valueMap);
+  }
+}
+
+// Cross-block sync and data movement.
+void DAGSyncPass::insertSyncAndMovementForCrossBlock(
+    mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType,
+    CoreType dstType, mlir::OpBuilder &builder, int flag, bool dstIsInnerBlock,
+    llvm::DenseMap<mlir::Value, CoreType> *valueMap, Graph &mainGraph) {
+  if (!dstIsInnerBlock) {
+    insertSyncAndMovement(srcOp, dstOp, srcType, dstType, builder, flag,
+                          valueMap, mainGraph);
+    return;
+  }
+
+  mlir::Location loc = srcOp->getLoc();
+  mlir::Block *dstBlock = dstOp->getBlock();
+
+  // Use srcOp's output.
+  if (srcOp->getNumResults() == 0) {
+    return;
+  }
+  mlir::Value srcResult = srcOp->getResult(0);
+
+  // CUBE -> VECTOR
+  if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) {
+    // Insert the sync instructions (cross-block handling).
+    auto coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+    auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX);
+    auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V);
+    auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+    // The set goes after srcOp, in the outer block.
+    builder.setInsertionPointAfter(srcOp);
+    builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe, waitPipe,
+                                         flagId);
+
+    // Insert the data movement (same logic as the same-block case).
+    insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc,
+                                   nullptr);
+
+    // The wait goes before the inner block's entry.
+    mlir::Operation *parentOp = dstBlock->getParentOp();
+    if (parentOp) {
+      while (srcOp->getBlock() != parentOp->getBlock()) {
+        parentOp = parentOp->getBlock()->getParentOp();
+      }
+      builder.setInsertionPoint(parentOp);
+      coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(),
+                                          hivm::TCoreType::VECTOR);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    } else {
+      builder.setInsertionPoint(dstOp);
+      coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(),
+                                          hivm::TCoreType::VECTOR);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    }
+
+  }
+  // VECTOR -> CUBE
+  else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) {
+
+    // Insert the sync instructions (cross-block handling).
+    auto coreAttr =
+        hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR);
+    auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3);
+    auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1);
+    auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+
+    // The set goes after srcOp, in the outer block.
+    builder.setInsertionPointAfter(srcOp);
+    builder.create<hivm::SyncBlockSetOp>(loc, coreAttr, setPipe, waitPipe,
+                                         flagId);
+
+    // Insert the data movement (same logic as the same-block case).
+    insertVectorToCubeDataMovement(srcOp, dstOp, srcOp, srcResult, builder,
+                                   loc, valueMap);
+
+    // The wait goes before the inner block's entry.
+    mlir::Operation *parentOp = dstBlock->getParentOp();
+    if (parentOp) {
+      builder.setInsertionPoint(parentOp);
+      coreAttr =
+          hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    } else {
+      builder.setInsertionPoint(dstOp);
+      coreAttr =
+          hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE);
+      builder.create<hivm::SyncBlockWaitOp>(loc, coreAttr, setPipe, waitPipe,
+                                            flagId);
+    }
+  }
+}
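+// Illustrative effect of LegalizeDot below (schematic Triton IR): a dot with
+// a non-zero accumulator
+//   %r = tt.dot %a, %b, %c
+// is rewritten into a zero-accumulator dot plus an explicit add,
+//   %z = arith.constant dense<0.0>
+//   %d = tt.dot %a, %b, %z
+//   %r = arith.addf %d, %c
+// so the matmul itself always starts from a zero accumulator.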
+void LegalizeDot(triton::FuncOp funcOp) {
+  mlir::OpBuilder builder(funcOp);
+  funcOp.walk([&](triton::DotOp dotOp) {
+    // Get the dot operands.
+    Value a = dotOp.getOperands()[0];
+    Value b = dotOp.getOperands()[1];
+    Value c = dotOp.getOperands()[2]; // the accumulator
+
+    // Is the accumulator an all-zero constant?
+    bool isZeroAccumulator = false;
+
+    // Directly an arith.constant 0?
+    if (auto constantOp = c.getDefiningOp<arith::ConstantOp>()) {
+      if (auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValue())) {
+        if (denseAttr.isSplat() &&
+            denseAttr.getSplatValue<FloatAttr>().getValueAsDouble() == 0.0) {
+          isZeroAccumulator = true;
+        }
+      }
+    }
+
+    if (!isZeroAccumulator) {
+      // Build a fresh zero accumulator.
+      Location loc = dotOp.getLoc();
+      auto resultType = dotOp.getResult().getType();
+
+      Value originalResult = dotOp.getResult();
+      builder.setInsertionPoint(dotOp);
+      // All-zero tensor.
+      auto zeroAttr = DenseElementsAttr::get(
+          dyn_cast<RankedTensorType>(resultType), APFloat(0.0f));
+      auto zeroConstant = builder.create<arith::ConstantOp>(loc, zeroAttr);
+
+      // New dot using zero as the accumulator.
+      auto newDot =
+          builder.create<triton::DotOp>(loc, resultType, a, b, zeroConstant);
+
+      // Add the new dot result to the original accumulator c.
+      auto addOp = builder.create<arith::AddFOp>(loc, newDot, c);
+
+      // Replace the original dot with addOp.
+      originalResult.replaceAllUsesWith(addOp.getResult());
+
+      // Erase the original dot if it has no remaining uses.
+      if (dotOp.use_empty()) {
+        dotOp.erase();
+      }
+    }
+  });
+}
+
+static void rewriteCopyChainForCbub(hivm::CopyOp copyOp,
+                                    ArrayRef<int64_t> newShape,
+                                    OpBuilder &builder) {
+
+  // The copy input (ins) should be the result of a to_memref.
+  Value insVal = copyOp.getOperands()[0];
+  auto toMemRefOp = insVal.getDefiningOp<bufferization::ToMemrefOp>();
+  if (!toMemRefOp)
+    return;
+
+  Value inputTensor = toMemRefOp.getTensor();
+  auto inputTensorType = dyn_cast<RankedTensorType>(inputTensor.getType());
+  if (!inputTensorType || inputTensorType.getRank() != 2)
+    return;
+
+  // blk = 32 bytes / element width.
+  // Intermediate shapes: (M, N/blk, blk) -> transpose [1, 0, 2]
+  // -> (N/blk, M, blk) -> final 4-D newShape.
+  int64_t M = inputTensorType.getShape()[0];
+  int64_t N = inputTensorType.getShape()[1];
+  auto elemType = inputTensorType.getElementType();
+  auto blkOr = getBlockElemsFor32BAlign(elemType);
+  if (failed(blkOr))
+    return;
+  int64_t blk = (int64_t)*blkOr;
+  SmallVector<int64_t> intermediateShape3D = {M, N / blk, blk};
+  SmallVector<int64_t> intermediateShapetrans = {N / blk, M, blk};
+  auto elementType = inputTensorType.getElementType();
+  auto interTensor3DType =
+      RankedTensorType::get(intermediateShape3D, elementType);
+  auto interTensortransType =
+      RankedTensorType::get(intermediateShapetrans, elementType);
+
+  auto finalTensorType = RankedTensorType::get(newShape, elementType);
+
+  auto loc = inputTensor.getLoc();
+
+  // Set the insertion point after the op defining the input tensor.
+  auto tensorOp = inputTensor.getDefiningOp();
+  builder.setInsertionPointAfter(tensorOp);
+
+  // Insert triton.reshape to expand the 2-D tensor to 3-D.
+  auto reshape3DOp =
+      builder.create<triton::ReshapeOp>(loc, interTensor3DType, inputTensor);
+  (*valueTypes)[reshape3DOp.getResult()] = CoreType::VECTOR_ONLY;
+
+  // Mark the tiling dim for the reshape op.
+  auto markOp3d = builder.create(loc, reshape3DOp);
+  auto tilingDimAttr3d = builder.getDictionaryAttr(SmallVector<NamedAttribute>{
+      NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))});
+  markOp3d->setAttr("tiling_dim_mapping", tilingDimAttr3d);
+
+  // Insert tt.trans {order = [1, 0, 2]} to reorder the dims.
+  SmallVector<int32_t> order = {1, 0, 2};
+  auto orderAttr = builder.getDenseI32ArrayAttr(order);
+  auto transOp = builder.create<triton::TransOp>(
+      loc, interTensortransType, reshape3DOp.getResult(), orderAttr);
+  (*valueTypes)[transOp.getResult()] = CoreType::VECTOR_ONLY;
+
+  // Insert triton.reshape to expand the 3-D tensor to 4-D.
+  auto reshape4DOp = builder.create<triton::ReshapeOp>(loc, finalTensorType,
+                                                       transOp.getResult());
+  (*valueTypes)[reshape4DOp.getResult()] = CoreType::VECTOR_ONLY;
+
+  // Mark the tiling dim for the reshape op.
+  auto markOp4d = builder.create(loc, reshape4DOp);
+  auto tilingDimAttr4d = builder.getDictionaryAttr(SmallVector<NamedAttribute>{
+      NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))});
+  markOp4d->setAttr("tiling_dim_mapping", tilingDimAttr4d);
+
+  // Create the new to_memref.
+  builder.setInsertionPoint(toMemRefOp);
+  auto newMemRefType = MemRefType::get(newShape, elementType, mlir::AffineMap{},
+                                       toMemRefOp.getType().getMemorySpace());
+  auto newToMemRefOp = builder.create<bufferization::ToMemrefOp>(
+      toMemRefOp.getLoc(), newMemRefType, reshape4DOp.getResult());
+  (*valueTypes)[newToMemRefOp.getResult()] = CoreType::VECTOR_ONLY;
+
+  // Create the new copyOp (replacing the old one); the source must be the
+  // new memref, not the 4-D tensor itself.
+  builder.setInsertionPoint(copyOp);
+  auto resultTypes = copyOp->getResultTypes();
+  auto newCopyOp =
+      builder.create<hivm::CopyOp>(copyOp.getLoc(),
+                                   resultTypes,               // TypeRange
+                                   newToMemRefOp.getResult(), // src (ins)
+                                   copyOp.getOperands()[1]    // dst (outs)
+      );
+
+  // Replace uses and clean up the old ops.
+  copyOp.replaceAllUsesWith(newCopyOp);
+  copyOp.erase();
+  toMemRefOp.erase();
+
+  return;
+}
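+// Worked example of the rewrite above (illustrative, f16, M = 64, N = 128,
+// blk = 16): the 2-D 64x128 tensor is reshaped to 64x8x16, transposed with
+// order [1, 0, 2] to 8x64x16, and reshaped again to the final 8x4x16x16
+// buffer, matching newCbubAllocShape's (N/blk, M/16, 16, blk).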
+
+template <typename OpTy>
+OpTy createBlockSync(OpBuilder builder, hivm::TCoreType coreType,
+                     hivm::PIPE srcPipe, hivm::PIPE dstPipe, int flag,
+                     Operation *cause) {
+  auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag);
+  auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), coreType);
+  auto setPipe = PipeAttr::get(builder.getContext(), srcPipe);
+  auto waitPipe = PipeAttr::get(builder.getContext(), dstPipe);
+  return builder.create<OpTy>(cause->getLoc(), coreAttr, setPipe, waitPipe,
+                              flagId);
+}
+
+// Since we do not have llvm::set_intersects in this LLVM version...
+template <typename S1Ty, typename S2Ty> bool intersects(S1Ty &s1, S2Ty &s2) {
+  if (s1.size() > s2.size()) {
+    return intersects(s2, s1);
+  }
+
+  return llvm::any_of(s1, [&](auto e) { return s2.count(e); });
+}
+
+bool mayAlias(DataFlowSolver &solver, Value ptrA, Value ptrB) {
+  if (ptrA == ptrB) {
+    return true;
+  }
+  const auto *stateA = solver.lookupState<dataflow::Lattice<AliasInfo>>(ptrA);
+  const auto *stateB = solver.lookupState<dataflow::Lattice<AliasInfo>>(ptrB);
+  if (!stateA || !stateB) { // not a triton pointer type
+    return true;
+  }
+  auto infoA = stateA->getValue();
+  auto infoB = stateB->getValue();
+
+  return intersects(infoA.getAllocs(), infoB.getAllocs());
+}
+
+const size_t MAX_EXPECTED_PARENTS_COUNT = 8;
+
+std::optional<std::pair<mlir::Operation *, mlir::Operation *>>
+findAncestorCommonBlock(mlir::Operation *opA, mlir::Operation *opB) {
+  if (opA->getBlock() == opB->getBlock()) {
+    return std::make_pair(opA, opB);
+  }
+
+  // Record all ancestors of opA.
+  llvm::SmallPtrSet<mlir::Operation *, MAX_EXPECTED_PARENTS_COUNT> ancestorsA;
+  mlir::Operation *curr = opA;
+  while (curr) {
+    ancestorsA.insert(curr);
+    curr = curr->getParentOp();
+  }
+
+  // Walking up from opB, find the first ancestor that is also an ancestor
+  // of opA.
+  mlir::Operation *commonAncOp = nullptr;
+  curr = opB;
+  while (curr) {
+    if (ancestorsA.count(curr)) {
+      commonAncOp = curr;
+      break;
+    }
+    curr = curr->getParentOp();
+  }
+
+  if (!commonAncOp) {
+    return std::nullopt;
+  }
+
+  // Find the ancestors that live in the same block of the common ancestor.
+  for (mlir::Region &region : commonAncOp->getRegions()) {
+    for (mlir::Block &block : region) {
+      auto *ancA = block.findAncestorOpInBlock(*opA);
+      auto *ancB = block.findAncestorOpInBlock(*opB);
+      if (ancA && ancB) {
+        return std::make_pair(ancA, ancB);
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+struct SyncCandidate {
+  CoreType srcCoreType;
+  Operation *setCause;
+  Operation *setAfter;
+  Operation *waitCause;
+  Operation *waitBefore;
+};
+
+// Create the matched (setOp, waitOp) pair.
+void createBlockSyncBetween(OpBuilder builder, hivm::PIPE srcPipe,
+                            hivm::PIPE dstPipe, SyncCandidate candidate,
+                            int flag) {
+  auto srcCoreType = toHivm(candidate.srcCoreType);
+  auto dstCoreType = toHivm(!candidate.srcCoreType);
+
+  builder.setInsertionPointAfter(candidate.setAfter);
+  createBlockSync<hivm::SyncBlockSetOp>(builder, srcCoreType, srcPipe, dstPipe,
+                                        flag, candidate.setCause);
+  builder.setInsertionPoint(candidate.waitBefore);
+  createBlockSync<hivm::SyncBlockWaitOp>(builder, dstCoreType, srcPipe,
+                                         dstPipe, flag, candidate.waitCause);
+}
+
+void addMemEffectsSync(triton::FuncOp funcOp, Graph *graph, OpBuilder &builder,
+                       int &syncFlag) {
+  DominanceInfo domInfo(funcOp);
+  PostDominanceInfo postDomInfo(funcOp);
+  DataFlowSolver solver;
+  // Analyses assumed from the included headers: dead-code analysis plus
+  // triton's shared-memory alias analysis (see the warning below).
+  solver.load<dataflow::DeadCodeAnalysis>();
+  solver.load<SharedMemoryAliasAnalysis>();
+
+  if (failed(solver.initializeAndRun(funcOp))) {
+    funcOp->emitWarning("SharedMemoryAliasAnalysis failed! This could lead to "
+                        "potential memory related issues! 
\n"); + } + + // [(node, EffectInstance, LinearisationPt)] + llvm::SmallVector> memOps; + + // [(setAfter, waitBefore, srcOP, dstOp)][CoreType] + llvm::SmallVector candidates; + + funcOp.walk([&](MemoryEffectOpInterface memIface) { + auto *op = memIface.getOperation(); + if (llvm::isa(op)) { + return; + } + + auto *currNode = graph->getOpMap()[op].get(); + SmallVector effects; + + memIface.getEffects(effects); + + for (auto &effect : effects) { + if (!isa(effect.getEffect())) { + continue; + } + memOps.emplace_back(currNode, effect); + bool isWrite = isa(effect.getEffect()); + for (auto &[prevNode, prevEffect] : memOps) { + if ((isa(prevEffect.getEffect()) || isWrite) && + mayAlias(solver, prevEffect.getValue(), effect.getValue()) && + prevNode->isOn() != + currNode->isOn() // write is forced on single core type, so we + // are safe to judge based on whether the core + // types are different + ) { + CoreType srcCoreType = isWrite ? !currNode->isOn() : prevNode->isOn(); + auto opPair = findAncestorCommonBlock(prevNode->op, currNode->op); + if (!opPair.has_value()) { + op->emitWarning(llvm::formatv( + "Unable to find ancestors in common block with {0}\n", + *prevNode->op)); + continue; + } + auto [setAfter, waitBefore] = opPair.value(); + if (setAfter == waitBefore) { + continue; + } + candidates.push_back(SyncCandidate{srcCoreType, prevNode->op, + setAfter, op, waitBefore}); + } + } + } + }); + + auto addBlockSyncCommon = [&builder, &syncFlag](SyncCandidate cand) { + llvm::dbgs() << "\n\n=== Insert sync between ===\n" + << *cand.setAfter << "\n" + << *cand.waitBefore << "\n=== Insert Sync End ===\n\n"; + + auto srcPipe = cand.srcCoreType == CoreType::CUBE_ONLY + ? hivm::PIPE::PIPE_FIX + : hivm::PIPE::PIPE_MTE2; + auto dstPipe = hivm::PIPE::PIPE_S; + createBlockSyncBetween(builder, srcPipe, dstPipe, cand, syncFlag % 14); + syncFlag++; + }; + + if (candidates.empty()) { + return; + } + + auto setAfterDominate = [&domInfo](Operation *a, Operation *b) { + if (domInfo.dominates(a, b)) { + return true; + } + if (domInfo.dominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { + return false; + } + if (b->isAncestor(a)) { + return true; + } + return false; + }; + + auto waitBeforePostDominate = [&postDomInfo](Operation *a, Operation *b) { + if (postDomInfo.postDominates(a, b)) { + return true; + } + if (postDomInfo.postDominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { + return true; + } + if (b->isAncestor(a)) { + return false; + } + return false; + }; + + llvm::sort(candidates, [&](const SyncCandidate &a, const SyncCandidate &b) { + if (a.setAfter != b.setAfter) { + return setAfterDominate(a.setAfter, b.setAfter); + } + + if (a.waitBefore != b.waitBefore) { + return waitBeforePostDominate(a.waitBefore, b.waitBefore); + } + + return false; + }); + + for (auto [i, cand] : llvm::enumerate(candidates)) { + bool shouldInsert = true; + for (auto otherCand : ArrayRef(candidates).drop_front(i + 1)) { + bool duplicated = (cand.waitBefore == otherCand.waitBefore && + cand.setAfter == otherCand.setAfter && + cand.srcCoreType == otherCand.srcCoreType); + bool containsOther = + (cand.srcCoreType == otherCand.srcCoreType && + setAfterDominate(cand.setAfter, otherCand.setAfter) && + waitBeforePostDominate(cand.waitBefore, otherCand.waitBefore)); + if (duplicated || containsOther) { + shouldInsert = false; + break; + } + } + + if (shouldInsert) { + addBlockSyncCommon(cand); + } + } +} + +void DAGSyncPass::runOnOperation() { + auto module = getOperation(); + mlir::OpBuilder 
builder(&getContext());
+
+  // Iterate over all functions.
+  for (auto funcOp :
+       llvm::make_early_inc_range(module.getOps<triton::FuncOp>())) {
+    LegalizeDot(funcOp);
+    // Skip invalid functions.
+    if (funcOp.getBody().empty()) {
+      continue;
+    }
+
+    auto unique_graph = Graph::fromMultiBlockFunc(funcOp);
+    std::shared_ptr<Graph> shared_graph = std::move(unique_graph);
+    auto &main_graph = *shared_graph;
+
+    auto funcName = funcOp.getName();
+
+    // Fetch the DAG graph mappings.
+    auto opMapRaw = main_graph.getOpMapLegacy();
+    valueTypes = &main_graph.getValueTypes();
+    auto *opMap = &opMapRaw;
+
+    if (!opMap || !valueTypes) {
+      llvm::errs() << "Warning: Failed to create DAG graph for function "
+                   << funcOp.getName() << "\n";
+      continue;
+    }
+
+    // Avoid inserting duplicate syncs.
+    llvm::DenseSet<std::pair<mlir::Operation *, mlir::Operation *>>
+        processedPairs;
+    int syncFlag = 1;
+    addMemEffectsSync(funcOp, shared_graph.get(), builder, syncFlag);
+
+    // 3. Walk every operation in the function.
+    funcOp.walk([&](mlir::Operation *op) {
+      // Find the DAG node for this op.
+      auto nodeIt = opMap->find(op);
+      if (nodeIt == opMap->end()) {
+        // Not in the entry-block DAG; probably nested inside control flow.
+        return;
+      }
+
+      OpNode *currentNode = nodeIt->second;
+
+      // scf.for needs special sync handling for its iteration arguments.
+      if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
+        int temp = syncFlag % 14;
+        processScfForSync(forOp, currentNode, valueTypes, builder, temp);
+      }
+
+      // Device type of the current node.
+      CoreType currentType = getNodeDeviceType(currentNode, valueTypes);
+
+      // 4. Visit all input nodes of the current node.
+      for (ValueNode *inputValNode : currentNode->getInputs()) {
+        auto inputOp = inputValNode->value.getDefiningOp();
+        if (!inputOp || !opMap->contains(inputOp)) {
+          continue;
+        }
+
+        auto inputNode = (*opMap)[inputOp];
+
+        // Device type of the input node.
+        CoreType inputType = getNodeDeviceType(inputNode, valueTypes);
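+        // Note on flag ids (illustrative): flags are always taken modulo 14
+        // (syncFlag % 14), i.e. the target is assumed to expose a small,
+        // fixed pool of cross-core sync flags that set/wait pairs cycle
+        // through.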
+        // 5. Decide whether sync and data movement must be inserted.
+        if (needVectorCubeSync(inputType, currentType)) {
+          // Skip pairs that were already handled.
+          auto opPair = std::make_pair(inputOp, op);
+          if (processedPairs.insert(opPair).second) {
+            // Is this a cross-block dependence?
+            mlir::Block *srcBlock = inputOp->getBlock();
+            mlir::Block *dstBlock = op->getBlock();
+
+            if (srcBlock == dstBlock) {
+              // Same block.
+              insertSyncAndMovement(inputOp, op, inputType, currentType,
+                                    builder, syncFlag % 14, valueTypes,
+                                    main_graph);
+              syncFlag++;
+            } else {
+              // Cross-block: is dst nested inside src's block?
+              bool dstIsInnerBlock = false;
+              mlir::Operation *dstParentOp = dstBlock->getParentOp();
+
+              // Walk upward to see whether dstBlock lives in srcBlock's
+              // region.
+              while (dstParentOp) {
+                if (dstParentOp->getBlock() == srcBlock) {
+                  dstIsInnerBlock = true;
+                  break;
+                }
+                if (dstParentOp->getBlock()) {
+                  dstParentOp = dstParentOp->getBlock()->getParentOp();
+                } else {
+                  break;
+                }
+              }
+              if (dstIsInnerBlock) {
+                insertSyncAndMovementForCrossBlock(
+                    inputOp, op, inputType, currentType, builder,
+                    syncFlag % 14, dstIsInnerBlock, valueTypes, main_graph);
+                syncFlag++;
+              }
+            }
+          }
+        }
+      }
+    });
+
+    funcOp.walk([&](hivm::CopyOp copyOp) {
+      rewriteCopyChainForCbub(
+          copyOp,
+          cast<MemRefType>(copyOp.getOperands()[1].getType()).getShape(),
+          builder);
+    });
+    GraphManager::getInstance().registerGraph(funcName, shared_graph);
+  }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::triton::createDAGSyncPass() {
+  return std::make_unique<DAGSyncPass>();
+}
diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc
index d52868ea6f..f131669ef3 100644
--- a/third_party/ascend/python/src/ir.cc
+++ b/third_party/ascend/python/src/ir.cc
@@ -231,6 +231,14 @@ void init_triton_ir(py::module &&m) {
   py::class_<MLIRContext>(m, "context", py::module_local())
       .def(py::init<>())
+      .def(
+          "__enter__", [](MLIRContext &self) -> MLIRContext & { return self; },
+          py::return_value_policy::reference)
+      .def("__exit__",
+           [](MLIRContext &, py::object, py::object, py::object) -> bool {
+             // Keep context alive for the duration of the scope.
+ return false; + }) .def("printOpOnDiagnostic", [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) .def("printStackTraceOnDiagnostic", @@ -662,6 +670,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, const std::vector &array) { return self.getBuilder().getI64ArrayAttr(array); }) + .def("get_type_array_attr", + [](TritonOpBuilder &self, const std::vector &array) { + return self.getBuilder().getTypeArrayAttr(array); + }) // Use arith.ConstantOp to create constants // Constants .def("get_int1", @@ -1708,52 +1720,61 @@ void init_triton_ir(py::module &&m) { printingFlags); } }) - .def("run", [](PassManager &self, ModuleOp &mod) { - // TODO: maybe dump module to file and print error for better - // diagnostics - - auto reproducerPath = - triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); - if (!reproducerPath.empty()) { - auto anchorName = self.getOpAnchorName(); - auto passes = self.getPasses(); - Operation *op = mod.getOperation(); - makeReproducer(anchorName, passes, op, reproducerPath); - } - - if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { - ::llvm::DebugFlag = true; - } - - if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); - !debugOnly.empty()) { - llvm::SmallVector split; - llvm::SmallVector storage; - llvm::SmallVector debugTypes; - - StringRef(debugOnly.c_str()).split(split, ','); - llvm::transform(split, std::back_inserter(debugTypes), - [&storage](StringRef str) { - // StringRefs are not always null-terminated. - // The purpose for this storage pattern is to - // produce a collection of C-strings that are. - storage.push_back(str.str()); - return storage.back().c_str(); - }); - - ::llvm::DebugFlag = true; - using namespace llvm; - setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); - } - - bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); - if (haveTiming) { - self.enableTiming(); - } - - if (failed(self.run(mod.getOperation()))) - throw std::runtime_error("PassManager::run failed"); - }); + .def( + "run", + [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto *context = mod.getContext(); + if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) + context->disableMultithreading(); + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + context->disableMultithreading(); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = + triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. 
+ storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }, + py::call_guard()); } void init_triton_env_vars(py::module &m) { diff --git a/third_party/ascend/triton_ascend.cc b/third_party/ascend/triton_ascend.cc index 2f08b0f331..7e3a6a2e88 100644 --- a/third_party/ascend/triton_ascend.cc +++ b/third_party/ascend/triton_ascend.cc @@ -9,14 +9,11 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "ascend/include/TritonToHFusion/Passes.h" -#include "ascend/include/TritonToHIVM/Passes.h" -#include "ascend/include/TritonToLLVM/Passes.h" -#include "incubated/Conversion/DiscreteMaskAccessConversion/Passes.h" -#include "incubated/Conversion/TritonToAnnotation/Passes.h" -#include "incubated/Conversion/TritonToLinalgIncubated/Passes.h" -#include "incubated/Conversion/TritonToStructuredIncubated/Passes.h" -#include "incubated/Conversion/TritonToUnstructureIncubated/Passes.h" +#include "AutoBlockify/Passes.h" +#include "TritonAffinityOpt/Passes.h" +#include "TritonToHFusion/Passes.h" +#include "TritonToHIVM/Passes.h" +#include "TritonToLLVM/Passes.h" #include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" #include "ir.h" // TritonOpBuilder @@ -187,26 +184,6 @@ void init_triton_ascend_ir(py::module &&m) { return indexSelectSimdOp.getResult(); }) - .def("create_embedding_gather", - [](TritonOpBuilder &self, Value &src, Value &idx, - const int64_t bound, const int64_t blksiz, - std::vector &offsets, - std::vector &numels) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(idx.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - retShape.push_back(blksiz); - auto resType = RankedTensorType::get(retShape, elemTy); - auto idxBitWidth = idxTy.getElementType().getIntOrFloatBitWidth(); - auto bound_val = - self.create(bound, idxBitWidth); - auto blksiz_val = - self.create(blksiz, idxBitWidth); - - return self.create( - resType, src, idx, bound_val, blksiz_val, offsets, numels); - }) .def("create_index_put", [](TritonOpBuilder &self, Value &ptr, Value &index, Value &value, const int32_t dim, const int64_t indexBoundary, @@ -309,49 +286,16 @@ void init_triton_ascend_ir(py::module &&m) { } void init_triton_ascend_passes_ttir(py::module &&m) { - m.def("add_triton_to_structure_incubated", - [](mlir::PassManager &pm, bool enableMaskFallbackConversion, - bool optimizeDynamicOffset, bool compileOn91095) { - pm.addPass(mlir::triton::createTritonToStructuredIncubatedPass( - enableMaskFallbackConversion, optimizeDynamicOffset, - compileOn91095)); - }); - - m.def("add_triton_to_annotation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToAnnotationPass()); + m.def("add_auto_blockify", [](mlir::PassManager &pm, int autoBlockifySize) { + AutoBlockifyOptions opts; + opts.autoBlockifySize = autoBlockifySize; + pm.addPass(mlir::triton::createAutoBlockifyPass(opts)); }); - m.def("add_triton_to_linalg_incubated", - [](mlir::PassManager &pm, bool globalKernel, bool namedOps, - bool enableNd2nzOnVector, bool enableSelectAnalysis, - bool compileOn91095) { - pm.addPass(mlir::triton::Incubated::createTritonToLinalgIncubatedPass( 
- globalKernel, namedOps, enableNd2nzOnVector, enableSelectAnalysis, - compileOn91095)); - }); - - m.def("add_triton_to_unstructure_incubated", - [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { - TritonToUnstructureIncubatedOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass( - mlir::triton::createTritonToUnstructureIncubatedPass(opts)); - }); - m.def("add_triton_to_hfusion", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::createTritonToHFusionPass()); }); - m.def("add_discrete_mask_access_conversion", - [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { - DiscreteMaskAccessConversionOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass( - mlir::triton::createDiscreteMaskAccessConversionPass(opts)); - }); - m.def("add_triton_to_hivm", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::createTritonToHIVMPass()); }); @@ -363,6 +307,18 @@ void init_triton_ascend_passes_ttir(py::module &&m) { m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::createBubbleUpOperationPass()); }); + + m.def("add_dag_sync", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGSyncPass()); + }); + + m.def("add_dag_scope", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGScopePass()); + }); + + m.def("add_dag_ssbuffer", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGSSBufferPass()); + }); } // Forward declaration for ascend_ir bindings (defined in ascend_ir.cc) @@ -372,7 +328,6 @@ void init_triton_ascend(py::module &&m) { auto passes = m.def_submodule("passes"); // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { - context.allowUnregisteredDialects(); mlir::DialectRegistry registry; registry.insert(); context.appendDialectRegistry(registry); diff --git a/third_party/ascend/tutorials/03-matrix-multiplication.py b/third_party/ascend/tutorials/03-matrix-multiplication.py new file mode 100644 index 0000000000..d4cbb0d35f --- /dev/null +++ b/third_party/ascend/tutorials/03-matrix-multiplication.py @@ -0,0 +1,199 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+""" +Matrix Multiplication +=============== +""" + +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.cann.extension as extension + +DEV = "npu" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + msk_m = offs_am < M + msk_n = offs_bn < N + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk + a = tl.load( + a_ptrs, + mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + # We accumulate along the K dimension. 
+        accumulator = tl.dot(a, b, accumulator)
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C with masks.
+    # The write-back is split across two vector cores with
+    # `extension.parallel`: each core applies the activation to one half
+    # of the accumulator and stores it. Remove the loop (and the slicing)
+    # to fall back to a single-core vector write-back.
+    SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2
+    for s in extension.parallel(0, 2, bind_sub_block=True):
+        vec_sub_blk = extension.extract_slice(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1))
+        if ACTIVATION == "leaky_relu_custom":
+            vec_sub_blk = leaky_relu_custom(vec_sub_blk)
+        c_sub_blk = vec_sub_blk.to(tl.float16)
+        offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, c_sub_blk, mask=c_mask)
+
+
+# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
+@triton.jit
+def leaky_relu_custom(x):
+    return tl.where(x >= 0, x, 0.01 * x) + 1.0
+
+
+def torch_matmul(a, b, activation=""):
+    c = torch.matmul(a, b)
+    if activation == "leaky_relu_custom":
+        c = torch.where(c >= 0, c, 0.01 * c) + 1.0
+    return c
+
+
+# %%
+# We can now create a convenience wrapper function that only takes two input tensors,
+# and (1) checks the shape constraints; (2) allocates the output; (3) launches the above kernel.
+
+
+def matmul(a, b, activation=""):
+    # Check constraints.
+    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    M, K = a.shape
+    K, N = b.shape
+    # Allocate the output.
+    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
+
+    # 1D launch kernel where each block gets its own program.
+
+    def grid(META):
+        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
+
+    matmul_kernel[grid](
+        a, b, c,  #
+        M, N, K,  #
+        a.stride(0), a.stride(1),  #
+        b.stride(0), b.stride(1),  #
+        c.stride(0), c.stride(1),  #
+        ACTIVATION=activation,  #
+    )
+    return c
+
+
+# %%
+# Unit Test
+# ---------
+#
+# We can test our custom matrix multiplication operation against a native torch
+# implementation (`torch.matmul`, which runs on the NPU here).
+def test():
+    activation = "leaky_relu_custom"
+    torch.manual_seed(0)
+    a = torch.randn((512, 512), device=DEV, dtype=torch.float16)
+    b = torch.randn((512, 512), device=DEV, dtype=torch.float16)
+    triton_output = matmul(a, b, activation)
+    torch_output = torch_matmul(a, b, activation)
+    print(f"triton_output_with_fp16_inputs={triton_output}")
+    print(f"torch_output_with_fp16_inputs={torch_output}")
+    torch.testing.assert_close(triton_output, torch_output, atol=1e-3, rtol=1e-3)
+    print("Passed")
+
+
+if __name__ == "__main__":
+    test()
diff --git a/third_party/ascend/tutorials/04-low-memory-dropout.py b/third_party/ascend/tutorials/04-low-memory-dropout.py
new file mode 100644
index 0000000000..2c5570a0f4
--- /dev/null
+++ b/third_party/ascend/tutorials/04-low-memory-dropout.py
@@ -0,0 +1,137 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+# Copyright 2018-2020 Philippe Tillet
+# Copyright 2020-2022 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+Low-Memory Dropout
+==================
+"""
+
+import tabulate
+import torch
+import torch_npu
+
+import triton
+import triton.language as tl
+
+DEV = "npu"
+
+
+@triton.jit
+def _dropout(
+    x_ptr,  # pointer to the input
+    x_keep_ptr,  # pointer to a mask of 0s and 1s
+    output_ptr,  # pointer to the output
+    n_elements,  # number of elements in the `x` tensor
+    p,  # probability that an element of `x` is changed to zero
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask)
+    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
+    # The line below is the crucial part: kept elements are rescaled by
+    # 1 / (1 - p) so the output has the same expected value as the input.
+ output = tl.where(x_keep != 0, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +def test(): + # Input tensor + x = torch.randn(size=(10, ), device=DEV) + # Dropout mask + p = 0.5 + x_keep = (torch.rand(size=(10, ), device=DEV) > p).to(torch.int32) + # + output = dropout(x, x_keep=x_keep, p=p) + print(tabulate.tabulate([ + ["input"] + x.tolist(), + ["keep mask"] + x_keep.tolist(), + ["output"] + output.tolist(), + ])) + + x = torch.randn(size=(10, ), device=DEV) + # Compare this to the baseline - dropout mask is never instantiated! + output = seeded_dropout(x, p=0.5, seed=123) + output2 = seeded_dropout(x, p=0.5, seed=123) + output3 = seeded_dropout(x, p=0.5, seed=512) + + print( + tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist(), + ])) + + +if __name__ == "__main__": + test() diff --git a/third_party/ascend/tutorials/05-layer-norm.py b/third_party/ascend/tutorials/05-layer-norm.py new file mode 100644 index 0000000000..8af05fc81f --- /dev/null +++ b/third_party/ascend/tutorials/05-layer-norm.py @@ -0,0 +1,126 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
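A side note on why the seeded variant above is "low-memory": the keep-mask is a pure function of `(seed, offsets)`, so it never has to be stored — a backward pass can regenerate it on the fly. The same property, demonstrated host-side with plain PyTorch as a minimal sketch (a CPU `torch.Generator` only mimics `tl.rand` here; the kernel's Philox-style generator is the real mechanism):

import torch

def seeded_mask(shape, p, seed, device="cpu"):
    # Rebuild the identical pseudo-random keep-mask from the seed alone.
    g = torch.Generator(device=device).manual_seed(seed)
    return torch.rand(shape, generator=g, device=device) > p

x = torch.randn(10)
m1 = seeded_mask(x.shape, p=0.5, seed=123)
m2 = seeded_mask(x.shape, p=0.5, seed=123)  # bit-identical; nothing was stored
m3 = seeded_mask(x.shape, p=0.5, seed=512)  # different seed, different mask
assert torch.equal(m1, m2)
print(x * m1 / (1 - 0.5))  # same rescaling as the kernel applies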
+""" +Layer Normalization +============= +""" + +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + return y + + +def _layer_norm(M, N, dtype, eps=1e-5, device='npu'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + print(f"y_tri: {y_tri}") + print(f"y_ref: {y_ref}") + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + _layer_norm(128, 128, torch.float16) + _layer_norm(128, 128, torch.bfloat16) + _layer_norm(128, 128, torch.float32) diff --git a/third_party/ascend/tutorials/06-demo-autotune.py b/third_party/ascend/tutorials/06-demo-autotune.py deleted file mode 100644 index 
dc37a9e306..0000000000 --- a/third_party/ascend/tutorials/06-demo-autotune.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -Autotune -============= -""" -import torch, torch_npu -import triton -import triton.language as tl - - -def test_triton_autotune(): - # Return a set of different kernel configurations for autotune - def get_autotune_config(): - return [ - triton.Config({'XS': 1 * 128, 'multibuffer': True}), - triton.Config({'XS': 12 * 1024, 'multibuffer': True}), - triton.Config({'XS': 12 * 1024, 'multibuffer': False}), - triton.Config({'XS': 8 * 1024, 'multibuffer': True}), - ] - - # Use @autotune decorator to automatically select the best kernel configuration - @triton.autotune(configs=get_autotune_config(), # List of configurations - key=["numel"], # the change of numel will trigger autotuning - ) - @triton.jit - def triton_calc_kernel(out_ptr0, in_ptr0, in_ptr1, numel, - XS: tl.constexpr # Block size controlling how many elements each thread block processes - ): - pid = tl.program_id(0) # Get current program ID - idx = pid * XS + tl.arange(0, XS) # Index range handled by current thread block - msk = idx < numel # Mask to avoid out-of-bound access - for i in range(10000): - tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0) # Load x0 - tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0) # Load x1 - tmp2 = tl.math.exp(tmp0) + tmp1 + i - tl.store(out_ptr0 + idx, tmp2, mask=msk) # Store result - - # Function to call the Triton kernel with autotuned configuration - def triton_calc_func(x0, x1): - n = x0.numel() - y0 = torch.empty_like(x0) - grid = lambda meta: (triton.cdiv(n, meta["XS"]), 1, 1) - triton_calc_kernel[grid](y0, x0, x1, n) - return y0 - - # Reference implementation using PyTorch for correctness check - def torch_calc_func(x0, x1): - return torch.exp(x0) + x1 + 10000 - 1 - - DEV = "npu" - DTYPE = torch.float32 - N = 192 * 1024 - x0 = torch.randn((N, ), dtype=DTYPE, device=DEV) - x1 = torch.randn((N, ), dtype=DTYPE, device=DEV) - torch_ref = torch_calc_func(x0, x1) - triton_cal = triton_calc_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - -if __name__ == "__main__": - test_triton_autotune() - print("success: test_triton_autotune") diff --git a/third_party/ascend/tutorials/06-fused-attention.py b/third_party/ascend/tutorials/06-fused-attention.py new file mode 100644 index 0000000000..8f67a6b3fb --- /dev/null +++ 
b/third_party/ascend/tutorials/06-fused-attention.py @@ -0,0 +1,352 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants + STAGE: tl.constexpr, offs_m: tl.constexpr, + offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, + fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + # Set the processing range [lo, hi) for the current stage (in column block units) + # Causal attention, as the name implies, restricts the flow of information during computation, + # only allowing the model to see the current and previous positions. + # In other words, the output at the current position can only depend on the input at or before this position, + # and cannot access information from future positions. + # Causal attention ensures sequential order and prevents "leakage of future information." 
+ # But the following logic will also be triggered + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K is [HEAD_DIM, N_CTX], shift along the second dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator , used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Modify K + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Construct upper triangular mask + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Set invalid positions to -∞ + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. 
Store the updated accumulator back.
+        tl.store(acc_ptr + block2d_acc, acc)
+
+        m_i = m_ij  # Update current block max
+        # Advance V and K block pointers to next BLOCK_N range
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
+    # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i
+    return acc_ptr, l_i, m_i
+
+
+@triton.jit
+def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr,
+              stride_qk: tl.constexpr, stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr,
+              stride_kk: tl.constexpr, stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr,
+              stride_vk: tl.constexpr, stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr,
+              stride_on: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, N_CTX: tl.constexpr, HEAD_DIM: tl.constexpr,
+              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr):
+    # Total number of blocks in sequence dimension (M)
+    NUM_BLOCKS_M = N_CTX // BLOCK_M
+    # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H)
+    NUM_BLOCKS = NUM_BLOCKS_M * Z * H
+
+    # Current M-dimension block index. The loop stride of 20 matches the
+    # number of cores the host launches (`num_cores = 20` below), so each
+    # core persistently processes every 20th task.
+    pid = tl.program_id(0)
+
+    for block_idx in range(pid, NUM_BLOCKS, 20):
+        task_hz_idx = block_idx // NUM_BLOCKS_M
+        task_m_idx = block_idx % NUM_BLOCKS_M
+        off_z = task_hz_idx // H
+        off_h = task_hz_idx % H
+        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+        # Create block pointers for Q, K, V, Output
+        Q_block_ptr = tl.make_block_ptr(
+            base=Q + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_qm, stride_qk),
+            offsets=(task_m_idx * BLOCK_M, 0),
+            block_shape=(BLOCK_M, HEAD_DIM),
+            order=(1, 0),
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_vn, stride_vk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_N, HEAD_DIM),
+            order=(1, 0),
+        )
+        K_block_ptr = tl.make_block_ptr(
+            base=K + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_kn, stride_kk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_N, HEAD_DIM),
+            order=(1, 0),
+        )
+        O_block_ptr = tl.make_block_ptr(
+            base=Out + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_om, stride_on),
+            offsets=(task_m_idx * BLOCK_M, 0),
+            block_shape=(BLOCK_M, HEAD_DIM),
+            order=(1, 0),
+        )
+        # Initialize offsets
+        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+        offs_n = tl.arange(0, BLOCK_N)
+
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
+
+        # Initialize accumulator
+        if HEAD_DIM < 256:
+            acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
+        else:
+            acc_offset = (off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM +
+                          off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + task_m_idx * BLOCK_M * HEAD_DIM)
+            acc_ptr = acc + acc_offset
+
+        # load q: it will stay in SRAM throughout
+        q = tl.load(Q_block_ptr)
+
+        # stage 1: off-band
+        # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+        # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
+        if STAGE & 1:
+            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
+                                                task_m_idx, sm_scale,  #
+                                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
+                                                )
+        # stage 2: on-band
+        if STAGE & 2:
+            # The barrier makes it easier for the compiler to schedule the
+            # two loops independently
+            acc_ptr, l_i,
m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + accumulator = tl.load(acc_ptr + block2d_acc) + accumulator = accumulator / l_i[:, None] + + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, BM, BN): + """ + Forward computation interface: + Args: + ctx: Context object + q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM] + k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM] + v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM] + causal: Whether to enable causal attention + sm_scale: Scaling factor for QK product + BM: Q block size (BLOCK_M) + BN: K/V block size (BLOCK_N) + Returns: + out: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM] + """ + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + out = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + + # Number of NPU cores (adjust based on hardware) + num_cores = 20 + acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[(num_cores, )](q, k, v, M, out, acc, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), + v.stride(2), v.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], HEAD_DIM=HEAD_DIM_K, BLOCK_M=BM, BLOCK_N=BN, + STAGE=stage, **extra_kern_args) + + ctx.save_for_backward(q, k, v, out, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return out + + +attention = _attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [ + (1, 1, 128, 128, False, torch.float16, 32, 128), + (1, 1, 128, 128, False, torch.bfloat16, 64, 128), + (1, 2, 256, 256, False, torch.bfloat16, 32, 256), + (2, 2, 128, 256, False, torch.float16, 64, 128), + (4, 32, 64, 64, False, torch.float16, 32, 64), + (4, 32, 1024, 64, False, torch.bfloat16, 64, 128), + (4, 32, 4096, 64, False, torch.float16, 128, 128), +]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): + # Filter out non-integer cases; N_CTX must be divisible by BM and BN, and HEAD_DIM must be divisible by 16. 
+ if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0: + pytest.skip("Skipping non-divisible case") + + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + + sm_scale = 0.5 + + tri_out = attention(q, k, v, causal, sm_scale, BM, BN) + ref_out = torch_npu.npu_fusion_attention( + q, + k, + v, + H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) + print(f"[PASSED] Attention shape:({Z}, {H}, {N_CTX}, {HEAD_DIM}), BM: {BM}, BN: {BN}, dtype: {dtype}") + + +if __name__ == "__main__": + test_op(1, 1, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 1, 128, 128, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(1, 2, 256, 256, causal=False, dtype=torch.bfloat16, BM=32, BN=256) + test_op(2, 2, 128, 256, causal=False, dtype=torch.float16, BM=64, BN=128) + test_op(4, 32, 64, 64, causal=False, dtype=torch.float16, BM=32, BN=64) + test_op(4, 32, 1024, 64, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(4, 32, 4096, 64, causal=False, dtype=torch.float16, BM=128, BN=128) diff --git a/third_party/ascend/tutorials/07-extern-functions.py b/third_party/ascend/tutorials/07-extern-functions.py new file mode 100644 index 0000000000..e433640245 --- /dev/null +++ b/third_party/ascend/tutorials/07-extern-functions.py @@ -0,0 +1,87 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
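The `_attn_fwd_inner` loop above is the flash-attention online-softmax recurrence: keep a running row max `m_i` and denominator `l_i`, and rescale the accumulator by `alpha = exp(m_old - m_new)` whenever the max grows. A NumPy sketch of that recurrence on a single query row, which verifies it reproduces the exact softmax result (illustration only, not the kernel's API):

import numpy as np

np.random.seed(0)
scores = np.random.randn(256).astype(np.float32)  # one row of q @ k^T, already scaled
v = np.random.randn(256, 64).astype(np.float32)

m, l, acc = -np.inf, 0.0, np.zeros(64, dtype=np.float32)
for s0 in range(0, 256, 64):                      # one BLOCK_N chunk at a time
    blk = scores[s0:s0 + 64]
    m_new = max(m, blk.max())
    alpha = np.exp(m - m_new)                     # rescales the stale state
    p = np.exp(blk - m_new)
    l = l * alpha + p.sum()
    acc = acc * alpha + p @ v[s0:s0 + 64]
    m = m_new

ref = np.exp(scores - scores.max())
ref /= ref.sum()
np.testing.assert_allclose(acc / l, ref @ v, rtol=1e-4, atol=1e-5)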
+""" +Libdevice (`tl.extra.libdevice`) function +============================== +""" +import inspect +import os +from pathlib import Path + +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +from triton.backends.ascend.compiler import get_libdevice + +DEV = "npu" + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + + +def test(): + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=DEV) + output_triton = torch.zeros(size, device=DEV) + output_torch = torch.asin(x) + assert x.device.type == DEV and output_triton.device.type == DEV + n_elements = output_torch.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) + print(output_torch) + print(output_triton) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + current_file = inspect.getfile(inspect.currentframe()) + current_dir = Path(os.path.dirname(os.path.abspath(current_file))) + extern_libs = {'libdevice': get_libdevice()} + + output_triton = torch.empty_like(x) + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs) + torch.testing.assert_close(output_torch, output_triton, rtol=1e-4, atol=1e-4) + print(output_torch) + print(output_triton) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + +if __name__ == "__main__": + test() diff --git a/third_party/ascend/tutorials/07-profiler.py b/third_party/ascend/tutorials/07-profiler.py deleted file mode 100644 index f62800f902..0000000000 --- a/third_party/ascend/tutorials/07-profiler.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - - -def profiler_wrapper(fn, *args): - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, - skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, - profile_memory=False, with_stack=False, with_flops=False, with_modules=False, - experimental_config=experimental_config) as prof: - stream.synchronize() - for i in range(skip_first + (wait + warmup + active) * repeat): - fn(*args) - prof.step() - stream.synchronize() - - -def test_add(x0, x1): - - def torch_func(x0, x1): - res = x0 + x1 - return res - - @triton.jit - def triton_kernel_add(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - idx = tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.store(out_ptr0 + idx, tmp2) - - def triton_func(x0, x1): - y0 = torch.empty_like(x0) - triton_kernel_add[1, 1, 1](y0, x0, x1, N) - return y0 - - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - def wrapper_func(x0, x1): - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - - profiler_wrapper(wrapper_func, x0, x1) - - -def test_or(x0, x1): - - def torch_func(x0, x1): - res = x0 | x1 - return res - - @triton.jit - def triton_kernel_or(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - idx = tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 | tmp1 - tl.store(out_ptr0 + idx, tmp2) - - def triton_func(x0, x1): - y0 = torch.empty_like(x0) - triton_kernel_or[1, 1, 1](y0, x0, x1, N) - return y0 - - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - def wrapper_func(x0, x1): - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - - profiler_wrapper(wrapper_func, x0, x1) - - -def test_inductor_add(x0, x1): - # torch_npu._inductor requires torch_npu 2.6.0+ experimental version - import torch_npu._inductor - - def torch_func(x0, x1): - res = x0 + x1 - return res - - compiled_func = torch.compile(torch_func, backend="inductor") - profiler_wrapper(compiled_func, x0, x1) - print("[INFO] Check ./result_profiling directory to find the kernel_details.csv file. 
" - " Check the columns: Input Shapes,Input Data Types,Input Formats") - - -if __name__ == "__main__": - test_case_is_inductor = False - N = 1024 - low = 1 - high = 100 - - # float32 - x0_fp32 = torch.rand((N, ), dtype=torch.float32).npu() - x1_fp32 = torch.rand((N, ), dtype=torch.float32).npu() - - # float16 - x0_fp16 = torch.rand((N, ), dtype=torch.float16).npu() - x1_fp16 = torch.rand((N, ), dtype=torch.float16).npu() - - # bfloat16 - x0_bf16 = torch.rand((N, ), dtype=torch.bfloat16).npu() - x1_bf16 = torch.rand((N, ), dtype=torch.bfloat16).npu() - - # int64 - x0_i64 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int64).npu() - x1_i64 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int64).npu() - - # int32 - x0_i32 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int32).npu() - x1_i32 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int32).npu() - - # int16 - x0_i16 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int16).npu() - x1_i16 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int16).npu() - - # int8 - x0_i8 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int8).npu() - x1_i8 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int8).npu() - - # bool (i1) - x0_i1 = torch.randint(low=0, high=2, size=(N, )).bool().npu() - x1_i1 = torch.randint(low=0, high=2, size=(N, )).bool().npu() - - test_cases = [ - ('fp32', x0_fp32, x1_fp32), - ('fp16', x0_fp16, x1_fp16), - ('bf16', x0_bf16, x1_bf16), - ('i64', x0_i64, x1_i64), - ('i32', x0_i32, x1_i32), - ('i16', x0_i16, x1_i16), - ('i8', x0_i8, x1_i8), - ('i1', x0_i1, x1_i1), - ] - - for dtype_name, x0, x1 in test_cases: - print(f"Running test for {dtype_name}...") - if dtype_name != 'i1': - if (test_case_is_inductor): - test_inductor_add(x0, x1) - else: - test_add(x0, x1) - else: - test_or(x0, x1) diff --git a/third_party/ascend/tutorials/08-grouped-gemm.py b/third_party/ascend/tutorials/08-grouped-gemm.py new file mode 100644 index 0000000000..96739be81d --- /dev/null +++ b/third_party/ascend/tutorials/08-grouped-gemm.py @@ -0,0 +1,281 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+""" +Group GEMM +============================ +""" + +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +NUM_CORES = get_npu_properties()["num_aicore"] + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + group_gemm_sizes, + g_lds, + group_size, + NUM_SM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for _ in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + tl.store(c_ptrs, c) + tile_idx += NUM_SM + + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + device = torch.device(DEV) + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=device, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), 
C.stride(0)] + + d_a_ptrs = torch.tensor(A_addrs, device=device) + d_b_ptrs = torch.tensor(B_addrs, device=device) + d_c_ptrs = torch.tensor(C_addrs, device=device) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + + def grid(meta): + return (meta['NUM_SM'], ) + + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +def test(): + group_m = [1024, 512, 256, 128] + group_n = [1024, 512, 256, 128] + group_k = [1024, 512, 256, 128] + group_A = [] + group_B = [] + assert len(group_m) == len(group_n) + assert len(group_n) == len(group_k) + group_size = len(group_m) + for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device=DEV, dtype=torch.float16) + B = torch.rand((K, N), device=DEV, dtype=torch.float16) + group_A.append(A) + group_B.append(B) + + tri_out = group_gemm_fn(group_A, group_B) + ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] + for i in range(group_size): + torch.testing.assert_close(ref_out[i], tri_out[i], atol=1e-2, rtol=1e-3) + print("Passed") + + +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + + def grid(meta): + return (meta['NUM_SM'], ) + + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], + x_vals=[2**i for i in range(7, 11)], + line_arg='provider', + line_vals=['torch', 'triton'], + line_names=["Torch", "Triton"], + styles=[('green', '-'), ('blue', '-')], + ylabel="runtime(ms)", + plot_name="group-gemm-performance", + args={}, + )) +def benchmark(N, provider): + group_size = 4 + group_A = [] + group_B = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for _ in range(group_size): + A = torch.rand((N, N), device=DEV, dtype=torch.float16) + B = torch.rand((N, N), device=DEV, dtype=torch.float16) + C = torch.empty((N, N), device=DEV, dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device=DEV) + d_b_ptrs = torch.tensor(B_addrs, device=DEV) + d_c_ptrs = torch.tensor(C_addrs, device=DEV) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEV) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEV) + + quantiles = [0.5, 0.2, 0.8] + + def bench_torch(): + torch_perf_fn(group_A, group_B) + + def bench_triton(): + triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size) + + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(bench_torch, quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(bench_triton, quantiles=quantiles) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + test() diff --git a/third_party/ascend/tutorials/09-persistent-matmul.py b/third_party/ascend/tutorials/09-persistent-matmul.py new file mode 100644 index 0000000000..0e085dc624 --- /dev/null +++ b/third_party/ascend/tutorials/09-persistent-matmul.py @@ -0,0 +1,334 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Persistent Matmul +===================== +""" + +import argparse +import time + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" +DTYPE = torch.float16 + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def get_num_compute_cores(): + return get_npu_properties()["num_aicore"] + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + bytes_per_elem = args["c_ptr"].element_size() + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + 
accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_sm = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_sm += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_sm): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def get_configs(dtype): + return { + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + } + }[dtype] + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + 
return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def matmul_persistent(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + num_sms = get_num_compute_cores() + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + return (min(num_sms, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])), ) + + matmul_kernel_persistent[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + NUM_SMS=num_sms, + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def torch_matmul(a, b): + return torch.matmul(a, b) + + +def bench(K, reps=10): + M = 8192 + N = 8192 + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + for _ in range(reps): + _ = torch_matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul_persistent(a, b) + time.sleep(0.01) + + +def validate(M, N, K): + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + torch_result = torch_matmul(a, b) + naive_result = matmul(a, b) + persistent_result = matmul_persistent(a, b) + + naive_vs_torch = "✅" if torch.allclose(naive_result, torch_result, atol=1.0) else "❌" + persistent_vs_torch = "✅" if torch.allclose(persistent_result, torch_result, atol=1.0) else "❌" + naive_vs_persistent = "✅" if torch.allclose(naive_result, persistent_result, atol=1.0) else "❌" + + print(f"M={M}, N={N}, K={K} verification naive vs torch: {naive_vs_torch} " + f"persistent vs torch: {persistent_vs_torch} naive vs persistent: {naive_vs_persistent}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-K", type=int, required=False, default=512) + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + args = parser.parse_args() + + if args.K and args.K_range is None: + args.K_range = [args.K, args.K] + args.K_step = 1 + + torch.manual_seed(0) + + validate(32, 32, 32) + validate(8192, 8192, 512) + + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K) diff --git a/third_party/ascend/tutorials/15-embedding_gather_demo.py b/third_party/ascend/tutorials/15-embedding_gather_demo.py deleted file mode 100644 index 84fd70ef27..0000000000 --- a/third_party/ascend/tutorials/15-embedding_gather_demo.py +++ /dev/null @@ -1,118 +0,0 @@ -# only available on 910_95 -import torch -import torch_npu -from torch import empty_strided -from torch._dynamo.testing import rand_strided -import triton -import triton.language as tl - -y0_numel = 128 -r1_numel = 50 -x2_numel = 16 -embedding_size = 1353406 - - -def profiler_wrapper(fn, *args): - 
result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, - skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, - profile_memory=False, with_stack=False, with_flops=False, with_modules=False, - experimental_config=experimental_config) as prof: - stream.synchronize() - for i in range(skip_first + (wait + warmup + active) * repeat): - fn(*args) - prof.step() - stream.synchronize() - - -def get_autotune_config(): - return [ - triton.Config({ - 'Y0BLOCK': 4, 'Y0BLOCK_SUB': 2, 'X2BLOCK_SUB': x2_numel, 'R1BLOCK_SUB': r1_numel, 'EMBEDDING_SIZE': - embedding_size, 'multibuffer': False - }), - ] - - -@triton.autotune(configs=get_autotune_config(), # List of configurations - key=["numel"], # the change of numel will trigger autotuning - ) -@triton.jit -def triton_unk_fused_embedding_eq_sum_where_zeros_like_0(in_ptr0, in_ptr1, out_ptr0, y0_numel, r1_numel, x2_numel, - Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, - X2BLOCK_SUB: tl.constexpr, R1BLOCK_SUB: tl.constexpr, - EMBEDDING_SIZE: tl.constexpr): - y0_offset = tl.program_id(0) * Y0BLOCK - base_y0 = tl.arange(0, Y0BLOCK_SUB) - loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB - base_r1 = tl.arange(0, R1BLOCK_SUB) - base_x2 = tl.arange(0, X2BLOCK_SUB) - r1 = base_r1[None, None, :] - r1_mask = r1 < r1_numel - x2 = base_x2[None, None, :] - x2_mask = x2 < x2_numel - # loops_x1 = (x1_numel + X2BLOCK_SUB - 1) // X2BLOCK_SUB - # loops_r2 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB - for loop_y0 in range(loops_y0): - y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None, None] - y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) - tmp0 = tl.load(in_ptr0 + (r1 + 50 * y0), r1_mask & y0_mask, other=0.0).to(tl.int32) - tmp1 = tl.full([1, 1, 1], -1, tl.int32) - tmp2 = tmp0 == tmp1 - tmp3 = tl.full([1, 1, 1], 0, tl.int32) - tmp4 = tl.where(tmp2, tmp3, tmp0) - # tmp5 = tl.full([Y0BLOCK_SUB, X2BLOCK_SUB, R1BLOCK_SUB], 1353406, tl.int32) - # tmp6 = tmp4 + tmp5 - # tmp7 = tmp4 < 0 - # tmp8 = tl.where(tmp7, tmp6, tmp4) - # tl.device_assert(((0 <= tmp8) & (tmp8 < 1353406)) | ~(r2_mask & y0_mask), "index out of bounds: 0 <= tmp8 < 1353406") - # tmp10 = tl.load(in_ptr1 + (x1 + 16*tmp8), r2_mask & x1_mask & y0_mask) - # 用下面这行替换上述6行 SIMT - tmp8 = tl.reshape(tmp4, [Y0BLOCK_SUB, R1BLOCK_SUB]) - tmp10 = tl.index_select(in_ptr1, tmp8, EMBEDDING_SIZE, X2BLOCK_SUB, (y0_offset + (loop_y0 * Y0BLOCK_SUB), 0, 0), - (y0_numel, r1_numel, x2_numel)) - tmp14 = tl.sum(tmp10, 1).reshape(Y0BLOCK_SUB, 1, X2BLOCK_SUB) - tl.store(out_ptr0 + (x2 + 16 * y0), tmp14, x2_mask & y0_mask) - - -def triton_func(arg34_1: torch.Tensor, arg35_1: torch.Tensor, buf0: torch.Tensor): - y0_size, _ = arg34_1.size() - grid = lambda meta: (triton.cdiv(y0_size, meta['Y0BLOCK']), ) - triton_unk_fused_embedding_eq_sum_where_zeros_like_0[grid](arg34_1, arg35_1, buf0, y0_numel, r1_numel, x2_numel) - return buf0 - - -def torch_func(x0: torch.Tensor): - return torch.sqrt(x0) - - -torch.manual_seed(0) - 
-arg34_1 = rand_strided((y0_numel, r1_numel), (r1_numel, 1), device='npu', dtype=torch.int64) -arg35_1 = rand_strided((embedding_size, x2_numel), (x2_numel, 1), device='npu', dtype=torch.float32) -buf0 = empty_strided((y0_numel, x2_numel), (x2_numel, 1), device='npu', dtype=torch.float32) - -output_triton = triton_func(arg34_1, arg35_1, buf0) -print("triton = ", output_triton) - -# output_torch = torch_func(x0) -# print("torch = ", output_torch) -# torch.testing.assert_close(output_triton.cpu(), output_torch.cpu()) - -# def wrapper_func(x0, x1): -# torch_ref = torch_func(x0, x1) -# triton_cal = triton_func(x0, x1) - -# profiler_wrapper(wrapper_func, x0, x1) diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir new file mode 100644 index 0000000000..bb98ba7bee --- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir @@ -0,0 +1,262 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @backward_dkdv + +module { + tt.func public @backward_dkdv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: i32, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: f32, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: i32 {tt.divisibility = 16 : i32}, %arg28: i32 {tt.divisibility = 16 : i32}, %arg29: i32 {tt.divisibility = 16 : i32}, %arg30: i32 {tt.divisibility = 16 : i32}, %arg31: i32 {tt.divisibility = 16 : i32}, %arg32: i32) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32> + %cst_0 = arith.constant dense<0xFF800000> : tensor<32x32xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32> + %cst_2 = arith.constant dense<1> : tensor<32xi32> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %cst_3 = arith.constant 1.44269502 : f32 + %c1_i32 = arith.constant 1 : i32 + %alloc = memref.alloc() : memref<2x2x16x16xf16, #hivm.address_space> + %alloc_4 = memref.alloc() : memref<32x64xf32, #hivm.address_space> + %alloc_5 = memref.alloc() : memref<2x2x16x16xf16, #hivm.address_space> + %alloc_6 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %alloc_7 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %alloc_8 = memref.alloc() : 
memref<32x64xf32, #hivm.address_space> + %0 = tt.get_program_id x : i32 + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 10 + hivm.hir.sync_block_set[, , ] flag = 9 + hivm.hir.sync_block_set[, , ] flag = 8 + hivm.hir.sync_block_set[, , ] flag = 7 + %1 = tt.get_num_programs x : i32 + %2 = arith.mulf %arg15, %cst_3 : f32 + %3 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %7 = tt.splat %arg14 : i32 -> tensor<1x64xi32> + %8 = arith.cmpi slt, %6, %7 : tensor<1x64xi32> + %9 = tt.broadcast %8 : tensor<1x64xi1> -> tensor<32x64xi1> + %10 = tt.splat %arg27 : i32 -> tensor<32x1xi32> + %11 = arith.muli %4, %10 : tensor<32x1xi32> + %12 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<32x64xi32> + %13 = tt.splat %arg30 : i32 -> tensor<32x1xi32> + %14 = arith.muli %4, %13 : tensor<32x1xi32> + %15 = tt.splat %arg9 : i32 -> tensor<32xi32> + %16 = arith.muli %3, %15 : tensor<32xi32> + %17 = tt.splat %arg8 : i32 -> tensor<32xi32> + %18 = arith.addi %16, %17 : tensor<32xi32> + %19 = arith.subi %18, %cst_2 : tensor<32xi32> + %20 = arith.subi %arg8, %c1_i32 : i32 + %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %22 = tt.broadcast %21 : tensor<32x1xi32> -> tensor<32x32xi32> + %23 = tt.splat %2 : f32 -> tensor<32x32xf32> + %24 = tt.splat %arg15 : f32 -> tensor<32x32xf32> + scf.for %arg33 = %0 to %arg32 step %1 : i32 { + %25 = arith.divsi %arg33, %arg32 : i32 + %26 = arith.remsi %arg33, %arg32 : i32 + %27 = arith.divsi %26, %arg13 : i32 + %28 = arith.remsi %26, %arg13 : i32 + %29 = tt.addptr %arg10, %25 : !tt.ptr, i32 + %30 = tt.load %29 : !tt.ptr + %31 = tt.addptr %29, %c1_i32 : !tt.ptr, i32 + %32 = tt.load %31 : !tt.ptr + %33 = arith.subi %32, %30 : i32 + %34 = tt.addptr %arg11, %25 : !tt.ptr, i32 + %35 = tt.load %34 : !tt.ptr + %36 = tt.addptr %34, %c1_i32 : !tt.ptr, i32 + %37 = tt.load %36 : !tt.ptr + %38 = arith.subi %37, %35 : i32 + %39 = tt.splat %38 : i32 -> tensor<32x1xi32> + %40 = arith.cmpi slt, %4, %39 : tensor<32x1xi32> + %41 = tt.broadcast %40 : tensor<32x1xi1> -> tensor<32x64xi1> + %42 = arith.andi %41, %9 : tensor<32x64xi1> + %43 = arith.muli %35, %arg27 : i32 + %44 = tt.addptr %arg6, %43 : !tt.ptr, i32 + %45 = arith.muli %27, %arg28 : i32 + %46 = tt.addptr %44, %45 : !tt.ptr, i32 + %47 = arith.muli %28, %arg26 : i32 + %48 = tt.addptr %46, %47 : !tt.ptr, i32 + %49 = tt.splat %48 : !tt.ptr -> tensor<32x1x!tt.ptr> + %50 = tt.addptr %49, %11 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> + %51 = tt.broadcast %50 : tensor<32x1x!tt.ptr> -> tensor<32x64x!tt.ptr> + %52 = tt.addptr %51, %12 : tensor<32x64x!tt.ptr>, tensor<32x64xi32> + %53 = arith.muli %35, %arg30 : i32 + %54 = tt.addptr %arg7, %53 : !tt.ptr, i32 + %55 = arith.muli %27, %arg31 : i32 + %56 = tt.addptr %54, %55 : !tt.ptr, i32 + %57 = arith.muli %28, %arg29 : i32 + %58 = tt.addptr %56, %57 : !tt.ptr, i32 + %59 = tt.splat %58 : !tt.ptr -> tensor<32x1x!tt.ptr> + %60 = tt.addptr %59, %14 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> + %61 = tt.broadcast %60 : tensor<32x1x!tt.ptr> -> tensor<32x64x!tt.ptr> + %62 = tt.addptr %61, %12 : tensor<32x64x!tt.ptr>, tensor<32x64xi32> + %63 = arith.extsi %33 : i32 to i64 + %64 = tt.addptr %arg4, %30 : !tt.ptr, i32 + %65 = arith.muli %26, %arg23 : i32 + %66 = tt.addptr %64, %65 : !tt.ptr, i32 + %67 = tt.addptr %arg3, %30 : 
!tt.ptr, i32 + %68 = arith.muli %26, %arg22 : i32 + %69 = tt.addptr %67, %68 : !tt.ptr, i32 + %70:6 = scf.for %arg34 = %20 to %33 step %c32_i32 iter_args(%arg35 = %cst, %arg36 = %cst, %arg37 = %20, %arg38 = %20, %arg39 = %20, %arg40 = %20) -> (tensor<32x64xf32>, tensor<32x64xf32>, i32, i32, i32, i32) : i32 { + %73 = tt.make_tensor_ptr %66, [%63], [%c1_i64], [%arg40] {order = array} : > + %74 = tt.make_tensor_ptr %69, [%63], [%c1_i64], [%arg39] {order = array} : > + %75 = tt.load %74 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %76 = tt.expand_dims %75 {axis = 0 : i32} : tensor<32xf32> -> tensor<1x32xf32> + %77 = tt.load %73 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %78 = tt.expand_dims %77 {axis = 0 : i32} : tensor<32xf32> -> tensor<1x32xf32> + %79 = tt.splat %arg34 : i32 -> tensor<32xi32> + %80 = arith.addi %3, %79 : tensor<32xi32> + %81 = tt.expand_dims %80 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %82 = tt.broadcast %81 : tensor<1x32xi32> -> tensor<32x32xi32> + %83 = arith.cmpi sle, %22, %82 : tensor<32x32xi32> + %84 = arith.select %83, %cst_1, %cst_0 : tensor<32x32xi1>, tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_7 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %85 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf32> + %86 = arith.mulf %85, %23 : tensor<32x32xf32> + %87 = arith.addf %84, %86 : tensor<32x32xf32> + %88 = tt.broadcast %76 : tensor<1x32xf32> -> tensor<32x32xf32> + %89 = arith.subf %87, %88 : tensor<32x32xf32> + %90 = math.exp2 %89 : tensor<32x32xf32> + %91 = arith.mulf %24, %90 : tensor<32x32xf32> + %92 = tt.broadcast %78 : tensor<1x32xf32> -> tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 2 + %memspacecast_9 = memref.memory_space_cast %alloc_6 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %93 = bufferization.to_tensor %memspacecast_9 restrict writable : memref<32x32xf32> + %94 = arith.subf %93, %92 : tensor<32x32xf32> + %95 = arith.mulf %91, %94 : tensor<32x32xf32> + %96 = arith.truncf %90 : tensor<32x32xf32> to tensor<32x32xf16> + %97 = tt.reshape %96 : tensor<32x32xf16> -> tensor<2x16x2x16xf16> + %98 = tt.trans %97 {order = array} : tensor<2x16x2x16xf16> -> tensor<2x2x16x16xf16> + hivm.hir.sync_block_set[, , ] flag = 7 + hivm.hir.sync_block_set[, , ] flag = 8 + hivm.hir.sync_block_wait[, , ] flag = 11 + %99 = bufferization.to_memref %98 : memref<2x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%99 : memref<2x2x16x16xf16, #hivm.address_space>) outs(%alloc : memref<2x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 5 + %100 = arith.truncf %95 : tensor<32x32xf32> to tensor<32x32xf16> + %101 = tt.reshape %100 : tensor<32x32xf16> -> tensor<2x16x2x16xf16> + %102 = tt.trans %101 {order = array} : tensor<2x16x2x16xf16> -> tensor<2x2x16x16xf16> + hivm.hir.sync_block_wait[, , ] flag = 12 + %103 = bufferization.to_memref %102 : memref<2x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%103 : memref<2x2x16x16xf16, #hivm.address_space>) outs(%alloc_5 : memref<2x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + hivm.hir.sync_block_wait[, , ] flag = 4 + %memspacecast_10 = memref.memory_space_cast %alloc_4 : memref<32x64xf32, #hivm.address_space> to memref<32x64xf32> + %104 = bufferization.to_tensor %memspacecast_10 restrict writable : memref<32x64xf32> + %105 = arith.addf %104, %arg35 : tensor<32x64xf32> + hivm.hir.sync_block_wait[, , ] flag = 6 + 
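+        // the producer scope sets flag 6 only after its fixpipe has written %alloc_8, so the read below observes a complete tile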
%memspacecast_11 = memref.memory_space_cast %alloc_8 : memref<32x64xf32, #hivm.address_space> to memref<32x64xf32> + %106 = bufferization.to_tensor %memspacecast_11 restrict writable : memref<32x64xf32> + %107 = arith.addf %106, %arg36 : tensor<32x64xf32> + %108 = arith.addi %arg37, %c32_i32 : i32 + %109 = arith.addi %arg38, %c32_i32 : i32 + %110 = arith.addi %arg39, %c32_i32 : i32 + %111 = arith.addi %arg40, %c32_i32 : i32 + hivm.hir.sync_block_set[, , ] flag = 9 + hivm.hir.sync_block_set[, , ] flag = 10 + scf.yield %105, %107, %108, %109, %110, %111 : tensor<32x64xf32>, tensor<32x64xf32>, i32, i32, i32, i32 + } + %71 = arith.truncf %70#0 : tensor<32x64xf32> to tensor<32x64xf16> + tt.store %52, %71, %42 : tensor<32x64x!tt.ptr> + %72 = arith.truncf %70#1 : tensor<32x64xf32> to tensor<32x64xf16> + tt.store %62, %72, %42 : tensor<32x64x!tt.ptr> + } + hivm.hir.sync_block_wait[, , ] flag = 11 + hivm.hir.sync_block_wait[, , ] flag = 12 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 12 + hivm.hir.sync_block_set[, , ] flag = 11 + %1 = tt.get_num_programs x : i32 + %2 = arith.extsi %arg14 : i32 to i64 + %3 = arith.extsi %arg18 : i32 to i64 + %4 = arith.extsi %arg20 : i32 to i64 + %5 = arith.subi %arg8, %c1_i32 : i32 + %6 = arith.extsi %arg16 : i32 to i64 + %7 = arith.extsi %arg24 : i32 to i64 + scf.for %arg33 = %0 to %arg32 step %1 : i32 { + %8 = arith.divsi %arg33, %arg32 : i32 + %9 = arith.remsi %arg33, %arg32 : i32 + %10 = arith.divsi %9, %arg13 : i32 + %11 = tt.addptr %arg10, %8 : !tt.ptr, i32 + %12 = tt.load %11 : !tt.ptr + %13 = tt.addptr %11, %c1_i32 : !tt.ptr, i32 + %14 = tt.load %13 : !tt.ptr + %15 = arith.subi %14, %12 : i32 + %16 = tt.addptr %arg11, %8 : !tt.ptr, i32 + %17 = tt.load %16 : !tt.ptr + %18 = tt.addptr %16, %c1_i32 : !tt.ptr, i32 + %19 = tt.load %18 : !tt.ptr + %20 = arith.subi %19, %17 : i32 + %21 = arith.muli %17, %arg18 : i32 + %22 = tt.addptr %arg1, %21 : !tt.ptr, i32 + %23 = arith.muli %10, %arg19 : i32 + %24 = tt.addptr %22, %23 : !tt.ptr, i32 + %25 = arith.extsi %20 : i32 to i64 + %26 = tt.make_tensor_ptr %24, [%25, %2], [%3, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %27 = arith.muli %17, %arg20 : i32 + %28 = tt.addptr %arg2, %27 : !tt.ptr, i32 + %29 = arith.muli %10, %arg21 : i32 + %30 = tt.addptr %28, %29 : !tt.ptr, i32 + %31 = tt.make_tensor_ptr %30, [%25, %2], [%4, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %32 = tt.load %26 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %33 = tt.load %31 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %34 = arith.muli %12, %arg16 : i32 + %35 = tt.addptr %arg0, %34 : !tt.ptr, i32 + %36 = arith.muli %9, %arg17 : i32 + %37 = tt.addptr %35, %36 : !tt.ptr, i32 + %38 = arith.extsi %15 : i32 to i64 + %39 = arith.muli %12, %arg24 : i32 + %40 = tt.addptr %arg5, %39 : !tt.ptr, i32 + %41 = arith.muli %9, %arg25 : i32 + %42 = tt.addptr %40, %41 : !tt.ptr, i32 + %43:4 = scf.for %arg34 = %5 to %15 step %c32_i32 iter_args(%arg35 = %5, %arg36 = %5, %arg37 = %5, %arg38 = %5) -> (i32, i32, i32, i32) : i32 { + %44 = tt.make_tensor_ptr %42, [%2, %38], [%c1_i64, %7], [%c0_i32, %arg36] {order = array} : > + %45 = tt.make_tensor_ptr %37, [%2, %38], [%c1_i64, %6], [%c0_i32, %arg35] {order = array} : > + %46 = tt.load %45 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %47 = tt.load %44 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %48 = tt.dot %32, %46, %cst_1 : tensor<32x64xf16> * tensor<64x32xf16> -> tensor<32x32xf32> + 
hivm.hir.sync_block_wait[, , ] flag = 7 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%48 : tensor<32x32xf32>) outs(%alloc_7 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %49 = tt.dot %33, %47, %cst_1 : tensor<32x64xf16> * tensor<64x32xf16> -> tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 8 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%49 : tensor<32x32xf32>) outs(%alloc_6 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %50 = tt.trans %46 {order = array} : tensor<64x32xf16> -> tensor<32x64xf16> + hivm.hir.sync_block_wait[, , ] flag = 3 + %51 = hivm.hir.convert_layout %alloc_5 {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<2x2x16x16xf16, #hivm.address_space>) -> memref<32x32xf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %51 : memref<32x32xf16, #hivm.address_space> to memref<32x32xf16> + %52 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf16> + %53 = tt.dot %52, %50, %cst : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32> + hivm.hir.sync_block_set[, , ] flag = 12 + hivm.hir.sync_block_wait[, , ] flag = 9 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%53 : tensor<32x64xf32>) outs(%alloc_4 : memref<32x64xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 4 + %54 = tt.trans %47 {order = array} : tensor<64x32xf16> -> tensor<32x64xf16> + hivm.hir.sync_block_wait[, , ] flag = 5 + %55 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<2x2x16x16xf16, #hivm.address_space>) -> memref<32x32xf16, #hivm.address_space> + %memspacecast_9 = memref.memory_space_cast %55 : memref<32x32xf16, #hivm.address_space> to memref<32x32xf16> + %56 = bufferization.to_tensor %memspacecast_9 restrict writable : memref<32x32xf16> + %57 = tt.dot %56, %54, %cst : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32> + hivm.hir.sync_block_set[, , ] flag = 11 + hivm.hir.sync_block_wait[, , ] flag = 10 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%57 : tensor<32x64xf32>) outs(%alloc_8 : memref<32x64xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 6 + %58 = arith.addi %arg35, %c32_i32 : i32 + %59 = arith.addi %arg36, %c32_i32 : i32 + %60 = arith.addi %arg37, %c32_i32 : i32 + %61 = arith.addi %arg38, %c32_i32 : i32 + scf.yield %58, %59, %60, %61 : i32, i32, i32, i32 + } + } + hivm.hir.sync_block_wait[, , ] flag = 7 + hivm.hir.sync_block_wait[, , ] flag = 8 + hivm.hir.sync_block_wait[, , ] flag = 9 + hivm.hir.sync_block_wait[, , ] flag = 10 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir new file mode 100644 index 0000000000..b64a5b826c --- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir @@ -0,0 +1,421 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @_hstu_attn_fwd + +module { + tt.func public @_hstu_attn_fwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: f32, %arg9: f32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0> : tensor<256x1xi64> + %cst_0 = arith.constant dense : tensor<256x32xi1> + %cst_1 = arith.constant dense<0> : tensor<32x1xi64> + %cst_2 = arith.constant dense : tensor<32x32xi1> + %c32_i32 = arith.constant 32 : i32 + %c2_i32 = arith.constant 2 : i32 + %c8_i64 = arith.constant 8 : i64 + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i64 = arith.constant 2 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c128_i64 = arith.constant 128 : i64 + %c255_i32 = arith.constant 255 : i32 + %c1_i64 = arith.constant 1 : i64 + %c3_i32 = arith.constant 3 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<256x32xf16> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf16> + %cst_5 = arith.constant dense<0.000000e+00> : tensor<32x256xf32> + %cst_6 = arith.constant dense<128> : tensor<256x1xi64> + %cst_7 = arith.constant dense<256> : tensor<32x1xi64> + %cst_8 = arith.constant dense<1.000000e+00> : tensor<32x256xf32> + %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x32xf32> + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(32 : i64) : i64 + %2 = llvm.mlir.constant(64 : i64) : i64 + %3 = llvm.mlir.constant(96 : i64) : i64 + %4 = llvm.mlir.constant(0 : i32) : i32 + %5 = llvm.mlir.constant(1 : i32) : i32 + %6 = llvm.mlir.constant(2 : i64) : i64 + %7 = llvm.mlir.constant(2 : i32) : i32 + %8 = llvm.mlir.constant(4 : i32) : i32 + %9 = llvm.mlir.constant(6 : i32) : i32 + %10 = llvm.mlir.constant(1 : i64) : i64 + %11 = llvm.mlir.constant(3 : i32) : i32 + %c64_i64 = arith.constant 64 : i64 + %12 = llvm.mlir.constant(5 : i32) : i32 + %c0_i64 = arith.constant 0 : i64 + %alloc = memref.alloc() : memref<16x2x16x16xf16, #hivm.address_space> + %alloc_10 = memref.alloc() : memref<32x256xf32, #hivm.address_space> + %alloc_11 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %13 = tt.get_program_id x : i32 + %14 = tt.get_num_programs x : i32 + %15 = arith.cmpi sle, %arg10, %c32_i32 : i32 + %16 = scf.if %15 -> (i64) { + scf.yield %c2_i64 : i64 + } else { + %41 = tt.addptr %arg5, %c2_i32 : !tt.ptr, i32 + %42 = tt.load %41 : !tt.ptr + %43 = arith.extsi %42 : i32 to i64 + scf.yield %43 : i64 + } + %17 = arith.muli %16, %c8_i64 : i64 + %18 = arith.extsi %14 : i32 to i64 + %19 = arith.minsi %18, %17 : i64 + %20 = arith.divsi %17, %19 : i64 + %21 = arith.addi %20, %c1_i64 : i64 + %22 = arith.remsi %17, %19 : i64 + %23 = arith.extsi %13 : i32 to i64 + %24 = arith.cmpi slt, %23, %19 : i64 + %25 = arith.cmpi slt, %23, %22 : i64 + %26 = arith.muli %23, %21 : i64 + %27 = arith.muli %22, %21 : i64 + %28 = arith.subi %23, %22 : i64 + %29 = arith.muli %28, %20 : i64 + %30 = arith.addi %27, %29 : i64 + %31 = arith.select %25, %26, %30 : i64 + %32 = arith.select %24, %31, %c0_i64 : i64 + %33 = arith.select %25, %21, %20 : i64 + %34 = arith.select %24, %33, %c0_i64 : i64 + %35 = arith.cmpi sge, 
%23, %19 : i64 + cf.cond_br %35, ^bb1, ^bb2 + ^bb1: // 2 preds: ^bb0, ^bb2 + tt.return + ^bb2: // pred: ^bb0 + %36 = arith.cmpi sle, %34, %c0_i64 : i64 + cf.cond_br %36, ^bb1, ^bb3 + ^bb3: // pred: ^bb2 + %37 = llvm.inttoptr %0 : i64 to !llvm.ptr<11> + %38 = llvm.inttoptr %1 : i64 to !llvm.ptr<11> + %39 = llvm.inttoptr %2 : i64 to !llvm.ptr<11> + %40 = llvm.inttoptr %3 : i64 to !llvm.ptr<11> + llvm.store %4, %37 : i32, !llvm.ptr<11> + llvm.store %4, %38 : i32, !llvm.ptr<11> + llvm.store %4, %39 : i32, !llvm.ptr<11> + llvm.store %4, %40 : i32, !llvm.ptr<11> + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 14 + %41 = hivm.hir.get_sub_block_idx -> i64 + %42 = arith.muli %41, %1 : i64 + %43 = arith.addi %42, %1 : i64 + hivm.hir.sync_block_set[, , ] flag = 5 + hivm.hir.sync_block_set[, , ] flag = 4 + %44 = arith.addi %arg11, %c255_i32 : i32 + %45 = arith.divsi %44, %c256_i32 : i32 + %46 = arith.extsi %45 : i32 to i64 + %47 = arith.muli %34, %46 : i64 + %48 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %49 = arith.extsi %48 : tensor<32xi32> to tensor<32xi64> + %50 = tt.splat %arg8 : f32 -> tensor<32x256xf32> + %51 = tt.splat %arg9 : f32 -> tensor<32x256xf32> + %52 = arith.muli %47, %c2_i64 : i64 + %53 = arith.divsi %52, %6 : i64 + %54:5 = scf.for %arg13 = %c0_i64 to %52 step %c1_i64 iter_args(%arg14 = %c0_i64, %arg15 = %cst_1, %arg16 = %cst_2, %arg17 = %c0_i64, %arg18 = %c0_i64) -> (i64, tensor<32x1xi64>, tensor<32x32xi1>, i64, i64) : i64 { + hivm.hir.sync_block_wait[, , ] flag = 15 + %55 = llvm.inttoptr %43 : i64 to !llvm.ptr<11> + %56 = llvm.load %55 : !llvm.ptr<11> -> i32 + %57 = arith.andi %56, %5 : i32 + %58 = arith.cmpi eq, %57, %5 : i32 + %59 = arith.andi %56, %7 : i32 + %60 = arith.cmpi eq, %59, %c0_i32 : i32 + %61 = arith.andi %56, %8 : i32 + %62 = arith.cmpi eq, %61, %8 : i32 + %63 = arith.cmpi slt, %arg17, %53 : i64 + %64 = arith.andi %58, %60 : i1 + %65 = arith.andi %64, %63 : i1 + %66 = arith.cmpi slt, %arg18, %53 : i64 + %67 = arith.andi %62, %66 : i1 + %68:4 = scf.if %65 -> (i64, tensor<32x1xi64>, tensor<32x32xi1>, i64) { + %70 = arith.divsi %arg13, %46 : i64 + %71 = arith.addi %32, %70 : i64 + %72 = arith.divsi %71, %16 : i64 + %73 = arith.remsi %71, %16 : i64 + %74:2 = scf.if %15 -> (i64, i64) { + scf.yield %73, %c0_i64 : i64, i64 + } else { + %108:2 = scf.for %arg19 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg20 = %c0_i32, %arg21 = %c3_i32) -> (i32, i32) : i32 { + %115 = arith.addi %arg20, %arg21 : i32 + %116 = arith.divsi %115, %c2_i32 : i32 + %117 = tt.addptr %arg5, %116 : !tt.ptr, i32 + %118 = tt.load %117 : !tt.ptr + %119 = arith.extsi %118 : i32 to i64 + %120 = arith.cmpi sle, %119, %73 : i64 + %121 = arith.select %120, %arg21, %116 : i32 + %122 = scf.if %120 -> (i32) { + %123 = arith.addi %116, %c1_i32 : i32 + scf.yield %123 : i32 + } else { + scf.yield %arg20 : i32 + } + scf.yield %122, %121 : i32, i32 + } + %109 = arith.subi %108#0, %c1_i32 : i32 + %110 = arith.extsi %109 : i32 to i64 + %111 = tt.addptr %arg5, %110 : !tt.ptr, i64 + %112 = tt.load %111 : !tt.ptr + %113 = arith.extsi %112 : i32 to i64 + %114 = arith.subi %73, %113 : i64 + scf.yield %110, %114 : i64, i64 + } + %75 = tt.addptr %arg3, %74#0 : !tt.ptr, i64 + %76 = tt.load %75 : !tt.ptr + %77 = tt.addptr %75, %c1_i32 : !tt.ptr, i32 + %78 = tt.load %77 : !tt.ptr + %79 = arith.subi %78, %76 : i64 + %80 = arith.muli %72, %c32_i64 : i64 + %81 = arith.muli %76, %c256_i64 : i64 + %82 = arith.addi %80, %81 : i64 + %83 = arith.muli %74#1, %c32_i64 : i64 + %84 = tt.splat %83 
: i64 -> tensor<32xi64> + %85 = arith.addi %84, %49 : tensor<32xi64> + %86 = tt.splat %79 : i64 -> tensor<32xi64> + %87 = arith.cmpi slt, %85, %86 : tensor<32xi64> + %88 = tt.expand_dims %85 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> + %89 = arith.muli %88, %cst_7 : tensor<32x1xi64> + %90 = tt.expand_dims %87 {axis = 1 : i32} : tensor<32xi1> -> tensor<32x1xi1> + %91 = tt.broadcast %90 : tensor<32x1xi1> -> tensor<32x32xi1> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_10 : memref<32x256xf32, #hivm.address_space> to memref<32x256xf32> + %92 = bufferization.to_tensor %memspacecast restrict writable : memref<32x256xf32> + %93 = arith.mulf %92, %50 : tensor<32x256xf32> + %94 = arith.subf %cst_5, %93 : tensor<32x256xf32> + %95 = math.exp %94 : tensor<32x256xf32> + %96 = arith.addf %95, %cst_8 : tensor<32x256xf32> + %97 = arith.divf %cst_8, %96 : tensor<32x256xf32> + %98 = arith.mulf %93, %97 : tensor<32x256xf32> + %99 = arith.mulf %98, %51 : tensor<32x256xf32> + %100 = arith.truncf %99 : tensor<32x256xf32> to tensor<32x256xf16> + %101 = tt.reshape %100 : tensor<32x256xf16> -> tensor<2x16x16x16xf16> + %102 = tt.trans %101 {order = array} : tensor<2x16x16x16xf16> -> tensor<16x2x16x16xf16> + hivm.hir.sync_block_set[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 6 + %103 = bufferization.to_memref %102 : memref<16x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%103 : memref<16x2x16x16xf16, #hivm.address_space>) outs(%alloc : memref<16x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %104 = llvm.load %55 : !llvm.ptr<11> -> i32 + %105 = arith.andi %104, %9 : i32 + %106 = arith.ori %105, %7 : i32 + llvm.store %106, %55 : i32, !llvm.ptr<11> + %107 = arith.addi %arg17, %10 : i64 + scf.yield %82, %89, %91, %107 : i64, tensor<32x1xi64>, tensor<32x32xi1>, i64 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : i64, tensor<32x1xi64>, tensor<32x32xi1>, i64 + } + %69 = scf.if %67 -> (i64) { + hivm.hir.sync_block_wait[, , ] flag = 3 + %memspacecast = memref.memory_space_cast %alloc_11 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %70 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf32> + scf.for %arg19 = %c0 to %c32 step %c1 { + scf.for %arg20 = %c0 to %c32 step %c1 { + %extracted = tensor.extract %68#1[%arg19, %c0] {DiscreteMemAccess} : tensor<32x1xi64> + %74 = arith.addi %68#0, %extracted : i64 + %75 = arith.index_cast %arg20 : index to i32 + %76 = arith.extsi %75 : i32 to i64 + %77 = arith.addi %74, %76 : i64 + %78 = tt.addptr %arg7, %77 : !tt.ptr, i64 + %extracted_12 = tensor.extract %70[%arg19, %arg20] {DiscreteMemAccess} : tensor<32x32xf32> + %79 = arith.truncf %extracted_12 : f32 to f16 + %extracted_13 = tensor.extract %68#2[%arg19, %arg20] {DiscreteMemAccess} : tensor<32x32xi1> + tt.store %78, %79, %extracted_13 {DiscreteMemAccess} : !tt.ptr + } {ExtractedLoadOrStore} + } {ExtractedLoadOrStore} + hivm.hir.sync_block_set[, , ] flag = 5 + %71 = llvm.load %55 : !llvm.ptr<11> -> i32 + %72 = arith.andi %71, %11 : i32 + llvm.store %72, %55 : i32, !llvm.ptr<11> + %73 = arith.addi %arg18, %10 : i64 + scf.yield %73 : i64 + } else { + scf.yield %arg18 : i64 + } + hivm.hir.sync_block_set[, , ] flag = 14 + scf.yield %68#0, %68#1, %68#2, %68#3, %69 : i64, tensor<32x1xi64>, tensor<32x32xi1>, i64, i64 + } + hivm.hir.sync_block_wait[, , ] flag = 6 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 6 + %41 
= arith.addi %arg11, %c255_i32 : i32 + %42 = arith.divsi %41, %c256_i32 : i32 + %43 = arith.extsi %42 : i32 to i64 + %44 = arith.muli %34, %43 : i64 + %45 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %46 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %47 = arith.extsi %45 : tensor<32xi32> to tensor<32xi64> + %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %49 = tt.broadcast %48 : tensor<1x32xi32> -> tensor<32x32xi32> + %50 = arith.extsi %46 : tensor<256xi32> to tensor<256xi64> + %51 = tt.broadcast %48 : tensor<1x32xi32> -> tensor<256x32xi32> + %52 = arith.muli %44, %c2_i64 : i64 + %53 = arith.divsi %52, %6 : i64 + %54:5 = scf.for %arg13 = %c0_i64 to %52 step %c1_i64 iter_args(%arg14 = %c0_i64, %arg15 = %cst, %arg16 = %cst_0, %arg17 = %c0_i64, %arg18 = %c0_i64) -> (i64, tensor<256x1xi64>, tensor<256x32xi1>, i64, i64) : i64 { + hivm.hir.sync_block_wait[, , ] flag = 14 + %55 = llvm.inttoptr %c32_i64 : i64 to !llvm.ptr<11> + %56 = llvm.inttoptr %c64_i64 : i64 to !llvm.ptr<11> + %57 = llvm.load %55 : !llvm.ptr<11> -> i32 + %58 = llvm.load %56 : !llvm.ptr<11> -> i32 + %59 = arith.andi %57, %5 : i32 + %60 = arith.andi %58, %5 : i32 + %61 = arith.cmpi eq, %59, %c0_i32 : i32 + %62 = arith.cmpi eq, %60, %c0_i32 : i32 + %63 = arith.andi %61, %62 : i1 + %64 = arith.andi %57, %7 : i32 + %65 = arith.andi %58, %7 : i32 + %66 = arith.cmpi eq, %64, %7 : i32 + %67 = arith.cmpi eq, %65, %7 : i32 + %68 = arith.andi %66, %67 : i1 + %69 = arith.andi %57, %8 : i32 + %70 = arith.andi %58, %8 : i32 + %71 = arith.cmpi eq, %69, %c0_i32 : i32 + %72 = arith.cmpi eq, %70, %c0_i32 : i32 + %73 = arith.andi %71, %72 : i1 + %74 = arith.cmpi slt, %arg17, %53 : i64 + %75 = arith.andi %63, %74 : i1 + %76 = arith.cmpi slt, %arg18, %53 : i64 + %77 = arith.andi %68, %73 : i1 + %78 = arith.andi %77, %76 : i1 + %79:4 = scf.if %75 -> (i64, tensor<256x1xi64>, tensor<256x32xi1>, i64) { + %81 = arith.divsi %arg13, %43 : i64 + %82 = arith.addi %32, %81 : i64 + %83 = arith.remsi %arg13, %43 : i64 + %84 = arith.divsi %82, %16 : i64 + %85 = arith.remsi %82, %16 : i64 + %86:2 = scf.if %15 -> (i64, i64) { + scf.yield %85, %c0_i64 : i64, i64 + } else { + %140:2 = scf.for %arg19 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg20 = %c0_i32, %arg21 = %c3_i32) -> (i32, i32) : i32 { + %147 = arith.addi %arg20, %arg21 : i32 + %148 = arith.divsi %147, %c2_i32 : i32 + %149 = tt.addptr %arg5, %148 : !tt.ptr, i32 + %150 = tt.load %149 : !tt.ptr + %151 = arith.extsi %150 : i32 to i64 + %152 = arith.cmpi sle, %151, %85 : i64 + %153 = arith.select %152, %arg21, %148 : i32 + %154 = scf.if %152 -> (i32) { + %155 = arith.addi %148, %c1_i32 : i32 + scf.yield %155 : i32 + } else { + scf.yield %arg20 : i32 + } + scf.yield %154, %153 : i32, i32 + } + %141 = arith.subi %140#0, %c1_i32 : i32 + %142 = arith.extsi %141 : i32 to i64 + %143 = tt.addptr %arg5, %142 : !tt.ptr, i64 + %144 = tt.load %143 : !tt.ptr + %145 = arith.extsi %144 : i32 to i64 + %146 = arith.subi %85, %145 : i64 + scf.yield %142, %146 : i64, i64 + } + %87 = arith.divsi %84, %c2_i64 : i64 + %88 = tt.addptr %arg3, %86#0 : !tt.ptr, i64 + %89 = tt.load %88 : !tt.ptr + %90 = tt.addptr %88, %c1_i32 : !tt.ptr, i32 + %91 = tt.load %90 : !tt.ptr + %92 = tt.addptr %arg4, %86#0 : !tt.ptr, i64 + %93 = tt.load %92 : !tt.ptr + %94 = tt.addptr %92, %c1_i32 : !tt.ptr, i32 + %95 = tt.load %94 : !tt.ptr + %96 = arith.subi %91, %89 : i64 + %97 = arith.subi %95, %93 : i64 + %98 = arith.muli %84, %c32_i64 : i64 + %99 = arith.muli %89, 
%c256_i64 : i64 + %100 = arith.addi %98, %99 : i64 + %101 = tt.addptr %arg0, %100 : !tt.ptr, i64 + %102 = arith.muli %87, %c32_i64 : i64 + %103 = arith.muli %93, %c128_i64 : i64 + %104 = arith.addi %102, %103 : i64 + %105 = tt.addptr %arg1, %104 : !tt.ptr, i64 + %106 = arith.muli %83, %c256_i64 : i64 + %107 = arith.muli %86#1, %c32_i64 : i64 + %108 = tt.splat %107 : i64 -> tensor<32xi64> + %109 = arith.addi %108, %47 : tensor<32xi64> + %110 = tt.splat %96 : i64 -> tensor<32xi64> + %111 = arith.cmpi slt, %109, %110 : tensor<32xi64> + %112 = tt.expand_dims %109 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> + %113 = arith.muli %112, %cst_7 : tensor<32x1xi64> + %114 = tt.splat %101 : !tt.ptr -> tensor<32x1x!tt.ptr> + %115 = tt.addptr %114, %113 : tensor<32x1x!tt.ptr>, tensor<32x1xi64> + %116 = tt.broadcast %115 : tensor<32x1x!tt.ptr> -> tensor<32x32x!tt.ptr> + %117 = tt.addptr %116, %49 : tensor<32x32x!tt.ptr>, tensor<32x32xi32> + %118 = tt.expand_dims %111 {axis = 1 : i32} : tensor<32xi1> -> tensor<32x1xi1> + %119 = tt.broadcast %118 : tensor<32x1xi1> -> tensor<32x32xi1> + %120 = tt.load %117, %119, %cst_4 : tensor<32x32x!tt.ptr> + %121 = tt.splat %106 : i64 -> tensor<256xi64> + %122 = arith.addi %121, %50 : tensor<256xi64> + %123 = tt.splat %97 : i64 -> tensor<256xi64> + %124 = arith.cmpi slt, %122, %123 : tensor<256xi64> + %125 = tt.expand_dims %122 {axis = 1 : i32} : tensor<256xi64> -> tensor<256x1xi64> + %126 = arith.muli %125, %cst_6 : tensor<256x1xi64> + %127 = tt.splat %105 : !tt.ptr -> tensor<256x1x!tt.ptr> + %128 = tt.addptr %127, %126 : tensor<256x1x!tt.ptr>, tensor<256x1xi64> + %129 = tt.broadcast %128 : tensor<256x1x!tt.ptr> -> tensor<256x32x!tt.ptr> + %130 = tt.addptr %129, %51 : tensor<256x32x!tt.ptr>, tensor<256x32xi32> + %131 = tt.expand_dims %124 {axis = 1 : i32} : tensor<256xi1> -> tensor<256x1xi1> + %132 = tt.broadcast %131 : tensor<256x1xi1> -> tensor<256x32xi1> + %133 = tt.load %130, %132, %cst_3 : tensor<256x32x!tt.ptr> + %134 = tt.trans %133 {order = array} : tensor<256x32xf16> -> tensor<32x256xf16> + %135 = tt.dot %120, %134, %cst_5 : tensor<32x32xf16> * tensor<32x256xf16> -> tensor<32x256xf32> + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%135 : tensor<32x256xf32>) outs(%alloc_10 : memref<32x256xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %136 = llvm.load %55 : !llvm.ptr<11> -> i32 + %137 = arith.ori %136, %5 : i32 + %138 = arith.ori %137, %5 : i32 + llvm.store %137, %55 : i32, !llvm.ptr<11> + llvm.store %138, %56 : i32, !llvm.ptr<11> + %139 = arith.addi %arg17, %10 : i64 + scf.yield %104, %126, %132, %139 : i64, tensor<256x1xi64>, tensor<256x32xi1>, i64 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : i64, tensor<256x1xi64>, tensor<256x32xi1>, i64 + } + %80 = scf.if %78 -> (i64) { + %81 = tensor.empty() : tensor<256x32xf16> + %82 = scf.for %arg19 = %c0 to %c256 step %c1 iter_args(%arg20 = %81) -> (tensor<256x32xf16>) { + %extracted = tensor.extract %79#1[%arg19, %c0] {DiscreteMemAccess} : tensor<256x1xi64> + %92 = arith.addi %79#0, %extracted : i64 + %93 = tt.splat %92 : i64 -> tensor<1x32xi64> + %94 = arith.extsi %48 : tensor<1x32xi32> to tensor<1x32xi64> + %95 = arith.addi %93, %94 : tensor<1x32xi64> + %96 = tt.splat %arg2 : !tt.ptr -> tensor<1x32x!tt.ptr> + %97 = tt.addptr %96, %95 : tensor<1x32x!tt.ptr>, tensor<1x32xi64> + %98 = tt.load %97 {DiscreteMemAccess} : tensor<1x32x!tt.ptr> + %inserted_slice = tensor.insert_slice %98 into %arg20[%arg19, 0] [1, 32] [1, 1] : 
tensor<1x32xf16> into tensor<256x32xf16> + scf.yield {DiscreteMemAccess} %inserted_slice : tensor<256x32xf16> + } {ExtractedLoadOrStore} + %83 = arith.select %79#2, %82, %cst_3 : tensor<256x32xi1>, tensor<256x32xf16> + hivm.hir.sync_block_wait[, , ] flag = 2 + %84 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<16x2x16x16xf16, #hivm.address_space>) -> memref<32x256xf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %84 : memref<32x256xf16, #hivm.address_space> to memref<32x256xf16> + %85 = bufferization.to_tensor %memspacecast restrict writable : memref<32x256xf16> + %86 = tt.dot %85, %83, %cst_9 : tensor<32x256xf16> * tensor<256x32xf16> -> tensor<32x32xf32> + hivm.hir.sync_block_set[, , ] flag = 6 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%86 : tensor<32x32xf32>) outs(%alloc_11 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + %87 = llvm.load %55 : !llvm.ptr<11> -> i32 + %88 = arith.andi %87, %12 : i32 + %89 = arith.ori %88, %8 : i32 + %90 = arith.ori %89, %8 : i32 + llvm.store %89, %55 : i32, !llvm.ptr<11> + llvm.store %90, %56 : i32, !llvm.ptr<11> + %91 = arith.addi %arg18, %10 : i64 + scf.yield %91 : i64 + } else { + scf.yield %arg18 : i64 + } + hivm.hir.sync_block_set[, , ] flag = 15 + scf.yield %79#0, %79#1, %79#2, %79#3, %80 : i64, tensor<256x1xi64>, tensor<256x32xi1>, i64, i64 + } + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.sync_block_wait[, , ] flag = 14 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir new file mode 100644 index 0000000000..6a076b02e7 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir @@ -0,0 +1,479 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @dsa_prefill_kernel + +module { + tt.func public @dsa_prefill_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: f32) attributes {noinline = false} { + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x192xbf16> + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = 
arith.constant 0 : i32 + %cst_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> + %cst_2 = arith.constant dense<0xFF800000> : tensor<16xf32> + %cst_3 = arith.constant dense<9.99999996E-13> : tensor<16xf32> + %cst_4 = arith.constant dense<0xFF800000> : tensor<16x16xf32> + %cst_5 = arith.constant dense<0> : tensor<16x16xi8> + %cst_6 = arith.constant dense<1024> : tensor<1x16xi32> + %cst_7 = arith.constant dense<0.000000e+00> : tensor<16x16xf32> + %c1_i32 = arith.constant 1 : i32 + %cst_8 = arith.constant dense<1024> : tensor<16x1xi32> + %c16_i32 = arith.constant 16 : i32 + %cst_9 = arith.constant dense<0.000000e+00> : tensor<16xf32> + %cst_10 = arith.constant dense<0.000000e+00> : tensor<16x128xf32> + %cst_11 = arith.constant dense : tensor<16x1xi1> + %cst_12 = arith.constant dense<0> : tensor<16x1xi32> + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(32 : i64) : i64 + %2 = llvm.mlir.constant(64 : i64) : i64 + %3 = llvm.mlir.constant(96 : i64) : i64 + %4 = llvm.mlir.constant(0 : i32) : i32 + %5 = llvm.mlir.constant(1 : i32) : i32 + %c2_i32 = arith.constant 2 : i32 + %6 = llvm.mlir.constant(2 : i32) : i32 + %7 = llvm.mlir.constant(4 : i32) : i32 + %c3_i32 = arith.constant 3 : i32 + %c4_i32 = arith.constant 4 : i32 + %c6_i32 = arith.constant 6 : i32 + %8 = llvm.mlir.constant(6 : i32) : i32 + %9 = llvm.mlir.constant(3 : i32) : i32 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %10 = llvm.mlir.constant(5 : i32) : i32 + %alloc = memref.alloc() : memref<1x1x16x16xbf16, #hivm.address_space> + %alloc_13 = memref.alloc() : memref<16x16xf32, #hivm.address_space> + %alloc_14 = memref.alloc() : memref<16x128xf32, #hivm.address_space> + %11 = tt.get_program_id x : i32 + %12 = llvm.inttoptr %0 : i64 to !llvm.ptr<11> + %13 = llvm.inttoptr %1 : i64 to !llvm.ptr<11> + %14 = llvm.inttoptr %2 : i64 to !llvm.ptr<11> + %15 = llvm.inttoptr %3 : i64 to !llvm.ptr<11> + llvm.store %4, %12 : i32, !llvm.ptr<11> + llvm.store %4, %13 : i32, !llvm.ptr<11> + llvm.store %4, %14 : i32, !llvm.ptr<11> + llvm.store %4, %15 : i32, !llvm.ptr<11> + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 14 + %16 = hivm.hir.get_sub_block_idx -> i64 + %17 = arith.muli %16, %1 : i64 + %18 = arith.addi %17, %1 : i64 + hivm.hir.sync_block_set[, , ] flag = 5 + hivm.hir.sync_block_set[, , ] flag = 4 + %19 = arith.divsi %11, %c16_i32 : i32 + %20 = arith.remsi %11, %c16_i32 : i32 + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %22 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %23 = tt.splat %arg19 : f32 -> tensor<16x16xf32> + %24 = arith.muli %19, %arg17 : i32 + %25 = tt.splat %arg18 : i32 -> tensor<16x1xi32> + %26 = tt.splat %24 : i32 -> tensor<16x1xi32> + %27 = tt.splat %arg4 : !tt.ptr -> tensor<16x16x!tt.ptr> + %28 = arith.muli %19, %arg14 : i32 + %29 = arith.muli %20, %arg15 : i32 + %30 = arith.addi %28, %29 : i32 + %31 = tt.splat %arg16 : i32 -> tensor<16x1xi32> + %32 = tt.splat %30 : i32 -> tensor<16x1xi32> + %33 = tt.expand_dims %21 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %34 = tt.broadcast %33 : tensor<1x128xi32> -> tensor<16x128xi32> + %35 = tt.splat %arg3 : !tt.ptr -> tensor<16x128x!tt.ptr> + scf.for %arg20 = %c0_i32 to %c1024_i32 step %c16_i32 : i32 { + %36 = tt.splat %arg20 : i32 -> tensor<16xi32> + %37 = arith.addi %36, %22 : tensor<16xi32> + %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %39 = arith.cmpi slt, %38, %cst_8 : tensor<16x1xi32> + %40 = arith.addi 
%arg20, %c1_i32 : i32 + %41 = arith.muli %38, %25 : tensor<16x1xi32> + %42 = arith.addi %26, %41 : tensor<16x1xi32> + %43 = tt.broadcast %42 : tensor<16x1xi32> -> tensor<16x16xi32> + %44 = tt.broadcast %39 : tensor<16x1xi1> -> tensor<16x16xi1> + %45 = arith.muli %40, %c2_i32 : i32 + %46 = arith.divsi %45, %c16_i32 : i32 + %47 = arith.divsi %46, %6 : i32 + %48:20 = scf.for %arg21 = %c0_i32 to %45 step %c16_i32 iter_args(%arg22 = %cst_10, %arg23 = %cst_2, %arg24 = %cst_10, %arg25 = %cst_9, %arg26 = %c0_i32, %arg27 = %c0_i32, %arg28 = %cst_9, %arg29 = %cst_9, %arg30 = %cst_9, %arg31 = %cst_9, %arg32 = %cst_9, %arg33 = %c0_i32, %arg34 = %c0_i32, %arg35 = %cst_10, %arg36 = %cst_10, %arg37 = %cst_10, %arg38 = %cst_10, %arg39 = %cst_10, %arg40 = %c0_i32, %arg41 = %c0_i32) -> (tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32, i32) : i32 { + hivm.hir.sync_block_wait[, , ] flag = 15 + %57 = llvm.inttoptr %18 : i64 to !llvm.ptr<11> + %58 = llvm.load %57 : !llvm.ptr<11> -> i32 + %59 = arith.andi %58, %5 : i32 + %60 = arith.cmpi eq, %59, %5 : i32 + %61 = arith.andi %58, %6 : i32 + %62 = arith.cmpi eq, %61, %c0_i32 : i32 + %63 = arith.andi %58, %7 : i32 + %64 = arith.cmpi eq, %63, %7 : i32 + %65 = arith.cmpi slt, %arg26, %47 : i32 + %66 = arith.andi %60, %62 : i1 + %67 = arith.andi %66, %65 : i1 + %68 = arith.cmpi slt, %arg27, %47 : i32 + %69 = arith.andi %64, %68 : i1 + %70:16 = scf.if %67 -> (tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32) { + %72 = tt.splat %arg21 : i32 -> tensor<16xi32> + %73 = arith.addi %72, %22 : tensor<16xi32> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_13 : memref<16x16xf32, #hivm.address_space> to memref<16x16xf32> + %74 = bufferization.to_tensor %memspacecast restrict writable : memref<16x16xf32> + %75 = arith.mulf %74, %23 : tensor<16x16xf32> + %76 = tt.expand_dims %73 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %77 = tt.broadcast %76 : tensor<1x16xi32> -> tensor<16x16xi32> + %78 = arith.addi %43, %77 : tensor<16x16xi32> + %79 = arith.cmpi slt, %76, %cst_6 : tensor<1x16xi32> + %80 = tt.broadcast %79 : tensor<1x16xi1> -> tensor<16x16xi1> + %81 = arith.andi %44, %80 : tensor<16x16xi1> + %82 = tt.addptr %27, %78 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %83 = tt.bitcast %82 : tensor<16x16x!tt.ptr> -> tensor<16x16x!tt.ptr> + %84 = tt.load %83, %81, %cst_5 : tensor<16x16x!tt.ptr> + %85 = arith.cmpi ne, %84, %cst_5 : tensor<16x16xi8> + %86 = arith.select %85, %75, %cst_4 : tensor<16x16xi1>, tensor<16x16xf32> + %87 = "tt.reduce"(%86) <{axis = 1 : i32}> ({ + ^bb0(%arg42: f32, %arg43: f32): + %132 = arith.maxnumf %arg42, %arg43 : f32 + tt.reduce.return %132 : f32 + }) : (tensor<16x16xf32>) -> tensor<16xf32> + %88 = tt.expand_dims %87 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %89 = tt.broadcast %88 : tensor<16x1xf32> -> tensor<16x16xf32> + %90 = arith.subf %86, %89 : tensor<16x16xf32> + %91 = math.exp %90 : tensor<16x16xf32> + %92 = "tt.reduce"(%91) <{axis = 1 : i32}> ({ + ^bb0(%arg42: f32, %arg43: f32): + %132 = arith.addf %arg42, %arg43 : f32 + tt.reduce.return %132 : f32 + }) : 
(tensor<16x16xf32>) -> tensor<16xf32> + %93 = math.log %92 : tensor<16xf32> + %94 = arith.addf %87, %93 : tensor<16xf32> + %95 = math.exp %arg23 : tensor<16xf32> + %96 = arith.addf %94, %cst_3 : tensor<16xf32> + %97 = math.exp %96 : tensor<16xf32> + %98 = arith.addf %95, %97 : tensor<16xf32> + %99 = math.log %98 : tensor<16xf32> + %100 = arith.cmpf une, %99, %99 : tensor<16xf32> + %101 = arith.select %100, %arg23, %99 : tensor<16xi1>, tensor<16xf32> + %102 = arith.subf %arg23, %101 : tensor<16xf32> + %103 = math.exp %102 : tensor<16xf32> + %104 = arith.cmpf oeq, %87, %cst_2 : tensor<16xf32> + %105 = arith.select %104, %cst_1, %103 : tensor<16xi1>, tensor<16xf32> + %106 = tt.expand_dims %105 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %107 = tt.broadcast %106 : tensor<16x1xf32> -> tensor<16x128xf32> + %108 = arith.mulf %arg22, %107 : tensor<16x128xf32> + %109 = arith.remsi %arg40, %c6_i32 : i32 + %110 = arith.cmpi eq, %109, %c0_i32 : i32 + %111 = arith.select %110, %108, %arg24 : tensor<16x128xf32> + %112:5 = scf.if %110 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg35, %arg36, %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %132 = arith.cmpi eq, %109, %c1_i32 : i32 + %133 = arith.select %132, %108, %arg35 : tensor<16x128xf32> + %134:4 = scf.if %132 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg36, %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %135 = arith.cmpi eq, %109, %c2_i32 : i32 + %136 = arith.select %135, %108, %arg36 : tensor<16x128xf32> + %137:3 = scf.if %135 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %138 = arith.cmpi eq, %109, %c3_i32 : i32 + %139 = arith.select %138, %108, %arg37 : tensor<16x128xf32> + %140:2 = scf.if %138 -> (tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32> + } else { + %141 = arith.cmpi eq, %109, %c4_i32 : i32 + %142 = arith.select %141, %108, %arg38 : tensor<16x128xf32> + %143 = arith.select %141, %arg39, %108 : tensor<16x128xf32> + scf.yield %142, %143 : tensor<16x128xf32>, tensor<16x128xf32> + } + scf.yield %139, %140#0, %140#1 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + scf.yield %136, %137#0, %137#1, %137#2 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + scf.yield %133, %134#0, %134#1, %134#2, %134#3 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + %113 = arith.addi %arg40, %c1_i32 : i32 + %114 = arith.subf %94, %101 : tensor<16xf32> + %115 = math.exp %114 : tensor<16xf32> + %116 = arith.remsi %arg33, %c6_i32 : i32 + %117 = arith.cmpi eq, %116, %c0_i32 : i32 + %118 = arith.select %117, %115, %arg25 : tensor<16xf32> + %119:5 = scf.if %117 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg28, %arg29, %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %132 = arith.cmpi eq, %116, %c1_i32 : i32 + %133 = arith.select %132, %115, %arg28 : tensor<16xf32> + %134:4 = scf.if %132 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, 
tensor<16xf32>) { + scf.yield %arg29, %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %135 = arith.cmpi eq, %116, %c2_i32 : i32 + %136 = arith.select %135, %115, %arg29 : tensor<16xf32> + %137:3 = scf.if %135 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %138 = arith.cmpi eq, %116, %c3_i32 : i32 + %139 = arith.select %138, %115, %arg30 : tensor<16xf32> + %140:2 = scf.if %138 -> (tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg31, %arg32 : tensor<16xf32>, tensor<16xf32> + } else { + %141 = arith.cmpi eq, %116, %c4_i32 : i32 + %142 = arith.select %141, %115, %arg31 : tensor<16xf32> + %143 = arith.select %141, %arg32, %115 : tensor<16xf32> + scf.yield %142, %143 : tensor<16xf32>, tensor<16xf32> + } + scf.yield %139, %140#0, %140#1 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + scf.yield %136, %137#0, %137#1, %137#2 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + scf.yield %133, %134#0, %134#1, %134#2, %134#3 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + %120 = arith.addi %arg33, %c1_i32 : i32 + %121 = tt.expand_dims %92 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %122 = tt.broadcast %121 : tensor<16x1xf32> -> tensor<16x16xf32> + %123 = arith.divf %91, %122 : tensor<16x16xf32> + %124 = arith.truncf %123 : tensor<16x16xf32> to tensor<16x16xbf16> + %125 = tt.reshape %124 : tensor<16x16xbf16> -> tensor<1x16x1x16xbf16> + %126 = tt.trans %125 {order = array} : tensor<1x16x1x16xbf16> -> tensor<1x1x16x16xbf16> + hivm.hir.sync_block_set[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 6 + %127 = bufferization.to_memref %126 : memref<1x1x16x16xbf16, #hivm.address_space> + hivm.hir.copy ins(%127 : memref<1x1x16x16xbf16, #hivm.address_space>) outs(%alloc : memref<1x1x16x16xbf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %128 = llvm.load %57 : !llvm.ptr<11> -> i32 + %129 = arith.andi %128, %8 : i32 + %130 = arith.ori %129, %6 : i32 + llvm.store %130, %57 : i32, !llvm.ptr<11> + %131 = arith.addi %arg26, %5 : i32 + scf.yield %101, %111, %118, %131, %119#0, %119#1, %119#2, %119#3, %119#4, %120, %112#0, %112#1, %112#2, %112#3, %112#4, %113 : tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32 + } else { + scf.yield %arg23, %arg24, %arg25, %arg26, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40 : tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32 + } + %71:4 = scf.if %69 -> (tensor<16x128xf32>, i32, i32, i32) { + %72 = arith.remsi %arg41, %c6_i32 : i32 + %73 = arith.cmpi eq, %72, %c0_i32 : i32 + %74 = scf.if %73 -> (tensor<16x128xf32>) { + scf.yield %70#1 : tensor<16x128xf32> + } else { + %90 = arith.cmpi eq, %72, %c1_i32 : i32 + %91 = scf.if %90 -> (tensor<16x128xf32>) { + scf.yield %70#10 : tensor<16x128xf32> + } else { + %92 = arith.cmpi eq, %72, %c2_i32 : i32 + %93 = scf.if %92 -> (tensor<16x128xf32>) { + scf.yield %70#11 : tensor<16x128xf32> + } else { + %94 = arith.cmpi eq, %72, %c3_i32 : i32 + %95 = scf.if 
%94 -> (tensor<16x128xf32>) { + scf.yield %70#12 : tensor<16x128xf32> + } else { + %96 = arith.cmpi eq, %72, %c4_i32 : i32 + %97 = arith.select %96, %70#13, %70#14 : tensor<16x128xf32> + scf.yield %97 : tensor<16x128xf32> + } + scf.yield %95 : tensor<16x128xf32> + } + scf.yield %93 : tensor<16x128xf32> + } + scf.yield %91 : tensor<16x128xf32> + } + %75 = arith.addi %arg41, %c1_i32 : i32 + %76 = arith.remsi %arg34, %c6_i32 : i32 + %77 = arith.cmpi eq, %76, %c0_i32 : i32 + %78 = scf.if %77 -> (tensor<16xf32>) { + scf.yield %70#2 : tensor<16xf32> + } else { + %90 = arith.cmpi eq, %76, %c1_i32 : i32 + %91 = scf.if %90 -> (tensor<16xf32>) { + scf.yield %70#4 : tensor<16xf32> + } else { + %92 = arith.cmpi eq, %76, %c2_i32 : i32 + %93 = scf.if %92 -> (tensor<16xf32>) { + scf.yield %70#5 : tensor<16xf32> + } else { + %94 = arith.cmpi eq, %76, %c3_i32 : i32 + %95 = scf.if %94 -> (tensor<16xf32>) { + scf.yield %70#6 : tensor<16xf32> + } else { + %96 = arith.cmpi eq, %76, %c4_i32 : i32 + %97 = arith.select %96, %70#7, %70#8 : tensor<16xf32> + scf.yield %97 : tensor<16xf32> + } + scf.yield %95 : tensor<16xf32> + } + scf.yield %93 : tensor<16xf32> + } + scf.yield %91 : tensor<16xf32> + } + %79 = arith.addi %arg34, %c1_i32 : i32 + %80 = tt.expand_dims %78 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %81 = tt.broadcast %80 : tensor<16x1xf32> -> tensor<16x128xf32> + hivm.hir.sync_block_wait[, , ] flag = 3 + %memspacecast = memref.memory_space_cast %alloc_14 : memref<16x128xf32, #hivm.address_space> to memref<16x128xf32> + %82 = bufferization.to_tensor %memspacecast restrict writable : memref<16x128xf32> + %83 = arith.mulf %82, %81 : tensor<16x128xf32> + %84 = arith.cmpf une, %83, %83 : tensor<16x128xf32> + %85 = arith.select %84, %cst_10, %83 : tensor<16x128xi1>, tensor<16x128xf32> + %86 = arith.addf %74, %85 : tensor<16x128xf32> + hivm.hir.sync_block_set[, , ] flag = 5 + %87 = llvm.load %57 : !llvm.ptr<11> -> i32 + %88 = arith.andi %87, %9 : i32 + llvm.store %88, %57 : i32, !llvm.ptr<11> + %89 = arith.addi %arg27, %5 : i32 + scf.yield %86, %89, %79, %75 : tensor<16x128xf32>, i32, i32, i32 + } else { + scf.yield %arg22, %arg27, %arg34, %arg41 : tensor<16x128xf32>, i32, i32, i32 + } + hivm.hir.sync_block_set[, , ] flag = 14 + scf.yield %71#0, %70#0, %70#1, %70#2, %70#3, %71#1, %70#4, %70#5, %70#6, %70#7, %70#8, %70#9, %71#2, %70#10, %70#11, %70#12, %70#13, %70#14, %70#15, %71#3 : tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32, i32 + } + %49 = arith.cmpf une, %48#0, %48#0 : tensor<16x128xf32> + %50 = arith.select %49, %cst_10, %48#0 : tensor<16x128xi1>, tensor<16x128xf32> + %51 = arith.muli %38, %31 : tensor<16x1xi32> + %52 = arith.addi %32, %51 : tensor<16x1xi32> + %53 = tt.broadcast %52 : tensor<16x1xi32> -> tensor<16x128xi32> + %54 = arith.addi %53, %34 : tensor<16x128xi32> + %55 = tt.addptr %35, %54 : tensor<16x128x!tt.ptr>, tensor<16x128xi32> + %56 = arith.truncf %50 : tensor<16x128xf32> to tensor<16x128xbf16> + tt.store %55, %56 : tensor<16x128x!tt.ptr> + } + hivm.hir.sync_block_wait[, , ] flag = 6 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 6 + %16 = arith.divsi %11, %c16_i32 : i32 + %17 = arith.remsi %11, %c16_i32 : i32 + %18 = tt.make_range {end = 192 : i32, start = 0 : i32} : 
tensor<192xi32> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %21 = arith.muli %16, %arg5 : i32 + %22 = arith.muli %17, %arg6 : i32 + %23 = arith.addi %21, %22 : i32 + %24 = tt.splat %arg7 : i32 -> tensor<16x1xi32> + %25 = tt.splat %23 : i32 -> tensor<16x1xi32> + %26 = tt.expand_dims %18 {axis = 0 : i32} : tensor<192xi32> -> tensor<1x192xi32> + %27 = tt.broadcast %26 : tensor<1x192xi32> -> tensor<16x192xi32> + %28 = tt.splat %arg0 : !tt.ptr -> tensor<16x192x!tt.ptr> + %29 = arith.muli %16, %arg8 : i32 + %30 = arith.muli %17, %arg9 : i32 + %31 = arith.addi %29, %30 : i32 + %32 = tt.splat %arg10 : i32 -> tensor<16x1xi32> + %33 = tt.splat %31 : i32 -> tensor<16x1xi32> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<16x192x!tt.ptr> + %35 = arith.muli %16, %arg11 : i32 + %36 = arith.muli %17, %arg12 : i32 + %37 = arith.addi %35, %36 : i32 + %38 = tt.expand_dims %19 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + scf.for %arg20 = %c0_i32 to %c1024_i32 step %c16_i32 : i32 { + %39 = tt.splat %arg20 : i32 -> tensor<16xi32> + %40 = arith.addi %39, %20 : tensor<16xi32> + %41 = tt.expand_dims %40 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %42 = arith.muli %41, %24 : tensor<16x1xi32> + %43 = arith.addi %25, %42 : tensor<16x1xi32> + %44 = tt.broadcast %43 : tensor<16x1xi32> -> tensor<16x192xi32> + %45 = arith.addi %44, %27 : tensor<16x192xi32> + %46 = arith.cmpi slt, %41, %cst_8 : tensor<16x1xi32> + %47 = tt.addptr %28, %45 : tensor<16x192x!tt.ptr>, tensor<16x192xi32> + %48 = tt.broadcast %46 : tensor<16x1xi1> -> tensor<16x192xi1> + %49 = tt.load %47, %48, %cst_0 : tensor<16x192x!tt.ptr> + %50 = arith.addi %arg20, %c1_i32 : i32 + %51 = arith.muli %50, %c2_i32 : i32 + %52 = arith.divsi %51, %c16_i32 : i32 + %53 = arith.divsi %52, %6 : i32 + %54:4 = scf.for %arg21 = %c0_i32 to %51 step %c16_i32 iter_args(%arg22 = %cst_12, %arg23 = %cst_11, %arg24 = %c0_i32, %arg25 = %c0_i32) -> (tensor<16x1xi32>, tensor<16x1xi1>, i32, i32) : i32 { + hivm.hir.sync_block_wait[, , ] flag = 14 + %55 = llvm.inttoptr %c32_i64 : i64 to !llvm.ptr<11> + %56 = llvm.inttoptr %c64_i64 : i64 to !llvm.ptr<11> + %57 = llvm.load %55 : !llvm.ptr<11> -> i32 + %58 = llvm.load %56 : !llvm.ptr<11> -> i32 + %59 = arith.andi %57, %5 : i32 + %60 = arith.andi %58, %5 : i32 + %61 = arith.cmpi eq, %59, %c0_i32 : i32 + %62 = arith.cmpi eq, %60, %c0_i32 : i32 + %63 = arith.andi %61, %62 : i1 + %64 = arith.andi %57, %6 : i32 + %65 = arith.andi %58, %6 : i32 + %66 = arith.cmpi eq, %64, %6 : i32 + %67 = arith.cmpi eq, %65, %6 : i32 + %68 = arith.andi %66, %67 : i1 + %69 = arith.andi %57, %7 : i32 + %70 = arith.andi %58, %7 : i32 + %71 = arith.cmpi eq, %69, %c0_i32 : i32 + %72 = arith.cmpi eq, %70, %c0_i32 : i32 + %73 = arith.andi %71, %72 : i1 + %74 = arith.cmpi slt, %arg24, %53 : i32 + %75 = arith.andi %63, %74 : i1 + %76 = arith.cmpi slt, %arg25, %53 : i32 + %77 = arith.andi %68, %73 : i1 + %78 = arith.andi %77, %76 : i1 + %79:3 = scf.if %75 -> (tensor<16x1xi32>, tensor<16x1xi1>, i32) { + %81 = tt.splat %arg21 : i32 -> tensor<16xi32> + %82 = arith.addi %81, %20 : tensor<16xi32> + %83 = tt.expand_dims %82 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %84 = arith.muli %83, %32 : tensor<16x1xi32> + %85 = arith.addi %33, %84 : tensor<16x1xi32> + %86 = tt.broadcast %85 : tensor<16x1xi32> -> tensor<16x192xi32> + %87 = arith.addi %86, %27 : tensor<16x192xi32> + %88 = arith.cmpi slt, %83, %cst_8 : tensor<16x1xi32> + %89 = 
tt.addptr %34, %87 : tensor<16x192x!tt.ptr>, tensor<16x192xi32> + %90 = tt.broadcast %88 : tensor<16x1xi1> -> tensor<16x192xi1> + %91 = tt.load %89, %90, %cst_0 : tensor<16x192x!tt.ptr> + %92 = tt.trans %91 {order = array} : tensor<16x192xbf16> -> tensor<192x16xbf16> + %93 = tt.dot %49, %92, %cst_7 : tensor<16x192xbf16> * tensor<192x16xbf16> -> tensor<16x16xf32> + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%93 : tensor<16x16xf32>) outs(%alloc_13 : memref<16x16xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %94 = llvm.load %55 : !llvm.ptr<11> -> i32 + %95 = arith.ori %94, %5 : i32 + %96 = arith.ori %95, %5 : i32 + llvm.store %95, %55 : i32, !llvm.ptr<11> + llvm.store %96, %56 : i32, !llvm.ptr<11> + %97 = arith.addi %arg24, %5 : i32 + scf.yield %83, %88, %97 : tensor<16x1xi32>, tensor<16x1xi1>, i32 + } else { + scf.yield %arg22, %arg23, %arg24 : tensor<16x1xi32>, tensor<16x1xi1>, i32 + } + %80 = scf.if %78 -> (i32) { + %81 = tt.broadcast %79#1 : tensor<16x1xi1> -> tensor<16x128xi1> + %82 = tensor.empty() : tensor<16x128xbf16> + %83 = scf.for %arg26 = %c0 to %c16 step %c1 iter_args(%arg27 = %82) -> (tensor<16x128xbf16>) { + %extracted = tensor.extract %79#0[%arg26, %c0] {DiscreteMemAccess} : tensor<16x1xi32> + %93 = arith.muli %extracted, %arg13 : i32 + %94 = arith.addi %37, %93 : i32 + %95 = tt.splat %94 : i32 -> tensor<1x128xi32> + %96 = arith.addi %95, %38 : tensor<1x128xi32> + %97 = arith.extsi %96 : tensor<1x128xi32> to tensor<1x128xi64> + %98 = tt.splat %arg2 : !tt.ptr -> tensor<1x128x!tt.ptr> + %99 = tt.addptr %98, %97 : tensor<1x128x!tt.ptr>, tensor<1x128xi64> + %100 = tt.load %99 {DiscreteMemAccess} : tensor<1x128x!tt.ptr> + %inserted_slice = tensor.insert_slice %100 into %arg27[%arg26, 0] [1, 128] [1, 1] : tensor<1x128xbf16> into tensor<16x128xbf16> + scf.yield {DiscreteMemAccess} %inserted_slice : tensor<16x128xbf16> + } {ExtractedLoadOrStore} + %84 = arith.select %81, %83, %cst : tensor<16x128xi1>, tensor<16x128xbf16> + hivm.hir.sync_block_wait[, , ] flag = 2 + %85 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<1x1x16x16xbf16, #hivm.address_space>) -> memref<16x16xbf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %85 : memref<16x16xbf16, #hivm.address_space> to memref<16x16xbf16> + %86 = bufferization.to_tensor %memspacecast restrict writable : memref<16x16xbf16> + %87 = tt.dot %86, %84, %cst_10 : tensor<16x16xbf16> * tensor<16x128xbf16> -> tensor<16x128xf32> + hivm.hir.sync_block_set[, , ] flag = 6 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%87 : tensor<16x128xf32>) outs(%alloc_14 : memref<16x128xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + %88 = llvm.load %55 : !llvm.ptr<11> -> i32 + %89 = arith.andi %88, %10 : i32 + %90 = arith.ori %89, %7 : i32 + %91 = arith.ori %90, %7 : i32 + llvm.store %90, %55 : i32, !llvm.ptr<11> + llvm.store %91, %56 : i32, !llvm.ptr<11> + %92 = arith.addi %arg25, %5 : i32 + scf.yield %92 : i32 + } else { + scf.yield %arg25 : i32 + } + hivm.hir.sync_block_set[, , ] flag = 15 + scf.yield %79#0, %79#1, %79#2, %80 : tensor<16x1xi32>, tensor<16x1xi1>, i32, i32 + } + } + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.sync_block_wait[, , ] flag = 14 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} diff --git 
a/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir new file mode 100644 index 0000000000..c9a1dcb409 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir @@ -0,0 +1,134 @@ +// RUN: triton-opt --auto-blockify="auto-blockify-size=5" --split-input-file %s | FileCheck %s + +// ----- + +// CHECK-LABEL: tt.func @kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) attributes {auto_blockify_size = 5 : i32} { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<5x8xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant dense<8> : tensor<5xi32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<5xi32> +// CHECK: %[[VAL_4:.*]] = tt.get_num_programs x : i32 +// CHECK: %[[VAL_5:.*]] = tt.get_num_programs y : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_num_programs z : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_5]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.get_program_id x {logical_block_id} : i32 +// CHECK: %[[VAL_10:.*]] = tt.get_program_id y {logical_block_id} : i32 +// CHECK: %[[VAL_11:.*]] = tt.get_program_id z {logical_block_id} : i32 +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_10]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_16:.*]] = tt.make_range {end = 5 : i32, start = 0 : i32} : tensor<5xi32> +// CHECK: %[[VAL_17:.*]] = tt.splat %[[VAL_15]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : tensor<5xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_8]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_20:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_19]] : tensor<5xi32> +// CHECK: %[[VAL_21:.*]] = arith.cmpi sge, %[[VAL_18]], %[[VAL_3]] : tensor<5xi32> +// CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_20]], %[[VAL_21]] : tensor<5xi1> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_24:.*]] = arith.divsi %[[VAL_18]], %[[VAL_23]] : tensor<5xi32> +// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_4]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_26:.*]] = arith.remsi %[[VAL_24]], %[[VAL_25]] : tensor<5xi32> +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_26]], %[[VAL_2]] : tensor<5xi32> +// CHECK: %[[VAL_28:.*]] = tt.expand_dims %[[VAL_27]] {axis = 1 : i32} : tensor<5xi32> -> tensor<5x1xi32> +// CHECK: %[[VAL_29:.*]] = tt.broadcast %[[VAL_28]] : tensor<5x1xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_30:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_31:.*]] = tt.expand_dims %[[VAL_30]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_32:.*]] = tt.broadcast %[[VAL_31]] : tensor<1x8xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_29]], %[[VAL_32]] : tensor<5x8xi32> +// CHECK: %[[VAL_34:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<5x8x!tt.ptr> +// CHECK: %[[VAL_35:.*]] = tt.addptr %[[VAL_34]], %[[VAL_33]] : tensor<5x8x!tt.ptr>, tensor<5x8xi32> +// CHECK: %[[VAL_36:.*]] = tt.expand_dims %[[VAL_22]] {axis = 1 : i32} : tensor<5xi1> -> tensor<5x1xi1> +// CHECK: %[[VAL_37:.*]] = tt.broadcast %[[VAL_36]] : tensor<5x1xi1> -> tensor<5x8xi1> +// CHECK: tt.store %[[VAL_35]], %[[VAL_1]], %[[VAL_37]] : tensor<5x8x!tt.ptr> +// CHECK: tt.return 
+// CHECK: } +tt.func @kernel(%arg0: !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c8_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<8xi32> + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %4 = arith.addi %2, %3 : tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %cst : tensor<8x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: tt.func @kernel2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) attributes {auto_blockify_size = 5 : i32} { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<8> : tensor<5xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[VAL_7:.*]] = tt.get_num_programs x : i32 +// CHECK: %[[VAL_8:.*]] = tt.get_num_programs y : i32 +// CHECK: %[[VAL_9:.*]] = tt.get_num_programs z : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_8]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_12:.*]] = tt.get_program_id x {logical_block_id} : i32 +// CHECK: %[[VAL_13:.*]] = tt.get_program_id y {logical_block_id} : i32 +// CHECK: %[[VAL_14:.*]] = tt.get_program_id z {logical_block_id} : i32 +// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_13]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_19:.*]] = tt.make_range {end = 5 : i32, start = 0 : i32} : tensor<5xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_18]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : tensor<5xi32> +// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_10]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_23:.*]] = arith.divsi %[[VAL_21]], %[[VAL_22]] : tensor<5xi32> +// CHECK: %[[VAL_24:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_25:.*]] = arith.remsi %[[VAL_23]], %[[VAL_24]] : tensor<5xi32> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_9]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_27:.*]] = arith.divsi %[[VAL_21]], %[[VAL_26]] : tensor<5xi32> +// CHECK: %[[VAL_28:.*]] = tt.splat %[[VAL_8]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_29:.*]] = arith.remsi %[[VAL_27]], %[[VAL_28]] : tensor<5xi32> +// CHECK: %[[VAL_30:.*]] = arith.cmpi slt, %[[VAL_29]], %[[VAL_1]] : tensor<5xi32> +// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_25]], %[[VAL_1]] : tensor<5xi32> +// CHECK: %[[VAL_32:.*]] = tt.expand_dims %[[VAL_31]] {axis = 1 : i32} : tensor<5xi32> -> tensor<5x1xi32> +// CHECK: %[[VAL_33:.*]] = tt.broadcast %[[VAL_32]] : tensor<5x1xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_34:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_35:.*]] = tt.expand_dims %[[VAL_34]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_36:.*]] = tt.broadcast %[[VAL_35]] : tensor<1x8xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_33]], %[[VAL_36]] : tensor<5x8xi32> +// CHECK: %[[VAL_38:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<5x8x!tt.ptr> +// CHECK: %[[VAL_39:.*]] = tt.addptr %[[VAL_38]], %[[VAL_37]] : 
tensor<5x8x!tt.ptr>, tensor<5x8xi32>
+// CHECK: %[[VAL_40:.*]] = arith.subi %[[VAL_11]], %[[VAL_18]] : i32
+// CHECK: %[[VAL_41:.*]] = arith.maxsi %[[VAL_40]], %[[VAL_2]] : i32
+// CHECK: %[[VAL_42:.*]] = arith.index_cast %[[VAL_41]] : i32 to index
+// CHECK: %[[VAL_43:.*]] = arith.minsi %[[VAL_42]], %[[VAL_3]] : index
+// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_5]] to %[[VAL_43]] step %[[VAL_4]] {
+// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_30]]{{\[}}%[[VAL_44]]] : tensor<5xi1>
+// CHECK: scf.if %[[VAL_45]] {
+// CHECK: %[[VAL_46:.*]] = tensor.extract_slice %[[VAL_39]]{{\[}}%[[VAL_44]], 0] [1, 8] [1, 1] : tensor<5x8x!tt.ptr> to tensor<8x!tt.ptr>
+// CHECK: tt.store %[[VAL_46]], %[[VAL_6]] : tensor<8x!tt.ptr>
+// CHECK: }
+// CHECK: } {auto_blockify_loop}
+// CHECK: tt.return
+// CHECK: }
+tt.func @kernel2(%arg0: !tt.ptr) {
+ %cst = arith.constant dense<0.000000e+00> : tensor<8xf32>
+ %c8_i32 = arith.constant 8 : i32
+ %0 = tt.get_program_id x : i32
+ %a = tt.get_program_id y : i32
+ %b = arith.cmpi slt, %a, %c8_i32 : i32
+ %1 = arith.muli %0, %c8_i32 : i32
+ %2 = tt.splat %1 : i32 -> tensor<8xi32>
+ %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
+ %4 = arith.addi %2, %3 : tensor<8xi32>
+ %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr>
+ %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr>, tensor<8xi32>
+ scf.if %b {
+ tt.store %6, %cst : tensor<8x!tt.ptr>
+ scf.yield
+ }
+ tt.return
+}
diff --git a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir
new file mode 100644
index 0000000000..c6978c29bc
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir
@@ -0,0 +1,201 @@
+// RUN: triton-opt --discrete-mask-access-conversion --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: tt.func @atomic_add_i32
+// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32>
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]]
+// CHECK: %[[result:.*]] = tt.atomic_rmw add, acq_rel, gpu, %[[ptr:.*]], %[[value]]
+tt.func @atomic_add_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<200> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %7 : tensor<1024x!tt.ptr>
+ %9 = tt.atomic_rmw add, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32>
+ tt.return
+}
+
+// CHECK-LABEL: tt.func @atomic_fadd_f32
+// CHECK: %[[default:.*]] = arith.constant dense<0.000000e+00> : tensor<1024xf32>
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]]
+// CHECK: %[[result:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[ptr:.*]], %[[value]]
+tt.func @atomic_fadd_f32(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<200> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst :
tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw fadd, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xf32>, tensor<1024xi1>) -> tensor<1024xf32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_max_i32 +// CHECK: %[[default:.*]] = arith.constant dense<-2147483648> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_max_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_umax_i32 +// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw umax, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_umax_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw umax, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_min_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw min, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_min_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, 
tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw min, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_umin_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw umin, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_umin_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw umin, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_and_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw and, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_and_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw and, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_or_i32 +// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw or, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_or_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw or, acq_rel, gpu, %5, %8, 
%3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32>
+ tt.return
+}
+
+// CHECK-LABEL: tt.func @atomic_max_i16
+// CHECK: %[[default:.*]] = arith.constant dense<-32768> : tensor<1024xi16>
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]]
+// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]]
+tt.func @atomic_max_i16(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<200> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %7 : tensor<1024x!tt.ptr>
+ %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi16>, tensor<1024xi1>) -> tensor<1024xi16>
+ tt.return
+}
+
+// CHECK-LABEL: tt.func @atomic_max_f16
+// CHECK: %[[default:.*]] = arith.constant dense<0xFC00> : tensor<1024xf16>
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]]
+// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]]
+tt.func @atomic_max_f16(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<200> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %7 : tensor<1024x!tt.ptr>
+ %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xf16>, tensor<1024xi1>) -> tensor<1024xf16>
+ tt.return
+}
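// NOTE: editorial sketch, not part of the patch. Taken together, the checks
// above pin the fallback value each RMW kind substitutes on masked-off lanes,
// chosen so that applying the atomic there leaves memory unchanged: 0 for
// add/fadd/or/umax, INT_MIN for signed max, INT_MAX for min/umin (and, per
// the checks, also for and), and -inf (0xFC00) for f16 max. A minimal
// hypothetical instance of the rewritten pattern for add:
//   %neutral = arith.constant dense<0> : tensor<4xi32>
//   %value = arith.select %mask, %origin, %neutral : tensor<4xi1>, tensor<4xi32>
//   %result = tt.atomic_rmw add, acq_rel, gpu, %ptrs, %value, ...
diff --git a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir
new file mode 100644
index 0000000000..6854e118ea
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir
@@ -0,0 +1,67 @@
+// RUN: triton-opt --discrete-mask-access-conversion --split-input-file %s | FileCheck %s
+// RUN: triton-opt %s --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False'
+
+// CHECK-LABEL: tt.func @discrete_load
+// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]]
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[other:.*]]
+// CHECK: tt.store %[[store_ptr:.*]], %[[value]]
+tt.func @discrete_load(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<0> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<200> : tensor<1024xi32>
+ %cst_1 = arith.constant dense<400>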
: tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %5, %3, %cst : tensor<1024x!tt.ptr>
+ tt.store %7, %8 : tensor<1024x!tt.ptr>
+ tt.return
+}
+
+// CHECK-LABEL: tt.func @discrete_load_without_other
+// CHECK: %[[other:.*]] = arith.constant dense<0>
+// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]]
+// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[other]]
+// CHECK: tt.store %[[store_ptr:.*]], %[[value]]
+tt.func @discrete_load_without_other(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<0> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<200> : tensor<1024xi32>
+ %cst_1 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %5, %3 : tensor<1024x!tt.ptr>
+ tt.store %7, %8 : tensor<1024x!tt.ptr>
+ tt.return
+}
+
+// CHECK-LABEL: tt.func @discrete_store
+// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]] : tensor<1024x!tt.ptr>
+// CHECK: %[[origin_value:.*]] = tt.load %[[store_ptr:.*]] : tensor<1024x!tt.ptr>
+// CHECK: %[[store_value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[origin_value]]
+// CHECK: tt.store %[[store_ptr]], %[[store_value]]
+tt.func @discrete_store(%arg0: !tt.ptr, %arg1: !tt.ptr) {
+ %cst = arith.constant dense<0> : tensor<1024xi32>
+ %cst_0 = arith.constant dense<200> : tensor<1024xi32>
+ %cst_1 = arith.constant dense<400> : tensor<1024xi32>
+ %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+ %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32>
+ %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32>
+ %3 = arith.ori %1, %2 : tensor<1024xi1>
+ %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr>
+ %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+ %8 = tt.load %5 : tensor<1024x!tt.ptr>
+ tt.store %7, %8, %3 : tensor<1024x!tt.ptr>
+ tt.return
+}
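// NOTE: editorial sketch, not part of the patch. For a masked (discrete)
// store the conversion cannot simply drop lanes, so it emulates the mask
// with a read-modify-write, as the @discrete_store checks above require:
//   %new = tt.load %src_ptrs : tensor<4x!tt.ptr<f32>>
//   %old = tt.load %dst_ptrs : tensor<4x!tt.ptr<f32>>
//   %sel = arith.select %mask, %new, %old : tensor<4xi1>, tensor<4xf32>
//   tt.store %dst_ptrs, %sel : tensor<4x!tt.ptr<f32>>
// This is only safe if no other agent writes the masked-off lanes between
// the read of %old and the unconditional store.
diff --git a/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir
new file mode 100644
index 0000000000..d1e8eb8a2e
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir
@@ -0,0 +1,118 @@
+// RUN: triton-opt -allow-unregistered-dialect --triton-to-structured '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation --triton-to-structured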
--triton-to-linalg --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @matmul_kernel +// CHECK-DAG: %[[C0:.*]] = arith.constant{{.*}}0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant{{.*}}1 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant{{.*}}64 : index +// CHECK: %{{.*}} = scf.for %{{.*}} = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%{{.*}} = %{{.*}}, %[[ARG16:.*]] = %[[C0]]) -> (tensor<128x256xi32>, index) : i32 { +// CHECK: %[[INNERFOR:.*]]:3 = scf.for {{.*}} = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %[[ARG20:.*]] = %[[ARG16]]) -> (tensor<128x256xi32>, tensor<64x256xi64>, index) : i32 { +// CHECK: %{{.*}} = memref.reinterpret_cast %{{.*}} to offset: [%[[ARG20]]], sizes: [1, 64], strides: [%[[C1]], %[[C1]]] : memref to memref<1x64xi8, strided<[?, ?], offset: ?>> +// CHECK: %{{.*}} = linalg.broadcast ins(%{{.*}} : tensor<64xi8>) outs(%{{.*}} : tensor<128x64xi8>) dimensions = [0] +// CHECK: %[[RES72:.*]] = arith.addi %[[ARG20]], %[[C64]] : index +// CHECK: scf.yield %{{.*}}, %{{.*}}, %[[RES72]] : tensor<128x256xi32>, tensor<64x256xi64>, index +// CHECK: } {{{.*}}tts.simplify_tensor_iter_args.done} +// CHECK: scf.yield %[[INNERFOR]]#0, %[[INNERFOR]]#2 : tensor<128x256xi32>, index +// CHECK: } {{{.*}}tts.simplify_tensor_iter_args.done} + +module { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<64x256xi8> + %cst_0 = arith.constant dense<0> : tensor<64x256xi8> + %cst_1 = arith.constant dense<0> : tensor<128x64xi8> + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1024> : tensor<1x256xi32> + %cst_3 = arith.constant dense<1> : tensor<128x1xi32> + %cst_4 = arith.constant dense<64> : tensor<128x64xi32> + %cst_5 = arith.constant dense<0> : tensor<128x256xi32> + %c3_i32 = arith.constant 3 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_6 = arith.constant dense<8192> : tensor<64x1xi32> + %c8192_i32 = arith.constant 8192 : i32 + %c64_i32 = arith.constant 64 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_7 = arith.constant dense<1024> : tensor<256xi32> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c32_i32 : i32 + %2 = arith.muli %1, %c8_i32 : i32 + %3 = arith.subi %c1_i32, %2 : i32 + %4 = arith.minsi %3, %c8_i32 : i32 + %5 = arith.remsi %0, %c32_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = arith.muli %8, %c256_i32 : i32 + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.splat %9 : i32 -> tensor<256xi32> + %12 = arith.addi %11, %10 : tensor<256xi32> + %13 = arith.remsi %12, %cst_7 : tensor<256xi32> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %15 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %16 = tt.splat %arg0 : !tt.ptr -> tensor<1x64x!tt.ptr> + %17 = tt.addptr %16, %15 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> + %18 = tt.broadcast %17 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> + %19 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> 
+ %20 = tt.splat %arg4 : i32 -> tensor<64x1xi32> + %21 = arith.muli %19, %20 : tensor<64x1xi32> + %22 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %23 = tt.broadcast %21 : tensor<64x1xi32> -> tensor<64x256xi32> + %24 = tt.broadcast %22 : tensor<1x256xi32> -> tensor<64x256xi32> + %25 = arith.addi %23, %24 : tensor<64x256xi32> + %26 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %27 = tt.addptr %26, %25 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %28 = arith.cmpi slt, %19, %cst_6 : tensor<64x1xi32> + %29 = tt.broadcast %28 : tensor<64x1xi1> -> tensor<64x256xi1> + %30 = arith.muli %arg4, %c64_i32 : i32 + %31 = tt.splat %30 : i32 -> tensor<64x256xi32> + %32:2 = scf.for %arg6 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg7 = %cst_5, %arg8 = %18) -> (tensor<128x256xi32>, tensor<128x64x!tt.ptr>) : i32 { + %51 = arith.muli %arg6, %c32_i32 : i32 + %52 = arith.muli %arg6, %c2_i32 : i32 + %53 = arith.shli %c3_i32, %52 : i32 + %54 = tt.splat %53 : i32 -> tensor<64x256xi32> + %55 = tt.splat %52 : i32 -> tensor<64x256xi32> + %56:3 = scf.for %arg9 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg10 = %arg7, %arg11 = %arg8, %arg12 = %27) -> (tensor<128x256xi32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { + %57 = arith.addi %51, %arg9 : i32 + %58 = arith.muli %57, %c64_i32 : i32 + %59 = arith.subi %c8192_i32, %58 : i32 + %60 = tt.splat %59 : i32 -> tensor<1x64xi32> + %61 = arith.cmpi slt, %15, %60 : tensor<1x64xi32> + %62 = tt.broadcast %61 : tensor<1x64xi1> -> tensor<128x64xi1> + %63 = tt.load %arg11, %62, %cst_1 : tensor<128x64x!tt.ptr> + %64 = tt.load %arg12, %29, %cst_0 : tensor<64x256x!tt.ptr> + %65 = arith.extui %64 : tensor<64x256xi8> to tensor<64x256xi32> + %66 = arith.andi %65, %54 : tensor<64x256xi32> + %67 = arith.shrsi %66, %55 : tensor<64x256xi32> + %68 = arith.trunci %67 : tensor<64x256xi32> to tensor<64x256xi8> + %69 = arith.subi %68, %cst : tensor<64x256xi8> + %70 = tt.dot %63, %69, %arg10 : tensor<128x64xi8> * tensor<64x256xi8> -> tensor<128x256xi32> + %71 = tt.addptr %arg11, %cst_4 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %72 = tt.addptr %arg12, %31 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + scf.yield %70, %71, %72 : tensor<128x256xi32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> + } + scf.yield %56#0, %56#1 : tensor<128x256xi32>, tensor<128x64x!tt.ptr> + } + %33 = arith.muli %7, %c128_i32 : i32 + %34 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %35 = tt.splat %33 : i32 -> tensor<128xi32> + %36 = arith.addi %35, %34 : tensor<128xi32> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %38 = tt.splat %arg5 : i32 -> tensor<128x1xi32> + %39 = arith.muli %38, %37 : tensor<128x1xi32> + %40 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %41 = tt.addptr %40, %39 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %42 = tt.expand_dims %12 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %43 = tt.broadcast %41 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %44 = tt.broadcast %42 : tensor<1x256xi32> -> tensor<128x256xi32> + %45 = tt.addptr %43, %44 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %46 = arith.cmpi slt, %37, %cst_3 : tensor<128x1xi32> + %47 = arith.cmpi slt, %42, %cst_2 : tensor<1x256xi32> + %48 = tt.broadcast %46 : tensor<128x1xi1> -> tensor<128x256xi1> + %49 = tt.broadcast %47 : tensor<1x256xi1> -> tensor<128x256xi1> + %50 = arith.andi %48, %49 : tensor<128x256xi1> + tt.store %45, %32#0, %50 : tensor<128x256x!tt.ptr> + tt.return + } +} 
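// NOTE: editorial sketch, not part of the patch. A minimal hypothetical
// input in the spirit of the test above: the pointer tensor %2 is
// loop-carried with a uniform stride, so the simplification is expected to
// carry a single scalar offset instead (compare the
// tts.simplify_tensor_iter_args.done loops in the CHECK lines).
module {
  tt.func @iter_arg_sketch(%arg0: !tt.ptr<f32>) {
    %c0_i32 = arith.constant 0 : i32
    %c4_i32 = arith.constant 4 : i32
    %c1_i32 = arith.constant 1 : i32
    %cst = arith.constant dense<64> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    // Loop-carried pointer tensor advanced by a constant +64 each iteration.
    %3 = scf.for %arg1 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg2 = %2) -> (tensor<64x!tt.ptr<f32>>) : i32 {
      %4 = tt.load %arg2 : tensor<64x!tt.ptr<f32>>
      %5 = tt.addptr %arg2, %cst : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      scf.yield %5 : tensor<64x!tt.ptr<f32>>
    }
    tt.return
  }
}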
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir
new file mode 100644
index 0000000000..fbff6cb341
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir
@@ -0,0 +1,18 @@
+// RUN: triton-opt %s -triton-to-hfusion | FileCheck %s
+
+// CHECK-LABEL: tt.func @test_fp32_to_fp16_rtz
+tt.func @test_fp32_to_fp16_rtz(%arg0: tensor<1024xf32>) -> tensor<1024xf16> {
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1024xf16>
+ // CHECK: %[[RESULT:.*]] = hfusion.cast {mode = #hfusion.round_mode} ins(%arg0 : tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1024xf16>) -> tensor<1024xf16>
+ %0 = tt.fp_to_fp %arg0, rounding = rtz : tensor<1024xf32> -> tensor<1024xf16>
+ // CHECK: return %[[RESULT]]
+ tt.return %0 : tensor<1024xf16>
+}
+
+
+// CHECK-LABEL: tt.func @test_fp32_to_fp16_rtz_fail
+tt.func @test_fp32_to_fp16_rtz_fail(%arg0: tensor<1024xf32>) -> tensor<1024xf16> {
+ %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16>
+ // CHECK: %{{.*}} = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16>
+ tt.return %0 : tensor<1024xf16>
+}
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir
new file mode 100644
index 0000000000..d4c264913a
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir
@@ -0,0 +1,11 @@
+// RUN: triton-opt %s -triton-to-hfusion | FileCheck %s
+
+// CHECK: tensor.empty() : tensor<1xf32>
+// CHECK: hfusion.elemwise_binary {fun = #hfusion.binary_fn} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs(%0 : tensor<1xf32>) -> tensor<1xf32>
+
+module {
+ tt.func @test_mod(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+ %0 = ascend.mod %arg0, %arg1 : tensor<1xf32> tensor<1xf32> -> tensor<1xf32>
+ tt.return %0 : tensor<1xf32>
+ }
+}
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir
new file mode 100644
index 0000000000..186e7210fc
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir
@@ -0,0 +1,20 @@
+// RUN: triton-opt %s --triton-to-hivm | FileCheck %s
+
+// CHECK-LABEL: tt.func @triton_func
+tt.func @triton_func() {
+ ascend.custom "sync_block_set" {str_args = ["vector", 1 : i32]}
+ ascend.custom "sync_block_wait" {str_args = ["vector", 1 : i32]}
+ ascend.custom "sync_block_set" {str_args = ["cube", 2 : i32]}
+ ascend.custom "sync_block_wait" {str_args = ["cube", 2 : i32]}
+ ascend.custom "sync_block_all" {str_args = ["all_cube", 1 : i32]}
+ ascend.custom "sync_block_all" {str_args = ["all_vector", 1 : i32]}
+ ascend.custom "sync_block_all" {str_args = ["all", 1 : i32]}
+ tt.return
+}
+// CHECK: hivm.hir.sync_block_set[, , ] flag = 1
+// CHECK: hivm.hir.sync_block_wait[, , ] flag = 1
+// CHECK: hivm.hir.sync_block_set[, , ] flag = 2
+// CHECK: hivm.hir.sync_block_wait[, , ] flag = 2
+// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe =
+// CHECK: hivm.hir.sync_block[, 1 : i16] tvector_pipe =
+// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = tvector_pipe =
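// NOTE: editorial sketch, not part of the patch. Under the mapping the test
// above checks, a set/wait pair with the same flag id forms a cross-core
// handshake (the waiting side blocks until the flag is raised), while the
// "sync_block_all" variants lower to barrier-style hivm.hir.sync_block ops.
// Minimal hypothetical input for a single cube-side handshake:
tt.func @sync_sketch() {
  ascend.custom "sync_block_set" {str_args = ["cube", 3 : i32]}
  ascend.custom "sync_block_wait" {str_args = ["cube", 3 : i32]}
  tt.return
}
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir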
b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir new file mode 100644 index 0000000000..f3f36e5dcc --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir @@ -0,0 +1,109 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @matmul_atomic_add +// CHECK-NOT: GenericAtomicRMW +// CHECK: tensor.extract_slice +// CHECK: hivm.hir.store ins(%{{.*}} : tensor) outs(%{{.*}} : memref) atomic = + + tt.func public @matmul_atomic_add(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_program_id z : i32 + %c16_i32 = arith.constant 16 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %3 = arith.muli %0, %c16_i32_0 : i32 + %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %5 = tt.splat %3 : i32 -> tensor<16xi32> + %6 = arith.addi %5, %4 : tensor<16xi32> + %c16_i32_1 = arith.constant 16 : i32 + %c16_i32_2 = arith.constant 16 : i32 + %7 = arith.muli %1, %c16_i32_2 : i32 + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %9 = tt.splat %7 : i32 -> tensor<16xi32> + %10 = arith.addi %9, %8 : tensor<16xi32> + %11 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %12 = tt.splat %arg10 : i32 -> tensor<16x1xi32> + %13 = arith.muli %11, %12 : tensor<16x1xi32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<16x1x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %16 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %17 = tt.splat %arg11 : i32 -> tensor<1x16xi32> + %18 = arith.muli %16, %17 : tensor<1x16xi32> + %19 = tt.broadcast %15 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %20 = tt.broadcast %18 : tensor<1x16xi32> -> tensor<16x16xi32> + %21 = tt.addptr %19, %20 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %22 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %23 = tt.splat %arg3 : i32 -> tensor<16x1xi32> + %24 = arith.cmpi slt, %22, %23 : tensor<16x1xi32> + %25 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %26 = tt.splat %arg4 : i32 -> tensor<1x16xi32> + %27 = arith.cmpi slt, %25, %26 : tensor<1x16xi32> + %28 = tt.broadcast %24 : tensor<16x1xi1> -> tensor<16x16xi1> + %29 = tt.broadcast %27 : tensor<1x16xi1> -> tensor<16x16xi1> + %30 = arith.andi %28, %29 : tensor<16x16xi1> + %c16_i32_3 = arith.constant 16 : i32 + %c16_i32_4 = arith.constant 16 : i32 + %31 = arith.muli %2, %c16_i32_4 : i32 + %c32_i32 = arith.constant 32 : i32 + %32 = arith.bitcast %31 : i32 to i32 + %33 = arith.bitcast %arg5 : i32 to i32 + %34 = arith.bitcast %c32_i32 : i32 to i32 + %35 = ub.poison : i32 + scf.for %arg12 = %32 to %33 step %34 : i32 { + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %37 = tt.splat %arg12 : i32 -> tensor<16xi32> + %38 = arith.addi %37, %36 : tensor<16xi32> + %39 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %40 = tt.splat %arg6 : i32 -> tensor<16x1xi32> + %41 = arith.muli %39, %40 : tensor<16x1xi32> + %42 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr> + %43 = tt.addptr %42, %41 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %44 = tt.expand_dims %38 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %45 = tt.splat %arg7 : i32 
-> tensor<1x16xi32> + %46 = arith.muli %44, %45 : tensor<1x16xi32> + %47 = tt.broadcast %43 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %48 = tt.broadcast %46 : tensor<1x16xi32> -> tensor<16x16xi32> + %49 = tt.addptr %47, %48 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %50 = tt.expand_dims %38 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %51 = tt.splat %arg8 : i32 -> tensor<16x1xi32> + %52 = arith.muli %50, %51 : tensor<16x1xi32> + %53 = tt.splat %arg1 : !tt.ptr -> tensor<16x1x!tt.ptr> + %54 = tt.addptr %53, %52 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %55 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %56 = tt.splat %arg9 : i32 -> tensor<1x16xi32> + %57 = arith.muli %55, %56 : tensor<1x16xi32> + %58 = tt.broadcast %54 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %59 = tt.broadcast %57 : tensor<1x16xi32> -> tensor<16x16xi32> + %60 = tt.addptr %58, %59 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %61 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %62 = tt.splat %arg3 : i32 -> tensor<16x1xi32> + %63 = arith.cmpi slt, %61, %62 : tensor<16x1xi32> + %64 = tt.expand_dims %38 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %65 = tt.splat %arg5 : i32 -> tensor<1x16xi32> + %66 = arith.cmpi slt, %64, %65 : tensor<1x16xi32> + %67 = tt.broadcast %63 : tensor<16x1xi1> -> tensor<16x16xi1> + %68 = tt.broadcast %66 : tensor<1x16xi1> -> tensor<16x16xi1> + %69 = arith.andi %67, %68 : tensor<16x16xi1> + %70 = tt.expand_dims %38 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %71 = tt.splat %arg5 : i32 -> tensor<16x1xi32> + %72 = arith.cmpi slt, %70, %71 : tensor<16x1xi32> + %73 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %74 = tt.splat %arg4 : i32 -> tensor<1x16xi32> + %75 = arith.cmpi slt, %73, %74 : tensor<1x16xi32> + %76 = tt.broadcast %72 : tensor<16x1xi1> -> tensor<16x16xi1> + %77 = tt.broadcast %75 : tensor<1x16xi1> -> tensor<16x16xi1> + %78 = arith.andi %76, %77 : tensor<16x16xi1> + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0> : tensor<16x16xi32> + %79 = arith.sitofp %cst : tensor<16x16xi32> to tensor<16x16xf32> + %80 = tt.load %49, %69, %79 : tensor<16x16x!tt.ptr> + %c0_i32_5 = arith.constant 0 : i32 + %cst_6 = arith.constant dense<0> : tensor<16x16xi32> + %81 = arith.sitofp %cst_6 : tensor<16x16xi32> to tensor<16x16xf32> + %82 = tt.load %60, %78, %81 : tensor<16x16x!tt.ptr> + %cst_7 = arith.constant 0.000000e+00 : f32 + %cst_8 = arith.constant dense<0.000000e+00> : tensor<16x16xf32> + %83 = tt.dot %80, %82, %cst_8 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + %84 = tt.atomic_rmw fadd, acq_rel, gpu, %21, %83, %30 : (tensor<16x16x!tt.ptr>, tensor<16x16xf32>, tensor<16x16xi1>) -> tensor<16x16xf32> + } + tt.return + } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir new file mode 100644 index 0000000000..a5ec468da5 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir @@ -0,0 +1,46 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s + +module attributes {hacc.target = #hacc.target<"Ascend910B2">} { + tt.func public @moe_align_block_size_stage4(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: !tt.ptr {tt.divisibility = 
16 : i32} , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: i32) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<1xi32> + %cst_0 = arith.constant dense<0> : tensor<1xi32> + %c250_i32 = arith.constant 250 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg4, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + %3 = tt.addptr %1, %c1_i32 : !tt.ptr, i32 + %4 = tt.load %3 : !tt.ptr + scf.for %arg6 = %2 to %4 step %c16_i32 : i32 { + %22 = arith.divsi %arg6, %c16_i32 : i32 + %23 = tt.addptr %arg2, %22 : !tt.ptr, i32 + tt.store %23, %0 : !tt.ptr + } + %5 = arith.muli %0, %c250_i32 : i32 + %6 = tt.splat %0 : i32 -> tensor<1xi32> + %7 = arith.cmpi slt, %0, %arg5 : i32 + %8 = tt.splat %7 : i1 -> tensor<1xi1> + %9 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %10 = tt.splat %9 : !tt.ptr -> tensor<1x!tt.ptr> + %11 = tt.load %10, %8, %cst_0 : tensor<1x!tt.ptr> + %12 = tt.addptr %arg3, %5 : !tt.ptr, i32 + %13 = tt.splat %12 : !tt.ptr -> tensor<1x!tt.ptr> + %14 = tt.addptr %13, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %15 = tt.atomic_rmw add, acq_rel, gpu, %14, %cst, %8 : (tensor<1x!tt.ptr>, tensor<1xi32>, tensor<1xi1>) -> tensor<1xi32> + %16 = tt.splat %arg4 : !tt.ptr -> tensor<1x!tt.ptr> + %17 = tt.addptr %16, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %18 = tt.load %17, %8, %cst_0 : tensor<1x!tt.ptr> + %19 = arith.addi %15, %18 : tensor<1xi32> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr> + %21 = tt.addptr %20, %19 : tensor<1x!tt.ptr>, tensor<1xi32> + tt.store %21, %6, %8 : tensor<1x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @moe_align_block_size_stage4 + +// CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[.*]] to offset: [%[[.*]]], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: ?>> +// CHECK: %[[CAST2:.*]] = memref.alloc() : memref<1xi32> +// CHECK: memref.copy %[[CAST1]], %[[CAST2]] : memref<1xi32, strided<[1], offset: ?>> to memref<1xi32> diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir new file mode 100644 index 0000000000..96f84ef0f9 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @triton_fn_broadcast_nested +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[ARG2:.*]] to offset: [%[[ARG13:.*]]], sizes: [4, 1], strides: [%c4, %[[C1]]] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[CAST2:.*]] = memref.reinterpret_cast %[[ARG3:.*]] to offset: [%[[ARG13]]], sizes: [4, 1], strides: [%c4, %[[C1]]] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + +module { + tt.func @triton_fn_broadcast_nested(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32){ + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = scf.for %arg10 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg11 = %c0) -> (index) : i32 { + %1 = scf.for %arg12 = %c0_i32 to 
%c2_i32 step %c1_i32 iter_args(%arg13 = %arg11) -> (index) : i32 { + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%arg13], sizes: [4, 1], strides: [%c4, %c0] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + %alloc = memref.alloc() : memref<4x1xf32> + memref.copy %reinterpret_cast, %alloc : memref<4x1xf32, strided<[?, ?], offset: ?>> to memref<4x1xf32> + %2 = bufferization.to_tensor %alloc restrict writable : memref<4x1xf32> + %reinterpret_cast_0 = memref.reinterpret_cast %arg3 to offset: [%arg13], sizes: [4, 1], strides: [%c4, %c0] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + bufferization.materialize_in_destination %2 in writable %reinterpret_cast_0 : (tensor<4x1xf32>, memref<4x1xf32, strided<[?, ?], offset: ?>>) -> () + %3 = arith.addi %arg13, %c1 : index + scf.yield %3 : index + } + scf.yield %1 : index + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir new file mode 100644 index 0000000000..f844ca90d5 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir @@ -0,0 +1,37 @@ +// RUN: triton-opt --triton-to-linalg --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: parseSelect currently supports all-ones shape unless cond=i1 with dense constants +// CHECK-LABEL: func.func public @triton_for_if_load + +module { + tt.func public @triton_for_if_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<16xi32> + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32> + %cst_1 = arith.constant dense<32> : tensor<16xi32> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16xf32> + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.get_program_id x : i32 + %2 = arith.cmpi ne, %1, %c0_i32 : i32 + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %5:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<16xi32>, tensor<16xi32>) : i32 { + %6 = arith.muli %arg2, %c16_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<16xi32> + %8 = arith.addi %arg3, %7 : tensor<16xi32> + %9 = arith.addi %arg4, %7 : tensor<16xi32> + %10 = arith.select %2, %cst_0, %cst : tensor<16xi32> + %11 = arith.addi %8, %10 : tensor<16xi32> + %12 = tt.addptr %3, %11 : tensor<16x!tt.ptr>, tensor<16xi32> + %13 = arith.cmpi slt, %11, %cst_1 : tensor<16xi32> + %14 = tt.load %12 : tensor<16x!tt.ptr> + %15 = arith.select %13, %14, %cst_2 : tensor<16xi1>, tensor<16xf32> + %16 = tt.addptr %4, %9 : tensor<16x!tt.ptr>, tensor<16xi32> + tt.store %16, %15 : tensor<16x!tt.ptr> + scf.yield %11, %9 : tensor<16xi32>, tensor<16xi32> + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir new file mode 100644 index 0000000000..909de6174d --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' 
--split-input-file %s | FileCheck %s
+
+module {
+ tt.func public @test_cmp(%arg0: tensor<128xi32>) -> tensor<128xi1> {
+ %cst_12 = arith.constant dense<0> : tensor<128xi32>
+ %cst_13 = arith.constant dense<1> : tensor<128xi32>
+ %cst_14 = arith.constant dense<100> : tensor<128xi32>
+ %39 = arith.cmpi slt, %arg0, %cst_14 : tensor<128xi32>
+ %40 = arith.select %39, %cst_13, %cst_12 : tensor<128xi1>, tensor<128xi32>
+ %41 = arith.cmpi ne, %40, %cst_12 : tensor<128xi32>
+ tt.return %41 : tensor<128xi1>
+ }
+}
+
+// CHECK-LABEL: tt.func public @test_cmp(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xi32>) -> tensor<128xi1> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<0> : tensor<128xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : tensor<128xi32>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<100> : tensor<128xi32>
+// CHECK: %[[VAL_4:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : tensor<128xi32>
+// CHECK: %[[VAL_5:.*]] = arith.select %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : tensor<128xi1>, tensor<128xi32>
+// CHECK: %[[VAL_6:.*]] = arith.cmpi ne, %[[VAL_5]], %[[VAL_1]] : tensor<128xi32>
+// CHECK: tt.return %[[VAL_6]] : tensor<128xi1>
+// CHECK: }
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir
new file mode 100644
index 0000000000..7163b89f5f
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir
@@ -0,0 +1,73 @@
+// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s
+
+module {
+ tt.func public @test_promote_pointer_iter(%base_ptr: !tt.ptr {tt.divisibility = 16 : i32}) -> !tt.ptr {
+ %c1_i32 = arith.constant 1 : i32
+ %c0_f32 = arith.constant 0.000000e+00 : f32
+ %true_mask = arith.constant 1 : i1
+ %c0_index = arith.constant 0 : index
+ %c10_index = arith.constant 10 : index
+ %c1_index = arith.constant 1 : index
+ %final_ptr = scf.for %iv = %c0_index to %c10_index step %c1_index iter_args(%ptr = %base_ptr) -> (!tt.ptr) {
+ %data = tt.load %ptr, %true_mask, %c0_f32 : !tt.ptr
+ %new_ptr = tt.addptr %ptr, %c1_i32 : !tt.ptr, i32
+ scf.yield %new_ptr : !tt.ptr
+ }
+ tt.return %final_ptr : !tt.ptr
+ }
+}
+
+// CHECK-LABEL: tt.func public @test_promote_pointer_iter(
+// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) -> !tt.ptr {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_4]] iter_args(%[[VAL_7:.*]] = %[[VAL_0]]) -> (!tt.ptr) {
+// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_1]] : !tt.ptr, i32
+// CHECK: scf.yield %[[VAL_8]] : !tt.ptr
+// CHECK: }
+// CHECK: tt.return %[[VAL_5]] : !tt.ptr
+// CHECK: }
+
+
+// -----
+
+
+module {
+ tt.func public @test_promote_pointer_iter_advance(%base_ptr: !tt.ptr) -> !tt.ptr>{
+ %c0_i32 = arith.constant 0 : i32
+ %c1_i64 = arith.constant 1 : i64
+ %c32_i64 = arith.constant 32 : i64
+ %c1_i32 = arith.constant 1 : i32 // nonZeroConstant requires 1
+ %c0_i32_2 = arith.constant 0 : i32
+ %c0_index = arith.constant 0 : index
+ %c10_index = arith.constant 10 : index
+ %c1_index = arith.constant 1 : index
+ %cst = arith.constant dense<0.000000e+00> : tensor<32xf16>
+
%ptr0 = tt.make_tensor_ptr %base_ptr, [%c32_i64], [%c1_i64], [%c0_i32] {order = array} : !tt.ptr> + %final_ptr = scf.for %iv = %c0_index to %c10_index step %c1_index iter_args(%ptr = %ptr0) -> !tt.ptr> { + %data = tt.load %ptr : !tt.ptr> + %new_ptr = tt.advance %ptr, [%c1_i32, %c0_i32_2] : !tt.ptr> + scf.yield %new_ptr : !tt.ptr> + } + tt.return %final_ptr : !tt.ptr> + } +} + +// CHECK-LABEL: tt.func public @test_promote_pointer_iter_advance( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> !tt.ptr> { +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_3:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], [%[[VAL_3]]], [%[[VAL_2]]], [%[[VAL_1]]] {order = array} : > +// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (!tt.ptr>) { +// CHECK: %[[VAL_12:.*]] = tt.advance %[[VAL_11]], [%[[VAL_4]], %[[VAL_1]]] : > +// CHECK: scf.yield %[[VAL_12]] : !tt.ptr> +// CHECK: } +// CHECK: tt.return %[[VAL_9]] : !tt.ptr> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir new file mode 100644 index 0000000000..f046891ec1 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir @@ -0,0 +1,16 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s +module { + tt.func public @test_splat_cmp(%arg0: i32, %arg1: i32) -> tensor<128xi1> { + %0 = tt.splat %arg0 : i32 -> tensor<128xi32> + %1 = tt.splat %arg1 : i32 -> tensor<128xi32> + %2 = arith.cmpi slt, %0, %1 : tensor<128xi32> + tt.return %2 : tensor<128xi1> + } +} + +// CHECK-LABEL: tt.func public @test_splat_cmp( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) -> tensor<128xi1> { +// CHECK: %[[VAL_2:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_2]] : i1 -> tensor<128xi1> +// CHECK: tt.return %[[VAL_3]] : tensor<128xi1> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir new file mode 100644 index 0000000000..abdae3fa16 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir @@ -0,0 +1,117 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @test_cmp_ult(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> + %cst_0 = arith.constant dense<512> : tensor<1024xi32> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 {tt.divisibility = dense<512> : tensor<1xi32>} : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = arith.divsi %4, 
%cst_0 : tensor<1024xi32> + %6 = arith.cmpi ult, %5, %cst_0 : tensor<1024xi32> + %7 = arith.muli %5, %cst_0 : tensor<1024xi32> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9, %6, %cst : tensor<1024x!tt.ptr> + %11 = arith.muli %0, %c1024_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<1024xi32> + %13 = arith.addi %12, %2 : tensor<1024xi32> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %10 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_cmp_ult( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<1024> : tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<2xi32> +// CHECK: %[[VAL_5:.*]] = arith.constant 512 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_6]] {tt.divisibility = dense<512> : tensor<1xi32>} : i32 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_9]], %[[VAL_5]] : index +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_13:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_13]], %[[VAL_4]] : tensor<2xi32> +// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_12]] : index to i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : i32 -> tensor<2xi32> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_16]] : tensor<2xi32> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_11]] : index to i32 +// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_20]] : i32 -> tensor<2xi32> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_13]], %[[VAL_21]] : tensor<2xi32> +// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_22]], %[[VAL_4]] : tensor<2xi32> +// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_19]], %[[VAL_23]], %[[VAL_3]] : tensor<2x!tt.ptr> +// CHECK: %[[VAL_25:.*]] = tensor.empty() : tensor<2x512xf32> +// CHECK: %[[VAL_26:.*]] = linalg.broadcast ins(%[[VAL_24]] : tensor<2xf32>) outs(%[[VAL_25]] : tensor<2x512xf32>) dimensions = [1] +// CHECK: %[[VAL_27:.*]] = tensor.reshape %[[VAL_26]](%[[VAL_2]]) : (tensor<2x512xf32>, tensor<1xi64>) -> tensor<1024xf32> +// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_7]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_29:.*]] = tt.splat %[[VAL_28]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_10]] : tensor<1024xi32> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: tt.store %[[VAL_32]], %[[VAL_27]] : tensor<1024x!tt.ptr> +// CHECK: tt.return +// CHECK: } + + +// ----- + + +tt.func public @test_cmp_uge(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 
: i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> + %cst_0 = arith.constant dense<511> : tensor<1024xi32> + %cst_1 = arith.constant dense<512> : tensor<1024xi32> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 {tt.divisibility = dense<512> : tensor<1xi32>} : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = arith.divsi %4, %cst_1 : tensor<1024xi32> + %6 = arith.cmpi uge, %cst_0, %5 : tensor<1024xi32> + %7 = arith.muli %5, %cst_1 : tensor<1024xi32> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9, %6, %cst : tensor<1024x!tt.ptr> + %11 = arith.muli %0, %c1024_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<1024xi32> + %13 = arith.addi %12, %2 : tensor<1024xi32> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %10 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_cmp_uge( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : tensor<1024xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<511> : tensor<1024xi32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] {tt.divisibility = dense<512> : tensor<1xi32>} : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_8]] : tensor<1024xi32> +// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_4]] : tensor<1024xi32> +// CHECK: %[[VAL_12:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_3]] : tensor<1024xi32> +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_14]], %[[VAL_13]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_16:.*]] = tt.load %[[VAL_15]], %[[VAL_12]], %[[VAL_2]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_17]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_8]] : tensor<1024xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = tt.addptr %[[VAL_20]], %[[VAL_19]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: tt.store %[[VAL_21]], %[[VAL_16]] : tensor<1024x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir new file mode 100644 index 0000000000..4e1053793d --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir @@ -0,0 +1,28 @@ +// RUN: triton-opt 
'--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @test_non_splat_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %non_splat_mask = arith.constant dense<[false, true]> : tensor<2xi1> + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<2xf32> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %ptr_load = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %2 = tt.load %ptr_load, %non_splat_mask, %c0_f32 : tensor<2x!tt.ptr> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %ptr_store = tt.addptr %3, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %ptr_store, %2 : tensor<2x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_non_splat_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<[false, true]> : tensor<2xi1> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]], %[[VAL_2]], %[[VAL_3]] : tensor<2x!tt.ptr> +// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_4]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: tt.store %[[VAL_9]], %[[VAL_7]] : tensor<2x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir new file mode 100644 index 0000000000..efdaad3204 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +module { + tt.func public @test_stride_not_one(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<4xf32> + %fake_range_mask = arith.constant dense<[false, false, false, false]> : tensor<4xi1> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %ptr = tt.addptr %1, %0 : tensor<4x!tt.ptr>, tensor<4xi32> + %2 = tt.load %ptr, %fake_range_mask, %c0_f32 : tensor<4x!tt.ptr> + tt.store %ptr, %2 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: tt.func public @test_stride_not_one( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<4xf32> +// CHECK: %[[VAL_2:.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<4x!tt.ptr> +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<4x!tt.ptr>, tensor<4xi32> +// CHECK: tt.store %[[VAL_4]], %[[VAL_1]] : tensor<4x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir 
b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir new file mode 100644 index 0000000000..b48d7fc288 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir @@ -0,0 +1,81 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @kernel_with_rem_safe(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256xf32> + %cst_0 = arith.constant dense<1024> : tensor<256xi32> + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<64> : tensor<256xi32> + %cst_2 = arith.constant dense<128> : tensor<256xi32> + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %2 = arith.remsi %1, %cst_2 : tensor<256xi32> + %3 = arith.cmpi slt, %2, %cst_1 : tensor<256xi32> + %4 = arith.muli %0, %c256_i32 : i32 + %5 = tt.splat %4 : i32 -> tensor<256xi32> + %6 = arith.addi %5, %1 : tensor<256xi32> + %7 = arith.cmpi slt, %6, %cst_0 : tensor<256xi32> + %8 = arith.andi %3, %7 : tensor<256xi1> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %10 = tt.addptr %9, %2 : tensor<256x!tt.ptr>, tensor<256xi32> + %11 = tt.load %10, %8, %cst : tensor<256x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %13 = tt.addptr %12, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %13, %11, %8 : tensor<256x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @kernel_with_rem_safe( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : tensor<256xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<1024> : tensor<256xi32> +// CHECK: %[[VAL_4:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant dense<64> : tensor<256xi32> +// CHECK: %[[VAL_6:.*]] = arith.constant dense<128> : tensor<256xi32> +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> +// CHECK: %[[VAL_9:.*]] = arith.remsi %[[VAL_8]], %[[VAL_6]] : tensor<256xi32> +// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %[[VAL_9]], %[[VAL_5]] : tensor<256xi32> +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_11]] : i32 -> tensor<256xi32> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_8]] : tensor<256xi32> +// CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_3]] : tensor<256xi32> +// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_10]], %[[VAL_14]] : tensor<256xi1> +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<256x!tt.ptr> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_9]] : tensor<256x!tt.ptr>, tensor<256xi32> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]], %[[VAL_15]], %[[VAL_2]] : tensor<256x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<256x!tt.ptr> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_8]] : tensor<256x!tt.ptr>, tensor<256xi32> +// CHECK: tt.store %[[VAL_20]], %[[VAL_18]], %[[VAL_15]] : tensor<256x!tt.ptr> +// CHECK: tt.return +// CHECK: } + + +// ----- + + +tt.func public @test_remsi_with_broadcast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> 
tensor<2x2xf32> { + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<2x2xf32> + %c4 = arith.constant dense<4> : tensor<2x2xi32> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %2 = tt.broadcast %1 : tensor<1x2xi32> -> tensor<2x2xi32> + %3 = arith.remsi %2, %c4 : tensor<2x2xi32> + %4 = arith.trunci %3 : tensor<2x2xi32> to tensor<2x2xi1> + %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %vals = tt.load %ptrs, %4, %c0_f32 : tensor<2x2x!tt.ptr> + tt.return %vals : tensor<2x2xf32> +} + +// CHECK-LABEL: tt.func public @test_remsi_with_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<2x2xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<2x2xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant dense<4> : tensor<2x2xi32> +// CHECK: %[[VAL_3:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK: %[[VAL_4:.*]] = tt.expand_dims %[[VAL_3]] {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> +// CHECK: %[[VAL_5:.*]] = tt.broadcast %[[VAL_4]] : tensor<1x2xi32> -> tensor<2x2xi32> +// CHECK: %[[VAL_6:.*]] = arith.remsi %[[VAL_5]], %[[VAL_2]] : tensor<2x2xi32> +// CHECK: %[[VAL_7:.*]] = arith.trunci %[[VAL_6]] : tensor<2x2xi32> to tensor<2x2xi1> +// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x2x!tt.ptr> +// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]], %[[VAL_7]], %[[VAL_1]] : tensor<2x2x!tt.ptr> +// CHECK: tt.return %[[VAL_9]] : tensor<2x2xf32> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir new file mode 100644 index 0000000000..626f670d1d --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir @@ -0,0 +1,127 @@ +// RUN: triton-opt %s --bubble-up-operation | FileCheck %s + +// CHECK-LABEL: tt.func @test_subi_extract_bubbleup +tt.func @test_subi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.subi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_maxsi_extract_bubbleup +tt.func @test_maxsi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.maxsi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_minsi_extract_bubbleup +tt.func @test_minsi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.minsi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_extf_extract_bubbleup +tt.func @test_extf_extract_bubbleup(%a: tensor<128xf16>, %i: index, %c: f32) -> f32 { + %0 = arith.extf %a : tensor<128xf16> to tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_minnumf_extract_bubbleup +tt.func @test_minnumf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = arith.minnumf %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + 
tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_maxnumf_extract_bubbleup +tt.func @test_maxnumf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = arith.maxnumf %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_cmpf_extract_bubbleup +tt.func @test_cmpf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index) -> i1 { + %0 = arith.cmpf olt, %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xi1> + tt.return %1 : i1 +} + + +// CHECK-LABEL: tt.func @test_addptr_extract_bubbleup +tt.func @test_addptr_extract_bubbleup(%a: tensor<128x!tt.ptr>, %b: tensor<128xi32>, %i: index) -> !tt.ptr { + %0 = tt.addptr %a, %b : tensor<128x!tt.ptr>, tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128x!tt.ptr> + tt.return %1 : !tt.ptr +} + + +// CHECK-LABEL: tt.func @test_ceil_extract_bubbleup +tt.func @test_ceil_extract_bubbleup(%a: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = math.ceil %a : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_slice_extract_dropdim_bubbleup +tt.func @test_slice_extract_dropdim_bubbleup(%a: tensor<128x128x128xf32>, %i: index, %j: index) -> f32 { + %0 = tensor.extract_slice %a[0, %i, 0][1, 1, 128][1, 1, 1] : tensor<128x128x128xf32> to tensor<128xf32> + %1 = tensor.extract %0[%j] : tensor<128xf32> + tt.return %1 : f32 +} + + +// CHECK-LABEL: tt.func @test_expand_slice_bubbleup +tt.func @test_expand_slice_bubbleup(%a: tensor<128xf32>, %i: index, %c: f32) -> tensor<1x1xf32> { + %0 = tt.expand_dims %a {axis = 0 : i32} : tensor<128xf32> -> tensor<1x128xf32> + %1 = tensor.extract_slice %0[0, %i][1, 1][1, 1] : tensor<1x128xf32> to tensor<1x1xf32> + tt.return %1 : tensor<1x1xf32> +} + + +// CHECK-LABEL: tt.func @test_expand_slice_dropdim_bubbleup +tt.func @test_expand_slice_dropdim_bubbleup(%a: tensor<128x128xf32>, %i: index, %c: f32) -> tensor<128x1xf32> { + %0 = tt.expand_dims %a {axis = 2 : i32} : tensor<128x128xf32> -> tensor<128x128x1xf32> + %1 = tensor.extract_slice %0[%i, 0, 0][1, 128, 1][1, 1, 1] : tensor<128x128x1xf32> to tensor<128x1xf32> + tt.return %1 : tensor<128x1xf32> +} + + +// CHECK-LABEL: tt.func @test_splat_slice_bubbleup +tt.func @test_splat_slice_bubbleup(%a: f32, %i: index, %c: f32) -> tensor<1xf32> { + %0 = tt.splat %a : f32 -> tensor<128xf32> + %1 = tensor.extract_slice %0[%i][1][1] : tensor<128xf32> to tensor<1xf32> + tt.return %1 : tensor<1xf32> +} + + +// CHECK-LABEL: tt.func @test_makerange_slice_bubbleup +tt.func @test_makerange_slice_bubbleup(%i: index, %c: f32) -> tensor<1xi32> { + %0 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> + %1 = tensor.extract_slice %0[%i][1][1] : tensor<128xi32> to tensor<1xi32> + tt.return %1 : tensor<1xi32> +} + + +// CHECK-LABEL: tt.func @test_slice_all_bubbleup +tt.func @test_slice_all_bubbleup(%i: index, %c: f32) -> tensor<128xi32> { + %0 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> + %1 = tensor.extract_slice %0[0][128][1] : tensor<128xi32> to tensor<128xi32> + tt.return %1 : tensor<128xi32> +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir new file mode 100644 index 0000000000..7f71ec94a9 --- /dev/null +++ 
b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir @@ -0,0 +1,45 @@ +// RUN: triton-opt --triton-to-unstructure --split-input-file %s | FileCheck %s --implicit-check-not="DiscreteMemAccess" +// CHECK-LABEL: tt.func public @triton_for_if_load +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<16xi32> +// CHECK: %[[CST0:.*]] = arith.constant dense<1> : tensor<16xi32> +// CHECK: %[[SEL:.*]] = arith.select %{{.*}}, %[[CST0]], %[[CST]] : tensor<16xi32> +// CHECK: %[[ADD:.*]] = arith.addi %{{.*}}, %[[SEL]] : tensor<16xi32> +// CHECK: %[[ADDPTR:.*]] = tt.addptr %{{.*}}, %[[ADD]] : tensor<16x!tt.ptr>, tensor<16xi32> + + +module { + tt.func public @triton_for_if_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16xf32> + %cst_0 = arith.constant dense<32> : tensor<16xi32> + %cst_1 = arith.constant dense<1> : tensor<16xi32> + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.get_program_id x : i32 + %2 = arith.cmpi ne, %1, %c0_i32 : i32 + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %5:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<16xi32>, tensor<16xi32>) : i32 { + %6 = arith.muli %arg2, %c16_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<16xi32> + %8 = arith.addi %arg3, %7 : tensor<16xi32> + %9 = arith.addi %arg4, %7 : tensor<16xi32> + %10 = scf.if %2 -> (tensor<16xi32>) { + %16 = arith.addi %8, %cst_1 : tensor<16xi32> + scf.yield %16 : tensor<16xi32> + } else { + scf.yield %8 : tensor<16xi32> + } + %11 = tt.addptr %3, %10 : tensor<16x!tt.ptr>, tensor<16xi32> + %12 = arith.cmpi slt, %10, %cst_0 : tensor<16xi32> + %13 = tt.load %11 : tensor<16x!tt.ptr> + %14 = arith.select %12, %13, %cst : tensor<16xi1>, tensor<16xf32> + %15 = tt.addptr %4, %9 : tensor<16x!tt.ptr>, tensor<16xi32> + tt.store %15, %14 : tensor<16x!tt.ptr> + scf.yield %10, %9 : tensor<16xi32>, tensor<16xi32> + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir new file mode 100644 index 0000000000..77aed1ce16 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir @@ -0,0 +1,207 @@ +// RUN: triton-opt --triton-to-unstructure --split-input-file %s | FileCheck %s + +tt.func public @test_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c3_i32 = arith.constant 3 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<128> : tensor<128xi32> + %cst_0 = arith.constant dense<0> : tensor<128xi32> + %cst_1 = arith.constant dense<300> : tensor<128xi32> + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %3 = tt.splat %1 : i32 -> tensor<128xi32> + %4 = arith.addi %3, %2 : tensor<128xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = 
tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %10 = tt.load %9 : tensor<128x!tt.ptr> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<128x!tt.ptr>, tensor<128xi32> + %13:3 = scf.for %arg4 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg5 = %4, %arg6 = %7, %arg7 = %12) -> (tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) : i32 { + %14:3 = scf.for %arg8 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg9 = %arg5, %arg10 = %arg6, %arg11 = %arg7) -> (tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) : i32 { + %18 = arith.cmpi slt, %arg9, %cst_1 : tensor<128xi32> + %19 = tt.addptr %5, %arg9 : tensor<128x!tt.ptr>, tensor<128xi32> + %20 = tt.load %19, %18, %cst_0 : tensor<128x!tt.ptr> + %21 = tt.load %arg11 : tensor<128x!tt.ptr> + %22 = arith.addi %20, %21 : tensor<128xi32> + tt.store %arg10, %22, %18 : tensor<128x!tt.ptr> + %23 = arith.addi %arg9, %cst : tensor<128xi32> + %24 = tt.addptr %arg10, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %25 = tt.addptr %arg11, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %23, %24, %25 : tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + %15 = arith.addi %14#0, %cst : tensor<128xi32> + %16 = tt.addptr %14#1, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %17 = tt.addptr %14#2, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %15, %16, %17 : tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + tt.return +} + +// CHECK-LABEL: tt.func public @test_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<128> : tensor<128xi64> +// CHECK: %[[VAL_8:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_15:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_15]], %{{.*}} : i32 +// CHECK: %[[VAL_17:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_16]] : i32 -> tensor<128xi32> +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : tensor<128xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_19]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_19]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_24]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_26:.*]]:3 = scf.for %[[VAL_27:.*]] = %{{.*}} to %[[VAL_8]] step %{{.*}} iter_args(%[[VAL_28:.*]] = %[[VAL_19]], %[[VAL_29:.*]] = %[[VAL_21]], %[[VAL_30:.*]] = %[[VAL_25]]) -> (tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) : i32 { +// CHECK: %[[VAL_31:.*]]:3 = scf.for %[[VAL_32:.*]] = %{{.*}} to %[[VAL_8]] step %{{.*}} iter_args(%[[VAL_33:.*]] = %[[VAL_28]], %[[VAL_34:.*]] = %[[VAL_29]], %[[VAL_35:.*]] = %[[VAL_30]]) -> (tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) : i32 { +// CHECK: %[[VAL_36:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: 
%[[VAL_37:.*]] = tt.addptr %[[VAL_36]], %[[VAL_34]] : tensor<128x!tt.ptr>, tensor<128xi64> +// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_33]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_39:.*]] = tt.addptr %[[VAL_20]], %[[VAL_33]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_40:.*]] = tt.load %[[VAL_39]], %[[VAL_38]], %{{.*}} : tensor<128x!tt.ptr> +// CHECK: %[[VAL_41:.*]] = tensor.empty() : tensor<128xi32> +// CHECK: %[[VAL_42:.*]] = scf.for %[[VAL_43:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[VAL_44:.*]] = %[[VAL_41]]) -> (tensor<128xi32>) { +// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_35]]{{\[}}%[[VAL_43]]] {DiscreteMemAccess} : tensor<128xi64> +// CHECK: %[[VAL_46:.*]] = tt.addptr %[[VAL_3]], %[[VAL_45]] : !tt.ptr, i64 +// CHECK: %[[VAL_47:.*]] = tt.load %[[VAL_46]] {DiscreteMemAccess} : !tt.ptr +// CHECK: %[[VAL_49:.*]] = tensor.insert_slice %{{.*}} into %[[VAL_44]]{{\[}}%[[VAL_43]]] [1] [1] : tensor<1xi32> into tensor<128xi32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_49]] : tensor<128xi32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_40]], %[[VAL_42]] : tensor<128xi32> +// CHECK: tt.store %[[VAL_37]], %[[VAL_50]], %[[VAL_38]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_33]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_52:.*]] = arith.addi %[[VAL_34]], %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_35]], %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_55:.*]]#0, %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_55]]#1, %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_55]]#2, %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_54]], %[[VAL_56]], %[[VAL_57]] : tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: tt.return +// CHECK: } + +// ----- + +tt.func public @test_kernel2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c300_i32 = arith.constant 300 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<128> : tensor<128xi32> + %cst_0 = arith.constant dense<0> : tensor<128xi32> + %cst_1 = arith.constant dense<300> : tensor<128xi32> + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %3 = tt.splat %1 : i32 -> tensor<128xi32> + %4 = arith.addi %3, %2 : tensor<128xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %10 = tt.load %9 : tensor<128x!tt.ptr> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<128x!tt.ptr>, tensor<128xi32> + %13:3 = scf.while (%arg4 = %cst, %arg5 = %4, %arg6 = %7, %arg7 = %12) : (tensor<128xi32>, tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) -> (tensor<128x!tt.ptr>, tensor<128xi32>, tensor<128x!tt.ptr>) { + %14 = "tt.reduce"(%arg4) <{axis 
= 0 : i32}> ({ + ^bb0(%arg8: i32, %arg9: i32): + %16 = arith.addi %arg8, %arg9 : i32 + tt.reduce.return %16 : i32 + }) : (tensor<128xi32>) -> i32 + %15 = arith.cmpi slt, %14, %c300_i32 : i32 + scf.condition(%15) %arg7, %arg5, %arg6 : tensor<128x!tt.ptr>, tensor<128xi32>, tensor<128x!tt.ptr> + } do { + ^bb0(%arg4: tensor<128x!tt.ptr>, %arg5: tensor<128xi32>, %arg6: tensor<128x!tt.ptr>): + %14:4 = scf.while (%arg7 = %c0_i32, %arg8 = %arg6, %arg9 = %arg4, %arg10 = %arg5) : (i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>, tensor<128xi32>) -> (tensor<128xi32>, i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) { + %18 = arith.cmpi slt, %arg7, %c3_i32 : i32 + scf.condition(%18) %arg10, %arg7, %arg8, %arg9 : tensor<128xi32>, i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } do { + ^bb0(%arg7: tensor<128xi32>, %arg8: i32, %arg9: tensor<128x!tt.ptr>, %arg10: tensor<128x!tt.ptr>): + %18 = arith.cmpi slt, %arg7, %cst_1 : tensor<128xi32> + %19 = tt.addptr %5, %arg7 : tensor<128x!tt.ptr>, tensor<128xi32> + %20 = tt.load %19, %18, %cst_0 : tensor<128x!tt.ptr> + %21 = tt.load %arg10 : tensor<128x!tt.ptr> + %22 = arith.addi %20, %21 : tensor<128xi32> + tt.store %arg9, %22, %18 : tensor<128x!tt.ptr> + %23 = arith.addi %arg7, %cst : tensor<128xi32> + %24 = arith.addi %arg8, %c1_i32 : i32 + %25 = tt.addptr %arg9, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %26 = tt.addptr %arg10, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %24, %25, %26, %23 : i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>, tensor<128xi32> + } + %15 = arith.addi %14#0, %cst : tensor<128xi32> + %16 = tt.addptr %14#2, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %17 = tt.addptr %14#3, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %14#0, %15, %16, %17 : tensor<128xi32>, tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + tt.return +} + +// CHECK-LABEL: tt.func public @test_kernel2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<128> : tensor<128xi64> +// CHECK: %[[VAL_8:.*]] = arith.constant 300 : i32 +// CHECK: %[[VAL_16:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %{{.*}} : i32 +// CHECK: %[[VAL_18:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_17]] : i32 -> tensor<128xi32> +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : tensor<128xi32> +// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_22:.*]] = arith.extsi %[[VAL_20]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_20]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.extsi %[[VAL_25]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_27:.*]]:3 = scf.while (%[[VAL_28:.*]] = %{{.*}}, %[[VAL_29:.*]] = %[[VAL_20]], %[[VAL_30:.*]] = %[[VAL_22]], %[[VAL_31:.*]] = %[[VAL_26]]) : (tensor<128xi32>, tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) -> (tensor<128xi64>, tensor<128xi32>, tensor<128xi64>) { +// CHECK: %[[VAL_32:.*]] = "tt.reduce"(%[[VAL_28]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0(%[[VAL_33:.*]]: i32, 
%[[VAL_34:.*]]: i32): +// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_33]], %[[VAL_34]] : i32 +// CHECK: tt.reduce.return %[[VAL_35]] : i32 +// CHECK: }) : (tensor<128xi32>) -> i32 +// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_32]], %[[VAL_8]] : i32 +// CHECK: scf.condition(%[[VAL_36]]) %[[VAL_31]], %[[VAL_29]], %[[VAL_30]] : tensor<128xi64>, tensor<128xi32>, tensor<128xi64> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_37:.*]]: tensor<128xi64>, %[[VAL_38:.*]]: tensor<128xi32>, %[[VAL_39:.*]]: tensor<128xi64>): +// CHECK: %[[VAL_40:.*]]:4 = scf.while (%[[VAL_41:.*]] = %{{.*}}, %[[VAL_42:.*]] = %[[VAL_39]], %[[VAL_43:.*]] = %[[VAL_37]], %[[VAL_44:.*]] = %[[VAL_38]]) : (i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32>) -> (i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32>) { +// CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_41]], %{{.*}} : i32 +// CHECK: scf.condition(%[[VAL_45]]) %[[VAL_41]], %[[VAL_42]], %[[VAL_43]], %[[VAL_44]] : i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_46:.*]]: i32, %[[VAL_47:.*]]: tensor<128xi64>, %[[VAL_48:.*]]: tensor<128xi64>, %[[VAL_49:.*]]: tensor<128xi32>): +// CHECK: %[[VAL_50:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_51:.*]] = tt.addptr %[[VAL_50]], %[[VAL_47]] : tensor<128x!tt.ptr>, tensor<128xi64> +// CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[VAL_49]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_53:.*]] = tt.addptr %[[VAL_21]], %[[VAL_49]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_54:.*]] = tt.load %[[VAL_53]], %[[VAL_52]], %{{.*}} : tensor<128x!tt.ptr> +// CHECK: %[[VAL_55:.*]] = tensor.empty() : tensor<128xi32> +// CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[VAL_58:.*]] = %[[VAL_55]]) -> (tensor<128xi32>) { +// CHECK: %[[VAL_59:.*]] = tensor.extract %[[VAL_48]]{{\[}}%[[VAL_57]]] {DiscreteMemAccess} : tensor<128xi64> +// CHECK: %[[VAL_60:.*]] = tt.addptr %[[VAL_3]], %[[VAL_59]] : !tt.ptr, i64 +// CHECK: %[[VAL_61:.*]] = tt.load %[[VAL_60]] {DiscreteMemAccess} : !tt.ptr +// CHECK: %[[VAL_62:.*]] = tt.splat %[[VAL_61]] : i32 -> tensor<1xi32> +// CHECK: %[[VAL_63:.*]] = tensor.insert_slice %[[VAL_62]] into %[[VAL_58]]{{\[}}%[[VAL_57]]] [1] [1] : tensor<1xi32> into tensor<128xi32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_63]] : tensor<128xi32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_64:.*]] = arith.addi %[[VAL_54]], %[[VAL_56]] : tensor<128xi32> +// CHECK: tt.store %[[VAL_51]], %[[VAL_64]], %[[VAL_52]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_49]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_46]], %{{.*}} : i32 +// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_47]], %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_48]], %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_65]] : i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32> +// CHECK: } +// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_70:.*]]#3, %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_70]]#1, %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_70]]#2, %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_70]]#3, %[[VAL_69]], %[[VAL_71]], %[[VAL_72]] : tensor<128xi32>, tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git 
a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir
new file mode 100644
index 0000000000..cc3accd712
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir
@@ -0,0 +1,14 @@
+// RUN: triton-opt %s --triton-to-unstructure | FileCheck %s
+
+// CHECK-LABEL: tt.func @test_unstructure_splatandloadscenario
+// CHECK: %[[EXT:.*]] = tensor.extract %{{.*}}[%{{.*}}] {DiscreteMemAccess} : tensor<128x!tt.ptr>
+// CHECK: %[[VAL1:.*]] = tt.load %[[EXT]] : !tt.ptr
+// CHECK: %[[VAL2:.*]] = tt.splat %[[VAL1]] : f32 -> tensor<128xf32>
+tt.func @test_unstructure_splatandloadscenario(%base: !tt.ptr) -> tensor<128xf32> {
+  %offset = arith.constant 10 : i64
+  %offset_tensor = tt.splat %offset : i64 -> tensor<128xi64>
+  %base_tensor = tt.splat %base : !tt.ptr -> tensor<128x!tt.ptr>
+  %ptr = tt.addptr %base_tensor, %offset_tensor : tensor<128x!tt.ptr>, tensor<128xi64>
+  %val = tt.load %ptr : tensor<128x!tt.ptr>
+  tt.return %val : tensor<128xf32>
+}
diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir
new file mode 100644
index 0000000000..4516711f08
--- /dev/null
+++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir
@@ -0,0 +1,82 @@
+// RUN: triton-opt --triton-to-unstructure %s | FileCheck %s
+
+tt.func public @indirect_mix_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+  %cst = arith.constant dense<16> : tensor<1x8xi32>
+  %c8_i32 = arith.constant 8 : i32
+  %0 = tt.get_program_id x : i32
+  %1 = arith.muli %0, %c8_i32 : i32
+  %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
+  %3 = tt.splat %1 : i32 -> tensor<8xi32>
+  %4 = arith.addi %3, %2 : tensor<8xi32>
+  %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
+  %6 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr>
+  %7 = tt.addptr %6, %5 : tensor<16x!tt.ptr>, tensor<16xi32>
+  %8 = tt.load %7 : tensor<16x!tt.ptr>
+  %9 = tt.expand_dims %2 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32>
+  %10 = tt.splat %arg3 : i32 -> tensor<1x8xi32>
+  %11 = arith.muli %9, %10 : tensor<1x8xi32>
+  %12 = tt.expand_dims %8 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
+  %13 = arith.extsi %11 : tensor<1x8xi32> to tensor<1x8xi64>
+  %14 = tt.broadcast %13 : tensor<1x8xi64> -> tensor<16x8xi64>
+  %15 = tt.broadcast %12 : tensor<16x1xi64> -> tensor<16x8xi64>
+  %16 = arith.addi %14, %15 : tensor<16x8xi64>
+  %17 = tt.splat %arg2 : !tt.ptr -> tensor<16x8x!tt.ptr>
+  %18 = tt.addptr %17, %16 : tensor<16x8x!tt.ptr>, tensor<16x8xi64>
+  %19 = tt.load %18 : tensor<16x8x!tt.ptr>
+  %20 = math.exp %19 : tensor<16x8xf32>
+  %21 = tt.expand_dims %4 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32>
+  %22 = arith.muli %21, %cst : tensor<1x8xi32>
+  %23 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32>
+  %24 = tt.broadcast %22 : tensor<1x8xi32> -> tensor<16x8xi32>
+  %25 = tt.broadcast %23 : tensor<16x1xi32> -> tensor<16x8xi32>
+  %26 = arith.addi %24, %25 : tensor<16x8xi32>
+  %27 = tt.splat %arg0 : !tt.ptr -> tensor<16x8x!tt.ptr>
+  %28 = tt.addptr %27, %26 : tensor<16x8x!tt.ptr>, tensor<16x8xi32>
+  tt.store %28, %20 : tensor<16x8x!tt.ptr>
+  tt.return
+}
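+
+// The expectations below capture what --triton-to-unstructure does here: the
+// offsets in %16 depend on values loaded at run time through %arg1, so the
+// 16x8 gather cannot remain a structured access. The pass rewrites it into an
+// scf.for over the 16 rows that extracts one 1x8 offset slice per iteration,
+// performs a discrete load tagged DiscreteMemAccess, and reassembles the
+// full result with tensor.insert_slice.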
+ +// CHECK-LABEL: tt.func public @indirect_mix_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_7:.*]] = arith.constant dense<16> : tensor<1x8xi32> +// CHECK: %[[VAL_9:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %{{.*}} : i32 +// CHECK: %[[VAL_11:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_10]] : i32 -> tensor<8xi32> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : tensor<8xi32> +// CHECK: %[[VAL_14:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> +// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<16x!tt.ptr> +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_15]], %[[VAL_14]] : tensor<16x!tt.ptr>, tensor<16xi32> +// CHECK: %[[VAL_17:.*]] = tt.load %[[VAL_16]] : tensor<16x!tt.ptr> +// CHECK: %[[VAL_18:.*]] = tt.expand_dims %[[VAL_11]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_3]] : i32 -> tensor<1x8xi32> +// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : tensor<1x8xi32> +// CHECK: %[[VAL_21:.*]] = tt.expand_dims %[[VAL_17]] {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> +// CHECK: %[[VAL_22:.*]] = arith.extsi %[[VAL_20]] : tensor<1x8xi32> to tensor<1x8xi64> +// CHECK: %[[VAL_23:.*]] = tt.broadcast %[[VAL_22]] : tensor<1x8xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_24:.*]] = tt.broadcast %[[VAL_21]] : tensor<16x1xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : tensor<16x8xi64> +// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<16x8xf32> +// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_26]]) -> (tensor<16x8xf32>) { +// CHECK: %[[VAL_30:.*]] = tensor.extract_slice %[[VAL_25]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] {DiscreteMemAccess} : tensor<16x8xi64> to tensor<1x8xi64> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1x8x!tt.ptr>, tensor<1x8xi64> +// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] {DiscreteMemAccess} : tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_34:.*]] = tensor.insert_slice %[[VAL_33]] into %[[VAL_29]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] : tensor<1x8xf32> into tensor<16x8xf32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_34]] : tensor<16x8xf32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_35:.*]] = math.exp %[[VAL_27]] : tensor<16x8xf32> +// CHECK: %[[VAL_36:.*]] = tt.expand_dims %[[VAL_13]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : tensor<1x8xi32> +// CHECK: %[[VAL_38:.*]] = tt.expand_dims %[[VAL_14]] {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> +// CHECK: %[[VAL_39:.*]] = tt.broadcast %[[VAL_37]] : tensor<1x8xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_40:.*]] = tt.broadcast %[[VAL_38]] : tensor<16x1xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_39]], %[[VAL_40]] : tensor<16x8xi32> +// CHECK: %[[VAL_42:.*]] = tt.splat %[[VAL_0]] 
: !tt.ptr -> tensor<16x8x!tt.ptr> +// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_42]], %[[VAL_41]] : tensor<16x8x!tt.ptr>, tensor<16x8xi32> +// CHECK: tt.store %[[VAL_43]], %[[VAL_35]] : tensor<16x8x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/affine_map/affine_map.py b/third_party/ascend/unittest/affine_map/affine_map.py new file mode 100644 index 0000000000..5d8460aed6 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c2 = ascend_ir.affine_expr.get_constant(2) + + expr = (d0 + c2) * d1 + print("expr:", expr) + print("expr pure affine:", expr.is_pure_affine()) + print("expr hashable:", hash(expr)) + + m0 = ascend_ir.affine_map.get_identity(2) + m1 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + m2 = ascend_ir.affine_map.get(2, 0, [d0 + d1, d1]) + m3 = ascend_ir.affine_map.get_constant(7) + minor = ascend_ir.affine_map.get_minor_identity(3, 2) + + print("m0:", m0) + print("m1:", m1) + print("m2:", m2) + print("m1 inverse:", m1.inverse_permutation()) + print("m2 submap[1]:", m2.get_sub_map([1])) + print("m2 compose m1:", m2.compose(m1)) + print("m1 as dict:", m1.to_dict()) + print("m3 constant:", m3, "value=", m3.get_constant_result()) + print("minor identity:", minor) + print("m2 results:", [str(x) for x in m2.get_results()]) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py b/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py new file mode 100644 index 0000000000..a3e248a454 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + builder = ascend_ir.ascendnpu_ir_builder(ctx) + f32 = builder.get_float_ty() + ub_space = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + + # Build a memref type using an explicit affine map layout. 
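+        # The permutation list [1, 0] maps (d0, d1) -> (d1, d0), so the buffer's
+        # storage layout is the transpose of its logical [8, 16] shape.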
+ transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + memref_ty = builder.get_buffer_ty_with_affine_map([8, 16], f32, transpose_map, ub_space) + map_attr = builder.get_affine_map_attr(transpose_map) + + print("affine map:", transpose_map) + print("affine map attr:", map_attr) + print("memref type:", memref_ty) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py b/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py new file mode 100644 index 0000000000..4748e69a02 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + s0 = ascend_ir.affine_expr.get_symbol(0) + c3 = ascend_ir.affine_expr.get_constant(3) + c4 = ascend_ir.affine_expr.get_constant(4) + + # Complex expressions with symbols and integer arithmetic. + tiled_row = (d0 + s0).floordiv(c4) + tiled_col = (d1 + c3).ceildiv(c4) + inner = (d0 + d1).mod(c4) + + map_a = ascend_ir.affine_map.get(2, 1, [tiled_row, tiled_col, inner]) + map_b = ascend_ir.affine_map.get(2, 0, [d1, d0]) + map_comp = map_a.compose(map_b) + + print("map_a:", map_a) + print("map_b:", map_b) + print("map_a composed with map_b:", map_comp) + print("map_a results:", [str(r) for r in map_a.get_results()]) + print("map_a submap [0, 2]:", map_a.get_sub_map([0, 2])) + print("map_a metadata:", map_a.to_dict()) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py b/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py new file mode 100644 index 0000000000..d8be55d37e --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + builder = ascend_ir.ascendnpu_ir_builder(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c8 = ascend_ir.affine_expr.get_constant(8) + + # Example indexing maps: transpose and a tiled/reduced projection. 
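+        # map_in0 transposes (d0, d1) -> (d1, d0), map_in1 is the identity, and
+        # map_out projects onto an 8-wide tile grid as (d0 floordiv 8, d1 mod 8),
+        # analogous to the per-operand indexing_maps carried by a linalg op.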
+ map_in0 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + map_in1 = ascend_ir.affine_map.get(2, 0, [d0, d1]) + map_out = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c8), d1.mod(c8)]) + + indexing_map_attr = builder.get_affine_map_array_attr([map_in0, map_in1, map_out]) + print("indexing_map attr:", indexing_map_attr) + + ub_space = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + f32 = builder.get_float_ty() + memref_ty = builder.get_buffer_ty_with_affine_map([16, 32], f32, map_in0, ub_space) + print("buffer type with map_in0:", memref_ty) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py b/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py new file mode 100644 index 0000000000..9bb8e09667 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + identity_map = ascend_ir.affine_map.get_identity(2) + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + + print("identity map:", identity_map) + print(" dims:", identity_map.get_num_dims()) + print(" symbols:", identity_map.get_num_symbols()) + print(" results:", identity_map.get_num_results()) + print(" is_identity:", identity_map.is_identity()) + print(" is_permutation:", identity_map.is_permutation()) + + print("transpose map:", transpose_map) + print(" is_identity:", transpose_map.is_identity()) + print(" is_permutation:", transpose_map.is_permutation()) + print(" as python object:", transpose_map.to_dict()) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/autotune_ut/01-vector-add.py b/third_party/ascend/unittest/autotune_ut/01-vector-add.py new file mode 100644 index 0000000000..1219c1a2e3 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/01-vector-add.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Vector Add +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.backends.ascend.testing import do_bench_npu + + +@triton.autotune(configs=[], key=["n_elements"]) +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. 
+ y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements) + return output + + +def test_add(size: int): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + print(f"Vector Add {size} PASSED!") + + +if __name__ == "__main__": + test_add(98432) diff --git a/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py new file mode 100644 index 0000000000..66f6c7a371 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py @@ -0,0 +1,99 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
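In the vector-add demo above, the launch grid is a lambda over the autotuner's meta-parameters, so the number of programs is derived from whichever BLOCK_SIZE the tuner settles on. A minimal sketch of how that lambda resolves once a config is chosen (the config values below are hypothetical, not ones the Ascend tuner necessarily picks):

```python
import triton

n_elements = 98432
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

# Hypothetical tuner picks; triton.cdiv is plain ceiling division.
for block_size in (1024, 8 * 1024, 12 * 1024):
    print(block_size, "->", grid({"BLOCK_SIZE": block_size}))
# 1024 -> (97,), 8192 -> (13,), 12288 -> (9,)
```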
+""" +Fused Softmax +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.backends.ascend.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key=["n_rows", "n_cols"], +) +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) * XBLOCK + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): + # The stride represents how much we need to increase the pointer to advance 1 row + row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] + col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] + xmask = row_offsets < n_rows + ymask = col_offsets < n_cols + mask = xmask & ymask + input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = (tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE)) + softmax_output = numerator / denominator + # Write back output to DRAM + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +def softmax_torch(x): + return torch.softmax(x, axis=-1) + + +def softmax_autotune(x): + n_rows, n_cols = x.shape + BLOCK_SIZE = n_cols + + # Allocate output + y = torch.empty_like(x) + # Create a number of persistent programs. + softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, + n_cols, BLOCK_SIZE=BLOCK_SIZE) + return y + + +def test_softmax(shape, dtype): + x = torch.randn(shape, dtype=dtype, device="npu") + y_torch = softmax_torch(x) + y_triton = softmax_autotune(x) + assert torch.allclose(y_triton, y_torch) + print(f"Fused Softmax {shape} {dtype} PASSED!") + + +if __name__ == "__main__": + test_softmax((16896, 1024), torch.float32) diff --git a/third_party/ascend/unittest/autotune_ut/03-layer-norm.py b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py new file mode 100644 index 0000000000..ab547cfa36 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py @@ -0,0 +1,133 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+Layer Normalization
+=============
+"""
+
+import os
+
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.backends.ascend.runtime
+from triton.backends.ascend.testing import do_bench_npu
+
+
+@triton.autotune(
+    configs=[],
+    key=["M", "N"],
+)
+@triton.jit
+def _layer_norm_fwd_fused(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride,  # how much to increase the pointer when moving by 1 row
+    N,  # number of columns in X
+    M,  # number of rows in X
+    eps,  # epsilon to avoid division by zero
+    XBLOCK_SIZE: tl.constexpr,
+    RBLOCK_SIZE: tl.constexpr,
+):
+    # Map the program id to the rows of X and Y it should compute.
+    row_begin = tl.program_id(0) * XBLOCK_SIZE
+    row_idx = row_begin + tl.arange(0, XBLOCK_SIZE)
+    row_mask = row_idx < M
+    row_offsets = row_idx[:, None] * stride
+    # Compute mean
+    _mean = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32)
+    for off in range(0, N, RBLOCK_SIZE):
+        col_idx = off + tl.arange(0, RBLOCK_SIZE)
+        col_mask = col_idx < N
+        mask = row_mask[:, None] & col_mask[None, :]
+        a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=1, keep_dims=True) / N
+    # Compute variance
+    _var = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32)
+    for off in range(0, N, RBLOCK_SIZE):
+        col_idx = off + tl.arange(0, RBLOCK_SIZE)
+        col_mask = col_idx < N
+        mask = row_mask[:, None] & col_mask[None, :]
+        x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32)
+        x = tl.where(mask, x - mean, 0.0)
+        _var += x * x
+    var = tl.sum(_var, axis=1, keep_dims=True) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    # Write mean / rstd
+    tl.store(Mean + row_idx[:, None], mean, mask=row_mask[:, None])
+    tl.store(Rstd + row_idx[:, None], rstd, mask=row_mask[:, None])
+    # Normalize and apply linear transformation
+    for off in range(0, N, RBLOCK_SIZE):
+        col_idx = off + tl.arange(0, RBLOCK_SIZE)
+        col_mask = col_idx < N
+        mask = row_mask[:, None] & col_mask[None, :]
+        w = tl.load(W + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE))
+        b = tl.load(B + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE))
+        x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32)
+        x_hat = (x - mean) * rstd
+        y = x_hat * w + b
+        # Write output
+        tl.store(Y + row_offsets + col_idx[None, :], y, mask=mask)
+
+
+def layer_norm_torch(args):
+    x, w_shape, weight, bias, eps, dtype = args
+    return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
+
+
+def layer_norm_autotune(args):
+    x, weight, bias, eps = args
+    # allocate output
+    y = torch.empty_like(x)
+    # reshape input data into 2D tensor
+    x_arg = x.reshape(-1, x.shape[-1])
+    M, N = x_arg.shape
+    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
+    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
+
+    # enqueue kernel
+    _layer_norm_fwd_fused[lambda meta: (triton.cdiv(M,
meta["XBLOCK_SIZE"]), 1, 1)]( # + x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, M, eps # + ) + return y + + +def test_layer_norm(shape, dtype, eps=1e-5): + M, N = shape + device = "npu" + x_shape = shape + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device) + bias = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + y_torch = layer_norm_torch((x, w_shape, weight, bias, eps, dtype)) + y_triton = layer_norm_autotune((x, weight, bias, eps)) + assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + test_layer_norm((128, 32), torch.float16) diff --git a/third_party/ascend/unittest/autotune_ut/04-libentry.py b/third_party/ascend/unittest/autotune_ut/04-libentry.py new file mode 100644 index 0000000000..e956ba034b --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/04-libentry.py @@ -0,0 +1,97 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Vector Add with Libentry +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.runtime.libentry import libentry +from triton.backends.ascend.testing import do_bench_npu + + +# NB: Inserting any other decorator between @triton.autotune and @triton.jit disables +# parallel compilation during autotuning. To enable parallel compilation, apply @triton.autotune +# directly around @triton.jit (i.e., nest autotune as the outermost decorator on the JIT-compiled function) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1 * 1024, 'multibuffer': True}), + triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': True}), + triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': False}), + triton.Config({'BLOCK_SIZE': 8 * 1024, 'multibuffer': True}), + ], key=["n_elements"]) +@libentry() +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. 
We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements) + return output + + +def test_add(size: int): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + print(f"Vector Add {size} with libentry PASSED!") + + +if __name__ == "__main__": + test_add(98432) diff --git a/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py new file mode 100644 index 0000000000..4ce48e6c19 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py @@ -0,0 +1,163 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
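The NB in 04-libentry.py above is the key constraint when stacking decorators: @triton.autotune parallel-compiles its candidate configs only when it directly wraps @triton.jit. A minimal sketch of the two orderings, assuming the same libentry import this patch uses (kernel bodies reduced to a copy; the config value is illustrative):

```python
import triton
import triton.language as tl
from triton.runtime.libentry import libentry


# Serial compilation during tuning: @libentry() sits between the two.
@triton.autotune(configs=[triton.Config({'BLOCK_SIZE': 2048})], key=["n_elements"])
@libentry()
@triton.jit
def copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


# Parallel compilation during tuning: autotune directly wraps jit.
@triton.autotune(configs=[triton.Config({'BLOCK_SIZE': 2048})], key=["n_elements"])
@triton.jit
def copy_kernel_parallel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)
```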
+
+import os
+
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.backends.ascend.runtime
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[], key={"x": "n_elements"}, hints={
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {"x": "BLOCK_SIZE_SUB"},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    })
+@triton.jit
+def add_kernel(
+    x_ptr,
+    y_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_SUB: tl.constexpr,
+):
+    offset = tl.program_id(0) * BLOCK_SIZE
+    loops1 = (BLOCK_SIZE + BLOCK_SIZE_SUB - 1) // BLOCK_SIZE_SUB
+    for loop in range(0, loops1):
+        # Advance by one sub-block per iteration within this program's span.
+        x0 = offset + loop * BLOCK_SIZE_SUB + tl.arange(0, BLOCK_SIZE_SUB)
+        mask = x0 < n_elements
+        x = tl.load(x_ptr + x0, mask)
+        y = tl.load(y_ptr + x0, mask)
+        output = x + y
+        tl.store(output_ptr + x0, output, mask)
+
+
+def add_torch(x, y):
+    return x + y
+
+
+def add_autotune(x, y):
+    output = torch.empty_like(x)
+    n_elements = output.numel()
+    add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements)
+    return output
+
+
+@pytest.mark.autotune
+@pytest.mark.parametrize('size', [
+    2048,
+])
+def test_add(size: int):
+    x = torch.rand(size, device="npu")
+    y = torch.rand(size, device="npu")
+
+    output_torch = add_torch(x, y)
+    output_triton = add_autotune(x, y)
+    assert torch.allclose(output_triton, output_torch)
+
+
+@pytest.mark.autotune
+def test_add_no_reduction_axes():
+    try:
+
+        @triton.autotune(
+            configs=[], key={"x": "n_elements"}, hints={
+                "split_params": {"x": "BLOCK_SIZE"},
+                "tiling_params": {"x": "BLOCK_SIZE_SUB"},
+                "low_dim_axes": ["x"],
+            })
+        @triton.jit
+        def add_kernel_exception():
+            pass
+    except ValueError as e:
+        assert "reduction_axes must be a list" in str(e)
+
+
+@pytest.mark.autotune
+def test_add_no_low_dim_axes():
+    try:
+
+        @triton.autotune(
+            configs=[], key={"x": "n_elements"}, hints={
+                "split_params": {"x": "BLOCK_SIZE"},
+                "tiling_params": {"x": "BLOCK_SIZE_SUB"},
+                "reduction_axes": [],
+            })
+        @triton.jit
+        def add_kernel_exception():
+            pass
+    except ValueError as e:
+        assert "low_dim_axes must be a list" in str(e)
+
+
+@pytest.mark.autotune
+def test_add_no_tiling_params():
+    try:
+
+        @triton.autotune(configs=[], key={"x": "n_elements"}, hints={
+            "split_params": {"x": "BLOCK_SIZE"},
+            "low_dim_axes": ["x"],
+            "reduction_axes": [],
+        })
+        @triton.jit
+        def add_kernel_exception():
+            pass
+    except ValueError as e:
+        assert "tiling_params must be a dict" in str(e)
+
+
+@pytest.mark.autotune
+def test_add_no_split_params():
+    try:
+
+        @triton.autotune(
+            configs=[], key={"x": "n_elements"}, hints={
+                "tiling_params": {"x": "BLOCK_SIZE_SUB"},
+                "low_dim_axes": ["x"],
+                "reduction_axes": [],
+            })
+        @triton.jit
+        def add_kernel_exception():
+            pass
+    except ValueError as e:
+        assert "split_params must be a dict" in str(e)
+
+
+@pytest.mark.autotune
+def test_add_no_keyname():
+    try:
+
+        @triton.autotune(
+            configs=[], key={"x0": "n_elements"}, hints={
+                "tiling_params": {"x": "BLOCK_SIZE_SUB"},
+                "low_dim_axes": ["x"],
+                "reduction_axes": [],
+            })
+        @triton.jit
+        def add_kernel_exception():
+            pass
+    except ValueError as e:
+        assert "All keys in 'key' must be valid axis names" in str(e)
diff --git a/third_party/ascend/unittest/autotune_ut/test_common.py b/third_party/ascend/unittest/autotune_ut/test_common.py
index d512d3358e..8c80b45a7d 100644
--- a/third_party/ascend/unittest/autotune_ut/test_common.py
+++ b/third_party/ascend/unittest/autotune_ut/test_common.py
@@ -20,6 +20,7 @@ import unittest.mock as mock
import pytest +import torch def MockAutoTilingTunerRun(self, *args, **kwargs): @@ -84,3 +85,18 @@ def normalize_axis_list(axis_list: list, sym_to_sem: dict) -> list: def mock_autotuner(): with mock.patch("triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", new=MockAutoTilingTunerRun): yield + + +def generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.randn(size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + elif dtype == 'uint8': + return torch.randint(low=0, high=255, size=shape, dtype=torch.uint8) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) diff --git a/third_party/ascend/unittest/autotune_ut/test_customized_config.py b/third_party/ascend/unittest/autotune_ut/test_customized_config.py new file mode 100644 index 0000000000..3431109569 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_customized_config.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
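The generate_tensor helper added to test_common.py above resolves dtype strings with eval('torch.' + dtype). A getattr-based lookup gives the same mapping without eval; this is a sketch of an alternative, not a change the patch makes:

```python
import torch

def to_torch_dtype(name: str) -> torch.dtype:
    # getattr returns None for unknown names; the isinstance check also
    # rejects torch attributes that are not dtypes.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError('Invalid parameter "dtype" is found : {}'.format(name))
    return dtype

assert to_torch_dtype("bfloat16") is torch.bfloat16
```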
+ +import os + +import pytest +import torch +import torch_npu +import triton +import triton.backends.ascend.runtime +import triton.language as tl + +os.environ['TRITON_PRINT_AUTOTUNING'] = '0' + + +@triton.autotune( + configs=[ + triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 32}), + triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 64}), + triton.Config({'XBLOCK': 396, 'XBLOCK_SUB': 6}), + ], key=["n_rows", "n_cols"], hints={ + "auto_gen_config": False, + }) +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, +): + row_start = tl.program_id(0) * XBLOCK + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): + row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] + col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] + xmask = row_offsets < n_rows + ymask = col_offsets < n_cols + mask = xmask & ymask + input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) + numerator = tl.exp(row_minus_max) + denominator = (tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE)) + softmax_output = numerator / denominator + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +def softmax_torch(x): + return torch.softmax(x, axis=-1) + + +def softmax_autotune(x): + n_rows, n_cols = x.shape + BLOCK_SIZE = n_cols + y = torch.empty_like(x) + softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, + n_cols, BLOCK_SIZE=BLOCK_SIZE) + return y + + +@pytest.mark.autotune +@pytest.mark.parametrize('shape,dtype', [ + ((16896, 1024), torch.float32), +]) +def test_softmax(shape, dtype): + x = torch.randn(shape, dtype=dtype, device="npu") + y_torch = softmax_torch(x) + y_triton = softmax_autotune(x) + torch.testing.assert_close(y_torch, y_triton, rtol=1e-03, atol=1e-03, equal_nan=True) diff --git a/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py new file mode 100644 index 0000000000..b44355baab --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py @@ -0,0 +1,54 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import triton
+import triton.language as tl
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_low_dim_axis_parse_base_case1(mock_autotuner):
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(configs=[], key=["n_elements"])
+    @triton.jit
+    def triton_low_dim_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE  # <- Separate assignment
+
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"], )
+    act_res = triton_low_dim_axis_parse_base_case1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/autotune_ut/test_mask_parse.py b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py
new file mode 100644
index 0000000000..7a18f9f9f8
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py
@@ -0,0 +1,162 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+import triton.language as tl
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_triton_dot_case1(mock_autotuner):
+    """
+    The current operator is only used for axis analysis test cases.
+    CV fused operators do not support autotuning for now.
+    """
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(configs=[], key=["M", "N", "K"])
+    @triton.jit
+    def triton_dot_case1(
+        A,
+        B,
+        C,
+        M: tl.constexpr,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        MBLOCK: tl.constexpr,
+        NBLOCK: tl.constexpr,
+        MBLOCK_SUB: tl.constexpr,
+        NBLOCK_SUB: tl.constexpr,
+        KBLOCK_SUB: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
+        base_m = pid_m * MBLOCK
+        base_n = pid_n * NBLOCK
+
+        loops_m = (MBLOCK + MBLOCK_SUB - 1) // MBLOCK_SUB
+        loops_n = (NBLOCK + NBLOCK_SUB - 1) // NBLOCK_SUB
+        loops_k = (K + KBLOCK_SUB - 1) // KBLOCK_SUB
+
+        for loop_m in range(loops_m):
+            for loop_n in range(loops_n):
+                acc = tl.zeros((MBLOCK_SUB, NBLOCK_SUB), dtype=tl.float32)
+
+                mdx = base_m + loop_m * MBLOCK_SUB + tl.arange(0, MBLOCK_SUB)[:, None]
+                ndx = base_n + loop_n * NBLOCK_SUB + tl.arange(0, NBLOCK_SUB)[None, :]
+
+                for loop_k in range(loops_k):
+                    kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB)
+                    kdx_m = kdx[None, :]  # <-
+                    A_ptr = A + mdx * K + kdx_m
+                    a_mask = (mdx < M) & (kdx_m < K)  # Use the result of the subscript in the mask comparison
+                    a = tl.load(A_ptr, mask=a_mask, other=0.0)
+
+                    kdx_n = kdx[:, None]
+                    B_ptr = B + kdx_n * N + ndx
+                    b_mask = (kdx_n < K) & (ndx < N)
+                    b = tl.load(B_ptr, mask=b_mask, other=0.0)
+
+                    acc += tl.dot(a, b)
+
+                C_ptr = C + mdx * N + ndx
+                c_mask = (mdx < M) & (ndx < N)
+                tl.store(C_ptr, acc, mask=c_mask)
+
+    ref_res = {
+        "keys": {"x": "M", "y": "N", "z": "K"},
+        "split_params": {"x": "MBLOCK", "y": "NBLOCK"},
+        "tiling_params": {"x": "MBLOCK_SUB", "y": "NBLOCK_SUB", "z": "KBLOCK_SUB"},
+        "low_dim_axes": ["y", "z"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["MBLOCK"], meta["NBLOCK"])
+    act_res = triton_dot_case1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+@pytest.mark.skip
+def test_triton_dot_case2(mock_autotuner):
+    """
+    The current operator is only used for axis analysis test cases.
+    CV fused operators do not support autotuning for now.
+    """
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(configs=[], key=["M", "N", "K"])
+    @triton.jit
+    def triton_dot_case2(
+        A,
+        B,
+        C,
+        M: tl.constexpr,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        MBLOCK: tl.constexpr,
+        NBLOCK: tl.constexpr,
+        MBLOCK_SUB: tl.constexpr,
+        NBLOCK_SUB: tl.constexpr,
+        KBLOCK_SUB: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
+        base_m = pid_m * MBLOCK
+        base_n = pid_n * NBLOCK
+
+        loops_m = (MBLOCK + MBLOCK_SUB - 1) // MBLOCK_SUB
+        loops_n = (NBLOCK + NBLOCK_SUB - 1) // NBLOCK_SUB
+        loops_k = (K + KBLOCK_SUB - 1) // KBLOCK_SUB
+
+        for loop_m in range(loops_m):
+            for loop_n in range(loops_n):
+                acc = tl.zeros((MBLOCK_SUB, NBLOCK_SUB), dtype=tl.float32)
+
+                mdx = base_m + loop_m * MBLOCK_SUB + tl.arange(0, MBLOCK_SUB)[:, None]
+                ndx = base_n + loop_n * NBLOCK_SUB + tl.arange(0, NBLOCK_SUB)[None, :]
+
+                for loop_k in range(loops_k):
+                    kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB)
+                    A_ptr = A + mdx * K + kdx[None, :]  # <-
+                    a_mask = (mdx < M) & (kdx[None, :] < K)  # Compute the subscript directly in the mask comparison
+                    a = tl.load(A_ptr, mask=a_mask, other=0.0)
+
+                    B_ptr = B + kdx[:, None] * N + ndx
+                    b_mask = (kdx[:, None] < K) & (ndx < N)
+                    b = tl.load(B_ptr, mask=b_mask, other=0.0)
+
+                    acc += tl.dot(a, b)
+
+                C_ptr = C + mdx * N + ndx
+                c_mask = (mdx < M) & (ndx < N)
+                tl.store(C_ptr, acc, mask=c_mask)
+
+    ref_res = {
+        "keys": {"x": "M", "y": "N", "z": "K"},
+        "split_params": {"x": "MBLOCK", "y": "NBLOCK"},
+        "tiling_params": {"x": "MBLOCK_SUB", "y": "NBLOCK_SUB", "z": "KBLOCK_SUB"},
+        "low_dim_axes": ["y", "z"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["MBLOCK"], meta["NBLOCK"])
+    act_res = triton_dot_case2[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py
new file mode 100644
index 0000000000..f708cef694
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py
@@ -0,0 +1,99 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
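Both dot cases in test_mask_parse.py above derive their loop trip counts with (extent + TILE - 1) // TILE, the integer form of ceiling division that triton.cdiv also implements. A quick self-contained check of that identity:

```python
import math

def cdiv(a: int, b: int) -> int:
    # Integer ceiling division without floating point.
    return (a + b - 1) // b

for extent, tile in [(128, 32), (130, 32), (31, 32), (1, 32)]:
    assert cdiv(extent, tile) == math.ceil(extent / tile)
    print(f"{extent} elements / tile {tile} -> {cdiv(extent, tile)} loop iterations")
```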
+import os +import shutil +import pytest +import torch +import torch_npu +import triton +import triton.backends.ascend.runtime +import triton.language as tl +from triton.tools.get_ascend_devices import is_compile_on_910_95 + +import test_common + +os.environ['TRITON_ALWAYS_COMPILE'] = '1' +os.environ['TRITON_AUTOTUNE_PARALLEL_COMPILE'] = '0' + + +def case_torch(x): + return torch.permute(x, (1, 0)) + + +@triton.autotune(configs=[], key=['xnumel', 'ynumel'], hints={ + "auto_gen_config": True, +}) +@triton.jit +def triton_permute_2d( + output_ptr, + x_ptr, + xnumel: tl.constexpr, + ynumel: tl.constexpr, + XBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, +): + xpid = tl.program_id(0) + ypid = tl.program_id(1) + + x_off = xpid * XBLOCK + tl.arange(0, XBLOCK)[:, None] + y_off = ypid * YBLOCK + tl.arange(0, YBLOCK)[None, :] + mask = (x_off < xnumel) & (y_off < ynumel) + offs = y_off + x_off * ynumel + b = tl.load(x_ptr + offs, mask=mask) + ox_off = ypid * YBLOCK + tl.arange(0, YBLOCK)[:, None] + oy_off = xpid * XBLOCK + tl.arange(0, XBLOCK)[None, :] + o_mask = (ox_off < ynumel) & (oy_off < xnumel) + o_offs = oy_off + ox_off * xnumel + ret = tl.permute(b, (1, 0)) + tl.store(output_ptr + o_offs, ret, mask=o_mask) + + +def case_triton(x_cal, is_simt_only=False): + xnumel = x_cal.shape[0] + ynumel = x_cal.shape[1] + output = torch.randint(1, (ynumel, xnumel), dtype=x_cal.dtype, device=x_cal.device) + if is_simt_only: + (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)]( + output, x_cal, xnumel, ynumel, force_simt_only=True)) + else: + (triton_permute_2d[lambda meta: + (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)](output, x_cal, + xnumel, + ynumel)) + return output + + +@pytest.mark.parametrize('shape', [(1024, 32), (32, 8)]) +@pytest.mark.parametrize('dtype', ['bfloat16']) +def test_permute(shape, dtype): + x_cal = test_common.generate_tensor(shape, dtype).npu() + torch_output = case_torch(x_cal) + triton_output = case_triton(x_cal) + torch.testing.assert_close(torch_output, triton_output, rtol=1e-03, atol=1e-03, equal_nan=True) + + +@pytest.mark.skipif(not is_compile_on_910_95, reason="only support A5") +@pytest.mark.parametrize('shape', [(1024, 32)]) +@pytest.mark.parametrize('dtype', ['bfloat16']) +def test_permute_simt(shape, dtype): + x_cal = test_common.generate_tensor(shape, dtype).npu() + torch_output = case_torch(x_cal) + triton_output = case_triton(x_cal, True) + torch.testing.assert_close(torch_output, triton_output, rtol=1e-03, atol=1e-03, equal_nan=True) diff --git a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py index 2893bf3473..e6f3baa442 100644 --- a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py @@ -22,12 +22,12 @@ from test_common import check_axes_parse_res, mock_autotuner -def test_triton_max_last_dim_case(mock_autotuner): +def test_triton_max_last_dim_case1(mock_autotuner): import triton.backends.ascend.runtime @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) @triton.jit - def triton_max_last_dim( + def triton_max_last_dim1( in_ptr0, out_ptr0, x0_numel, @@ -50,7 +50,8 @@ def triton_max_last_dim( r1_mask = r1 < r1_numel tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) block_val = tl.maximum(block_val, tmp) - block_res = tl.max(block_val, axis=1)[:, None] + # 
Reduce along axis = 1 (the last dimension in this 2D tensor) + block_res = tl.max(block_val, axis=1)[:, None] # <- explicit positive axis index tl.store(out_ptr0 + x0, block_res, x0_mask) ref_res = { @@ -60,6 +61,97 @@ def triton_max_last_dim( "low_dim_axes": ["ry"], "reduction_axes": ["ry"], } - act_res = triton_max_last_dim[(1, )]() + grid = lambda meta: (meta["X0BLOCK"], ) + act_res = triton_max_last_dim1[grid]() + + check_axes_parse_res(act_res, ref_res) + + +def test_triton_max_last_dim_case2(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) + @triton.jit + def triton_max_last_dim2( + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, + R1BLOCK_SUB: tl.constexpr, + ): + x0_offset = tl.program_id(0) * X0BLOCK + base_x0 = tl.arange(0, X0BLOCK_SUB) + loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB + base_r1 = tl.arange(0, R1BLOCK_SUB) + loops_r1 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB + for loop_x0 in range(loops_x0): + x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0[:, None] + x0_mask = x0 < min(X0BLOCK + x0_offset, x0_numel) + block_val = tl.full([X0BLOCK_SUB, R1BLOCK_SUB], float("-inf"), tl.float32) + for loop_r1 in range(loops_r1): + r1 = (loop_r1 * R1BLOCK_SUB) + base_r1[None, :] + r1_mask = r1 < r1_numel + tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) + block_val = tl.maximum(block_val, tmp) + # Reduce along axis=-1 (the last dimension, equivalent to axis=1 in 2D) + block_res = tl.max(block_val, axis=-1)[:, None] # <- negative axis index (last dim) + tl.store(out_ptr0 + x0, block_res, x0_mask) + + ref_res = { + "keys": {"x": "x0_numel", "ry": "r1_numel"}, + "split_params": {"x": "X0BLOCK"}, + "tiling_params": {"x": "X0BLOCK_SUB", "ry": "R1BLOCK_SUB"}, + "low_dim_axes": ["ry"], + "reduction_axes": ["ry"], + } + grid = lambda meta: (meta["X0BLOCK"], ) + act_res = triton_max_last_dim2[grid]() + + check_axes_parse_res(act_res, ref_res) + + +def test_triton_max_last_dim_case3(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) + @triton.jit + def triton_max_last_dim3( + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, + R1BLOCK_SUB: tl.constexpr, + ): + x0_offset = tl.program_id(0) * X0BLOCK + base_x0 = tl.arange(0, X0BLOCK_SUB) + loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB + base_r1 = tl.arange(0, R1BLOCK_SUB) + loops_r1 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB + for loop_x0 in range(loops_x0): + x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0[:, None] + x0_mask = x0 < min(X0BLOCK + x0_offset, x0_numel) + block_val = tl.full([X0BLOCK_SUB, R1BLOCK_SUB], float("-inf"), tl.float32) + for loop_r1 in range(loops_r1): + r1 = (loop_r1 * R1BLOCK_SUB) + base_r1[None, :] + r1_mask = r1 < r1_numel + tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) + block_val = tl.maximum(block_val, tmp) + # Reduce along axis=1, passed as a positional argument (not keyword `axis=...`) + block_res = tl.max(block_val, 1)[:, None] # <- explicit positive axis index + tl.store(out_ptr0 + x0, block_res, x0_mask) + + ref_res = { + "keys": {"x": "x0_numel", "ry": "r1_numel"}, + "split_params": {"x": "X0BLOCK"}, + "tiling_params": {"x": "X0BLOCK_SUB", "ry": "R1BLOCK_SUB"}, + "low_dim_axes": ["ry"], + "reduction_axes": ["ry"], + } + grid = lambda meta: (meta["X0BLOCK"], ) + act_res = 
triton_max_last_dim3[grid]() check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py new file mode 100644 index 0000000000..078873341f --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py @@ -0,0 +1,149 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import unittest.mock as mock + +import triton +import triton.language as tl + +from test_common import check_axes_parse_res, mock_autotuner + + +def test_split_axis_parse_base_case1(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_split_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE # <- Separate assignment + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_split_axis_parse_base_case1[grid]() + + check_axes_parse_res(act_res, ref_res) + + +def test_split_axis_parse_base_case2(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_split_axis_parse_base_case2(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(axis=0) * BLOCK_SIZE # <- Computed inline but still named + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_split_axis_parse_base_case2[grid]() + + check_axes_parse_res(act_res, ref_res) + + +def test_split_axis_parse_base_case3(mock_autotuner): + import 
triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_split_axis_parse_base_case3(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # <- Fully fused + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_split_axis_parse_base_case3[grid]() + + check_axes_parse_res(act_res, ref_res) + + +def test_grid_stride_loop_block_only_tiling_semantics(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["N", "index_len"]) + @triton.jit + def triton_grid_stride_loop_block_only_tiling_semantics( + input_ptr, + output_ptr, + index_ptr, + N: tl.constexpr, + index_len: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + grid_x = tl.num_programs(axis=0) + grid_y = tl.num_programs(axis=1) + for x in range(pid_x * BLOCK_M, index_len, grid_x * BLOCK_M): + row_offsets = x + tl.arange(0, BLOCK_M) + indices = tl.load(index_ptr + row_offsets, mask=row_offsets < index_len, other=0) + for y in range(pid_y * BLOCK_N, N, grid_y * BLOCK_N): + col_offsets = y + tl.arange(0, BLOCK_N) + col_mask = col_offsets < N + inp_offset = indices[:, None] * N + col_offsets[None, :] + out_offset = row_offsets[:, None] * N + col_offsets[None, :] + selected = tl.load(input_ptr + inp_offset, mask=col_mask[None, :], other=0.0) + tl.store(output_ptr + out_offset, selected, mask=col_mask[None, :]) + + act_res = triton_grid_stride_loop_block_only_tiling_semantics[(1, 1)]() + assert act_res["split_params"] == {} + assert act_res["tiling_params"] == {"y": "BLOCK_M", "x": "BLOCK_N"} diff --git a/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py new file mode 100644 index 0000000000..1c5e150932 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py @@ -0,0 +1,123 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
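The grid-stride-loop test above asserts that split_params comes back empty: each program strides through the whole extent in steps of grid * BLOCK, so BLOCK only tiles the work rather than splitting it between programs. A plain-Python sketch of the index pattern (function and variable names are illustrative):

```python
def visited(pid: int, grid: int, block: int, extent: int):
    # Rows touched by one program in a grid-stride loop.
    rows = []
    for start in range(pid * block, extent, grid * block):
        rows.extend(range(start, min(start + block, extent)))
    return rows

# Two programs, BLOCK=4, extent=20: the union covers 0..19 with no overlap.
print(visited(0, 2, 4, 20))  # [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
print(visited(1, 2, 4, 20))  # [4, 5, 6, 7, 12, 13, 14, 15]
```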
+ +import pytest +import triton +import triton.language as tl +from test_common import check_axes_parse_res, mock_autotuner + + +def test_tiling_axis_parse_base_case1(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_tiling_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): + offset = tl.program_id(axis=0) * BLOCK_SIZE + base = tl.arange(0, BLOCK_SUB) + loops = (BLOCK_SIZE + BLOCK_SUB - 1) // BLOCK_SUB # <- + for loop in range(loops): + offsets = offset + (loop * BLOCK_SUB) + base + mask = offsets < min(BLOCK_SIZE + offset, n_elements) + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_tiling_axis_parse_base_case1[grid]() + + check_axes_parse_res(act_res, ref_res) + + +@pytest.mark.skip +def test_tiling_axis_parse_base_case2(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_tiling_axis_parse_base_case2(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): + offset = tl.program_id(axis=0) * BLOCK_SIZE + base = tl.arange(0, BLOCK_SUB) + for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB): + offsets = offset + offset_sub + base[:] # <- + mask = offsets < min(BLOCK_SIZE + offset, n_elements) + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_tiling_axis_parse_base_case2[grid]() + + check_axes_parse_res(act_res, ref_res) + + +@pytest.mark.skip +def test_tiling_axis_parse_base_case3(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["n_elements"]) + @triton.jit + def triton_tiling_axis_parse_base_case3(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): + offset = tl.program_id(axis=0) * BLOCK_SIZE + base = tl.arange(0, BLOCK_SUB)[:] # <- + for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB): + offsets = offset + offset_sub + base + mask = offsets < min(BLOCK_SIZE + offset, n_elements) + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + + tl.store(output_ptr + offsets, output, mask=mask) + + ref_res = { + "keys": {"x": "n_elements"}, + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + grid = lambda meta: (meta["BLOCK_SIZE"], ) + act_res = triton_tiling_axis_parse_base_case3[grid]() + + check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/unittest/custom_op/builtin_ops_demo.py b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py new file mode 100644 index 0000000000..e3e9fc6750 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +import 
subprocess
+import triton
+import triton.language as tl
+import triton.language.extra.cann.extension as al
+from triton.compiler.compiler import ASTSource
+from triton.compiler.code_generator import ast_to_ttir
+from triton._C.libtriton import ir
+from triton._C.libtriton.ascend import ir as ascend_ir
+from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg
+
+
+@triton.jit
+def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
+    i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+    x = tl.load(x_ptr + i, mask=i < n)
+    y = tl.load(y_ptr + i, mask=i < n)
+    index = tl.full([8], 0, tl.int32)
+    value = tl.full([8, 64], 0, tl.float32)
+    tmp = tl.full([8], 0, tl.float32)
+    x = al.custom("__builtin_index_select", x_ptr, index, dim=0, bound=100, end_offset=(2, 2), start_offset=(0, 0),
+                  src_stride=(4, 1), out=x)
+    al.custom("__builtin_index_put", x_ptr, index, value, dim=0, bound=12, dst_shape=(1, 2, 3), dst_offset=(4, 5, 6),
+              dst_stride=(8, 4, 1))
+    tmp = al.custom("__builtin_gather_load", y_ptr, index, bound=100, dim=0, src_stride=(1, ), index_shape=(3, ),
+                    offsets=(0, ), out=tmp)
+    al.custom("__builtin_scatter_store", out_ptr, value, index, 1, 0, (1, ), (2, ), (1, ))
+    y = al.custom("__builtin_indirect_load", x_ptr, index, mask=i < n, other=y, out=y)
+    al.custom("__builtin_indirect_store", out_ptr, index, value)
+    tl.store(out_ptr + i, y, mask=i < n)
+
+
+if __name__ == "__main__":
+    src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256})
+    context = ir.context()
+    ir.load_dialects(context)
+    ascend_ir.load_dialects(context)
+    options = NPUOptions()
+    try:
+        ttir = ast_to_ttir(my_kernel, src, context, options, {}, {})
+        print("=== TTIR ===")
+        print(ttir)
+        metadata = {
+            **options.__dict__,
+        }
+        linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True)
+        print("=== MLIR (linalg) ===")
+        print(linalg)
+    except subprocess.CalledProcessError as ex:
+        print(ex.stdout.decode())
+        print(ex.stderr.decode())
+        print("failed")
diff --git a/third_party/ascend/unittest/custom_op/custom_op_demo.py b/third_party/ascend/unittest/custom_op/custom_op_demo.py
new file mode 100644
index 0000000000..a28d6cdea6
--- /dev/null
+++ b/third_party/ascend/unittest/custom_op/custom_op_demo.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+import subprocess
+import os
+import triton
+import triton.language as tl
+import triton.language.extra.cann.extension as al
+from triton.compiler.compiler import ASTSource
+from triton.compiler.code_generator import ast_to_ttir
+from triton._C.libtriton import ir
+from triton._C.libtriton.ascend import ir as ascend_ir
+from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg
+
+
+@al.register_custom_op
+class min_custom_op:
+    core = al.CORE.VECTOR
+    pipe = al.PIPE.PIPE_MTE2
+    mode = al.MODE.SIMD
+
+    symbol = 'min_custom_op_impl'
+    bitcode = os.path.abspath(__file__)
+
+
+@al.register_custom_op
+class simple_custom_op:
+    # name is optional; the class name is used by default.
+    name = 'simple_custom_op'
+
+    # required attributes.
+    core = al.CORE.VECTOR
+    pipe = al.PIPE.PIPE_V
+    mode = al.MODE.SIMT
+
+    symbol = 'simple_custom_op_impl'
+    bitcode = os.path.abspath(__file__)
+
+    # The __init__ method is optional, but providing one improves the user
+    # experience; for example, you can validate arguments here.
+ def __init__(self, x, y, dim=0, out=None): + assert x.shape == y.shape, "x and y should have same shape" + assert isinstance(dim, int), "dim should be const integer" + assert out, "out is required" + + +@al.register_custom_op +class _example_custom_op: + name = 'example_custom_op' + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + symbol = 'example_custom_op_impl' + bitcode = os.path.abspath(__file__) + + def __init__(self, src, index, offset: tl.int64, axis, out=None): + # support validate arguments in __init__ method. + assert isinstance(src, tl.tensor), "src should be tensor" + assert index.dtype.is_int(), "index should be integer tensor" + assert isinstance(offset, int), "offset should be integer" + assert isinstance(axis, int), "axis should be integer" + + # support multi-output by using tuple or list. + assert isinstance(out, tuple) and len(out) == 2, "out should be tuple of 2 items" + + # setup the symbol name of the function that will be called at runtime. + rank = len(index.shape) + self.symbol = f"{self.name}_{rank}d_{src.dtype.cname}_{index.dtype.cname}" + + # setup source and compile command if it is implemented by user source code. + self.source = f"workspace/example_custom_op_impl.cce" + self.compile = "bisheng -O2 -std=c++17 -o $@ -c $<" + + # dynamic set argument type. + self.arg_type['axis'] = index.dtype + + +@al.builtin +def example_op(src, index, offset, axis, _builder=None): + # you can wrap a custom op as a builtin operation, + # output can be provided here to make it easy to use. + x = tl.semantic.full(src.shape, 0, tl.float32, _builder) + y = tl.semantic.full(index.shape, 0, tl.float32, _builder) + return al.custom_semantic(_example_custom_op.name, src, index, offset, axis, out=(x, y), _builder=_builder) + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + y = al.custom("min_custom_op", x, x_ptr, y_ptr + i, al.int64(0), (1, 2, 3), [4.1, 5.2], out=y) + y = al.custom("simple_custom_op", x, y, dim=1, out=y) + index = tl.full((2, 3), 0, tl.int64) + x, y = al.custom("example_custom_op", x, index, offset=1, axis=0, out=(x, y)) + result, _ = example_op(x, index, offset=2, axis=1) + tl.store(out_ptr + i, result, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + metadata = { + **options.__dict__, + } + linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py new file mode 100644 index 0000000000..fe2dbe351b --- /dev/null +++ b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. 
+# +# Demo: declare scratch/extra buffers on a custom op via `extra_buffers` (dtype, size) +# and read back the sizes from lowered HIVM MLIR (`extra_buffers_sizes` attribute). + +from __future__ import annotations + +import os +import re +import subprocess + +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + +# Scratch buffers requested by the custom kernel (element type + length in elements). +SCRATCH_SPEC = [ + (tl.float32, 1024), + (tl.bfloat16, 512), + (tl.int32, 256), +] + + +@al.register_custom_op +class demo_extra_buffer_op: + """Custom op that advertises extra device buffers for the NPU compiler / runtime.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "demo_extra_buffer_op_impl" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = list(SCRATCH_SPEC) + + +@triton.jit +def kernel_extra_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("demo_extra_buffer_op", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +def compile_to_linalg_mlir(kernel, signature: dict, constants: dict) -> str | None: + src = ASTSource(kernel, signature, constants) + ctx = ir.context() + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + options = NPUOptions() + try: + ttir = ast_to_ttir(kernel, src, ctx, options, {}, {}) + meta = {**options.__dict__} + return str(ttir_to_linalg(ttir, meta, options, named_ops=True)) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + return None + + +def extract_extra_buffer_sizes_from_mlir(mlir: str) -> list[int]: + """ + Parse `extra_buffers_sizes` from HIVM custom op text. + """ + # Parse [1024, 512, 256, ...] 
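+    # An illustrative (assumed) line this regex is written against:
+    #   hivm.hir.custom ... {extra_buffers_sizes = [1024, 512, 256]}
+    # The exact HIVM printing may differ; only the bracketed list matters here.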
+ m = re.search(r"extra_buffers_sizes\s*=\s*\[([^\]]+)\]", mlir) + if m: + raw = m.group(1).replace(" ", "") + return [int(x) for x in raw.split(",") if x] + + return [] + + +def main() -> None: + expected_sizes = [size for _, size in SCRATCH_SPEC] + print("Declared extra_buffers (dtype, element_count):") + for dt, sz in SCRATCH_SPEC: + print(f" {dt} -> {sz} elements") + + mlir = compile_to_linalg_mlir( + kernel_extra_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 128}, + ) + if not mlir: + print("Compilation failed.") + return + + parsed = extract_extra_buffer_sizes_from_mlir(mlir) + print("\nParsed extra_buffers_sizes from MLIR:", parsed) + if parsed == expected_sizes: + print("OK: MLIR sizes match the Python extra_buffers specification.") + elif parsed: + print("Note: parsed sizes differ from spec; inspect MLIR spelling below.") + else: + print("Could not parse extra_buffers_sizes automatically; " + "search the dump for 'extra_buffers_sizes'.") + + print("\n--- MLIR excerpt (lines containing hivm.hir.custom) ---") + for line in mlir.splitlines(): + if "hivm.hir.custom" in line and "demo_extra_buffer_op" in line: + print(line) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py new file mode 100644 index 0000000000..11e6dcf3cf --- /dev/null +++ b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +import os +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +def _make_indexing_maps(): + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c8 = ascend_ir.affine_expr.get_constant(8) + + # Input maps use transpose and identity-like projections. + in0 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + in1 = ascend_ir.affine_map.get(2, 0, [d0, d1]) + + # Output map models tiled coordinates. + out = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c8), d1.mod(c8)]) + return [in0, in1, out] + + +@al.register_custom_op +class complex_indexing_map_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "complex_indexing_map_custom" + # Fake path: this example checks IR lowering only. 
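+    # (The path below merely stands in for a real bitcode artifact; since the
+    # demo only lowers IR and never runs on device, any existing file should
+    # do here.)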
+ bitcode = os.path.abspath(__file__) + + def __init__(self, x, y, out=None): + assert out is not None + self.indexing_map = _make_indexing_maps() + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + out = al.custom("complex_indexing_map_custom_op", x, y, out=x) + tl.store(out_ptr + i, out, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + linalg = ttir_to_linalg(ttir, {**options.__dict__}, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py new file mode 100644 index 0000000000..6d42c5b832 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +import os +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +def _compose_indexing_maps(): + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c4 = ascend_ir.affine_expr.get_constant(4) + + # Base permutation map. + perm = ascend_ir.affine_map.get_permutation([1, 0]) + # Tile map (row-major tile decomposition). + tile = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c4), d1.mod(c4)]) + # Compose tile with permutation to build a different output indexing. 
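+    # Assuming MLIR's AffineMap.compose semantics, where the argument map is
+    # applied first, the composed map works out to:
+    #   (d0, d1) -> tile(perm(d0, d1)) = tile(d1, d0) = (d1 floordiv 4, d0 mod 4)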
+ out = tile.compose(perm) + + in0 = ascend_ir.affine_map.get_identity(2) + in1 = perm + return [in0, in1, out] + + +@al.register_custom_op +class compose_indexing_map_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "compose_indexing_map_custom" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, y, out=None): + assert out is not None + self.indexing_map = _compose_indexing_maps() + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + out = al.custom("compose_indexing_map_custom_op", x, y, out=y) + tl.store(out_ptr + i, out, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + linalg = ttir_to_linalg(ttir, {**options.__dict__}, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/unittest/custom_op/test_gather_load.py b/third_party/ascend/unittest/custom_op/test_gather_load.py new file mode 100644 index 0000000000..03e83b9171 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/test_gather_load.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def test_gather_load_kernel(src_ptr, index_ptr, out_ptr): + # index tile shape: (2, 2) + cols = tl.arange(0, 2)[None, :] # [[0, 1]] + rows = tl.arange(0, 2)[:, None] # [[0],[1]] + mask = (rows < 2) & (cols < 2) + + # load index tile to UB + index = tl.load(index_ptr + rows * 2 + cols, mask) + + # gather load from GM to UB + dst = tl.full(index.shape, 0, tl.float32) + gathered = al.custom("__builtin_gather_load", src_ptr, index, bound=4, dim=0, src_stride=(2, 1), index_shape=(2, 2), + offsets=(0, 0), out=dst) + + # store result to GM + tl.store(out_ptr + rows * 2 + cols, gathered, mask) + + +if __name__ == "__main__": + src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], device='npu') + index = torch.tensor([[0, 1], [2, 3]], device='npu') + out = torch.empty((2, 2), device='npu', dtype=torch.float32) + test_gather_load_kernel[(1, )](src, index, out) + print("result: ", out) # [[1., 4.], [5., 8.]] diff --git a/third_party/ascend/unittest/custom_op/test_index_select.py b/third_party/ascend/unittest/custom_op/test_index_select.py new file mode 100644 index 0000000000..d06174fde1 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/test_index_select.py @@ -0,0 +1,44 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def builtin_index_select_kernel(src_ptr, index_ptr, out_ptr): + # Define 2x2 tile indices for output tensor + r = tl.arange(0, 2)[:, None] # Row indices: shape [2, 1] + c = tl.arange(0, 2)[None, :] # Column indices: shape [1, 2] + + # Load index tensor (shape [2]) from GM to UB + idx = tl.load(index_ptr + tl.arange(0, 2)) + # Initialize empty 2x2 output tile in UB (default value: 0) + dst = tl.full((2, 
2), 0, dtype=tl.float32) + + # Invoke __builtin_index_select custom op to gather elements + out_tile = al.custom("__builtin_index_select", src_ptr, # Pointer to source tensor in GM + idx, # Index tensor (in UB) for gathering + dim=0, # Dimension to gather along + bound=4, # Upper bound for valid index values (out-of-bound check) + end_offset=(2, 2), # End offsets of each dimension for the index tensor + start_offset=(0, 0), # Start offsets of each dimension for the source tensor + src_stride=(4, 1), # Stride of each dimension for the source tensor in GM + out=dst # Output tensor (in UB) to store gathered elements + ) + + # Store the gathered tile from UB to output tensor in GM + tl.store(out_ptr + r * 2 + c, out_tile) + + +if __name__ == "__main__": + src = torch.tensor( + [[10., 11., 12., 13.], [20., 21., 22., 23.], [30., 31., 32., 33.], [40., 41., 42., 43.]], + device="npu", + dtype=torch.float32, + ) + index = torch.tensor([2, 0], device="npu", dtype=torch.int32) + out = torch.empty((2, 2), device="npu", dtype=torch.float32) + ref = torch.index_select(src, 0, index.to(torch.int64))[:, :2] + builtin_index_select_kernel[(1, )](src, index, out) + torch.testing.assert_close(out, ref) # ref: [[30., 31.], [10., 11.]] diff --git a/third_party/ascend/unittest/generalization_cases/acc_util.py b/third_party/ascend/unittest/generalization_cases/acc_util.py deleted file mode 100644 index b1295885a6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/acc_util.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import numpy as np -import torch -import torch_npu - -eval_standard = { - torch.float32: { - "rtol": 1e-6, - "small_value": 1e-6, - "small_value_atol": 1e-9, - "etol": 1e-4, - }, - torch.float16: { - "rtol": 1e-3, - "small_value": 1e-3, - "small_value_atol": 1e-5, - "etol": 1e-3, - }, - torch.bfloat16: { - "rtol": 4e-3, - "small_value": 1e-3, - "small_value_atol": 1e-5, - "etol": 1e-3, - }, -} - - -def assert_close(gold: torch.Tensor, act: torch.Tensor, eval_type: str = 'DEFAULT'): - gold = gold.cpu() - act = act.cpu() - if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: - assert gold.dtype == torch.float32, "golden should be f32" - assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" - eps = eval_standard[act.dtype]['small_value'] - rtol = eval_standard[act.dtype]['rtol'] - atol = eval_standard[act.dtype]['small_value_atol'] - if eval_type == 'DEFAULT': - ae = torch.abs(act - gold) - re = ae / torch.abs(gold) - mask = torch.abs(gold) < eps - - print(f"count ae > {atol}: {(ae > atol).sum()}") - print(f"count re > {rtol}: {(re > rtol).sum()}") - - not_close = torch.where(mask, ae > atol, re > rtol) - print(f"count not_close = {torch.sum(not_close).item()}") - print(f"not_close.numel = {not_close.numel()}, gold.numel = {gold.numel()}") - print(f"not close ratio = {torch.sum(not_close).item() / not_close.numel()}") - if not torch.any(not_close): - return False - - assert torch.sum( - not_close).item() < not_close.numel() * eps, "actual tensor are not close enough with golden tensor,\ -you can use 'benchmark_compare_close' function to compare again!" - - elif eval_type == 'ABS': - act = act.to(gold.dtype) - assert torch.equal(gold, act), "actual tensor and golden tensor are not binary equal!" - else: - assert 0, "ERROR! invalid eval_type" - return False - - -def benchmark_compare_close(gold: torch.Tensor, act: torch.Tensor, std: torch.tensor): - assert act.dtype == std.dtype, "standard tensor's dtype must equal to actual tensor's dtype!" 
- if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: - assert gold.dtype == torch.float32, "golden should be f32" - assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" - - gold = gold.cpu() - act = act.cpu() - std = std.cpu() - - eps = eval_standard[act.dtype]['small_value'] - atol = eval_standard[act.dtype]['small_value_atol'] - - mask = torch.abs(gold) <= eps - small_count = mask.sum().item() - - def calculate_relative_errors_except_small(tensor): - re = torch.abs(gold - tensor) / torch.abs(gold) - return torch.where(mask, 0, re) - - act_re = calculate_relative_errors_except_small(act) - std_re = calculate_relative_errors_except_small(std) - act_ae = torch.abs(gold - std) - std_ae = torch.abs(gold - std) - - # 小值域的定义为golden小于某个阈值 eps - act_small_error_count = (mask & (act_ae > atol)).sum().item() - std_small_error_count = (mask & (std_ae > atol)).sum().item() - act_total = act.numel() - std_total = std.numel() - - act_small_error_ratio = act_small_error_count / act_total - std_small_error_ratio = std_small_error_count / std_total - - def calculate_rmse(tensor): - dlt2 = (tensor - gold)**2 - dlt2_except_small_mean = torch.where(mask, 0, dlt2).sum() / small_count - return torch.sqrt(dlt2_except_small_mean) - - act_rmse = calculate_rmse(act) - std_rmse = calculate_rmse(std) - - print(f"act_re.max = {act_re.max()}, std_re.max = {std_re.max()}, limit ratio = 10") - print(f"act_re.sum = {act_re.sum()}, std_re.sum = {std_re.sum()}, limit_ratio = 2") - print( - f"act_small_error_ratio = {act_small_error_ratio}, std_small_error_ratio = {std_small_error_ratio}, limit_ratio = 2" - ) - print(f"act_rmse = {act_rmse}, std_rmse = {std_rmse}, limit_ratio = 2") - - # 条件 1:actual 与 golden 相对误差最大值超过 10 倍 standard 与 golden 相对误差最大值 - assert act_re.max() <= 10 * std_re.max(), "actual re max > stdandard re max's 10 times" - - # 条件 2:actual 与 golden 相对误差均值超过 2 倍 standard 与 golden 相对误差均值 - assert act_re.sum() <= 2 * std_re.sum(), "actual re sum > stdandard re sum's 2 times" - - # 条件 3:actual 小值域 ERROR 占比超过 standard 的两倍 - assert act_small_error_ratio <= 2 * std_small_error_ratio, "act_small_error_ratio > std_small_error_ratio 's 2 times" - - # 条件 4:actual 均方根误差差于 standard 的两倍 - assert act_rmse <= 2 * std_rmse, "act_rmse > std_rmse 's 2 times" - - return False diff --git a/third_party/ascend/unittest/generalization_cases/test_abs.py b/third_party/ascend/unittest/generalization_cases/test_abs.py deleted file mode 100644 index c5d06d0d40..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_abs.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, avoid_not_support -import math -import logging - - -def torch_pointwise(x0): - if x0.dtype != torch.uint32: - return torch.abs(x0) - else: - return torch.abs(x0.to(torch.float32)) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.abs(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_abs_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.abs(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_abs_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_abs_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_advance.py b/third_party/ascend/unittest/generalization_cases/test_advance.py deleted file mode 100644 index 8c7a75dc46..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_advance.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(5, ), - block_shape=(XB, ), - order=(0, ), - ) - bbptr = tl.advance(block_ptr_in, (-5, )) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffset = tl.program_id(0) - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(6 + xoffset, 5), - block_shape=(XB, YB), - order=(1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5)) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(3, 1, 2), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-3, -1, -2)) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def triton_advance_4d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(6, 5, 4, 3), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5, -4, -3)) - x = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -@triton.jit -def triton_advance_5d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(6, 5, 4, 3, 2), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - 
order=(4, 3, 2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5, -4, -3, -2)) - x = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -temporarily_not_support_dtype = ['bool'] - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_npu(dtype, shape): - if dtype in temporarily_not_support_dtype: - return - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - - a = x - blocks = list(x.size()) - strides = list(x.stride()) - grid = (1, ) - if len(shape) == 5: - triton_advance_5d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 4: - triton_advance_4d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 3: - fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) - elif len(shape) == 2: - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) - else: - fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) - else: - fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) - - torch.testing.assert_close(output, a) diff --git a/third_party/ascend/unittest/generalization_cases/test_and.py b/third_party/ascend/unittest/generalization_cases/test_and.py deleted file mode 100644 index 4bac287eaf..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_and.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x, y): - res = x & y - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X & Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_and_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val & y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_and_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x & y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_and_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_argmax.py b/third_party/ascend/unittest/generalization_cases/test_argmax.py deleted file mode 100644 index edbf8b9d8d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_argmax.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import math -import pytest -import torch -import torch_npu -import numpy as np -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - -logger = logging.getLogger(__name__) - - -# <<<<<<< test_argmax_1d -def torch_argmax(x0, dim, keepdim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - return torch.argmax(x0, dim=dim, keepdim=keepdim).npu() - - -@triton.jit -def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.argmax(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_argmax_1d(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmax_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_argmax(x0, dim=0, keepdim=True) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_1d - - -# <<<<<<< test_argmax_2d -@triton.jit -def triton_argmax_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.argmax(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_argmax_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=torch.int32).npu() - triton_argmax_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_argmax(x0, dim=dim, keepdim=False) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_2d - - -# <<<<<<< test_argmax_3d -def torch_argmax_3d(x0, no_reduce_dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - if no_reduce_dim == 0: - return torch.argmax(torch.max(x0, 1)[0], 1).npu() - elif no_reduce_dim == 1: - return torch.argmax(torch.max(x0, 0)[0], 1).npu() - elif no_reduce_dim == 2: - return torch.argmax(torch.max(x0, 0)[0], 0).npu() - else: - assert False, f"no reduce dim 
not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_argmax_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 0) - ret = tl.argmax(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmax_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 0) - ret = tl.argmax(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmax_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 1) - ret = tl.argmax(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_argmax_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_argmax_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_argmax_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_argmax_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_argmax_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=torch.int32).npu() - triton_argmax_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_argmax_3d(x0, no_reduce_dim) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_3d - - -# <<<<<<< test_argmax_4d -def torch_argmax_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmax(x0, dim) - - -@triton.jit -def argmax_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - 
tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmax_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmax_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_argmax_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_argmax_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 4, 8), - (2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmax_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmax_4d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmax_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_4d - - -# <<<<<<< test_argmax_5d -def torch_argmax_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmax(x0, dim) - - -@triton.jit -def argmax_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmax_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmax_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_argmax_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_argmax_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 2, 4, 8), - (2, 2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def 
test_argmax_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmax_5d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmax_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_5d - - -# <<<<<<< test_argmax_1d_bool -@triton.jit -def triton_argmax_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1) - tmp4 = tl.argmax(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_argmax_1d_bool(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype) - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmax_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel) - np_res = np.argmax(x0.numpy()) - np.equal(triton_res.item(), np_res) - - -# >>>>>>> test_argmax_1d_bool diff --git a/third_party/ascend/unittest/generalization_cases/test_argmin.py b/third_party/ascend/unittest/generalization_cases/test_argmin.py deleted file mode 100644 index 36a671d1ba..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_argmin.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import math -import pytest -import torch -import torch_npu -import numpy as np -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - -logger = logging.getLogger(__name__) - - -# <<<<<<< test_argmin_1d -def torch_argmin(input_tensor, dim, keepdim): - return torch.argmin(input_tensor, dim=dim, keepdim=keepdim) - - -@triton.jit -def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.argmin(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_argmin_1d(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmin_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_argmin(x0, dim=0, keepdim=True) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_1d - - -# <<<<<<< test_argmin_2d -@triton.jit -def triton_argmin_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.argmin(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_argmin_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=torch.int32).npu() - triton_argmin_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_argmin(x0, dim=dim, keepdim=False) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_2d - - -# <<<<<<< test_argmin_3d -def torch_argmin_3d(x0, no_reduce_dim): - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - if no_reduce_dim == 0: - return torch.argmin(torch.min(x0, 1)[0], 1) - elif no_reduce_dim == 1: - return torch.argmin(torch.min(x0, 0)[0], 1) - elif no_reduce_dim == 2: - return torch.argmin(torch.min(x0, 0)[0], 0) - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_argmin_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, 
None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 0) - ret = tl.argmin(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmin_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 0) - ret = tl.argmin(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmin_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 1) - ret = tl.argmin(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_argmin_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_argmin_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_argmin_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_argmin_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_argmin_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=torch.int32).npu() - triton_argmin_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_argmin_3d(x0, no_reduce_dim) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_3d - - -# <<<<<<< test_argmin_4d -def torch_argmin_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmin(x0, dim) - - -@triton.jit -def argmin_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmin_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] 
* ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmin_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_argmin_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_argmin_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 4, 8), - (2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmin_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmin_4d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmin_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_4d - - -# <<<<<<< test_argmin_5d -def torch_argmin_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmin(x0, dim) - - -@triton.jit -def argmin_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmin_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmin_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_argmin_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_argmin_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 2, 4, 8), - (2, 2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmin_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmin_5d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmin_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> 
test_argmin_5d
-
-
-# <<<<<<< test_argmin_1d_bool
-@triton.jit
-def triton_argmin_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
-    xoffset = tl.program_id(0) + tl.arange(0, XBLOCK)
-    tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1)
-    tmp4 = tl.argmin(tmp0, 0)
-    tl.store(out_ptr1, tmp4, None)
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape1d)
-@pytest.mark.parametrize('dtype', ['bool'])
-def test_argmin_1d_bool(dtype, shape):
-    dtype_size = get_dtype_size(dtype)
-    if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3):
-        logger.warning(f"dtype:{dtype} shape:{shape} mem overflow")
-        return
-    x0 = test_common.generate_tensor(shape, dtype)
-    triton_res = torch.empty(1, dtype=torch.int32).npu()
-    numel = shape[0]
-    triton_argmin_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel)
-    np_res = np.argmin(x0.numpy())
-    # assert the comparison; a bare np.equal(...) discards its result and can never fail
-    assert np.equal(triton_res.item(), np_res)
-
-
-# >>>>>>> test_argmin_1d_bool
diff --git a/third_party/ascend/unittest/generalization_cases/test_associative_scan.py b/third_party/ascend/unittest/generalization_cases/test_associative_scan.py
deleted file mode 100644
index 249abc09fb..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_associative_scan.py
+++ /dev/null
@@ -1,523 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
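-
-# The tests in this file check tl.associative_scan against a PyTorch
-# reference. As a minimal, self-contained sketch of the semantics under test
-# (`scan_reference` and `combine` are illustrative names, not part of the
-# suite): an inclusive associative scan is a left fold that keeps every
-# intermediate carry, so for input [x0, x1, x2] it yields
-# [x0, f(x1, x0), f(x2, f(x1, x0))].
-def scan_reference(values, combine):
-    out = [values[0]]
-    for v in values[1:]:
-        out.append(combine(v, out[-1]))
-    return out
-
-
-assert scan_reference([1, 3, 2], max) == [1, 3, 3]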
-
-import math
-import pytest
-import random
-import torch
-import torch_npu
-import triton
-import triton.language as tl
-
-import test_common
-from test_common import TestUtils, get_dtype_size
-
-
-def combine_fn_test_torch(a, b, combine_fn):
-    if combine_fn == 'maximum_fn':
-        return torch.maximum(a, b)  # maximum
-    elif combine_fn == 'minimum_fn':
-        return torch.minimum(a, b)  # minimum
-    elif combine_fn == 'bitwise_xor_fn':
-        return a ^ b  # bitwise XOR
-    elif combine_fn == 'bitwise_or_fn':
-        return a | b  # bitwise OR
-    elif combine_fn == 'bitwise_and_fn':
-        return a & b  # bitwise AND
-    else:
-        pytest.skip("combine_fn is not one of the supported functions, skipping.")
-
-
-def torch_func_scan(input: torch.Tensor, dim: int, combine_fn='maximum_fn', reverse=False):
-    """
-    PyTorch implementation of associative_scan whose semantics exactly match
-    Triton's; supports any combine_fn (e.g. a|b, a&b, min, max).
-    """
-    dim = dim % input.ndim
-
-    if reverse:
-        input = input.flip(dim)
-
-    N = input.size(dim)
-
-    tensors = torch.unbind(input, dim=dim)
-
-    outputs = []
-
-    carry = tensors[0]
-    outputs.append(carry)
-
-    for i in range(1, N):
-        carry = combine_fn_test_torch(tensors[i], carry, combine_fn)
-        outputs.append(carry)
-
-    output = torch.stack(outputs, dim=dim)
-
-    if reverse:
-        output = output.flip(dim)
-
-    return output
-
-
-@triton.jit
-def bitwise_and_fn(a, b):
-    return a & b
-
-
-@triton.jit
-def bitwise_or_fn(a, b):
-    return a | b
-
-
-@triton.jit
-def bitwise_xor_fn(a, b):
-    return a ^ b
-
-
-@triton.jit
-def minimum_fn(a, b):
-    return tl.minimum(a, b)
-
-
-@triton.jit
-def maximum_fn(a, b):
-    return tl.maximum(a, b)
-
-
-@triton.jit
-def triton_kernel_1d_scan(
-    out_ptr0,
-    in_ptr0,
-    dim: tl.constexpr,
-    reverse: tl.constexpr,
-    numel_x: tl.constexpr,
-    XBLOCK: tl.constexpr,
-    combine_fn_name: tl.constexpr,
-):
-    tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel")
-    idx = tl.arange(0, XBLOCK)
-    x = tl.load(in_ptr0 + idx)
-    if combine_fn_name == "maximum_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn)
-    elif combine_fn_name == "minimum_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn)
-    elif combine_fn_name == "bitwise_or_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn)
-    elif combine_fn_name == "bitwise_xor_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn)
-    elif combine_fn_name == "bitwise_and_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn)
-
-    tl.store(out_ptr0 + idx, ret)
-
-
-@triton.jit
-def triton_kernel_2d_scan(
-    out_ptr0,
-    in_ptr0,
-    dim: tl.constexpr,
-    reverse: tl.constexpr,
-    numel_x: tl.constexpr,
-    numel_r: tl.constexpr,
-    XBLOCK: tl.constexpr,
-    RBLOCK: tl.constexpr,
-    combine_fn_name: tl.constexpr,
-):
-    tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel")
-    tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel")
-    idx_x = tl.arange(0, XBLOCK)
-    idx_r = tl.arange(0, RBLOCK)
-    idx = idx_x[:, None] * numel_r + idx_r[None, :]
-    x = tl.load(in_ptr0 + idx)
-
-    if combine_fn_name == "maximum_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn)
-    elif combine_fn_name == "minimum_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn)
-    elif combine_fn_name == "bitwise_or_fn":
-        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn)
-    elif combine_fn_name == "bitwise_xor_fn":
-        
ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_3d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, - combine_fn_name: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_4d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - combine_fn_name: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_5d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, - combine_fn_name: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = 
tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -def triton_func_scan(x, dim, combine_fn, reverse): - res = torch.empty_like(x) - shape = x.size() - - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[0], combine_fn) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1], combine_fn) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2], combine_fn) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], combine_fn) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4], - combine_fn) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape2d) -@pytest.mark.parametrize("dim", [1]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_2d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - 
triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape3d) -@pytest.mark.parametrize("dim", [2]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_3d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape4d) -@pytest.mark.parametrize("dim", [3]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_4d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape5d) -@pytest.mark.parametrize("dim", [4]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_5d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", random.sample(TestUtils.test_shape2d, 5)) -@pytest.mark.parametrize("dim", [1]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_2d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res 
= torch_func_scan(x_gold, dim, combine_fn, reverse)
-
-    x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
-
-    test_common.validate_cmp(dtype, triton_res, cpu_res)
-
-
-@pytest.mark.parametrize("dtype", ['float16', 'float32'])
-@pytest.mark.parametrize("shape", TestUtils.test_shape3d)
-@pytest.mark.parametrize("dim", [2])
-@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn'])
-@pytest.mark.parametrize("reverse", [False])
-def test_scan_float_3d(dtype, shape, dim, combine_fn, reverse):
-    should_skip_due_to_mem(dtype, shape)
-    torch.manual_seed(0)
-    x = test_common.generate_tensor(shape=shape, dtype=dtype)
-    x_gold = x
-    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
-
-    x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
-
-    test_common.validate_cmp(dtype, triton_res, cpu_res)
-
-
-@pytest.mark.parametrize("dtype", ['float16', 'float32'])
-@pytest.mark.parametrize("shape", TestUtils.test_shape4d)
-@pytest.mark.parametrize("dim", [3])
-@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn'])
-@pytest.mark.parametrize("reverse", [False])
-def test_scan_float_4d(dtype, shape, dim, combine_fn, reverse):
-    should_skip_due_to_mem(dtype, shape)
-    torch.manual_seed(0)
-    x = test_common.generate_tensor(shape=shape, dtype=dtype)
-    x_gold = x
-    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
-
-    x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
-
-    test_common.validate_cmp(dtype, triton_res, cpu_res)
-
-
-@pytest.mark.parametrize("dtype", ['float16', 'float32'])
-@pytest.mark.parametrize("shape", TestUtils.test_shape5d)
-@pytest.mark.parametrize("dim", [4])
-@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn'])
-@pytest.mark.parametrize("reverse", [False])
-def test_scan_float_5d(dtype, shape, dim, combine_fn, reverse):
-    should_skip_due_to_mem(dtype, shape)
-    torch.manual_seed(0)
-    x = test_common.generate_tensor(shape=shape, dtype=dtype)
-    x_gold = x
-    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
-
-    x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
-
-    test_common.validate_cmp(dtype, triton_res, cpu_res)
-
-
-@pytest.mark.parametrize("dtype", ['float16', 'float32'])
-@pytest.mark.parametrize("shape", TestUtils.test_shape1d)
-@pytest.mark.parametrize("dim", [0])
-@pytest.mark.parametrize("combine_fn", ['bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn'])
-@pytest.mark.parametrize("reverse", [False])
-@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type")
-def test_scan_float_invalid(dtype, shape, dim, combine_fn, reverse):
-    should_skip_due_to_mem(dtype, shape)
-    torch.manual_seed(0)
-    x = test_common.generate_tensor(shape=shape, dtype=dtype)
-
-    x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
-
-
-@pytest.mark.parametrize("dtype", ['int32'])
-@pytest.mark.parametrize("shape", TestUtils.test_shape1d)
-@pytest.mark.parametrize("dim", [0])
-@pytest.mark.parametrize("combine_fn",
-                         ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn'])
-@pytest.mark.parametrize("reverse", [True])
-@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError,
-                               "reverse=True is not yet supported for scan op")
-def test_scan_invalid_reverse(dtype, shape, dim, combine_fn, reverse):
-    should_skip_due_to_mem(dtype, shape)
-    torch.manual_seed(0)
-    x = test_common.generate_tensor(shape=shape, dtype=dtype)
-
-    
x_npu = x.npu()
-    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_add.py b/third_party/ascend/unittest/generalization_cases/test_atomic_add.py
deleted file mode 100644
index d55448ae66..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_add.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import math
-import pytest
-import torch
-import triton
-
-import triton.language as tl
-
-import numpy as np
-import test_common
-from test_common import TestUtils
-
-filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}]
-
-
-@triton.jit
-def atomic_add(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-    offset = tl.program_id(0) * BLOCK_SIZE
-    index = offset + tl.arange(0, BLOCK_SIZE)[:]
-    xmask = index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (index), xmask)
-    tmp1 = tl.load(out_ptr0 + (index), xmask)
-    tl.atomic_add(out_ptr1 + (index), tmp0, xmask)
-    tl.atomic_add(out_ptr1 + (index), tmp1, xmask)
-
-
-@triton.jit
-def atomic_add_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(0)
-
-    x = tl.load(x_ptr)  # x is scalar or 1D, no mask needed
-
-    # Compute y indices
-    y_offset = pid * BLOCK_SIZE
-    y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
-    y_mask = y_indices < n_elements
-
-    y_value = tl.load(y_ptr + y_indices, y_mask)
-    # Atomic add: y += x (broadcasted)
-    tl.atomic_add(out_ptr + y_indices, y_value, mask=y_mask)
-    tl.atomic_add(out_ptr + y_indices, x, mask=y_mask)
-
-
-# Parameter combinations for the different test scenarios (x_shape, y_shape, BLOCK_SIZE)
-test_cases = [
-    ((1, 1, 1, 1), (1, 1, 1, 4), 4),
-    ((1, 1, 1, 3), (1, 5, 1, 3), 5),
-    ((3, ), (2, 3, 3, 3, 3), 81),
-    ((3, ), (2, 3, 3, 3), 27),
-    ((3, ), (2, 3, 3), 9),
-    ((3, ), (2, 3), 3),
-]
-
-
-def promote_dtype(x_dtype, y_dtype):
-    """
-    If y has lower precision than x, promote y's precision to match x.
-    """
-    # If the two dtypes already match, return directly
-    if x_dtype == y_dtype:
-        return y_dtype
-
-    # Priority list of dtypes, from low to high
-    priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32]
-
-    # Locate both dtypes in the priority list
-    x_priority = priority.index(x_dtype)
-    y_priority = priority.index(y_dtype)
-
-    # If y ranks lower than x, promote it to x's dtype
-    if y_priority < x_priority:
-        return x_dtype
-    else:
-        return y_dtype
-
-
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-@pytest.mark.parametrize('y_dtype_str',
filtered_dtype)
-@pytest.mark.parametrize('x_shape, y_shape, BLOCK_SIZE', test_cases)
-def test_atomic_add_broadcast_combined(x_dtype_str, y_dtype_str, x_shape, y_shape, BLOCK_SIZE):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    # Build x0 first
-    x0 = torch.full(x_shape, 83.0000, dtype=x_dtype).npu()
-
-    y_raw_dtype = eval('torch.' + y_dtype_str)
-
-    out_dtype = promote_dtype(x_dtype, y_raw_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-
-    # Build y and out
-    y = torch.full(y_shape, -105, dtype=y_raw_dtype).npu()
-    out = torch.full(y_shape, 0, dtype=out_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x0.clone()
-    y_temp = y.clone()
-    out_temp = out.clone()
-
-    # Compute the grid size and the total number of elements
-    n_elements = y.numel()
-    grid = (n_elements // BLOCK_SIZE, )  # number of program instances needed
-
-    # Launch the Triton kernel
-    atomic_add_broadcast[grid](x_ptr=x0, y_ptr=y, out_ptr=out, n_elements=n_elements, BLOCK_SIZE=BLOCK_SIZE)
-
-    # Verify the result: y += x (broadcast add)
-    expected = out_temp + y_temp + x_temp
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-@pytest.mark.parametrize('y_dtype_str', filtered_dtype)
-def test_atomic_add(x_dtype_str, y_dtype_str, shape):
-    # Resolve the torch dtypes from their string names
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    x0 = test_common.generate_tensor(shape, x_dtype_str).npu()
-    x1 = test_common.generate_tensor(shape, y_dtype_str).npu()
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    y = torch.full(x1.shape, 0, dtype=out_dtype).npu()
-
-    # Keep copies for verification
-    x0_temp = x0.clone()
-    x1_temp = x1.clone()
-    y_temp = y.clone()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1]
-        atomic_add[shape[0], 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[1])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: cap on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # round up
-        atomic_add[grid_size, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)
-
-    expected = y_temp + x1_temp + x0_temp
-    torch.testing.assert_close(y, expected)
-
-
-# 3d
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-@pytest.mark.parametrize('y_dtype_str', filtered_dtype)
-def test_atomic_add_3d(x_dtype_str, y_dtype_str, shape):
-    # Resolve the torch dtypes from their string names
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.'
+ y_dtype_str) - - x0 = test_common.generate_tensor(shape, x_dtype_str).npu() - x1 = test_common.generate_tensor(shape, y_dtype_str).npu() - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - y = torch.full(x1.shape, 0, dtype=out_dtype).npu() - - # 保存副本用于验证 - x0_temp = x0.clone() - x1_temp = x1.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_add[1, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[0] * shape[1] * shape[2]) - - expected = y_temp + x1_temp + x0_temp - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_add_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_add(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_add_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() - - x1_ref = x1 + x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_add_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_5d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' + y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 5: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - atomic_add_5d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1), (1, 1, 2, 2)], - [(1, 1, 1, 1), (1, 1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_4d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 4: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1 = triton_shape1 - - atomic_add_4d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr, - YB1: tl.constexpr, ZB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2), (1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_3d_2(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 3: - triton_shape1.append(1) - XB1, YB1, ZB1 = triton_shape1 - - atomic_add_3d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] - - offsets1 = tl.arange(0, XB1) * (YB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 2), (2, 2)], - [(1, 1), (2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_2d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 2: - triton_shape1.append(1) - XB1, YB1 = triton_shape1 - - atomic_add_2d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - XB1=XB1, - YB1=YB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [ - ['uint8', (32, 32), 2], - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], -]) -def test_atomic_add_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] / ncore - split_size = shape[0] // ncore - x0_value = 3 - x0_cpu = torch.full(shape, x0_value, dtype=eval(f'torch.{dtype}')).cpu() - x0 = x0_cpu.to("npu") - x1_cpu = torch.full((split_size, shape[1]), 4, dtype=eval(f'torch.{dtype}')).cpu() - x1 = x1_cpu.to("npu") - y_cpu = torch.full((split_size, shape[1]), -10, dtype=eval(f'torch.{dtype}')).cpu() - y = y_cpu.to("npu") - - x1_np = x1_cpu.numpy() - y_ref_np = x1_np + 0 - x1_ref_np = x1_np + ncore * x0_value - - x1_ref = torch.from_numpy(x1_ref_np).npu() - y_ref = torch.from_numpy(y_ref_np).npu() - - @triton.jit - def atomic_add_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_add_uint[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, x1, x1_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_and.py b/third_party/ascend/unittest/generalization_cases/test_atomic_and.py deleted file mode 100644 index 0ef250741c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_and.py +++ /dev/null @@ -1,562 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
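-
-# The tests in this file rely on tl.atomic_and being associative and
-# commutative: the final buffer contents do not depend on the order in which
-# program instances commit their updates, so expected values can be computed
-# with a plain sequential fold. A minimal sketch of that reasoning
-# (`and_fold` is an illustrative name, not part of the suite); an all-ones
-# seed is the identity for AND (x & ~0 == x), which is why test_atomic_and
-# below seeds y with torch.iinfo(dtype).max:
-from functools import reduce
-
-
-def and_fold(values, seed):
-    return reduce(lambda acc, v: acc & v, values, seed)
-
-
-assert and_fold([0b1100, 0b1010], 0b1111) == 0b1000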
-
-import math
-import pytest
-import torch
-import triton
-
-import triton.language as tl
-
-import test_common
-from test_common import TestUtils
-
-filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'float16', 'float32', 'bfloat16', 'bool'}]
-
-
-@triton.jit
-def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr):
-    in_offset = tl.program_id(0) * BLOCK_SIZE
-    out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
-    in_index = in_offset + tl.arange(0, BLOCK_SIZE)
-    out_index = out_offset + tl.arange(0, BLOCK_SIZE)
-    xmask = in_index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (in_index), xmask)
-    tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask)
-
-
-@triton.jit
-def atomic_and_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(0)
-
-    x = tl.load(x_ptr)  # x is scalar or 1D, no mask needed
-
-    # Compute y indices
-    y_offset = pid * BLOCK_SIZE
-    y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
-    y_mask = y_indices < n_elements
-
-    y_value = tl.load(y_ptr + y_indices, y_mask)
-    # Atomic and: y &= x (broadcasted)
-    tl.atomic_and(out_ptr + y_indices, y_value, mask=y_mask)
-    tl.atomic_and(out_ptr + y_indices, x, mask=y_mask)
-
-
-# Parameter combinations for the different test scenarios (x_shape, y_shape, BLOCK_SIZE)
-test_cases = [
-    ((1, 1, 1, 1), (1, 1, 1, 4), 4),
-    ((1, 1, 1, 3), (1, 5, 1, 3), 5),
-    ((3, ), (2, 3, 3, 3, 3), 81),
-    ((3, ), (2, 3, 3, 3), 27),
-    ((3, ), (2, 3, 3), 9),
-    ((3, ), (2, 3), 3),
-]
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_and(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    # Any bit ORed with 0 is unchanged, and any bit ANDed with 1 is unchanged;
-    # to keep the seed neutral for AND it must be all ones, not 0
-    y = torch.full(shape, torch.iinfo(x_dtype).max, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    y_temp = y.clone()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1] * 2
-        atomic_and[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: cap on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # round up
-        aligned_size = grid_size * BLOCK_SIZE
-        x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
-        x_concat[0:n_elements] = x[0:n_elements]
-        x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)]
-        atomic_and[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size)
-
-    expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)]
-    torch.testing.assert_close(y, expected)
-
-
-# 3d
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_and_3d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.'
+ x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape_ub_overflow) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "ub overflow") -def test_atomic_and_ub_overflow(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - -@triton.jit -def atomic_and_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_and(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_and_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = x1 & x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_and_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_and_5d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB * ZB * MB * NB) - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_5d(x_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(x0_shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(y_shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 5: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - atomic_and_5d[(2, )]( - x_ptr=x, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_and_4d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB * ZB * MB) - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1), (1, 1, 2, 2)], - [(1, 1, 1, 1), (1, 1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def 
test_atomic_and_4d(x_dtype_str, param_list):
-    x0_shape, y_shape = param_list
-
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(x0_shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(y_shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*x0_shape]
-    while len(triton_shape) < 4:
-        triton_shape.append(1)
-    XB, YB, ZB, MB = triton_shape
-
-    triton_shape1 = [*y_shape]
-    while len(triton_shape1) < 4:
-        triton_shape1.append(1)
-    XB1, YB1, ZB1, MB1 = triton_shape1
-
-    atomic_and_4d[(2, )](
-        x_ptr=x,
-        out_ptr=out,
-        XB=XB,
-        YB=YB,
-        ZB=ZB,
-        MB=MB,
-        XB1=XB1,
-        YB1=YB1,
-        ZB1=ZB1,
-        MB1=MB1,
-    )
-
-    expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@triton.jit
-def atomic_and_3d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr,
-                  YB1: tl.constexpr, ZB1: tl.constexpr):
-    base = tl.program_id(0) * (XB * YB * ZB)
-    offsets = tl.arange(0, XB) * (YB * ZB)
-    offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB)
-    offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :]
-
-    offsets1 = tl.arange(0, XB1) * (YB1 * ZB1)
-    offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1)
-    offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :]
-
-    based_offsets = offsets + base
-
-    tmp0 = tl.load(x_ptr + based_offsets)
-    tl.atomic_and(out_ptr + offsets1, tmp0)
-
-
-@pytest.mark.parametrize('param_list', [
-    [(1, 1, 1), (1, 1, 2)],
-    [(1, 1, 2), (1, 2, 2)],
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_and_3d_2(x_dtype_str, param_list):
-    x0_shape, y_shape = param_list
-
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(x0_shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(y_shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*x0_shape]
-    while len(triton_shape) < 3:
-        triton_shape.append(1)
-    XB, YB, ZB = triton_shape
-
-    triton_shape1 = [*y_shape]
-    while len(triton_shape1) < 3:
-        triton_shape1.append(1)
-    XB1, YB1, ZB1 = triton_shape1
-
-    atomic_and_3d[(2, )](
-        x_ptr=x,
-        out_ptr=out,
-        XB=XB,
-        YB=YB,
-        ZB=ZB,
-        XB1=XB1,
-        YB1=YB1,
-        ZB1=ZB1,
-    )
-
-    expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@triton.jit
-def atomic_and_2d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr):
-    base = tl.program_id(0) * (XB * YB)
-    offsets = tl.arange(0, XB) * (YB)
-    offsets = offsets[:, None] + tl.arange(0, YB)[None, :]
-
-    offsets1 = tl.arange(0, XB1) * (YB1)
-    offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :]
-
-    based_offsets = offsets + base
-
-    tmp0 = tl.load(x_ptr + based_offsets)
-    tl.atomic_and(out_ptr + offsets1, tmp0)
-
-
-@pytest.mark.parametrize('param_list', [
-    [(1, 2), (2, 2)],
-    [(1, 1), (2, 2)],
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_and_2d(x_dtype_str, param_list):
-    x0_shape, y_shape = param_list
-
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(x0_shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(y_shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*x0_shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB, YB = triton_shape
-
-    triton_shape1 = [*y_shape]
-    while len(triton_shape1) < 2:
-        triton_shape1.append(1)
-    XB1, YB1 = triton_shape1
-
-    atomic_and_2d[(2, )](
-        x_ptr=x,
-        out_ptr=out,
-        XB=XB,
-        YB=YB,
-        XB1=XB1,
-        YB1=YB1,
-    )
-
-    expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@triton.jit
-def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr,
-               mode: tl.constexpr = 0):
-    in_offset = tl.program_id(0) * BLOCK_SIZE
-    out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
-    in_index = in_offset + tl.arange(0, BLOCK_SIZE)
-    out_index = out_offset + tl.arange(0, BLOCK_SIZE)
-    xmask = in_index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (in_index), xmask)
-    if mode == 0:
-        tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, 'acq_rel', 'cta')
-    elif mode == 1:
-        tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "test")
-    elif mode == 2:
-        tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "acq_rel", "test")
-
-
-invalid_types_float = ['float16', 'float32', 'bfloat16']
-
-
-@pytest.mark.parametrize("sigtype", invalid_types_float)
-@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "must be signless-integer-like")
-def test_invalid_types_float(sigtype):
-    N = 32
-    x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu()
-    y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu()
-
-    atomic_and[1, 1, 1](x, y, 1, 1, 32)
-
-
-default_types = ['int8']
-
-
-@pytest.mark.parametrize("sigtype", default_types)
-@pytest.mark.parametrize("test_type", ["sem", "scope"])
-@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Memory semantic test not supported")
-def test_invalid_sem_scope(sigtype, test_type):
-    N = 32
-    x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu()
-    y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu()
-
-    if test_type == "sem":
-        atomic_and[1, 1, 1](x, y, 1, 1, 32, 1)
-    elif test_type == "scope":
-        atomic_and[1, 1, 1](x, y, 1, 1, 32, 2)
-
-
-@triton.jit
-def _atomic_and_ss(in_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr, SEM: tl.constexpr, SCOPE: tl.constexpr):
-    pid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = pid < n_cols
-    val = tl.load(in_ptr + pid, mask)
-    tl.atomic_and(out_ptr + pid, val, mask, sem=SEM, scope=SCOPE)
-
-
-SEMS = ("relaxed", "acquire", "release", "acq_rel")
-SCOPES = ("cta", "gpu", "sys")
-
-
-@pytest.mark.parametrize("sem", SEMS)
-@pytest.mark.parametrize("scope", SCOPES)
-def test_atomic_sem_vs_scope(sem: str, scope: str):
-    n_cols = 1024
-    BLOCK = 128
-    grid = (triton.cdiv(n_cols, BLOCK), )
-
-    inp = torch.full((n_cols, ), 0xFF, dtype=torch.int32, device="npu")
-
-    base = torch.full_like(inp, 0xFF)
-    _atomic_and_ss[grid](inp, base, n_cols, BLOCK_SIZE=BLOCK, SEM="acq_rel", SCOPE="gpu")
-
-    cur = torch.full_like(inp, 0xFF)
-    _atomic_and_ss[grid](inp, cur, n_cols, BLOCK_SIZE=BLOCK, SEM=sem, SCOPE=scope)
-
-    torch.testing.assert_close(cur, base)
-
-
-@pytest.mark.parametrize('param_list', [
-    ['uint8', (32, 32), 2],
-    ['uint16', (32, 32), 2],
-    ['uint32', (32, 32), 2],
-    ['uint64', (32, 32), 2],
-])
-def test_atomic_and_uint(param_list):
-    dtype, shape, ncore = param_list
-    block_size = shape[0] * shape[1] // ncore
-    split_size = shape[0] // ncore
-
-    val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu()
-    val = val_cpu.to("npu")
-
-    pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu()
-    pointer = pointer_cpu.to("npu")
-    pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu()
-    pointer_old = pointer_old_cpu.to("npu")
-    pointer_ref_cpu = pointer_cpu.clone()
-
-    for i in range(ncore - 1):
-        pointer_ref_cpu &= val_cpu[(i * split_size):((i + 1) * split_size)]
-
-    pointer_ref_last = pointer_ref_cpu.clone()
-    pointer_ref_cpu &= val_cpu[((ncore - 1) * split_size):(ncore * split_size)]
-    pointer_ref = pointer_ref_cpu.to("npu")
-
-    @triton.jit
-    def atomic_and_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-        xoffset = tl.program_id(0) * BLOCK_SIZE
-        xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-        yindex = tl.arange(0, BLOCK_SIZE)[:]
-        xmask = xindex < n_elements
-        x0 = xindex
-        x1 = yindex
-        tmp0 = tl.load(in_ptr0 + (x0), xmask)
-        tmp1 = tl.atomic_and(out_ptr0 + (x1), tmp0, xmask)
-        tl.store(out_ptr1 + (x1), tmp1, xmask)
-
-    n_elements = shape[0] * shape[1]
-    atomic_and_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
-    test_common.validate_cmp(dtype, pointer, pointer_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py b/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py
deleted file mode 100644
index eab2568755..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py
+++ /dev/null
@@ -1,484 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import math
-import pytest
-import torch
-import triton
-
-import triton.language as tl
-
-import test_common
-from test_common import TestUtils
-import numpy as np
-
-filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'int8', 'bool'}]
-
-
-@triton.jit
-def atomic_cas(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr):
-    in_offset = tl.program_id(0) * BLOCK_SIZE
-    out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
-    in_index = in_offset + tl.arange(0, BLOCK_SIZE)
-    out_index = out_offset + tl.arange(0, BLOCK_SIZE)
-    xmask = in_index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (in_index), xmask)
-    tmp1 = tl.load(in_ptr1 + (in_index), xmask)
-    tl.atomic_cas(out_ptr0 + (out_index), tmp1, tmp0)
-
-
-@triton.jit
-def atomic_cas_ndim(x_ptr, y_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr,
-                    DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr):
-    sub_idx = tl.program_id(1)
-    base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE
-    base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE
-    offsets_src = tl.arange(0, BLOCK_SIZE) + base_src
-    offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst
-    mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4
-    tmp = tl.load(x_ptr + offsets_src, mask)
-    tmp_c = tl.load(y_ptr + offsets_src, mask)
-    tl.atomic_cas(out_ptr + offsets_dst, tmp_c, tmp)
-
-
-@triton.jit
-def atomic_cas_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(0)
-
-    x = tl.load(x_ptr)  # x is scalar or 1D, no mask needed
-
-    # Compute y indices
-    y_offset = pid * BLOCK_SIZE
-    y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
-    y_mask = y_indices < n_elements
-
-    y_value = tl.load(y_ptr + y_indices, y_mask)
-    # Atomic CAS: where out equals y_value, swap in the broadcast x
-    tl.atomic_cas(out_ptr + y_indices, y_value, x)
-
-
-# Parameter combinations for the different test scenarios (x_shape, y_shape, BLOCK_SIZE)
-test_cases = [
-    ((1, 1, 1, 1), (1, 1, 1, 4), 4),
-    ((1, 1, 1, 3), (1, 5, 1, 3), 5),
-    ((3, ), (2, 3, 3, 3, 3), 81),
-    ((3, ), (2, 3, 3, 3), 27),
-    ((3, ), (2, 3, 3), 9),
-    ((3, ), (2, 3), 3),
-]
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    c = torch.randint(low=0, high=2, size=x_shape, dtype=x_dtype).npu()
-    y = torch.randint(low=0, high=2, size=shape, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    c_temp = c.clone()
-    y_temp = y.clone()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1] * 2
-        atomic_cas[shape[0] * 2, 1, 1](x, c, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: upper bound on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
-        aligned_size = grid_size * BLOCK_SIZE
-        # value
-        x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
-        x_concat[0:n_elements] = x[0:n_elements]
-        x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)]
-        # compare
-        c_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
-        c_concat[0:n_elements] = c[0:n_elements]
-        c_concat[aligned_size:(aligned_size + n_elements)] = c[n_elements:(n_elements * 2)]
-        atomic_cas[grid_size * 2, 1, 1](x_concat, c_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE,
-                                        BLOCK_NUM=grid_size)
-
-    expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected)
-    torch.testing.assert_close(y, expected)
-
-
-# 3d
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_3d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    y = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    c_temp = c.clone()
-    y_temp = y.clone()
-
-    n_elements = shape[0] * shape[1] * shape[2]
-    atomic_cas[2, 1, 1](x, c, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1)
-
-    expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected)
-    torch.testing.assert_close(y, expected)
-
-
-@triton.jit
-def atomic_cas_multi_d(in_ptr0, in_ptr1, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr,
-                       MB: tl.constexpr, NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tmp1 = tl.load(in_ptr1 + offsets)
-    tl.atomic_cas(out_ptr0 + offsets, tmp1, tmp0)
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('dtype', filtered_dtype)
-def test_atomic_cas_4d_5d(dtype, shape):
-    x0_value = 3
-    x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
-    c = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu()
-    x1 = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu()
-
-    x1_ref = torch.where(x1 == c, 3, x1)
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-
-    atomic_cas_multi_d[(1, )](x0, c, x1, *triton_shape)
-    test_common.validate_cmp(dtype, x1, x1_ref)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1, 2),
-    (10, 1, 15, 1, 7),
-    (1, 1, 1, 1, 257),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_5d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    c_temp = c.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    XB, YB, ZB, MB, NB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_cas_ndim[(2 * XB * YB * ZB * MB, ncore)](
-        x_ptr=x,
-        y_ptr=c,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=XB,
-        DIM1=YB,
-        DIM2=ZB,
-        DIM3=MB,
-        DIM4=NB,
-    )
-
-    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1),
-    (1, 1, 2, 2),
-    (1, 3, 2, 7),
-    (1, 3, 2, 651),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_4d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    c_temp = c.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 4:
-        triton_shape.append(1)
-    XB, YB, ZB, MB = triton_shape
-
-    BLOCK_SIZE = 256
-    ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_cas_ndim[(2 * XB * YB * ZB, ncore)](
-        x_ptr=x,
-        y_ptr=c,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=XB,
-        DIM2=YB,
-        DIM3=ZB,
-        DIM4=MB,
-    )
-
-    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1),
-    (1, 1, 2),
-    (1, 31, 275),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_3d_2(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    c_temp = c.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 3:
-        triton_shape.append(1)
-    XB, YB, ZB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_cas_ndim[(2 * XB * YB, ncore)](
-        x_ptr=x,
-        y_ptr=c,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=XB,
-        DIM3=YB,
-        DIM4=ZB,
-    )
-
-    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 2),
-    (1, 1),
-    (257, 1),
-    (257, 2),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_2d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    c_temp = c.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB, YB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_cas_ndim[(2 * XB, ncore)](
-        x_ptr=x,
-        y_ptr=c,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=XB,
-        DIM4=YB,
-    )
-
-    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_cas_1d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    c_temp = c.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB = triton_shape[0]
-    BLOCK_SIZE = 256
-    ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_cas_ndim[(2, ncore)](
-        x_ptr=x,
-        y_ptr=c,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=1,
-        DIM4=XB,
-    )
-
-    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
-    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('param_list', [
-    ['uint16', (32, 32), 2],
-    ['uint32', (32, 32), 2],
-    ['uint64', (32, 32), 2],
-])
-def test_atomic_cas_uint(param_list):
-    dtype, shape, ncore = param_list
-    block_size = shape[0] * shape[1] // ncore
-    split_size = shape[0] // ncore
-
-    import random
-    cmp_val = [random.randint(0, 10) for _ in range(ncore)]
-
-    cmp_cpu_parts = []
-    for i in range(ncore):
-        part = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')) * cmp_val[i]
-        cmp_cpu_parts.append(part)
-    cmp_cpu = torch.cat(cmp_cpu_parts, dim=0)
-    cmp = cmp_cpu.to("npu")
-
-    val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu()
-    val = val_cpu.to("npu")
-
-    pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu()
-    pointer = pointer_cpu.to("npu")
-    pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu()
-    pointer_old = pointer_old_cpu.to("npu")
-    pointer_ref_cpu = pointer_cpu.clone()
-
-    pointer_ref_np = pointer_cpu.numpy()
-    val_np = val_cpu.numpy()
-    for i in range(ncore):
-        val_subview_np = val_np[(i * split_size):((i + 1) * split_size)]
-        pointer_ref_np = np.where(pointer_ref_np == cmp_val[i], val_subview_np, pointer_ref_np)
-    pointer_ref_cpu = torch.from_numpy(pointer_ref_np)
-    pointer_ref = pointer_ref_cpu.to("npu")
-
-    @triton.jit
-    def atomic_cas_uint(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-        xoffset = tl.program_id(0) * BLOCK_SIZE
-        xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-        yindex = tl.arange(0, BLOCK_SIZE)[:]
-        xmask = xindex < n_elements
-        x0 = xindex
-        x1 = yindex
-        val = tl.load(in_ptr0 + (x0), xmask)
-        cmp = tl.load(in_ptr1 + (x0), xmask)
-        tmp1 = tl.atomic_cas(out_ptr0 + (x1), cmp, val)
-        tl.store(out_ptr1 + (x1), tmp1, xmask)
-
-    n_elements = shape[0] * shape[1]
-    atomic_cas_uint[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
-    test_common.validate_cmp(dtype, pointer, pointer_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_max.py b/third_party/ascend/unittest/generalization_cases/test_atomic_max.py
deleted file mode 100644
index 87347cb2fd..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_max.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import random
-import triton
-import triton.language as tl
-import torch
-import pytest
-import test_common
-from test_common import TestUtils
-
-
-@triton.jit
-def triton_test_fn_atomic_max_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr):
-    xoffset = tl.program_id(0) * BLOCK_SIZE
-    index = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-    mask = index < n_elements
-    inp0 = tl.load(in_ptr0 + (index), mask)
-    inp1 = tl.load(in_ptr1 + (index), mask)
-    tmp1 = tl.atomic_max(out_ptr1 + (index), inp0, mask)
-    tmp2 = tl.atomic_max(out_ptr1 + (index), inp1, mask)
-
-
-def promote_dtype(x_dtype, y_dtype):
-    """
-    Promote y's precision to match x when y has lower precision than x.
-    """
-    # If the two dtypes already match, return immediately
-    if x_dtype == y_dtype:
-        return y_dtype
-
-    # Dtype priority list, from low to high
-    priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32]
-
-    # Position of each dtype in the priority list
-    x_priority = priority.index(x_dtype)
-    y_priority = priority.index(y_dtype)
-
-    # If y ranks lower than x, promote it to x's dtype
-    if y_priority < x_priority:
-        return x_dtype
-    else:
-        return y_dtype
-
-
-# torch.max does not support int
-@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5))
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_max(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    x0 = test_common.generate_tensor(shape, x_dtype_str)
-    x1 = test_common.generate_tensor(shape, y_dtype_str)
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    out = torch.full(x1.shape, 0, dtype=out_dtype)
-
-    out_ref = torch.maximum(out, x0)
-    out_ref = torch.maximum(out_ref, x1)
-    out_ref = out_ref.npu()
-    x0 = x0.npu()
-    x1 = x1.npu()
-    out = out.npu()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1]
-        triton_test_fn_atomic_max_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: upper bound on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
-        triton_test_fn_atomic_max_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
-
-    torch.testing.assert_close(out, out_ref)
-
-
-# 3d
-testlist = [
-    (1, 22, 39),
-    (27, 1, 39),
-    (27, 22, 1),
-    (1, 1, 23),
-    (23, 1, 1),
-    (1, 23, 1),
-    (27, 5, 3),
-    (2, 29, 4),
-    (7, 31, 7),
-    (3, 5, 8),
-    (7, 17, 15),
-    (25, 5, 16),
-    (23, 5, 31),
-    (7, 11, 32),
-    (7, 11, 33),
-    (2, 3, 255),
-    (3, 3, 256),
-    (3, 2, 257),
-]
-
-
-@pytest.mark.parametrize('shape', random.sample(testlist, 5))
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_max_3d(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    ncore = 1
-    split_size = shape[0] // ncore
-    x0 = test_common.generate_tensor(shape, x_dtype_str)
-    x1 = test_common.generate_tensor(shape, y_dtype_str)
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    y = torch.full(shape, 0, dtype=out_dtype)
-
-    out_ref = torch.full_like(x0, 0, dtype=out_dtype)
-    out_ref = torch.maximum(out_ref, x0)
-    out_ref = torch.maximum(out_ref, x1)
-    x0 = x0.npu()
-    x1 = x1.npu()
-    y = y.npu()
-
-    n_elements = shape[0] * shape[1] * shape[2]
-    triton_test_fn_atomic_max_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2])
-    y = y.cpu()
-    torch.testing.assert_close(y, out_ref)
-
-
-@triton.jit
-def atomic_max_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr,
-                       NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tl.atomic_max(out_ptr0 + offsets, tmp0)
-
-
-filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}]
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('dtype', filtered_dtype)
-def test_atomic_max_4d_5d(dtype, shape):
-    x0_value = 3
-    x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
-    x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu()
-
-    x1_ref = torch.maximum(x1, x0)
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_max_multi_d[(1, )](x0, x1, *triton_shape)
-    test_common.validate_cmp(dtype, x1, x1_ref)
-
-
-@triton.jit
-def atomic_max_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr,
-                         MB: tl.constexpr, NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tmp1 = tl.load(out_ptr0 + offsets)
-    tl.atomic_max(out_ptr1 + offsets, tmp0)
-    tl.atomic_max(out_ptr1 + offsets, tmp1)
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_max_4d_5d_2(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32:
-        x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu()
-    else:
-        x0 = torch.randn(shape, dtype=eval('torch.' + x_dtype_str)).npu()
-
-    if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32:
-        x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu()
-    else:
-        x1 = torch.randn(shape, dtype=eval('torch.' + y_dtype_str)).npu()
-
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32:
-        y = torch.full(shape, torch.iinfo(out_dtype).min, dtype=out_dtype).npu()
-    else:
-        y = torch.full(shape, float('-inf'), dtype=out_dtype).npu()
-
-    y_tmp = y
-    x1_ref = torch.maximum(y_tmp, x0)
-    x1_ref = torch.maximum(x1_ref, x1)
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_max_multi_d_2[(1, )](x0, x1, y, *triton_shape)
-    torch.testing.assert_close(y, x1_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_min.py b/third_party/ascend/unittest/generalization_cases/test_atomic_min.py
deleted file mode 100644
index a74e99058f..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_min.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import random
-import triton
-import triton.language as tl
-import torch
-import pytest
-import test_common
-from test_common import TestUtils
-
-
-@triton.jit
-def triton_test_fn_atomic_min_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr):
-    xoffset = tl.program_id(0) * BLOCK_SIZE
-    index = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-    mask = index < n_elements
-    inp0 = tl.load(in_ptr0 + (index), mask)
-    inp1 = tl.load(in_ptr1 + (index), mask)
-    tmp1 = tl.atomic_min(out_ptr1 + (index), inp0, mask)
-    tmp2 = tl.atomic_min(out_ptr1 + (index), inp1, mask)
-
-
-def promote_dtype(x_dtype, y_dtype):
-    """
-    Promote y's precision to match x when y has lower precision than x.
-    """
-    # If the two dtypes already match, return immediately
-    if x_dtype == y_dtype:
-        return y_dtype
-
-    # Dtype priority list, from low to high
-    priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32]
-
-    # Position of each dtype in the priority list
-    x_priority = priority.index(x_dtype)
-    y_priority = priority.index(y_dtype)
-
-    # If y ranks lower than x, promote it to x's dtype
-    if y_priority < x_priority:
-        return x_dtype
-    else:
-        return y_dtype
-
-
-# torch.min does not support int
-@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5))
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_min(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    x0 = test_common.generate_tensor(shape, x_dtype_str)
-    x1 = test_common.generate_tensor(shape, y_dtype_str)
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32:  # integer dtype check
-        out = torch.full(x1.shape, torch.iinfo(out_dtype).max, dtype=out_dtype)
-    else:
-        out = torch.full(x1.shape, torch.finfo(out_dtype).max, dtype=out_dtype)
-
-    out_ref = torch.minimum(out, x0)
-    out_ref = torch.minimum(out_ref, x1)
-    out_ref = out_ref.npu()
-    x0 = x0.npu()
-    x1 = x1.npu()
-    out = out.npu()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1]
-        triton_test_fn_atomic_min_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: upper bound on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
-        triton_test_fn_atomic_min_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
-
-    torch.testing.assert_close(out, out_ref)
-
-
-# 3d
-testlist = [
-    (1, 22, 39),
-    (27, 1, 39),
-    (27, 22, 1),
-    (1, 1, 23),
-    (23, 1, 1),
-    (1, 23, 1),
-    (27, 5, 3),
-    (2, 29, 4),
-    (7, 31, 7),
-    (3, 5, 8),
-    (7, 17, 15),
-    (25, 5, 16),
-    (23, 5, 31),
-    (7, 11, 32),
-    (7, 11, 33),
-    (2, 3, 255),
-    (3, 3, 256),
-    (3, 2, 257),
-]
-
-
-@pytest.mark.parametrize('shape', random.sample(testlist, 5))
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_min_3d(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    ncore = 1
-    split_size = shape[0] // ncore
-    x0 = test_common.generate_tensor(shape, x_dtype_str)
-    x1 = test_common.generate_tensor(shape, y_dtype_str)
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32:
-        y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype)
-    else:
-        y = torch.full(shape, float('inf'), dtype=out_dtype)
-
-    y_tmp = y
-    x1_ref = torch.minimum(y_tmp, x0)
-    x1_ref = torch.minimum(x1_ref, x1)
-    x0 = x0.npu()
-    x1 = x1.npu()
-    y = y.npu()
-
-    n_elements = shape[0] * shape[1] * shape[2]
-    triton_test_fn_atomic_min_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2])
-    y = y.cpu()
-    torch.testing.assert_close(y, x1_ref)
-
-
-@triton.jit
-def atomic_min_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr,
-                       NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tl.atomic_min(out_ptr0 + offsets, tmp0)
-
-
-filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}]
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('dtype', filtered_dtype)
-def test_atomic_min_4d_5d(dtype, shape):
-    x0_value = 1
-    x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
-    x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu()
-
-    x1_ref = torch.minimum(x1, x0)
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_min_multi_d[(1, )](x0, x1, *triton_shape)
-    test_common.validate_cmp(dtype, x1, x1_ref)
-
-
-@triton.jit
-def atomic_min_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr,
-                         MB: tl.constexpr, NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tmp1 = tl.load(out_ptr0 + offsets)
-    tl.atomic_min(out_ptr1 + offsets, tmp0)
-    tl.atomic_min(out_ptr1 + offsets, tmp1)
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
-def test_atomic_min_4d_5d_2(x_dtype_str, y_dtype_str, shape):
-    # Get the original torch dtypes
-    x_dtype = eval('torch.' + x_dtype_str)
-    y_dtype = eval('torch.' + y_dtype_str)
-
-    if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32:
-        x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu()
-    else:
-        x0 = torch.randn(shape, dtype=eval('torch.' + x_dtype_str)).npu()
-
-    if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32:
-        x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu()
-    else:
-        x1 = torch.randn(shape, dtype=eval('torch.' + y_dtype_str)).npu()
-
-    out_dtype = promote_dtype(x_dtype, y_dtype)
-    if out_dtype == torch.bfloat16:
-        out_dtype = torch.float32
-    if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32:
-        y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype).npu()
-    else:
-        y = torch.full(shape, float('inf'), dtype=out_dtype).npu()
-
-    y_tmp = y
-    x1_ref = torch.minimum(y_tmp, x0)
-    x1_ref = torch.minimum(x1_ref, x1)
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_min_multi_d_2[(1, )](x0, x1, y, *triton_shape)
-    torch.testing.assert_close(y, x1_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_or.py b/third_party/ascend/unittest/generalization_cases/test_atomic_or.py
deleted file mode 100644
index 4e5493b362..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_or.py
+++ /dev/null
@@ -1,438 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import math
-import pytest
-import torch
-import triton
-
-import triton.language as tl
-
-import test_common
-from test_common import TestUtils
-
-filtered_dtype = [
-    dtype for dtype in TestUtils.full_dtype
-    if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'}
-]
-
-
-@triton.jit
-def atomic_or(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr):
-    in_offset = tl.program_id(0) * BLOCK_SIZE
-    out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
-    in_index = in_offset + tl.arange(0, BLOCK_SIZE)
-    out_index = out_offset + tl.arange(0, BLOCK_SIZE)
-    xmask = in_index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (in_index), xmask)
-    tl.atomic_or(out_ptr0 + (out_index), tmp0, xmask)
-
-
-@triton.jit
-def atomic_or_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr,
-                   DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr):
-    sub_idx = tl.program_id(1)
-    base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE
-    base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE
-    offsets_src = tl.arange(0, BLOCK_SIZE) + base_src
-    offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst
-    mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4
-    tmp = tl.load(x_ptr + offsets_src, mask)
-    tl.atomic_or(out_ptr + offsets_dst, tmp, mask)
-
-
-@triton.jit
-def atomic_or_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(0)
-
-    x = tl.load(x_ptr)  # x is scalar or 1D, no mask needed
-
-    # Compute y indices
-    y_offset = pid * BLOCK_SIZE
-    y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
-    y_mask = y_indices < n_elements
-
-    y_value = tl.load(y_ptr + y_indices, y_mask)
-    # Atomic or: y |= x (broadcast)
-    tl.atomic_or(out_ptr + y_indices, y_value, mask=y_mask)
-    tl.atomic_or(out_ptr + y_indices, x, mask=y_mask)
-
-
-# Parameter combinations for the different test scenarios (x_shape, y_shape, BLOCK_SIZE)
-test_cases = [
-    ((1, 1, 1, 1), (1, 1, 1, 4), 4),
-    ((1, 1, 1, 3), (1, 5, 1, 3), 5),
-    ((3, ), (2, 3, 3, 3, 3), 81),
-    ((3, ), (2, 3, 3, 3), 27),
-    ((3, ), (2, 3, 3), 9),
-    ((3, ), (2, 3), 3),
-]
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    y = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    y_temp = y.clone()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1] * 2
-        atomic_or[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: upper bound on the block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
-        aligned_size = grid_size * BLOCK_SIZE
-        x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
-        x_concat[0:n_elements] = x[0:n_elements]
-        x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)]
-        atomic_or[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size)
-
-    expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)]
-    torch.testing.assert_close(y, expected)
-
-
-# 3d
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_3d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    y = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    y_temp = y.clone()
-
-    n_elements = shape[0] * shape[1] * shape[2]
-    atomic_or[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1)
-
-    expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)]
-    torch.testing.assert_close(y, expected)
-
-
-@triton.jit
-def atomic_or_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr,
-                      NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tl.atomic_or(out_ptr0 + offsets, tmp0)
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('dtype', filtered_dtype)
-def test_atomic_or_4d_5d(dtype, shape):
-    x0_value = 3
-    x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
-    x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu()
-
-    x1_ref = x1 | x0_value
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_or_multi_d[(1, )](x0, x1, *triton_shape)
-    test_common.validate_cmp(dtype, x1, x1_ref)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1, 2),
-    (10, 1, 15, 1, 7),
-    (1, 1, 1, 1, 257),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_5d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    XB, YB, ZB, MB, NB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_or_ndim[(2 * XB * YB * ZB * MB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=XB,
-        DIM1=YB,
-        DIM2=ZB,
-        DIM3=MB,
-        DIM4=NB,
-    )
-
-    expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1),
-    (1, 1, 2, 2),
-    (1, 3, 2, 7),
-    (1, 3, 2, 651),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_4d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 4:
-        triton_shape.append(1)
-    XB, YB, ZB, MB = triton_shape
-
-    BLOCK_SIZE = 256
-    ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_or_ndim[(2 * XB * YB * ZB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=XB,
-        DIM2=YB,
-        DIM3=ZB,
-        DIM4=MB,
-    )
-
-    expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1),
-    (1, 1, 2),
-    (1, 31, 275),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_3d_2(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 3:
-        triton_shape.append(1)
-    XB, YB, ZB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_or_ndim[(2 * XB * YB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=XB,
-        DIM3=YB,
-        DIM4=ZB,
-    )
-
-    expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 2),
-    (1, 1),
-    (257, 1),
-    (257, 2),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_2d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB, YB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_or_ndim[(2 * XB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=XB,
-        DIM4=YB,
-    )
-
-    expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_or_1d(x_dtype_str, shape):
-    # Get the original torch dtype
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB = triton_shape[0]
-    BLOCK_SIZE = 256
-    ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_or_ndim[(2, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=1,
-        DIM4=XB,
-    )
-
-    expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('param_list', [
-    ['uint8', (32, 32), 2],
-    ['uint16', (32, 32), 2],
-    ['uint32', (32, 32), 2],
-    ['uint64', (32, 32), 2],
-])
-def test_atomic_or_uint(param_list):
-    dtype, shape, ncore = param_list
-    block_size = shape[0] * shape[1] // ncore
-    split_size = shape[0] // ncore
-
-    val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu()
-    val = val_cpu.to("npu")
-
-    pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu()
-    pointer = pointer_cpu.to("npu")
-    pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu()
-    pointer_old = pointer_old_cpu.to("npu")
-    pointer_ref_cpu = pointer_cpu.clone()
-
-    for i in range(ncore - 1):
-        pointer_ref_cpu |= val_cpu[(i * split_size):((i + 1) * split_size)]
-
-    pointer_ref_last = pointer_ref_cpu.clone()
-    pointer_ref_cpu |= val_cpu[((ncore - 1) * split_size):(ncore * split_size)]
-    pointer_ref = pointer_ref_cpu.to("npu")
-
-    @triton.jit
-    def atomic_or_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-        xoffset = tl.program_id(0) * BLOCK_SIZE
-        xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-        yindex = tl.arange(0, BLOCK_SIZE)[:]
-        xmask = xindex < n_elements
-        x0 = xindex
-        x1 = yindex
-        tmp0 = tl.load(in_ptr0 + (x0), xmask)
-        tmp1 = tl.atomic_or(out_ptr0 + (x1), tmp0, xmask)
-        tl.store(out_ptr1 + (x1), tmp1, xmask)
-
-    n_elements = shape[0] * shape[1]
-    atomic_or_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
-    test_common.validate_cmp(dtype, pointer, pointer_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py
deleted file mode 100644
index 740378e929..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py
+++ /dev/null
@@ -1,434 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'bool'}] - - -@triton.jit -def atomic_xchg(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tl.atomic_xchg(out_ptr0 + (out_index), tmp0, xmask) - - -@triton.jit -def atomic_xchg_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr, - DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): - sub_idx = tl.program_id(1) - base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE - base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE - offsets_src = tl.arange(0, BLOCK_SIZE) + base_src - offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst - mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 - tmp = tl.load(x_ptr + offsets_src, mask) - tl.atomic_xchg(out_ptr + offsets_dst, tmp, mask) - - -@triton.jit -def atomic_xchg_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y |= x (broadcasted) - tl.atomic_xchg(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_xchg(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_xchg[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - atomic_xchg[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) - - expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_xchg[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_xchg_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_xchg(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_xchg_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() - - x1_ref = x0 - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_xchg_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1, 1, 2), - (10, 1, 15, 1, 7), - (1, 1, 1, 1, 257), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_5d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - BLOCK_SIZE = 256 - ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB * ZB * MB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=XB, - DIM1=YB, - DIM2=ZB, - DIM3=MB, - DIM4=NB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1, 1), - (1, 1, 2, 2), - (1, 3, 2, 7), - (1, 3, 2, 651), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_4d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - BLOCK_SIZE = 256 - ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB * ZB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=XB, - DIM2=YB, - DIM3=ZB, - DIM4=MB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1), - (1, 1, 2), - (1, 31, 275), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_3d_2(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - BLOCK_SIZE = 256 - ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=XB, - DIM3=YB, - DIM4=ZB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 2), - (1, 1), - (257, 1), - (257, 2), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_2d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' 
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB, YB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xchg_ndim[(2 * XB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=XB,
-        DIM4=YB,
-    )
-
-    expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xchg_1d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB = triton_shape[0]
-    BLOCK_SIZE = 256
-    ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xchg_ndim[(2, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=1,
-        DIM4=XB,
-    )
-
-    expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape)
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('param_list', [
-    ['uint8', (32, 32), 2],
-    ['uint16', (32, 32), 2],
-    ['uint32', (32, 32), 2],
-    ['uint64', (32, 32), 2],
-])
-def test_atomic_xchg_uint(param_list):
-    dtype, shape, ncore = param_list
-    block_size = shape[0] * shape[1] // ncore
-    split_size = shape[0] // ncore
-
-    val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu()
-    val = val_cpu.to("npu")
-
-    pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu()
-    pointer = pointer_cpu.to("npu")
-
-    pointer_ref = pointer.clone()
-    pointer_old_cpu = torch.full_like(val_cpu, -10).cpu()
-    pointer_old = pointer_old_cpu.to("npu")
-    pointer_old_ref = pointer_old.clone()
-
-    pointer_ref = val[((ncore - 1) * split_size):(ncore * split_size)].clone()
-    pointer_old_ref[0:split_size] = pointer
-    pointer_old_ref[split_size:((ncore - 1) * split_size)] = val[0:(ncore - 2) * split_size]
-
-    @triton.jit
-    def atomic_xchg_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-        xoffset = tl.program_id(0) * BLOCK_SIZE
-        xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-        yindex = tl.arange(0, BLOCK_SIZE)[:]
-        xmask = xindex < n_elements
-        x0 = xindex
-        x1 = yindex
-        tmp0 = tl.load(in_ptr0 + (x0), xmask)
-        tmp1 = tl.atomic_xchg(out_ptr0 + (x1), tmp0, xmask)
-        tl.store(out_ptr1 + (x0), tmp1, xmask)
-
-    n_elements = shape[0] * shape[1]
-    atomic_xchg_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
-
-    pointer_cpu = pointer.cpu()
-    pointer_ref_cpu = pointer_ref.cpu()
-    assert (pointer_cpu == pointer_ref_cpu).all()
diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py
deleted file mode 100644
index 4a83697261..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py
+++ /dev/null
@@ -1,441 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-import math
-import pytest
-import torch
-import triton
-
-import triton.language as tl
-
-import test_common
-from test_common import TestUtils
-
-filtered_dtype = [
-    dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'bool'}
-]
-
-
-@triton.jit
-def atomic_xor(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr):
-    in_offset = tl.program_id(0) * BLOCK_SIZE
-    out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE
-    in_index = in_offset + tl.arange(0, BLOCK_SIZE)
-    out_index = out_offset + tl.arange(0, BLOCK_SIZE)
-    xmask = in_index < n_elements
-
-    tmp0 = tl.load(in_ptr0 + (in_index), xmask)
-    tl.atomic_xor(out_ptr0 + (out_index), tmp0, xmask)
-
-
-@triton.jit
-def atomic_xor_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr,
-                    DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr):
-    sub_idx = tl.program_id(1)
-    base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE
-    base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE
-    offsets_src = tl.arange(0, BLOCK_SIZE) + base_src
-    offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst
-    mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4
-    tmp = tl.load(x_ptr + offsets_src, mask)
-    tl.atomic_xor(out_ptr + offsets_dst, tmp, mask)
-
-
-@triton.jit
-def atomic_xor_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(0)
-
-    x = tl.load(x_ptr)  # x is scalar or 1D, no mask needed
-
-    # Compute y indices
-    y_offset = pid * BLOCK_SIZE
-    y_indices = y_offset + tl.arange(0, BLOCK_SIZE)
-    y_mask = y_indices < n_elements
-
-    y_value = tl.load(y_ptr + y_indices, y_mask)
-    # Atomic xor: out ^= y, then out ^= x (broadcasted)
-    tl.atomic_xor(out_ptr + y_indices, y_value, mask=y_mask)
-    tl.atomic_xor(out_ptr + y_indices, x, mask=y_mask)
-
-
-# Parameter combinations for the different test scenarios (x_shape, y_shape, BLOCK_SIZE)
-test_cases = [
-    ((1, 1, 1, 1), (1, 1, 1, 4), 4),
-    ((1, 1, 1, 3), (1, 5, 1, 3), 5),
-    ((3, ), (2, 3, 3, 3, 3), 81),
-    ((3, ), (2, 3, 3, 3), 27),
-    ((3, ), (2, 3, 3), 9),
-    ((3, ), (2, 3), 3),
-]
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    if len(shape) == 1 and shape[0] == 1:  # known issue with the golden data; verified manually
-        return
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    y = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    y_temp = y.clone()
-
-    if len(shape) == 2:
-        n_elements = shape[0] * shape[1] * 2
-        atomic_xor[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0])
-    elif len(shape) == 1:
-        n_elements = shape[0]
-        BLOCK_SIZE = min(1024, shape[0])  # 1024: cap on the maximum block size
-        grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
-        aligned_size = grid_size * BLOCK_SIZE
-        x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu()
-        x_concat[0:n_elements] = x[0:n_elements]
-        x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)]
-        atomic_xor[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size)
-
-    expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)]
-    torch.testing.assert_close(y, expected)
-
-
-# 3d
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d)
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_3d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = test_common.generate_tensor(x_shape, x_dtype_str).npu()
-    y = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    # Keep copies for verification
-    x_temp = x.clone()
-    y_temp = y.clone()
-
-    n_elements = shape[0] * shape[1] * shape[2]
-    atomic_xor[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1)
-
-    expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)]
-    torch.testing.assert_close(y, expected)
-
-
-@triton.jit
-def atomic_xor_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr,
-                       NB: tl.constexpr):
-    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
-    if (YB * ZB * MB * NB) > 1:
-        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
-    if (ZB * MB * NB) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
-    if (MB * NB) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
-    if NB > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
-
-    tmp0 = tl.load(in_ptr0 + offsets)
-    tl.atomic_xor(out_ptr0 + offsets, tmp0)
-
-
-# multi_d
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('shape', [
-    (2, 4, 8, 4),
-    (8, 4, 2, 4),
-    (2, 8, 2, 2),
-    (2, 4, 8, 4, 2),
-    (8, 4, 2, 4, 4),
-    (2, 8, 2, 2, 2),
-])
-@pytest.mark.parametrize('dtype', filtered_dtype)
-def test_atomic_xor_4d_5d(dtype, shape):
-    x0_value = 3
-    x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu()
-    x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu()
-
-    x1_ref = x1 ^ x0_value
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    atomic_xor_multi_d[(1, )](x0, x1, *triton_shape)
-    test_common.validate_cmp(dtype, x1, x1_ref)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1, 2),
-    (10, 1, 15, 1, 7),
-    (1, 1, 1, 1, 257),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_5d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 5:
-        triton_shape.append(1)
-    XB, YB, ZB, MB, NB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xor_ndim[(2 * XB * YB * ZB * MB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=XB,
-        DIM1=YB,
-        DIM2=ZB,
-        DIM3=MB,
-        DIM4=NB,
-    )
-
-    expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1, 1),
-    (1, 1, 2, 2),
-    (1, 3, 2, 7),
-    (1, 3, 2, 651),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_4d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 4:
-        triton_shape.append(1)
-    XB, YB, ZB, MB = triton_shape
-
-    BLOCK_SIZE = 256
-    ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xor_ndim[(2 * XB * YB * ZB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=XB,
-        DIM2=YB,
-        DIM3=ZB,
-        DIM4=MB,
-    )
-
-    expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 1, 1),
-    (1, 1, 2),
-    (1, 31, 275),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_3d_2(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 3:
-        triton_shape.append(1)
-    XB, YB, ZB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xor_ndim[(2 * XB * YB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=XB,
-        DIM3=YB,
-        DIM4=ZB,
-    )
-
-    expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [
-    (1, 2),
-    (1, 1),
-    (257, 1),
-    (257, 2),
-])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_2d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB, YB = triton_shape
-    BLOCK_SIZE = 256
-    ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xor_ndim[(2 * XB, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=XB,
-        DIM4=YB,
-    )
-
-    expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )])
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype)
-def test_atomic_xor_1d(x_dtype_str, shape):
-    # Resolve the torch dtype from its string name
-    x_dtype = eval('torch.' + x_dtype_str)
-    x_shape = list(shape[:])
-    x_shape[0] *= 2
-    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
-    out = torch.full(shape, 0, dtype=x_dtype).npu()
-
-    x_temp = x.clone()
-    out_temp = out.clone()
-
-    triton_shape = [*shape]
-    while len(triton_shape) < 2:
-        triton_shape.append(1)
-    XB = triton_shape[0]
-    BLOCK_SIZE = 256
-    ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    atomic_xor_ndim[(2, ncore)](
-        x_ptr=x,
-        out_ptr=out,
-        NCORE=ncore,
-        BLOCK_SIZE=BLOCK_SIZE,
-        DIM0=1,
-        DIM1=1,
-        DIM2=1,
-        DIM3=1,
-        DIM4=XB,
-    )
-
-    expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]]
-    torch.testing.assert_close(out, expected)
-
-
-@pytest.mark.parametrize('param_list', [
-    ['uint8', (32, 32), 2],
-    ['uint16', (32, 32), 2],
-    ['uint32', (32, 32), 2],
-    ['uint64', (32, 32), 2],
-])
-def test_atomic_xor_uint(param_list):
-    dtype, shape, ncore = param_list
-    block_size = shape[0] * shape[1] // ncore
-    split_size = shape[0] // ncore
-
-    val_value = 3
-    val_cpu = torch.full(shape, val_value, dtype=eval(f'torch.{dtype}')).cpu()
-    val = val_cpu.to("npu")
-
-    pointer_value = 5
-    pointer_cpu = torch.full((split_size, shape[1]), pointer_value, dtype=eval(f'torch.{dtype}')).cpu()
-    pointer = pointer_cpu.to("npu")
-    pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu()
-    pointer_old = pointer_old_cpu.to("npu")
-
-    pointer_result = pointer_value
-    for _ in range(ncore):
-        pointer_result ^= val_value
-
-    pointer_ref_cpu = torch.full_like(pointer_cpu, pointer_result).cpu()
-    pointer_ref = pointer_ref_cpu.to("npu")
-
-    @triton.jit
-    def atomic_xor_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
-        xoffset = tl.program_id(0) * BLOCK_SIZE
-        xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
-        yindex = tl.arange(0, BLOCK_SIZE)[:]
-        xmask = xindex < n_elements
-        x0 = xindex
-        x1 = yindex
-        tmp0 = tl.load(in_ptr0 + (x0), xmask)
-        tmp1 = tl.atomic_xor(out_ptr0 + (x1), tmp0, xmask)
-        tl.store(out_ptr1 + (x1), tmp1, xmask)
-
-    n_elements = shape[0] * shape[1]
-    atomic_xor_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
-    test_common.validate_cmp(dtype, pointer, pointer_ref)
diff --git a/third_party/ascend/unittest/generalization_cases/test_broadcast.py b/third_party/ascend/unittest/generalization_cases/test_broadcast.py
deleted file mode 100644
index e9f7a46d8b..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_broadcast.py
+++ /dev/null
@@ -1,324 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): - xidx = tl.arange(0, XS)[None, :] - base = tl.load(x_ptr + xidx) - out = base.broadcast_to((YS, XS)) - oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_1d(shape, dtype): - XS = shape[0] - YS = 4 - - x = test_common.generate_tensor((XS, ), dtype=dtype).npu() - std = torch.broadcast_to(x, (YS, XS)) - output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() - fn_broadcast_1d[1, 1, 1](output, x, XS, YS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): - zoffset = tl.program_id(0) * ZS - zidx = tl.arange(0, ZS)[None, :] - base = tl.load(x_ptr + zoffset + zidx) - out = base.broadcast_to((YS, ZS)) - oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_2d(shape, dtype): - XS = shape[0] - ZS = shape[1] - YS = 4 - NUMEL = XS * ZS - - x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() # randn not support int type - std = torch.broadcast_to(x, (XS, YS, ZS)) - output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() - fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim0(shape, dtype): - L, M, N = 
shape - x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() - ans = x0.repeat(L, 1, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim1(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() - ans = x0.repeat(1, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim2(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() - ans = x0.repeat(1, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim01(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() - ans = x0.repeat(L, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim02(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() - ans = x0.repeat(L, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim12(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() - ans = x0.repeat(1, M, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def fn_broadcast_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, F_X: tl.constexpr, - F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, T_N: tl.constexpr, T_X: tl.constexpr, - T_Y: tl.constexpr): - from_offsets = tl.arange(0, F_L) - if F_M is not None: - from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] - if F_N is not None: - from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] - if F_X is not None: - from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] - if F_Y is not None: - from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] - - to_offsets = tl.arange(0, T_L) - if T_M is not None: - to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] - if T_N is not None: - to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] - if T_X is not None: - to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] - if T_Y is not None: - to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] - - from_data = tl.load(from_ptr + from_offsets) - to_data = tl.load(to_ptr + to_offsets) - ret_data = tl.broadcast(from_data, to_data) - - tl.store(to_ptr + to_offsets, ret_data) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shapes', [ - [(1, 64, 16, 1), (2, 64, 16, 2)], - [(8, 1, 1, 2), (8, 8, 4, 2)], -]) -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -def test_broadcast_to_4d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -@pytest.mark.parametrize('shapes', [ - [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], - [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], -]) -def test_broadcast_to_5d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@triton.jit -def fn_broadcast(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, 
ret) - - -XS: tl.constexpr = 2 -YS: tl.constexpr = 4 -ZS: tl.constexpr = 8 - - -@pytest.mark.parametrize('dtype', - ["uint8", "int8", "int16", "int32", "int64", "float16", "float32", "bfloat16", "bool"]) -def test_broadcast_alltype(dtype): - input = test_common.generate_tensor((1, YS, ZS), dtype).npu() - ans = input.repeat(XS, 1, 1) - output = torch.zeros((XS, YS, ZS), dtype=eval('torch.' + dtype)).npu() - fn_broadcast[1, 1, 1](input, output, XS, YS, ZS) - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py b/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py deleted file mode 100644 index 4ec6173874..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
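-
-# These cases exercise tl.broadcast_to / tl.broadcast against torch golden
-# data. The semantics mirror torch.broadcast_to: a size-1 axis can be
-# stretched to any target length. A small host-side sketch of the reference
-# (illustrative only, not part of the kernels under test):
-#   x = torch.arange(3).reshape(1, 3)   # shape (1, 3)
-#   torch.broadcast_to(x, (4, 3))       # the single row is repeated 4 times
-# The kernels below rebuild the same result with explicit index arithmetic,
-# so torch.broadcast_to / Tensor.repeat / Tensor.expand serve as references.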
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): - xidx = tl.arange(0, XS)[None, :] - base = tl.load(x_ptr + xidx) - out = base.broadcast_to((YS, XS)) - oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_1d(shape, dtype): - XS = shape[0] - YS = 4 - - x = test_common.generate_tensor((XS, ), dtype=dtype).npu() - std = torch.broadcast_to(x, (YS, XS)) - output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() - fn_broadcast_1d[1, 1, 1](output, x, XS, YS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): - zoffset = tl.program_id(0) * ZS - zidx = tl.arange(0, ZS)[None, :] - base = tl.load(x_ptr + zoffset + zidx) - out = base.broadcast_to((YS, ZS)) - oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_2d(shape, dtype): - XS = shape[0] - ZS = shape[1] - YS = 4 - NUMEL = XS * ZS - - x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() - std = torch.broadcast_to(x, (XS, YS, ZS)) - output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() - fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim0(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() - ans = x0.repeat(L, 1, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim1(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() - ans = x0.repeat(1, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim2(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() - ans = x0.repeat(1, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim01(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() - ans = x0.repeat(L, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim02(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() - ans = x0.repeat(L, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim12(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() - ans = x0.repeat(1, M, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def fn_broadcast_to_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, - F_X: tl.constexpr, F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, - T_N: tl.constexpr, T_X: tl.constexpr, T_Y: tl.constexpr): - from_offsets = tl.arange(0, F_L) - if F_M is not None: - from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] - if F_N is not None: - from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] - if F_X is not None: - from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] - if F_Y is not None: - from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] - - to_offsets = tl.arange(0, T_L) - if T_M is not None: - to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] - if T_N is not None: - to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] - if T_X is not None: - to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] - if T_Y is not None: - to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] - - from_data = tl.load(from_ptr + from_offsets) - if F_Y is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X, T_Y)) - elif F_X is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X)) - elif F_N is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N)) - elif F_M is not None: - ret_data = from_data.broadcast_to((T_L, T_M)) - else: - ret_data = from_data.broadcast_to((T_L)) - - tl.store(to_ptr + to_offsets, ret_data) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shapes', [ - [(1, 64, 16, 1), (2, 64, 16, 2)], - [(8, 1, 1, 2), (8, 8, 4, 2)], -]) -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -def test_broadcast_to_4d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -@pytest.mark.parametrize('shapes', [ - [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], - [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], -]) -def test_broadcast_to_5d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -XS: tl.constexpr = 2 -YS: tl.constexpr = 4 -ZS: tl.constexpr = 8 -NUMEL: tl.constexpr = XS * ZS - - -@triton.jit -def fn_broadcast_to(output_ptr, input_ptr, length): - col_offsets = tl.arange(0, NUMEL) - input = 
tl.load(input_ptr + col_offsets) - result = input.reshape((XS, 1, ZS)).broadcast_to((XS, YS, ZS)).reshape((XS * YS * ZS)) - brc_col_offsets = tl.arange(0, NUMEL * YS) - tl.store(output_ptr + brc_col_offsets, result) - - -@pytest.mark.parametrize('dtype', - ["uint8", "int8", "int16", "int32", "int64", "float16", "float32", "bfloat16", "bool"]) -def test_broadcast_to_alltype(dtype): - length = NUMEL - input = test_common.generate_tensor((XS, 1, ZS), dtype).npu() - output = test_common.generate_tensor((XS, YS, ZS), dtype).npu() - fn_broadcast_to[1, 1, 1](output, input, length, debug=True) - assert (torch.equal(output, input.repeat(1, YS, 1))) diff --git a/third_party/ascend/unittest/generalization_cases/test_cast.py b/third_party/ascend/unittest/generalization_cases/test_cast.py deleted file mode 100644 index 3e1608cb97..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cast.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
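-
-# The cast tests below guard against unified-buffer (UB) overflow before
-# launching. A sketch of the guard they apply (ub_size and the divisors come
-# from test_common.TestUtils; the int8 destination path needs the most
-# headroom):
-#   dtype_size = max(get_dtype_size(srcDtype), get_dtype_size(dstDtype))
-#   limit = TestUtils.ub_size / 100 if dstDtype == 'int8' else TestUtils.ub_size / 12
-#   if dtype_size * math.prod(shape) >= limit:
-#       return  # skip: the working set would overflow UB memory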
- -import math -import random -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -@triton.jit -def cast_to_bool(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int1) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i8(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int8) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i32(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i64(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int64) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_fp32(output_ptr, x_ptr, x_stride, y_stride, z_stride, 
DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.float32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_fp16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.float16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_bf16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.bfloat16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_uint32(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.uint32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_int64(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int64) - tl.store(output_ptr + idx, ret) - - -triton_func_map = { - "bool": cast_to_bool, "int8": cast_to_i8, "int16": cast_to_i16, "int32": cast_to_i32, "float16": cast_to_fp16, - "bfloat16": cast_to_bf16, "float32": cast_to_fp32, "uint32": cast_to_uint32, "int64": cast_to_int64 -} - - -def structParam(x0): - dim = x0.dim() - stride0, stride1, stride2 = 0, 0, 0 - shape0, shape1, shape2 = 0, 0, 0 - if dim 
>= 1: - stride0 = x0.stride(0) - shape0 = x0.shape[0] - if dim >= 2: - stride1 = x0.stride(1) - shape1 = x0.shape[1] - if dim == 3: - stride2 = x0.stride(2) - shape2 = x0.shape[2] - return dim, stride0, stride1, stride2, shape0, shape1, shape2 - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('srcDtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dstDtype', TestUtils.full_dtype) -def test_cast(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - triton_func = triton_func_map.get(dstDtype, None) - assert triton_func is not None, f"triton_func not Found, srcDtype:{srcDtype}, dstDtype:{dstDtype}" - triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() - dim, stride0, stride1, stride2, XB, YB, ZB = structParam(x0) - assert 0 <= dim <= 3, f"dim out of range [0, 3], dim:{dim}" - triton_func[1, 1, 1](triton_res, x0, stride0, stride1, stride2, dim, XB, YB, ZB) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -@triton.jit -def cast_to_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - dtype = output_ptr.type.element_ty - - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + offsets) - ret = tl.cast(X, dtype=dtype) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (6, 2, 4, 2), - (4, 2, 8, 4), - (4, 3, 8, 4), -]) -@pytest.mark.parametrize('srcDtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('dstDtype', ['int8', 'float16', 'float32']) -def test_cast_4d(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - - triton_res = torch.empty(shape, dtype=eval("torch." 
+ dstDtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - cast_to_multi_d[grid](triton_res, x0, *triton_shape) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 6, 2, 4, 2), - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 4), -]) -@pytest.mark.parametrize('srcDtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('dstDtype', ['int8', 'float16', 'float32']) -def test_cast_5d(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - - triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - cast_to_multi_d[grid](triton_res, x0, *triton_shape) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(3, ), (3, 3), (3, 3, 3)]: - for srcDtype in ['int8', 'float32', 'bool']: - for dstDtype in ['int8', 'float32', 'bool']: - test_cast(srcDtype, dstDtype, shape) diff --git a/third_party/ascend/unittest/generalization_cases/test_cdiv.py b/third_party/ascend/unittest/generalization_cases/test_cdiv.py deleted file mode 100644 index 4f9afe73c6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cdiv.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
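-
-# tl.cdiv is ceiling division. The torch reference below uses the classic
-# identity cdiv(x, y) = (x + y - 1) // y for positive y; for example
-# cdiv(7, 2) = (7 + 2 - 1) // 2 = 4. Inputs are pre-shrunk for int8 because
-# (x + y - 1) can overflow there, where Triton and torch on CPU disagree.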
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_cdiv(x0, x1, dtype): - return (x0 + x1 - 1) // x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.cdiv(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_cdiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.cdiv(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - # Generate test data; for cdiv, int8 overflow behaves differently in Triton than in torch on CPU - x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - y = (y.abs() // 2 + 1) - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.'
+ dtype)) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_cdiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() - y = test_common.generate_tensor(shape, dtype).npu() - y = (y.abs() // 2 + 1) - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_cdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_ceil.py b/third_party/ascend/unittest/generalization_cases/test_ceil.py deleted file mode 100644 index bb0e925658..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ceil.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
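Both torch_cdiv above and tl.cdiv implement ceiling division through the add-then-floor identity ceil(x / y) == (x + y - 1) // y for y > 0. A plain-Python check, which also shows why the int8 test halves its inputs:

def cdiv(x, y):
    # Ceiling division for y > 0, e.g. cdiv(7, 3) == 3 while 7 // 3 == 2.
    return (x + y - 1) // y

assert cdiv(7, 3) == 3 and cdiv(6, 3) == 2
# In 8-bit arithmetic the intermediate x + y - 1 can exceed 127 and wrap,
# which is why test_case2 halves its operands before launching the kernel.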
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_ceil(x0): - res = torch.ceil(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.ceil(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_ceil_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.ceil(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate test data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_ceil(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_ceil_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_ceil(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ceil_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_common.py b/third_party/ascend/unittest/generalization_cases/test_common.py deleted file mode 100644 index e6cf112f74..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_common.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
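The ceil tests above (and the cos tests later in this patch) pick their launch grid with the same heuristic: run the whole tensor in one program until its size crosses a byte threshold, then split the innermost axis into one program per slice. A hedged sketch of that logic (the 512/8192 thresholds are the values hard-coded in these tests, not documented hardware limits):

import torch

def choose_grid(x: torch.Tensor, ZB: int, dtype: str):
    # Returns (grid, ZB): one program by default, or ZB programs that each
    # process a Z-slice of size 1 once the tensor exceeds the threshold.
    limit = 512 if dtype == 'int8' else 8192
    if x.numel() * x.element_size() >= limit:
        return (1, 1, ZB), 1
    return (1, 1, 1), ZB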
- -import os -import re -import torch -import torch_npu -import math -import logging -from typing import AnyStr -import pytest -import functools -import numpy as np - -_float_dtypes = ['float32', 'float16', 'bfloat16'] -_int_dtypes = ['int32', 'int64', 'int16', 'int8'] -_uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] - -log_level = os.getenv("LOG_LEVEL", "WARN").upper() -level_mapping = { - "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARNING, "ERROR": logging.ERROR, "CRITICAL": - logging.CRITICAL -} - -logging.basicConfig(level=level_mapping.get(log_level, logging.WARNING), - format="[%(asctime)s][%(levelname)s] %(message)s") - -bisheng_not_support_dtypes = { - 'abs': [], 'eq': [], 'ne': [], 'flip': ['int64', - 'bfloat16'], 'load_store': ['int64'], 'permute2d': ['int64'], 'permute3d': - ['int64'], 'trans2d': ['int64'], 'trans3d': ['int64'], 'matmul': ['int16', 'int32', 'uint32', 'int64', 'bool'] -} - -tritonascend_not_support_dtypes = { - 'abs': ['bool'], - 'eq': ['bool'], - 'ne': ['bool'], - 'flip': ['bool'], - 'load_store': ['bool'], - 'permute2d': ['bool'], - 'permute3d': ['bool'], - 'trans2d': ['bool'], - 'trans3d': ['bool'], -} - - -def avoid_not_support(op: AnyStr): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(shape, dtype, *args, **kwargs): - if dtype in bisheng_not_support_dtypes.get(op, []): - logging.warning(f'skipped: bisheng does not support dtype:{dtype}') - return - if dtype in tritonascend_not_support_dtypes.get(op, []): - logging.warning(f'skipped: triton-ascend does not support dtype:{dtype}') - return - return test_func(shape, dtype, *args, **kwargs) - - return wrapper - - return decorator - - -def get_shape1d(in_shape1d): - result = [] - for i in in_shape1d: - v = tuple((i, )) - result.append(v) - return result - - -def get_shape2d(in_shape1d, custom_shape): - result = [] - for a in in_shape1d: - for b in custom_shape: - t1 = tuple((a, b)) - t2 = tuple((b, a)) - if t1 not in result: - result.append(t1) - if t2 not in result: - result.append(t2) - return result - - -def get_shape3d(): - return [(1, 22, 39), (27, 1, 39), (27, 22, 1), (23, 1, 1), (1, 23, 1), (1, 1, 23), (37, 5, 3), (2, 29, 4), - (7, 31, 7), (3, 5, 8), (7, 17, 15), (23, 5, 16), (23, 5, 31), (7, 11, 32), (7, 11, 33), (2, 3, 255), - (3, 3, 256), (3, 2, 257)] - - -def get_shape1_2_3d(in_shape1d, custom_shape): - return get_shape1d(in_shape1d) + get_shape2d(in_shape1d, custom_shape) + get_shape3d() - - -class TestUtils: - in_shape1d = [1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 37, 741] - custom_shape = [3, 13, 32, 256] - batch = [1, 2, 3, 4, 5, 8] - test_shape1d = get_shape1d(in_shape1d) - test_shape2d = get_shape2d(in_shape1d, custom_shape) - test_shape3d = [ - (1, 22, 39), - (27, 1, 39), - (27, 22, 1), - (1, 1, 23), - (23, 1, 1), - (1, 23, 1), - (37, 5, 3), - (2, 29, 4), - (7, 31, 7), - (3, 5, 8), - (7, 17, 15), - (25, 5, 16), - (23, 5, 31), - (7, 11, 32), - (7, 11, 33), - (2, 3, 255), - (3, 3, 256), - (3, 2, 257), - ] - test_shape4d = [(8, 4, 8, 8), (1, 11, 16, 2)] - test_shape5d = [(2, 3, 4, 5, 6), (1, 3, 4, 5, 6), (3, 6, 2, 4, 4)] - test_shape6d = [(2, 3, 5, 6, 3, 2)] - test_shape7d = [(1, 2, 3, 4, 3, 2, 2)] - test_shape_ub_overflow = [(10, 50, 1000)] - test_shape8d = [(1, 2, 3, 2, 5, 3, 7, 2), (1, 3, 2, 5, 6, 7, 2, 1), (2, 3, 7, 3, 2, 3, 2, 3)] - full_shape_4_8d = test_shape4d + test_shape5d + test_shape6d + test_shape7d + test_shape8d - - full_shape = test_shape1d + test_shape2d + test_shape3d - test_shape1_2_3d = full_shape - full_dtype = ['int8', 'int16', 'int32', 'int64',
'float16', 'bfloat16', 'float32', 'bool'] - ub_size = 98304 * 2 - dtype_list = full_dtype - - -def get_dtype_size(dtype): - torch_dtype = eval('torch.' + dtype) - bits = 0 - if torch_dtype == torch.bool: - bits = 8 - elif torch.is_floating_point(torch.tensor(0, dtype=torch_dtype)): - bits = torch.finfo(torch_dtype).bits - else: - bits = torch.iinfo(torch_dtype).bits - return bits // 8 - - -def check_ub_mem_overflow(dtype, shape): - bytes = get_dtype_size(dtype) - if bytes * math.prod(shape) > TestUtils.ub_size: - logging.warning(f'dtype:{dtype} shape:{shape} mem overflow') - return True - return False - - -def generate_numpy(shape, dtype, low=None, high=None): - if dtype in _int_dtypes + _uint_dtypes: - iinfo = np.iinfo(getattr(np, dtype)) - low = iinfo.min if low is None else max(low, iinfo.min) - high = iinfo.max if high is None else min(high, iinfo.max) - dty = getattr(np, dtype) - return np.random.randint(low, high, shape, dtype=dty) - elif dtype == 'float16' or dtype == 'float32': - return np.random.normal(0, 1, shape).astype(dtype) - elif dtype == 'bfloat16': - return (np.random.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') - elif dtype == 'bool': - return np.random.randint(low=0, high=2, size=shape).astype(bool) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.randn(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'uint32': - return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - elif dtype == 'uint8': - return torch.randint(low=0, high=255, size=shape, dtype=torch.uint8) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def generate_tensor_int_withSigns(shape, dtype): - if dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=-32768, high=32767, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=-128, high=127, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def get_triton_sig_typename(dtype): - if dtype == 'float32': - tyname = "*fp32" - elif dtype == 'int32': - tyname = "*i32" - elif dtype == 'int64': - tyname = "*i64" - elif dtype == 'float16': - tyname = "*fp16" - elif dtype == 'int16': - tyname = "*i16" - elif dtype == 'int8': - tyname = "*i8" - elif dtype == 'bool': - tyname = "*i1" - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - return tyname - - -# Relative error: abs(x_ref - x_cal) / abs(x_ref) -# Absolute error: abs(x_ref - x_cal) - - -# calculation type operators require different error range -# It is a stricter verification and not satisfied now, save it here -def validate_cal(dtype, y_cal, y_ref): - if dtype == 'float16': - if torch.mean(y_ref) < 0.001: - assert torch.abs(y_cal - y_ref) < 0.001, "|y_cal - y_ref| < 0.001 is required !" 
- else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - # all true - assert diff.all(), "Relative error must be less than 0.001 !" - if dtype == 'float32': - if torch.mean(y_ref) < 0.0001: - assert torch.abs(y_cal - y_ref) < 0.0001, "|y_cal - y_ref| < 0.0001 is required !" - else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.0001 - assert diff.all(), "Relative error must be less than 0.0001 !" - elif dtype == 'bfloat16': - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - assert diff.all(), "Relative error must be less than 0.001 !" - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'uint8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -# moving and comparison ops require no precision error -def validate_cmp(dtype, y_cal, y_ref): - y_cal = y_cal.npu() - y_ref = y_ref.npu() - if dtype == 'float16': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'bfloat16': - torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, - equal_nan=True) - elif dtype == 'float32': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'uint8' or dtype == 'uint16' or dtype == 'uint32' or dtype == 'uint64': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal.cpu(), y_ref.cpu()) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def validate_cmp_with_expection(dtype, y_cal, y_ref, expect): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - if expect: - assert torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - else: - assert not torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8' \ - or dtype == 'uint8' or dtype == 'uint16' or dtype == 'uint32' or dtype == 'uint64': - if expect: - assert torch.equal(y_cal, y_ref) - else: - assert not torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def raises_with_match(expected_exception, match_pattern): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - with pytest.raises(expected_exception, match=match_pattern): - return test_func(*args, **kwargs) - - return wrapper - - return decorator - - -def capture_output(expected_output): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - capsys = kwargs.pop('capsys', None) - if capsys is None: - try: - capsys = pytest.fixture(capsys)() - except: - raise RuntimeError("This decorator requires pytest's capsys fixture") - test_func(capsys, *args, **kwargs) - captured = capsys.readouterr() - # pybind11::scoped_ostream_redirect captures std::cout with \x00 inserted - # for now, no idea how to eliminate \x00 from C++ side.
- cleaned = re.sub(r"\x00", "", captured.out) - assert expected_output in cleaned - - return wrapper - - return decorator diff --git a/third_party/ascend/unittest/generalization_cases/test_cos.py b/third_party/ascend/unittest/generalization_cases/test_cos.py deleted file mode 100644 index 9f1980a515..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cos.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_cos(x0): - res = torch.cos(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.cos(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_cos_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < 
SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.cos(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate test data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_cos(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_cos_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_cos(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_cos_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_count_dim0.py b/third_party/ascend/unittest/generalization_cases/test_count_dim0.py deleted file mode 100644 index 826f649909..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_count_dim0.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8']) -def test_count_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), dtype=torch.float32).npu() - count[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_gt_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_gt(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), 
dtype=torch.float32).npu() - count_gt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_lt_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_lt(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), dtype=torch.float32).npu() - count_lt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git a/third_party/ascend/unittest/generalization_cases/test_count_dim1.py b/third_party/ascend/unittest/generalization_cases/test_count_dim1.py deleted file mode 100644 index ebd19cf7ab..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_count_dim1.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
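The count kernels above express counting as a reduction: compare against the threshold, widen the boolean mask to float32 so the accumulation does not overflow a narrow integer type, and tl.sum along the reduced axis. The same computation in plain PyTorch:

import torch

# Per-column count of elements equal to 8 (dim=0), as standard_count computes.
x = torch.tensor([[1, 8, 8], [8, 2, 8]], dtype=torch.int8)
counts = (x == 8).to(torch.float32).sum(dim=0)
assert counts.tolist() == [1.0, 1.0, 2.0]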
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8']) -def test_count_dim1_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_gt_dim1_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_gt(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count_gt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_lt_dim1_common(shape, dtype): - rblock = 
shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_lt(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count_lt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git a/third_party/ascend/unittest/generalization_cases/test_cumprod.py b/third_party/ascend/unittest/generalization_cases/test_cumprod.py deleted file mode 100644 index 9af5216d9b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cumprod.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
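The dim-1 variants above reduce along rows rather than columns, so the grid is laid out on the second launch axis (count[1, xblock, 1]) with one program per row block. The matching PyTorch reference for the greater-than count:

import torch

# Per-row count of elements above a threshold (dim=1), as standard_count_gt does.
x = torch.tensor([[0.1, 0.9, 0.6], [0.7, 0.2, 0.3]])
row_counts = (x > 0.5).to(torch.float32).sum(dim=1)
assert row_counts.tolist() == [2.0, 1.0]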
- -import math -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -from triton.runtime.libentry import libentry - -from test_common import TestUtils, validate_cmp, get_dtype_size - - -def torch_func(x, dim, reverse): - is_bf16 = x.dtype == torch.bfloat16 - if is_bf16: - x = x.to(torch.float32) - if reverse: - x = torch.flip(x, [dim]) - res = torch.cumprod(x, dim=dim) - if is_bf16: - res = res.to(torch.bfloat16) - return res - - -@libentry() -@triton.jit -def triton_kernel_1d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - XBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - idx = tl.arange(0, XBLOCK) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_2d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx = idx_x[:, None] * numel_r + idx_r[None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_3d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_4d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_5d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, 
axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -def convert_cumprod_dtype(x: torch.Tensor) -> torch.Tensor: - """ - Return the tensor converted according to the cumprod dtype-promotion rules. - """ - dtype_map = { - torch.int8: torch.int64, - torch.int16: torch.int64, - torch.int32: torch.int64, - torch.int64: torch.int64, - torch.bfloat16: torch.bfloat16, - torch.float16: torch.float16, - torch.float32: torch.float32, - torch.bool: torch.int64, - } - - target_dtype = dtype_map.get(x.dtype, None) - if target_dtype is None: - raise ValueError(f"Unsupported input dtype for cumprod conversion: {x.dtype}") - - return x.to(target_dtype) - - -def triton_func(x, dim, reverse): - x = convert_cumprod_dtype(x) - - res = torch.empty_like(x) - shape = x.size() - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[0]) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1]) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2]) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3]) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def cumprod_generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.rand(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.'
+ dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -# reverse=True is not supported. -@pytest.mark.parametrize("dtype", TestUtils.full_dtype) -@pytest.mark.parametrize("shape", TestUtils.full_shape) -@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) -@pytest.mark.parametrize("reverse", [False]) -def test_cumprod(dtype, shape, dim, reverse): - should_skip_due_to_mem(dtype, shape) - - x = cumprod_generate_tensor(shape=shape, dtype=dtype) - x_npu = x.npu() - - triton_res = triton_func(x_npu, dim, reverse) - - x_gold = x - cpu_res = torch_func(x_gold, dim, reverse) - - validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_cumsum.py b/third_party/ascend/unittest/generalization_cases/test_cumsum.py deleted file mode 100644 index 06ef04eceb..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cumsum.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
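The dtype table in convert_cumprod_dtype above mirrors what torch itself does for cumulative ops: bool and narrow integer inputs accumulate in int64, while floating types keep their dtype. A quick check of that assumption:

import torch

# torch.cumprod promotes narrow integer inputs to int64; floats are unchanged.
x = torch.tensor([1, 2, 3], dtype=torch.int8)
assert torch.cumprod(x, dim=0).dtype == torch.int64
assert torch.cumprod(x, dim=0).tolist() == [1, 2, 6]
assert torch.cumprod(x.float(), dim=0).dtype == torch.float32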
- -import math -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -from triton.runtime.libentry import libentry - -import acc_util -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_func(x, dim, reverse): - if reverse: - x = torch.flip(x, [dim]) - res = torch.cumsum(x, dim=dim) - return res - - -@libentry() -@triton.jit -def triton_kernel_1d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - XBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - idx = tl.arange(0, XBLOCK) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_2d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx = idx_x[:, None] * numel_r + idx_r[None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_3d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_4d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_5d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -def convert_cumsum_dtype(x: torch.Tensor) -> 
torch.Tensor: - """ - Return the tensor converted according to the cumsum dtype-promotion rules. - """ - dtype_map = { - torch.int8: torch.int64, - torch.int16: torch.int64, - torch.int32: torch.int64, - torch.int64: torch.int64, - torch.bfloat16: torch.bfloat16, - torch.float16: torch.float16, - torch.float32: torch.float32, - torch.bool: torch.int64, - } - - target_dtype = dtype_map.get(x.dtype, None) - if target_dtype is None: - raise ValueError(f"Unsupported input dtype for cumsum conversion: {x.dtype}") - - return x.to(target_dtype) - - -def triton_func(x, dim, reverse): - x = convert_cumsum_dtype(x) - - res = torch.empty_like(x) - shape = x.size() - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[0]) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1]) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2]) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3]) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def cumsum_generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.rand(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -# reverse=True is not supported. - - -@pytest.mark.parametrize("dtype", TestUtils.full_dtype) -@pytest.mark.parametrize("shape", TestUtils.full_shape) -@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) -@pytest.mark.parametrize("reverse", [False]) -def test_cumsum(dtype, shape, dim, reverse): - should_skip_due_to_mem(dtype, shape) - - x = cumsum_generate_tensor(shape=shape, dtype=dtype) - x_npu = x.npu() - - triton_res = triton_func(x_npu, dim, reverse) - - x_gold = x - cpu_res = torch_func(x_gold, dim, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py b/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py deleted file mode 100644 index fb17fcb23f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import numpy as np -import torch -import logging -import pytest -import test_common -from test_common import TestUtils - - -def torch_invert(x0, ddtype): - if 'float' in str(ddtype): - x0 = x0.to(torch.int32) - y_ref = ~x0 - y_ref = y_ref.to(ddtype) - else: - y_ref = ~x0 - return y_ref - - -@triton.jit -def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X - Y - tl.debug_barrier() - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_invert_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = ~x_val - tl.debug_barrier() - tl.store(output_ptr + offsets, ret, mask=masks) - - -test_shape_1d_2d_3d = [(1, 
), (2, ), (1, 1), (3, 13), (1, 1, 1), (4, 3, 8)] -test_shape_4_5d = [(1, 1, 1, 1), (2, 2, 2, 2), (1, 1, 1, 1, 1), (2, 2, 2, 2, 1)] - - -@pytest.mark.parametrize('shape', test_shape_1d_2d_3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x - y - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_sub[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', test_shape_1d_2d_3d + test_shape_4_5d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_invert_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_invert(x, eval('torch.' 
+ dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_device_print.py b/third_party/ascend/unittest/generalization_cases/test_device_print.py deleted file mode 100644 index 5421db96ed..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_device_print.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import sys -import os -import subprocess -import tempfile -import textwrap - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] - - -@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) -def test_device_print_int32(sigtype): - - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - temp_script = f.name - - f.write( - textwrap.dedent(f""" -import torch -import torch_npu -import triton -import triton.language as tl -import os -import sys - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.device_print("OUTPUT = ", tmp2) - tl.store(out_ptr0 + idx, tmp2) - -def main(): - shape = (8,) - XS = 8 - dtype = torch.{sigtype} - - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - - XVALS_INT = [0, -128, 127, -32768, 32767, -2147483648, 2147483647, 2147483648] - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - out = torch.empty_like(x0) - - triton_kernel[1,](out, x0, x1, XS) - - print("Kernel execution completed") - - return out - -if __name__ == "__main__": - result = main() - print(f"Result shape: {{result.shape}}") - """)) - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - torch_ref = x0 + x1 - if 'int' in sigtype: - torch_ref_str = ','.join([str(int(val)) for val in torch_ref.cpu().numpy()]) - else: - values = torch_ref.cpu() - if values.dtype == torch.bfloat16: - values = values.float() - torch_ref_str = ','.join([f"{float(val):.6f}" for val in values.numpy()]) - - result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) - - captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr - - ##with open(f"manual_capture_{sigtype}.txt", "w") as f: - ##f.write(captured_output) - ##f.write(f"torch_ref:{torch_ref_str}") - - if os.path.exists(temp_script): - os.remove(temp_script) - - assert torch_ref_str in captured_output diff --git a/third_party/ascend/unittest/generalization_cases/test_div_rn.py b/third_party/ascend/unittest/generalization_cases/test_div_rn.py deleted file mode 100644 index f0ce253288..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_div_rn.py +++ 
/dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_divRn(x0, x1): - return x0 / x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.div_rn(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_div_rn_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.div_rn(x_val, y_val) - tl.store(output_ptr + offsets, ret, 
mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_divRn(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_div_rn_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - ans = torch_divRn(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_div_rn_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py b/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py deleted file mode 100644 index c40360e45a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
- -import contextlib -import itertools -import re -import math -import textwrap -import os -import inspect -import pathlib -import test_common -import numpy as np -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -from numpy.random import RandomState -from triton.language.extra import libdevice - - -@triton.jit -def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr, - stride_b1: tl.constexpr, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr, acc_num: tl.constexpr): - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K - str_a0: tl.constexpr = stride_a0 - a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, str_a0)[None, :] * stride_a1 - b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 - - a = tl.load(a_ptr) - b = tl.load(b_ptr) - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if a_scale is not None: - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - a_scale = tl.load(scale_a_ptr) - if b_scale is not None: - scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - b_scale = tl.load(scale_b_ptr) - accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) - if acc_num is not None: - for _ in range(acc_num): - accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) - - out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] - tl.store(out_ptr, accumulator.to(a.dtype)) - - -def golden_ref(x, scale_x, y, scale_y): - shape_expand_x = x.shape[-1] // scale_x.shape[-1] - if x.dtype == torch.bfloat16: - upscale_x = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int16) - upscale_x = (upscale_x + 127 << 7).view(torch.bfloat16) - else: - scale_fp32 = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int32) - scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) - upscale_x = scale_fp32.to(torch.float16) - upscale_y = None - if scale_y is None: - upscale_y = torch.ones_like(y) - else: - scale_y = scale_y.T - shape_expand_y = y.shape[0] // scale_y.shape[0] - if y.dtype == torch.bfloat16: - upscale_y = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int16) - upscale_y = (upscale_y + 127 << 7).view(torch.bfloat16) - else: - scale_fp32 = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int32) - scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) - upscale_y = scale_fp32.to(torch.float16) - ret = torch.matmul(x * upscale_x, y * upscale_y) - return ret - - -@pytest.mark.parametrize("M, N, K, rhs_scale, normal_type, acc_num, num_warps", - [(M, N, K, rhs_scale, normal_type, acc_num, 4) - for M, N, K in itertools.product([16, 32, 64, 128], [16, 32, 64, 128], [32, 64]) - for rhs_scale in [False, True] - for normal_type in ["bf16", "fp16"] - for acc_num in [None, 1, 2]]) -def test_scaled_dot(M, N, K, rhs_scale, normal_type, num_warps, acc_num): - device = "npu" - - # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid - # overflow when scaling. 
- comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 - - torch.manual_seed(0) - - def make_arg(shape, ty): - if ty == "bf16" or ty == "fp16": - comp_dtype = torch.float16 if ty == "fp16" else torch.bfloat16 - ret = torch.randn(shape, dtype=comp_dtype, device=device) - # Clamp to avoid relative error issues - ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) - else: - ret = torch.randint(256, shape, dtype=torch.int8, device=device) - return ret - - type_a = normal_type - type_b = type_a - - x = make_arg((M, K), type_a) - y = make_arg((K, N), type_b) - - min_scale, max_scale = (0, 142) if type_a == torch.bfloat16 else (124, 131) - scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device=device) - min_scale, max_scale = (0, 142) if type_b == torch.bfloat16 else (124, 131) - scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device=device) - - if not rhs_scale: - scale_y = None - - kernel_kwargs = {"num_warps": num_warps} - z = x.new_empty((M, N), dtype=x.dtype) - pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, acc_num, - **kernel_kwargs) - z_ref = golden_ref(x, scale_x, y, scale_y) - if acc_num is not None: - z_ref = z_ref * (acc_num + 1) - - atol = 1e-5 - rtol = 1e-2 - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("B, M, N, K", [(1, 32, 64, 64)]) -def test_4d_dot(B, M, N, K): - device = "npu" - torch.manual_seed(0) - - x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) - y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) - - x2d = x4d.view(-1, N) # shape (B*B*M, N) - y2d = y4d.view(-1, K) # shape (B*B*N, K) - scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), dtype=torch.int8, device=device) - scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), dtype=torch.int8, device=device) - - z = torch.empty((x2d.shape[0], y2d.shape[0]), dtype=x2d.dtype, device=device) - acc_num = None - dot_scale_kernel[(1, )](x2d, *x2d.stride(), scale_x, y2d, *y2d.stride(), None, z, x2d.shape[0], y2d.shape[0], K, - "fp16", "fp16", None, num_warps=4) - z_ref = golden_ref(x2d, scale_x, y2d, None) - if acc_num is not None: - z_ref = z_ref * (acc_num + 1) - - atol = 1e-5 - rtol = 1e-2 - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("B, M, N, K", [(2, 16, 16, 32)]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, - r"lhs last dimension .* must equal rhs penultimate dimension") -def test_2d_dot_invalid_shape(B, M, N, K): - device = "npu" - torch.manual_seed(0) - - x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) - y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) - - x2d = x4d.view(-1, N) # shape (B*B*M, N) - y2d = y4d.view(-1, K) # shape (B*B*N, K) - scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), dtype=torch.int8, device=device) - scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), dtype=torch.int8, device=device) - - z = torch.empty((x2d.shape[0], y2d.shape[0]), dtype=x2d.dtype, device=device) - acc_num = None - dot_scale_kernel[(1, )](x2d, *x2d.stride(), scale_x, y2d, *y2d.stride(), None, z, x2d.shape[0], y2d.shape[0], K, - "fp16", "fp16", None, num_warps=4) - - -VALID_MAIN_DTYPES = { - torch.float16, # fp16 - torch.bfloat16, # bf16 -} - -ALL_DTYPES = { - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.float32, # fp32 - torch.bool, -}
-ILLEGAL_MAIN_DTYPES = ALL_DTYPES - VALID_MAIN_DTYPES - -ILLEGAL_SCALE_DTYPES = { - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.bfloat16, - torch.bool, -} - -from itertools import product - - -def is_legal_dtype(lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype): - return (lhs_dtype in VALID_MAIN_DTYPES and rhs_dtype in VALID_MAIN_DTYPES and lhs_scale_dtype is torch.int8 - and rhs_scale_dtype is torch.int8) - - -illegal_cases = [] -for lhs, rhs, lhs_s, rhs_s in product( - VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, - VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, - {torch.int8} | ILLEGAL_SCALE_DTYPES, - {torch.int8} | ILLEGAL_SCALE_DTYPES, -): - - if not is_legal_dtype(lhs, rhs, lhs_s, rhs_s): - illegal_cases.append((lhs, rhs, lhs_s, rhs_s)) - -illegal_cases = sorted(set(illegal_cases), key=lambda t: tuple(str(i) for i in t)) - - -@pytest.mark.parametrize( - "lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype", - illegal_cases, -) -@test_common.raises_with_match(Exception, r"(?i)invalid|unsupported|dtype") -def test_invalid_dtype_should_fail(lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype): - device = "npu" - M, N, K = 32, 32, 64 - num_warps = 4 - - def make_tensor(shape, dtype): - return torch.randn(shape, dtype=dtype, device=device) \ - if dtype.is_floating_point else \ - torch.randint(-10, 10, shape, dtype=dtype, device=device) - - def make_scale(shape, dtype): - return torch.randint(-10, 10, shape, dtype=dtype, device=device) - - x = make_tensor((M, K), lhs_dtype) - y = make_tensor((K, N), rhs_dtype) - lhs_scale = make_scale((M, K // 32), lhs_scale_dtype) - rhs_scale = make_scale((N, K // 32), rhs_scale_dtype) - z = torch.empty((M, N), dtype=lhs_dtype, device=device) - - dot_scale_kernel[(1, )]( - x, - *x.stride(), - lhs_scale, - y, - *y.stride(), - rhs_scale, - z, - M, - N, - K, - str(lhs_dtype).split('.')[-1], - str(rhs_dtype).split('.')[-1], - None, - num_warps=num_warps, - ) diff --git a/third_party/ascend/unittest/generalization_cases/test_eq.py b/third_party/ascend/unittest/generalization_cases/test_eq.py deleted file mode 100644 index 94292ac21c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_eq.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math -import logging - - -def torch_eq(x0, x1): - if x0.dtype != torch.uint32: - return x0 == x1 - else: - return x0.to(torch.float32) == x1.to(torch.float32) - - -@triton.jit -def triton_eq(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = XBLOCK // XBLOCK_SUB - for loop1 in range(loops1): - x_index = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N) - tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N) - tmp2 = tmp0 == tmp1 - tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N) - - -@triton.jit -def triton_eq_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val == y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_eq(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - # Generate input data - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # Torch reference result - torch_res = torch_eq(x0, x1).to(eval('torch.' + dtype)) - # Triton kernel result - triton_res = torch.zeros(shape, dtype=eval('torch.'
+ dtype)).npu() - N = triton_res.numel() - triton_eq[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) - # Compare the results - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_dtype = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_dtype, triton_res, torch_res) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_eq_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_eq(x, y).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_eq_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_erf.py b/third_party/ascend/unittest/generalization_cases/test_erf.py deleted file mode 100644 index b82945f0f0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_erf.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_erf(x0): - res = torch.erf(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.erf(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_erf_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr, - BLOCK_TOTAL: tl.constexpr): - - pid = tl.program_id(0) - start_idx = pid * BLOCK_TOTAL - local_idx = tl.arange(0, BLOCK_TOTAL) - global_idx = start_idx + local_idx - total_elements = SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - masks = global_idx < total_elements - - dim1_base = SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - dim2_base = SHAPE_2 * SHAPE_3 * SHAPE_4 - dim3_base = SHAPE_3 * SHAPE_4 - dim4_base = SHAPE_4 - - idx_0 = (global_idx // dim1_base) % SHAPE_0 - idx_1 = (global_idx // dim2_base) % SHAPE_1 - idx_2 = (global_idx // dim3_base) % SHAPE_2 - idx_3 = (global_idx // dim4_base) % SHAPE_3 - idx_4 = global_idx % SHAPE_4 - - offsets = idx_0 * STRIDE_0 + idx_1 * STRIDE_1 + idx_2 * STRIDE_2 + idx_3 * STRIDE_3 + idx_4 * STRIDE_4 - - x_val = tl.load(x_ptr + offsets, mask=masks) - ret = tl.erf(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_erf(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_erf_4d_5d(shape, dtype): - logging.debug(f"Testing erf for shape={shape}, dtype={dtype}") - - x = test_common.generate_tensor(shape, dtype).npu() - output = torch.empty_like(x) - - ans = torch_erf(x) - - shape_5d = list(shape) - strides_5d = list(x.stride()) - while len(shape_5d) < 5: - shape_5d.append(1) - strides_5d.append(1) - - MAX_BLOCK_ELEMENTS = 1024 - total_elements = x.numel() - - block_5d = [1] * 5 - for i in reversed(range(5)): - if shape_5d[i] == 0: - continue - max_block_i = min(shape_5d[i], MAX_BLOCK_ELEMENTS // (torch.prod(torch.tensor(block_5d)).item())) - block_5d[i] = max_block_i - if torch.prod(torch.tensor(block_5d)).item() >= MAX_BLOCK_ELEMENTS: - break - block_total = torch.prod(torch.tensor(block_5d)).item() - - grid = (triton.cdiv(total_elements, block_total), ) - logging.debug(f"Grid={grid}, block_5d={block_5d}, block_total={block_total}") - - triton_erf_4d_5d[grid](output, x, *block_5d, *shape_5d, *strides_5d, block_total) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_exp.py b/third_party/ascend/unittest/generalization_cases/test_exp.py deleted file mode 100644 index 52233f89c8..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_exp.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x0): - res = torch.exp(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.exp(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_exp_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.exp(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_exp_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_exp_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_exp2.py b/third_party/ascend/unittest/generalization_cases/test_exp2.py deleted file mode 100644 index b8e8aa3122..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_exp2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_exp2(x0): - res = torch.pow(2, x0, out=None) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.exp2(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_exp2_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.exp2(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_exp2(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_exp2_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_exp2(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_exp2_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_expand_dims.py b/third_party/ascend/unittest/generalization_cases/test_expand_dims.py deleted file mode 100644 index f9a85f044c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_expand_dims.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yidx = tl.arange(0, YB) - - X = tl.load(x_ptr + yidx) - - ret = tl.expand_dims(X, 1) - - oidx = yidx[:, None] + tl.arange(0, 1)[None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_expand_dims_1d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(1) - - output = torch.randint(1, (shape[0], 1), dtype=eval('torch.' + dtype)).npu() - - fn_npu_1d[1, 1, 1](output, x, YB=shape[0], ZB=1, debug=True) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yoffs = tl.program_id(0) - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) - - idx = yidx[:, None] * ZB + zidx[None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 1) - - oidx = yidx[:, None, None] * ZB + tl.arange(0, 1)[None, :, None] + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_expand_dims_2d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(1) - - output = torch.randint(1, (shape[0], 1, shape[1]), dtype=eval('torch.' + dtype)).npu() - - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, YB=1, ZB=shape[1]) - else: - fn_npu_2d[1, 1, 1](output, x, YB=shape[0], ZB=shape[1]) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 2) - - oidx = xidx[:, None, None, None] * YB * ZB + yidx[None, :, None, None] * ZB + tl.arange( - 0, 1)[None, None, :, None] + zidx[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_expand_dims_3d(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(2) - - output = torch.randint(1, (shape[0], shape[1], 1, shape[2]), dtype=eval('torch.' 
+ dtype)).npu() - - fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr): - in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + in_offsets) - - ret = tl.expand_dims(X, DIM).reshape(XB * YB * ZB * MB * NB) - - out_offsets = tl.arange(0, XB * YB * ZB * MB * NB) - tl.store(output_ptr + out_offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('shape', [ - (2, 64, 16, 2), - (8, 8, 4, 2), - (8, 8, 4, 1), -]) -@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3]) -def test_npu_4d(shape, dtype, dim): - x = test_common.generate_tensor(shape, dtype).npu() - expected = x.unsqueeze(dim) - - output = torch.empty_like(expected) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) - - torch.testing.assert_close(output, expected) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('shape', [ - (2, 32, 3, 16, 2), - (8, 8, 3, 4, 2), - (8, 8, 3, 4, 1), -]) -@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3, 4]) -def test_npu_5d(shape, dtype, dim): - x = test_common.generate_tensor(shape, dtype).npu() - expected = x.unsqueeze(dim) - - output = torch.empty_like(expected) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) - - torch.testing.assert_close(output, expected) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 2) - - oidx = xidx[:, None, None, None] * YB * ZB + yidx[None, :, None, None] * ZB + tl.arange( - 0, 1)[None, None, :, None] + zidx[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -paras = [ - ('bfloat16', eval('torch.bfloat16'), 1, 255, 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 125, 1, 256, 16), - ('uint16', eval('torch.uint16'), 1, 256, 2, 2, 3), - ('uint32', eval('torch.uint32'), 1, 256, 8, 8, 4), - ('uint64', eval('torch.uint64'), 1, 256, 8, 8, 4), - ('bool', eval('torch.bool'), 0, 2, 1, 1, 2), -] - - -@pytest.mark.parametrize('para_type,data_type,low,top,XB,YB,ZB', paras) -def test_expand_dims(para_type, data_type, low, top, XB, YB, ZB): - x = torch.randint(low=low, high=top, size=(XB, YB, ZB), dtype=data_type).npu() - a = x.unsqueeze(2) - output = torch.randint(1, (XB, YB, 1, ZB), dtype=data_type).npu() - fn_npu_[1, 1, 1](output, x, XB=XB, YB=YB, ZB=ZB, debug=True) - test_common.validate_cmp(para_type, output, a) diff --git
a/third_party/ascend/unittest/generalization_cases/test_fdiv.py b/third_party/ascend/unittest/generalization_cases/test_fdiv.py deleted file mode 100644 index 099a82387f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_fdiv.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_fdiv(x0, x1): - res = x0 / x1 - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.fdiv(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_fdiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = 
masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.fdiv(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # Generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_fdiv(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_fdiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_fdiv(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_fdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_full_op.py b/third_party/ascend/unittest/generalization_cases/test_full_op.py deleted file mode 100644 index e74314d52a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_full_op.py +++ /dev/null @@ -1,1096 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
- -import triton -import triton.language as tl -import test_common - -from test_common import TestUtils -import torch -import torch_npu -import pytest -import math -import random - - -@triton.jit -def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int8) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.uint32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int64) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.float16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.float32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.bfloat16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=0, dtype=tl.int1) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, 
dtype=tl.int8) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.uint32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int64) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.float16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.float32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.bfloat16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=0, dtype=tl.int1) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int8) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.uint32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int64) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr): - zidx = 
tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.float16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.float32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.bfloat16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=0, dtype=tl.int1) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape1d = TestUtils.test_shape1d -test_shape2d = TestUtils.test_shape2d -test_shape3d = TestUtils.test_shape3d - -# Map each dtype string to its (test_func, torch dtype) pair -dtype_mapping3d = { - 'int8': (fn_npu_int8_3d, torch.int8), - 'int16': (fn_npu_int16_3d, torch.int16), - 'int32': (fn_npu_int32_3d, torch.int32), - 'uint32': (fn_npu_uint32_3d, torch.uint32), - 'int64': (fn_npu_int64_3d, torch.int64), - 'float16': (fn_npu_fp16_3d, torch.float16), - 'float32': (fn_npu_fp32_3d, torch.float32), - 'bfloat16': (fn_npu_bf16_3d, torch.bfloat16), - 'bool': (fn_npu_bool_3d, torch.bool), -} -dtype_mapping2d = { - 'int8': (fn_npu_int8_2d, torch.int8), - 'int16': (fn_npu_int16_2d, torch.int16), - 'int32': (fn_npu_int32_2d, torch.int32), - 'uint32': (fn_npu_uint32_2d, torch.uint32), - 'int64': (fn_npu_int64_2d, torch.int64), - 'float16': (fn_npu_fp16_2d, torch.float16), - 'float32': (fn_npu_fp32_2d, torch.float32), - 'bfloat16': (fn_npu_bf16_2d, torch.bfloat16), - 'bool': (fn_npu_bool_2d, torch.bool), -} -dtype_mapping1d = { - 'int8': (fn_npu_int8_1d, torch.int8), - 'int16': (fn_npu_int16_1d, torch.int16), - 'int32': (fn_npu_int32_1d, torch.int32), - 'uint32': (fn_npu_uint32_1d, torch.uint32), - 'int64': (fn_npu_int64_1d, torch.int64), - 'float16': (fn_npu_fp16_1d, torch.float16), - 'float32': (fn_npu_fp32_1d, torch.float32), - 'bfloat16': (fn_npu_bf16_1d, torch.bfloat16), - 'bool': (fn_npu_bool_1d, torch.bool), -} - -# Generate the test cases -testlist = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape1d - for func, dtype in [dtype_mapping1d[sigtype]] # unpack the mapped pair directly - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape2d - for func, dtype in [dtype_mapping2d[sigtype]] # unpack the mapped pair directly - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape3d - for func, dtype in [dtype_mapping3d[sigtype]] # unpack the mapped pair directly - ] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist) -def test_npu(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 3: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2]), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], shape[1], shape[2]), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1], shape[2]), dtype=dtype).npu() - testfunc[(1, 1, 1)](output, shape[0], shape[1], shape[2], debug=True) - if len(shape) == 2: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1]), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], shape[1]), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1]), dtype=dtype).npu() - shape0 = shape[0] - shape1 = shape[1] - if x.numel() * x.element_size() >= 8192: - grid = 
(shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - testfunc[grid](output, shape0, shape1, debug=True) - if len(shape) == 1: - if dtype == torch.bool: - x = torch.full((shape[0], ), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], ), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], ), dtype=dtype).npu() - testfunc[1, 1, 1](output, shape[0], debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - dtype = output_ptr.type.element_ty - - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - if (YB * ZB * MB * NB) == 1: - ret = tl.full((XB, ), value=100, dtype=dtype) - elif (ZB * MB * NB) == 1: - ret = tl.full((XB, YB), value=100, dtype=dtype) - elif (MB * NB) == 1: - ret = tl.full((XB, YB, ZB), value=100, dtype=dtype) - elif NB == 1: - ret = tl.full((XB, YB, ZB, MB), value=100, dtype=dtype) - else: - ret = tl.full((XB, YB, ZB, MB, NB), value=100, dtype=dtype) - - tl.store(output_ptr + offsets, ret) - - -testlist_multi_d = [ - (fn_npu_multi_d, 'float32', torch.float32, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float32, (2, 4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float16, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float16, (2, 4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.int8, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.int8, (2, 4, 2, 16, 16)), -] - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist_multi_d) -def test_npu_4d_5d(testfunc, sigtype, dtype, shape): - x = torch.full(shape, 100, dtype=dtype).npu() - - print(f"shape = {x.shape}") - print(x.dtype) - print(torch.flatten(x)[0:16]) - - output = torch.randint(1, shape, dtype=dtype).npu() - - print(f"output.dtype={output.dtype}") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - testfunc[(1, )](output, *triton_shape) - print(torch.flatten(output)[0:16]) - - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, 
F), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) 
- fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape6d = TestUtils.test_shape6d -dtype_mapping6d = { - 'int8': (fn_npu_int8_6d, torch.int8), - 'int16': (fn_npu_int16_6d, torch.int16), - 'int32': (fn_npu_int32_6d, torch.int32), - 'int64': (fn_npu_int64_6d, torch.int64), - 'float16': (fn_npu_fp16_6d, torch.float16), - 'float32': (fn_npu_fp32_6d, torch.float32), - 'bfloat16': (fn_npu_bf16_6d, torch.bfloat16), - 'bool': (fn_npu_bool_6d, torch.bool), -} - -testlist6d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape6d - for func, dtype in [dtype_mapping6d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist6d) -def test_npu_6d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 6: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]), 0, dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) 
- - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, 
G), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape7d = TestUtils.test_shape7d -dtype_mapping7d = { - 'int8': (fn_npu_int8_7d, torch.int8), - 'int16': (fn_npu_int16_7d, torch.int16), - 'int32': (fn_npu_int32_7d, torch.int32), - 'int64': (fn_npu_int64_7d, torch.int64), - 'float16': (fn_npu_fp16_7d, torch.float16), - 'float32': (fn_npu_fp32_7d, torch.float32), - 'bfloat16': (fn_npu_bf16_7d, torch.bfloat16), - 'bool': (fn_npu_bool_7d, torch.bool), -} - -testlist7d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape7d - for func, dtype in [dtype_mapping7d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist7d) -def test_npu_7d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 7: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6]), 0, dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: 
tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * 
C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, 
None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape8d = TestUtils.test_shape8d -dtype_mapping8d = { - 'int8': (fn_npu_int8_8d, torch.int8), - 'int16': (fn_npu_int16_8d, torch.int16), - 'int32': (fn_npu_int32_8d, torch.int32), - 'int64': (fn_npu_int64_8d, torch.int64), - 'float16': (fn_npu_fp16_8d, torch.float16), - 'float32': (fn_npu_fp32_8d, torch.float32), - 'bfloat16': (fn_npu_bf16_8d, torch.bfloat16), - 'bool': (fn_npu_bool_8d, torch.bool), -} - -testlist8d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape8d - for func, dtype in [dtype_mapping8d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist8d) -def test_npu_8d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 8: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6], shape[7]), 0, - dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) diff --git a/third_party/ascend/unittest/generalization_cases/test_ge_op.py b/third_party/ascend/unittest/generalization_cases/test_ge_op.py deleted file mode 100644 index d23da78f38..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ge_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_ge_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val >= y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', 
TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_ge(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.ge(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_ge_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_ge_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_ge_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_ge_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.ge(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ge_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_add.py b/third_party/ascend/unittest/generalization_cases/test_general_add.py deleted file mode 100644 index 8ddb6adeb0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_add.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_add(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_add_broadcast(in_ptr0, in_ptr1, out_ptr0, X_SHAPE_0: tl.constexpr, X_SHAPE_1: tl.constexpr, - X_SHAPE_2: tl.constexpr, X_SHAPE_3: tl.constexpr, X_SHAPE_4: tl.constexpr, - Y_SHAPE_0: tl.constexpr, Y_SHAPE_1: tl.constexpr, Y_SHAPE_2: tl.constexpr, - Y_SHAPE_3: tl.constexpr, Y_SHAPE_4: tl.constexpr): - x_idx0 = tl.arange(0, X_SHAPE_0) - x_idx1 = tl.arange(0, X_SHAPE_1) - x_idx2 = tl.arange(0, X_SHAPE_2) - x_idx3 = tl.arange(0, X_SHAPE_3) - x_idx4 = tl.arange(0, X_SHAPE_4) - - y_idx0 = tl.arange(0, Y_SHAPE_0) - y_idx1 = tl.arange(0, Y_SHAPE_1) - y_idx2 = tl.arange(0, Y_SHAPE_2) - y_idx3 = tl.arange(0, Y_SHAPE_3) - y_idx4 = tl.arange(0, Y_SHAPE_4) - - xidx = x_idx0[:, None, None, None, None] * X_SHAPE_1 * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx1[None, :, None, None, None] * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx2[None, None, :, None, None] * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx3[None, None, None, :, None] * X_SHAPE_4 + x_idx4[None, None, None, None, :] - - yidx = y_idx0[:, None, None, None, None] * Y_SHAPE_1 * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ - y_idx1[None, :, None, None, None] * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ - y_idx2[None, None, :, None, None] * Y_SHAPE_3 * Y_SHAPE_4 + \ - y_idx3[None, None, None, :, None] * Y_SHAPE_4 + y_idx4[None, None, None, None, :] - - X = tl.load(in_ptr0 + xidx) - Y = tl.load(in_ptr1 + yidx) - ret = X + Y - - tl.store(out_ptr0 + xidx, ret) - - -@triton.jit -def triton_add_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, 
:] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val + y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_add(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x + y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_add[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_add[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_add[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_add[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_add[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_add[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_add[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_add_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_add_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -def promote_dtype(x_dtype, y_dtype): - """ - Promote y's dtype to match x when y has lower precision. - """ - # If the two dtypes already match, return immediately - if x_dtype == y_dtype: - return y_dtype - - # Priority list of dtypes, from lowest to highest - priority = [ - torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.bfloat16, torch.float32 - ] - - # Find each dtype's position in the priority list - x_priority = priority.index(x_dtype) - y_priority = priority.index(y_dtype) - - # If y has lower priority than x, promote to x's dtype - if y_priority < x_priority: - return x_dtype - else: - return y_dtype - - -@pytest.mark.parametrize('param_list', - [[(5, 1, 1, 1, 1), - (5, 1, 1, 2, 1)], [(2, 1), (2, 4)], [(2, 1, 1), (2, 4, 2)], [(2, 1, 1, 1), (2, 4, 2, 2)], - [(2, 1, 1, 1, 1), - (2, 4, 2, 2, 2)], [(1, ), (4, )], [(1, 2, 1), (1, 2, 3)], [(1, 1, 1, 1), (7, 1, 1, 1)]]) -@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_add_broadcast(param_list, x_dtype_str, y_dtype_str): - x_shape, y_shape = param_list - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = test_common.generate_tensor(y_shape, y_dtype_str).npu() - if y.numel() > x.numel(): - tmp = y - y = x - x = tmp - ans = x + y - while x.dim() < 5: - x = x.unsqueeze(-1) - while y.dim() < 5: - y = y.unsqueeze(-1) - bf2fpFlag = False - out_dtype = promote_dtype(x_dtype, y_dtype) - if (x_dtype == torch.bfloat16 and y_dtype == torch.float16) or \ - (x_dtype == torch.float16 and y_dtype == torch.bfloat16): - out_dtype = torch.float32 - bf2fpFlag = True - out_dtype = str(out_dtype).split('.')[-1] - out = test_common.generate_tensor(x.shape, out_dtype).npu() - - triton_add_broadcast[1, 1, 1](x, y, out, *x.shape, *y.shape) - while out.dim() > ans.dim(): - out = out.squeeze(-1) - - if bf2fpFlag: - torch.testing.assert_close(out, ans, rtol=1e-3, atol=1e-3) - else: - torch.testing.assert_close(out, ans) - - -@triton.jit -def add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr, - YB1: tl.constexpr, ZB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = 
tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] - - offsets1 = tl.arange(0, XB1) * (YB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@pytest.mark.parametrize('param_list', [ - [(5, 1, 1, 1, 1), (5, 1, 1, 2, 1)], -]) -@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_add_2d_to_5d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - ndim = max(len(x0_shape), len(y_shape)) - # Resolve the original torch dtypes - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - - x0 = test_common.generate_tensor(x0_shape, x_dtype_str).npu() - y = test_common.generate_tensor(y_shape, y_dtype_str).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < ndim: - triton_shape.append(1) - - triton_shape1 = [*y_shape] - while len(triton_shape1) < ndim: - triton_shape1.append(1) - - # Dispatch on the number of dimensions - if ndim == 2: - XB, YB = triton_shape - XB1, YB1 = triton_shape1 - - add_2d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - XB1=XB1, - YB1=YB1, - ) - - elif ndim == 3: - XB, YB, ZB = triton_shape - XB1, YB1, ZB1 = triton_shape1 - - add_3d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - ) - - elif ndim == 4: - XB, YB, ZB, MB = triton_shape - XB1, YB1, ZB1, MB1 = triton_shape1 - - add_4d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - ) - - elif ndim == 5: - XB, YB, ZB, MB, NB = triton_shape - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - add_5d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - else: - raise ValueError(f"Unsupported tensor dim: {ndim}") - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_add_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - # numpy reference result - ans_numpy = np_x0 + np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - triton_add[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_clamp.py b/third_party/ascend/unittest/generalization_cases/test_general_clamp.py deleted file mode 100644 index d6d91c20e7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_clamp.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest - -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -def torch_clamp(x0, min_, max_): - res = torch.clamp(x0, min_, max_) - return res - - -@triton.jit -def tt_clamp_1d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - idx = tl.arange(0, XB) - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_clamp_2d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_clamp_3d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def triton_clamp_4d_5d(x_ptr, output_ptr, min_ptr, max_ptr, BLOCK_0: tl.constexpr, BLOCK_1: 
tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - min_ = tl.load(min_ptr + offsets) - max_ = tl.load(max_ptr + offsets) - ret = tl.clamp(x_val, min_, max_) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_clamp(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - a = test_common.generate_tensor(shape, dtype) - b = test_common.generate_tensor(shape, dtype) - min_ = torch.min(a, b).npu() - max_ = torch.max(a, b).npu() - - grid = (1, 1, 1) - - y_cal = torch.empty(shape, dtype=eval('torch.' + dtype), device="npu") - - y_ref = torch_clamp(x, min_, max_) - if len(shape) == 1: - tt_clamp_1d[grid](x, y_cal, min_, max_, x.numel(), 1, 1, x.numel(), 1, 1) - elif len(shape) == 2: - xnumel, ynumel, znumel = shape + (1, ) - XB, YB, ZB = xnumel, ynumel, znumel - if x.numel() * x.element_size() > 8192: - grid = (1, ynumel, 1) - YB = 1 - tt_clamp_2d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) - - elif len(shape) == 3: - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - tt_clamp_3d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_clamp_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - a = test_common.generate_tensor(shape, dtype) - b = test_common.generate_tensor(shape, dtype) - min_ = torch.min(a, b).npu() - max_ = torch.max(a, b).npu() - - output = torch.empty(shape, dtype=eval('torch.' 
+ dtype), device="npu") - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_clamp(x, min_, max_) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_clamp_4d_5d[grid](x, output, min_, max_, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_div.py b/third_party/ascend/unittest/generalization_cases/test_general_div.py deleted file mode 100644 index fd4e252177..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_div.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_div(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X / Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_div_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val / y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shapes with int8 exceed the UB size -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_div(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - y[y == 0] = 1 - - ans = x / y - output = torch.zeros_like(ans) - if len(shape) == 1: - triton_div[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_div[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_div[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_div[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_div[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_div[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - 
triton_div[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - # change dtype because of triton processing: the triton div op converts int to float - if dtype in ['int8', 'int16', 'int32', 'int64']: - dtype = 'float32' - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_div_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - y[y == 0] = 1 - - new_shape = shape - if dtype == 'int8' or dtype == 'int16' or dtype == 'int32' or dtype == 'int64': - output = torch.randint(1, new_shape, dtype=eval('torch.float32')).npu() - dtype = 'float32' - else: - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - ans = x / y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_div_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_floor.py b/third_party/ascend/unittest/generalization_cases/test_general_floor.py deleted file mode 100644 index 38bc1621ac..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_floor.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + tl.floor(Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_floor_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val + tl.floor(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_floor(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + torch.floor(y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_floor_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + torch.floor(y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_floor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py b/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py deleted file mode 100644 index 9e5cf1c6ef..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_floordiv(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X // Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_floordiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val // y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shapes with int8 exceed the UB size -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_floordiv(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - - new_shape = shape - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - y[y == 0] = 1 - ans = x // y - ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) - ans = ans + ans_mask - - if len(shape) == 1: - triton_floordiv[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_floordiv[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_floordiv[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_floordiv[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_floordiv[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_floordiv[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_floordiv[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_floordiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - y[y == 0] = 1 - ans = x // y - ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) - ans = ans + ans_mask - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_floordiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = y.masked_fill(y == 0, 1) - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_floordiv[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_fma.py b/third_party/ascend/unittest/generalization_cases/test_general_fma.py deleted file mode 100644 index eb255558ca..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_fma.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - Z = tl.load(z_ptr + idx) - - ret = tl.fma(X, Y, Z) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_fma_4d_5d(output_ptr, x_ptr, y_ptr, z_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - z_val = tl.load(z_ptr + offsets, masks) - ret = tl.fma(x_val, y_val, z_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', 
['float32', 'float16', 'bfloat16']) # math.fma does not support int dtypes -def test_fma(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) - ans = ans.to(torch.bfloat16) - else: - ans = x * y + z - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_fma_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) - ans = ans.to(torch.bfloat16) - else: - ans = x * y + z - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_fma_4d_5d[grid](output, x, y, z, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_gather.py b/third_party/ascend/unittest/generalization_cases/test_general_gather.py deleted file mode 100644 index ee2b8bc437..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_gather.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import numpy as np -import torch -import torch_npu -import triton -import triton.language as tl -import test_common -import pytest -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -@pytest.mark.parametrize("src_shape, indices_shape, axis", [ - ([2, 2], [4, 2], 0), - ([3, 3], [1, 3], 0), - ([3, 4], [4, 4], 0), - ([4, 4], [8, 4], 0), - ([4, 32], [4, 16], 1), - ([4, 64], [4, 32], 1), - ([128, 64], [128, 128], 1), -]) -def test_gather(src_shape, indices_shape, axis): - - @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, - src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, - idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): - output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) - return output - - DEV = "npu" - src = torch.randn(src_shape, device=DEV) - indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) - - dtype_size = get_dtype_size('int32') - if dtype_size * math.prod(src.shape) >= (TestUtils.ub_size / 8): - print(f"dtype:int32 shape:{src.shape} mem overflow") - return - - ref = torch.gather(src, axis, indices) - result = triton_gather(src, axis, indices) - torch.testing.assert_close(result, ref, rtol=0, atol=0) - - -@triton.jit -def gather_kernel_multi_d(src_ptr, idx_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - MB: tl.constexpr, NB: tl.constexpr, I_XB: tl.constexpr, I_YB: tl.constexpr, - I_ZB: tl.constexpr, I_MB: tl.constexpr, I_NB: tl.constexpr, DIMS: tl.constexpr, - AXIS: tl.constexpr): - in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - idx_offsets = tl.arange(0, I_XB) * (I_YB * I_ZB * I_MB * I_NB) - if DIMS > 1: - idx_offsets = idx_offsets[:, None] + tl.arange(0, I_YB)[None, :] * (I_ZB * I_MB * I_NB) - if DIMS > 2: - idx_offsets = idx_offsets[:, :, None] + tl.arange(0, I_ZB)[None, None, :] 
* (I_MB * I_NB) - if DIMS > 3: - idx_offsets = idx_offsets[:, :, :, None] + tl.arange(0, I_MB)[None, None, None, :] * I_NB - if DIMS > 4: - idx_offsets = idx_offsets[:, :, :, :, None] + tl.arange(0, I_NB)[None, None, None, None, :] - - src = tl.load(src_ptr + in_offsets) - idx = tl.load(idx_ptr + idx_offsets) - - out = tl.gather(src, idx, AXIS) - - tl.store(out_ptr + idx_offsets, out) - - -def triton_gather_multi_d(src: torch.Tensor, axis: int, indices: torch.Tensor): - output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - - s_shape = [*(src.shape)] - while len(s_shape) < 5: - s_shape.append(1) - i_shape = [*(indices.shape)] - while len(i_shape) < 5: - i_shape.append(1) - gather_kernel_multi_d[(1, )](src, indices, output, *s_shape, *i_shape, len(src.shape), axis) - return output - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize("src_shape, indices_shape, axis", [ - ((2, 2, 4, 8), (2, 2, 4, 8), 0), - ((2, 2, 4, 1), (2, 2, 4, 1), 3), - ((2, 3, 4, 8), (2, 3, 4, 8), 1), - ((2, 3, 4, 8), (2, 3, 4, 8), 2), - ((2, 2, 2, 4, 1), (2, 2, 2, 4, 1), 4), - ((2, 2, 2, 4, 8), (2, 2, 2, 4, 8), 1), - ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 2), - ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 0), -]) -def test_gather_4d_5d(src_shape, indices_shape, axis): - DEV = "npu" - src = torch.randn(src_shape, device=DEV) - indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) - - ref = torch.gather(src, axis, indices) - result = triton_gather_multi_d(src, axis, indices) - torch.testing.assert_close(result, ref, rtol=0, atol=0) - - -if __name__ == "__main__": - test_gather([4, 64], [4, 32], 1) - print("success: test_gather") diff --git a/third_party/ascend/unittest/generalization_cases/test_general_interleave.py b/third_party/ascend/unittest/generalization_cases/test_general_interleave.py deleted file mode 100644 index cce95fef86..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_interleave.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import logging -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - zoffs2 = tl.program_id(2) * ZB * 2 - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - zidx2 = tl.arange(0, 2 * ZB) + zoffs2 - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.interleave(X, Y) - - oidx = xidx[:, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None] * ZNUMEL * 2 + zidx2[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_interleave_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp4 = tl.arange(0, 2 * BLOCK_3)[None, None, None, :] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.interleave(x_val, y_val) - - out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * STRIDE_2 * 2 + tmp4 * STRIDE_3 - out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp4 < 2 * SHAPE_3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_interleave_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - tmp5 = tl.arange(0, 2 * BLOCK_4)[None, None, None, None, :] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.interleave(x_val, y_val) - - out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * STRIDE_2 * 2 + tmp3 * STRIDE_3 * 2 + tmp5 * STRIDE_4 - out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp5 < 2 * SHAPE_4) 
- tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_interleave(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = shape[:-1] + (2 * shape[-1], ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1).reshape(new_shape) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_interleave_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape[:-1] + (2 * shape[-1], ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1).reshape(new_shape) - - blocks = list(x.size()) - strides = list(x.stride()) - - grid = (1, ) - if len(shape) == 4: - triton_interleave_4d[grid](output, x, y, *blocks, *blocks, *strides) - else: - triton_interleave_5d[grid](output, x, y, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.interleave(X, Y) - - oidx = xidx[:, None, None] * YB * ZB * 2 + yidx[None, :, None] * ZB * 2 + tl.arange(0, 2 * ZB)[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', [ - ('bfloat16', eval('torch.bfloat16'), 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 256, 16), - ('bool', eval('torch.bool'), 1, 1, 2), -]) -def test_interleave_u(para_type, data_type, XB, YB, ZB): - x = torch.full((XB, YB, ZB), 100, dtype=data_type).npu() - y = torch.full((XB, YB, ZB), 30, dtype=data_type).npu() - output = torch.randint(1, (XB, YB, ZB * 2), dtype=data_type).npu() - ans = torch.stack((x, y), dim=-1).reshape(XB, YB, ZB * 2) - fn_npu_dtype[1, 1, 1](output, x, y, XB, YB, ZB) - test_common.validate_cmp(para_type, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_join.py b/third_party/ascend/unittest/generalization_cases/test_general_join.py deleted file mode 100644 index a1d8cd3cd0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_join.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) Huawei 
Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - ret = tl.join(X, Y) - - oidx = xidx[:, None, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None, None] * ZNUMEL * 2 + \ - zidx[None, None, :, None] * 2 + tl.arange(0, 2)[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_join_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.join(x_val, y_val) - - out_tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - out_tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - out_tmp4 = tl.arange(0, 2)[None, None, None, None, :] - out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 * STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ - + out_tmp3 * STRIDE_3 * 2 + out_tmp4 - out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ - & (out_tmp3 < SHAPE_3) & (out_tmp4 < 2) - tl.store(output_ptr 
+ out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_join_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.join(x_val, y_val) - - out_tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None, None] - out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None, None] - out_tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None, None] - out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None, None] - out_tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :, None] - out_tmp5 = tl.arange(0, 2)[None, None, None, None, None, :] - out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 * STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ - + out_tmp3 * STRIDE_3 * 2 + out_tmp4 * STRIDE_4 * 2 + out_tmp5 - out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ - & (out_tmp3 < SHAPE_3) & (out_tmp4 < SHAPE_4) & (out_tmp5 < 2) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_join(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = shape + (2, ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_join_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape + (2, ), dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1) - - blocks = list(x.size()) - strides = list(x.stride()) - - grid = (1, ) - if len(shape) == 4: - triton_join_4d[grid](output, x, y, *blocks, *blocks, *strides) - else: - triton_join_5d[grid](output, x, y, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - - idx = xidx[:, None] * YB + yidx[None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.join(X, Y) - - oidx = xidx[:, None, None] * YB * 2 + yidx[None, :, None] * 2 + tl.arange(0, 2)[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', [ - ('bfloat16', eval('torch.bfloat16'), 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 256, 16), - ('bool', eval('torch.bool'), 1, 1, 2), -]) -def test_join_u(para_type, data_type, XB, YB, ZB): - x = torch.full((XB, YB), 100, dtype=data_type).npu() - y = torch.full((XB, YB), 30, dtype=data_type).npu() - - ans = torch.stack((x, y), dim=-1) - output = torch.randint(1, (XB, YB, 2), dtype=data_type).npu() - fn_npu_dtype[1, 1, 1](output, x, y, XB, YB, ZB, debug=True) - test_common.validate_cmp(para_type, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_log.py b/third_party/ascend/unittest/generalization_cases/test_general_log.py deleted file mode 100644 index 8bcafa2274..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_log.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import triton.language.extra.ascend.libdevice as libdevice -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.log(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_log_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.log(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log(x).to(eval('torch.' 
+ dtype)) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log(x).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_log_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_log2.py b/third_party/ascend/unittest/generalization_cases/test_general_log2.py deleted file mode 100644 index 0a0321466d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_log2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
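
The triton_*_4d_5d kernels in these deleted files all share one indexing scheme: offsets and masks are grown one dimension at a time, and guards of the form if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1 are constexpr, so shapes padded with trailing 1s compile down to lower-rank indexing. The host-side preparation each test repeats inline amounts to the following helper (a sketch of what the deleted "while len(blocks) < 5" loops do; pad_to_5d is a hypothetical name):

    def pad_to_5d(t):
        # pad block sizes and strides with trailing 1s so one 5-D kernel
        # also serves 4-D tensors; dims of extent 1 are compiled away
        blocks, strides = list(t.size()), list(t.stride())
        while len(blocks) < 5:
            blocks.append(1)
            strides.append(1)
        return blocks, strides

after which the launch is kernel[(1,)](output, x, *blocks, *blocks, *strides), exactly as in the tests.
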
- -import logging -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import triton.language.extra.ascend.libdevice as libdevice -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.log2(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_log2_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.log2(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log2(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log2(x).to(eval('torch.' 
+ dtype)) - - if len(shape) == 1: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) - elif len(shape) == 2: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_log2_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log2_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log2(x).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_log2_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_maximum.py b/third_party/ascend/unittest/generalization_cases/test_general_maximum.py deleted file mode 100644 index e7de75d0d3..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_maximum.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
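
A detail worth noting across these tests: torch.randint(1, shape, dtype=...) draws integers from [0, 1), so every "random" output buffer is in fact all zeros and only serves as a placeholder that the kernel store overwrites in full. An uninitialized buffer would state that intent more directly (a hypothetical simplification, not what the deleted code does):

    # placeholder fully overwritten by the kernel; getattr avoids eval()
    output = torch.empty(shape, dtype=getattr(torch, dtype)).npu()
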
-
-import logging
-import triton
-import triton.language as tl
-import torch
-import pytest
-import test_common
-from test_common import TestUtils
-import math
-
-
-def torch_maximum(x, y):
-    return torch.maximum(x, y)
-
-
-@triton.jit
-def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr,
-            YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr):
-    xoffs = tl.program_id(0) * XB
-    yoffs = tl.program_id(1) * YB
-    zoffs = tl.program_id(2) * ZB
-
-    xidx = tl.arange(0, XB) + xoffs
-    yidx = tl.arange(0, YB) + yoffs
-    zidx = tl.arange(0, ZB) + zoffs
-
-    idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-    Y = tl.load(y_ptr + idx)
-
-    ret = tl.maximum(X, Y)
-
-    tl.store(output_ptr + idx, ret)
-
-
-@triton.jit
-def triton_maximum_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr,
-                         BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr,
-                         SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr,
-                         STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr,
-                         STRIDE_4: tl.constexpr):
-    offsets = tl.program_id(0)
-
-    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
-    masks = tl.arange(0, BLOCK_0) < SHAPE_0
-    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
-        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
-    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
-        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
-    if (BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
-        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
-    if BLOCK_4 > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
-        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
-
-    x_val = tl.load(x_ptr + offsets, masks)
-    y_val = tl.load(y_ptr + offsets, masks)
-    ret = tl.maximum(x_val, y_val)
-    tl.store(output_ptr + offsets, ret, mask=masks)
-
-
-@pytest.mark.parametrize('shape', TestUtils.full_shape)
-@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool'])
-def test_maximum(dtype, shape):
-    # generate input data
-    x = test_common.generate_tensor(shape, dtype).npu()
-    y = test_common.generate_tensor(shape, dtype).npu()
-    z = test_common.generate_tensor(shape, dtype).npu()
-    new_shape = shape
-
-    output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_maximum(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_maximum_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_maximum(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_maximum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_minimum.py b/third_party/ascend/unittest/generalization_cases/test_general_minimum.py deleted file mode 100644 index 1b419c0e5e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_minimum.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
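
Both the maximum tests above and the minimum tests below delegate comparison to test_common.validate_cmp. For readers without the Ascend harness, a plain-torch stand-in could look like this (a sketch; torch.testing.assert_close applies dtype-based default tolerances, which may differ from validate_cmp's):

    import torch

    def validate_against_torch(expected, actual):
        # move both tensors to host and compare with default tolerances
        torch.testing.assert_close(actual.cpu(), expected.cpu())
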
-
-import logging
-import triton
-import triton.language as tl
-import torch
-import pytest
-import test_common
-from test_common import TestUtils
-import math
-
-
-def torch_minimum(x, y):
-    return torch.minimum(x, y)
-
-
-@triton.jit
-def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr,
-            YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr):
-    xoffs = tl.program_id(0) * XB
-    yoffs = tl.program_id(1) * YB
-    zoffs = tl.program_id(2) * ZB
-
-    xidx = tl.arange(0, XB) + xoffs
-    yidx = tl.arange(0, YB) + yoffs
-    zidx = tl.arange(0, ZB) + zoffs
-
-    idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-    Y = tl.load(y_ptr + idx)
-
-    ret = tl.minimum(X, Y)
-
-    tl.store(output_ptr + idx, ret)
-
-
-@triton.jit
-def triton_minimum_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr,
-                         BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr,
-                         SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr,
-                         STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr,
-                         STRIDE_4: tl.constexpr):
-    offsets = tl.program_id(0)
-
-    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
-    masks = tl.arange(0, BLOCK_0) < SHAPE_0
-    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
-        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
-    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
-        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
-    if (BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
-        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
-    if BLOCK_4 > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
-        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
-
-    x_val = tl.load(x_ptr + offsets, masks)
-    y_val = tl.load(y_ptr + offsets, masks)
-    ret = tl.minimum(x_val, y_val)
-    tl.store(output_ptr + offsets, ret, mask=masks)
-
-
-@pytest.mark.parametrize('shape', TestUtils.full_shape)
-@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool'])
-def test_minimum(dtype, shape):
-    # generate input data
-    x = test_common.generate_tensor(shape, dtype).npu()
-    y = test_common.generate_tensor(shape, dtype).npu()
-    z = test_common.generate_tensor(shape, dtype).npu()
-    new_shape = shape
-
-    output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_minimum(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_minimum_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_minimum(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_minimum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_mul.py b/third_party/ascend/unittest/generalization_cases/test_general_mul.py deleted file mode 100644 index 29666d72e4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_mul.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
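
The recurring branch "if x.numel() * x.element_size() >= 8192" in these tests splits the last axis across programs once the tensor no longer fits a single program's on-chip buffer, giving each program one z-slice. Factored out, the heuristic reads (a sketch; the 8192-byte threshold comes from the deleted tests and is presumably tied to the NPU unified-buffer budget; choose_grid is a hypothetical name):

    def choose_grid(x, XB, YB, ZB, threshold=8192):
        # small tensors: one program covers everything;
        # large tensors: one program per z-slice, block extent 1 in z
        if x.numel() * x.element_size() >= threshold:
            return (1, 1, ZB), XB, YB, 1
        return (1, 1, 1), XB, YB, ZB
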
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_mul(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X * Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_mul_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val * y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mul(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x * y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_mul[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_mul[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_mul[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_mul[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_mul[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_mul[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: 
- triton_mul[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mul_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x * y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_mul_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_mul_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - #numpy result - ans_numpy = np_x0 * np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_mul[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_ravel.py b/third_party/ascend/unittest/generalization_cases/test_general_ravel.py deleted file mode 100644 index 2e6735709d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_ravel.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
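
test_mul_uint above computes its reference in numpy rather than torch because torch's uint16/uint32/uint64 arithmetic support is limited. The essential pattern, with the modular wraparound numpy defines for unsigned types (a standalone illustration; the values and the uint8 dtype are not from the deleted test):

    import numpy as np

    a = np.array([250, 7], dtype=np.uint8)
    b = np.array([10, 3], dtype=np.uint8)
    ref = a * b  # wraps modulo 2**8: array([196, 21], dtype=uint8)
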
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.ravel(X) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_ravel_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.ravel(x_val) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_ravel(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.ravel(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_ravel_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.ravel(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ravel_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.ravel(X) - - oidx = tl.arange(0, XB * YB * ZB) - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 4), - ('uint8', torch.uint8, 1, 256, 16), - ('bool', torch.bool, 1, 1, 2), -]) -def test_ravel_u(sigtype, dtype, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), sigtype).npu() - ans = torch.ravel(x) - output = test_common.generate_tensor((1, XB * YB * ZB), sigtype).npu() - output = output.reshape(-1) - fn_npu_dtype[1, 1, 1](output, x, XB, YB, ZB) - test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_reshape.py b/third_party/ascend/unittest/generalization_cases/test_general_reshape.py deleted file mode 100644 index 25b661b0ad..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_reshape.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import logging - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.reshape(X, (ZB * YB * XB, )) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_reshape_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.reshape(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4, )) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_reshape(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.reshape(-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_reshape_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.reshape(-1) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_reshape_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py b/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py deleted file mode 100644 index 54a0984fec..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
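
The ravel and reshape files above exercise the same flattening contract: for a contiguous block, tl.ravel(x) and tl.reshape(x, (numel,)) yield the same 1-D value, which the tests store through a flat tl.arange. A minimal 2-D version of the pattern (assuming power-of-two XB and YB and a contiguous input; flatten_2d is an illustrative name):

    import triton
    import triton.language as tl


    @triton.jit
    def flatten_2d(out_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr):
        idx = tl.arange(0, XB)[:, None] * YB + tl.arange(0, YB)[None, :]
        x = tl.load(x_ptr + idx)
        flat = tl.ravel(x)  # equivalent here to tl.reshape(x, (XB * YB,))
        tl.store(out_ptr + tl.arange(0, XB * YB), flat)
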
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.rsqrt(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_rsqrt_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.rsqrt(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_rsqrt( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).abs().npu() - y = test_common.generate_tensor(shape, dtype).abs().npu() - z = test_common.generate_tensor(shape, dtype).abs().npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.rsqrt(x) - - if len(shape) == 1: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) - elif len(shape) == 2: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_rsqrt_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_rsqrt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.rsqrt(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rsqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py b/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py deleted file mode 100644 index 91737647e6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
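
The *_invalid_dtype_case tests above assert that float-only ops such as tl.rsqrt fail at compile time on integer inputs. The project-specific test_common.raises_with_match decorator can be read as shorthand for a plain pytest context manager (a sketch reusing the fn_npu_ kernel from the deleted rsqrt file; it assumes torch_npu is importable, as in the surrounding tests):

    import pytest
    import torch
    import torch_npu  # enables the .npu() device methods
    import triton

    def test_rsqrt_rejects_int32():
        x = torch.zeros(1, dtype=torch.int32).npu()
        out = torch.zeros(1, dtype=torch.int32).npu()
        with pytest.raises(triton.compiler.errors.CompilationError, match="Expected dtype"):
            fn_npu_[1, 1, 1](out, x, x, x, 1, 1, 1, 1, 1, 1)
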
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.sigmoid(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sigmoid_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sigmoid(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sigmoid( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) - else: - ans = torch.sigmoid(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sigmoid_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sigmoid_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) - else: - ans = torch.sigmoid(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sigmoid_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sin.py b/third_party/ascend/unittest/generalization_cases/test_general_sin.py deleted file mode 100644 index f52d0405de..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sin.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import numpy as np -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.sin(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sin_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sin(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -import logging - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sin( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sin(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sin_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sin_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sin(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sin_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_softmax.py b/third_party/ascend/unittest/generalization_cases/test_general_softmax.py deleted file mode 100644 index ce4b34d4a1..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_softmax.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
-
-# The actual implementation differs from the official definition, possibly due to the triton submodule version; the current submodule does not accept a dim argument and always computes softmax along dim 0.
-# arith.maximum does not support reduces such as 1x3 -> 3 or 1 -> 1.
-import triton
-import triton.language as tl
-import torch
-import logging
-import pytest
-import test_common
-from test_common import TestUtils
-import math
-
-
-def torch_softmax_d0(x1):
-    res = torch.softmax(x1, axis=0).to(x1.dtype)
-    return res
-
-
-@triton.jit
-def tt_softmax_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr,
-                  YB: tl.constexpr, ZB: tl.constexpr):
-    idx = tl.arange(0, XB)
-    x = tl.load(in_ptr + idx)
-    ret = tl.softmax(x)
-    tl.store(out_ptr + idx, ret)
-
-
-@triton.jit
-def tt_softmax_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr,
-                  YB: tl.constexpr, ZB: tl.constexpr):
-    xoffs = tl.program_id(0) * XB
-    yoffs = tl.program_id(1) * YB
-    xidx = tl.arange(0, XB) + xoffs
-    yidx = tl.arange(0, YB) + yoffs
-    idx = xidx[:, None] * ynumel + yidx[None, :]
-
-    a = tl.load(in_ptr + idx)
-    ret = tl.softmax(a)
-
-    tl.store(out_ptr + idx, ret)
-
-
-@triton.jit
-def tt_softmax_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr,
-                  YB: tl.constexpr, ZB: tl.constexpr):
-    xoffs = tl.program_id(0) * XB
-    yoffs = tl.program_id(1) * YB
-    zoffs = tl.program_id(2) * ZB
-
-    xidx = tl.arange(0, XB) + xoffs
-    yidx = tl.arange(0, YB) + yoffs
-    zidx = tl.arange(0, ZB) + zoffs
-
-    idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :]
-
-    a = tl.load(in_ptr + idx)
-    ret = tl.softmax(a)
-
-    tl.store(out_ptr + idx, ret)
-
-
-@triton.jit
-def triton_softmax_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr,
-                         BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr,
-                         SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr,
-                         STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr,
-                         STRIDE_4: tl.constexpr):
-    offsets = tl.program_id(0)
-
-    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
-    masks = tl.arange(0, BLOCK_0) < SHAPE_0
-    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
-        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
-    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
-        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
-    if (BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
-        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
-    if BLOCK_4 > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
-        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
-
-    x_val = tl.load(x_ptr + offsets, masks)
-    ret = tl.softmax(x_val)
-    tl.store(output_ptr + offsets, ret, mask=masks)
-
-
-@pytest.mark.parametrize('shape', TestUtils.full_shape)
-@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16'])
-def test_softmax(dtype, shape):
-    logging.log(logging.DEBUG, f"shape = {shape}")  # logging.log() accepts no flush= keyword
-    torch.manual_seed(0)
-    x = torch.rand(shape, dtype=eval('torch.' + dtype), device="npu") * 10
-    grid = (1, 1, 1)
-
-    y_cal = torch.rand(shape, dtype=eval('torch.'
+ dtype), device="npu") - - y_ref = torch_softmax_d0(x) - if len(shape) == 1: - tt_softmax_1d[grid](x, y_cal, x.numel(), 1, 1, x.numel(), 1, 1) - elif len(shape) == 2: - xnumel, ynumel, znumel = shape + (1, ) - XB, YB, ZB = xnumel, ynumel, znumel - if x.numel() * x.element_size() > 8192: - grid = (1, ynumel, 1) - YB = 1 - tt_softmax_2d[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB) - - elif len(shape) == 3: - mx = max(shape[1], shape[2]) - if mx == shape[1]: - tt_softmax_3d[1, shape[1], 1](x, y_cal, shape[0], shape[1], shape[2], shape[0], 1, shape[2]) - else: - tt_softmax_3d[1, 1, shape[2]](x, y_cal, shape[0], shape[1], shape[2], shape[0], shape[1], 1) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_softmax_invalid_dtype_case(dtype): - x0 = test_common.generate_tensor((1, ), dtype).npu() - - y_cal = torch.zeros((1, ), dtype=eval('torch.' + dtype)).npu() - tt_softmax_1d[1, 1, 1](x0, y_cal, 0, 0, 0, 1, 0, 0) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_softmax_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_softmax_d0(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_softmax_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_split.py b/third_party/ascend/unittest/generalization_cases/test_general_split.py deleted file mode 100644 index 9efedcf71a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_split.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
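Taken on its own, the dim-0 behavior pinned down by the softmax tests above is easy to reproduce host-side. A minimal sketch, assuming only stock PyTorch on CPU (no NPU device or Triton needed); the shape values are illustrative:

import torch

def torch_softmax_d0(x1):
    # Same reference helper the deleted tests use: the reduction always
    # runs over dim 0, regardless of the tensor's rank.
    return torch.softmax(x1, dim=0).to(x1.dtype)

x = torch.rand(4, 8, dtype=torch.float32) * 10
ref = torch_softmax_d0(x)
# Every column sums to 1 because the reduction is over dim 0, not the last dim.
assert torch.allclose(ref.sum(dim=0), torch.ones(8), atol=1e-5)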
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx=xidx[:,None,None,None]*YNUMEL*ZNUMEL*2+yidx[None,:,None,None]*ZNUMEL*2+ \ - zidx[None,None,:,None]*2 + tl.arange(0,2)[None,None,None,:] - - X = tl.load(x_ptr + idx) - - xx, yy = tl.split(X) - - oidx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -import logging - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_split(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - xx = torch.stack((x, y), dim=-1) - - a, b = torch.split(xx, 1, dim=-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - a = a.reshape(XB, YB, ZB) - b = b.reshape(XB, YB, ZB) - output = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' + dtype)).npu() - output1 = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' 
+ dtype)).npu() - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, xx, output1, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, a, output) - test_common.validate_cmp(dtype, b, output1) - - -@triton.jit -def fn_npu_4_8d(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, WB: tl.constexpr, - VB: tl.constexpr, UB: tl.constexpr, TB: tl.constexpr, SB: tl.constexpr): - - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - widx = tl.arange(0, WB) - vidx = tl.arange(0, VB) - uidx = tl.arange(0, UB) - tidx = tl.arange(0, TB) - sidx = tl.arange(0, SB) - - idx = (xidx[:, None, None, None, None, None, None, None, None] * YB * ZB * WB * VB * UB * TB * SB * 2 + - yidx[None, :, None, None, None, None, None, None, None] * ZB * WB * VB * UB * TB * SB * 2 + - zidx[None, None, :, None, None, None, None, None, None] * WB * VB * UB * TB * SB * 2 + - widx[None, None, None, :, None, None, None, None, None] * VB * UB * TB * SB * 2 + - vidx[None, None, None, None, :, None, None, None, None] * UB * TB * SB * 2 + - uidx[None, None, None, None, None, :, None, None, None] * TB * SB * 2 + - tidx[None, None, None, None, None, None, :, None, None] * SB * 2 + - sidx[None, None, None, None, None, None, None, :, None] * 2 + - tl.arange(0, 2)[None, None, None, None, None, None, None, None, :]) - - X = tl.load(x_ptr + idx) - xx, yy = tl.split(X) - - oidx = (xidx[:, None, None, None, None, None, None, None] * YB * ZB * WB * VB * UB * TB * SB + - yidx[None, :, None, None, None, None, None, None] * ZB * WB * VB * UB * TB * SB + - zidx[None, None, :, None, None, None, None, None] * WB * VB * UB * TB * SB + - widx[None, None, None, :, None, None, None, None] * VB * UB * TB * SB + - vidx[None, None, None, None, :, None, None, None] * UB * TB * SB + - uidx[None, None, None, None, None, :, None, None] * TB * SB + - tidx[None, None, None, None, None, None, :, None] * SB + sidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape_4_8d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_split_4_8d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' 
+ dtype)).npu() - xx = torch.stack((x, y), dim=-1) - - a, b = torch.split(xx, 1, dim=-1) - - if len(shape) == 1: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, 1, shape[0] - elif len(shape) == 2: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, shape[0], shape[1] - elif len(shape) == 3: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, shape[0], shape[1], shape[2] - elif len(shape) == 4: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, shape[0], shape[1], shape[2], shape[3] - elif len(shape) == 5: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4] - elif len(shape) == 6: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5] - elif len(shape) == 7: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6] - else: - XB, YB, ZB, WB, VB, UB, TB, SB = shape - - a = a.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) - b = b.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) - - output = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' + dtype)).npu() - output1 = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' + dtype)).npu() - - grid = (1, 1, 1) - fn_npu_4_8d[grid](output, xx, output1, XB, YB, ZB, WB, VB, UB, TB, SB) - - test_common.validate_cmp(dtype, a, output) - test_common.validate_cmp(dtype, b, output1) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - xx, yy = tl.split(X) - - oidx = xidx[:, None] * YB + yidx[None, :] - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -@pytest.mark.parametrize('para_type, data_type, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 2), - ('uint8', torch.uint8, 1, 256, 2), - ('bool', torch.bool, 1, 1, 2), -]) -def test_split_u(para_type, data_type, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), para_type).npu() - a, b = torch.split(x, 1, dim=-1) - a = a.reshape(XB, YB) - b = b.reshape(XB, YB) - - output = test_common.generate_tensor((XB, YB), para_type).npu() - output1 = test_common.generate_tensor((XB, YB), para_type).npu() - fn_npu_[1, 1, 1](output, x, output1, XB, YB, ZB, debug=True) - - test_common.validate_cmp(para_type, a, output) - test_common.validate_cmp(para_type, b, output1) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sub.py b/third_party/ascend/unittest/generalization_cases/test_general_sub.py deleted file mode 100644 index a831d6c421..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sub.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X - Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sub_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val - y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x - y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_sub[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_sub[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_sub[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], 
shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_sub[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_sub[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_sub[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_sub[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x - y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sub_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_sub_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - #numpy result - ans_numpy = np_x0 - np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_sub[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py b/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py deleted file mode 100644 index ec4a7c8bc0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - -full_dtype = test_common._float_dtypes + test_common._int_dtypes + test_common._uint_dtypes -temporarily_not_support_dtype = ['bool'] - - -@triton.jit -def triton_tensor_descriptor_2d( - out_ptr, - x_ptr, - M: tl.constexpr, - N: tl.constexpr, - M_BLOCK: tl.constexpr, - N_BLOCK: tl.constexpr, -): - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - moffset = tl.program_id(0) * M_BLOCK - noffset = tl.program_id(1) * N_BLOCK - block = in_desc.load([moffset, noffset]) - out_desc.store([moffset, noffset], block) - - -@triton.jit -def triton_tensor_descriptor_3d( - out_ptr, - x_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - stride_m: tl.constexpr, - stride_n: tl.constexpr, - stride_k: tl.constexpr, - M_BLOCK: tl.constexpr, - N_BLOCK: tl.constexpr, - K_BLOCK: tl.constexpr, -): - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[M, N, K], - strides=[stride_m, stride_n, stride_k], - block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[M, N, K], - strides=[stride_m, stride_n, stride_k], - block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], - ) - moffset = tl.program_id(0) * M_BLOCK - noffset = tl.program_id(1) * N_BLOCK - koffset = tl.program_id(2) * K_BLOCK - block = in_desc.load([moffset, noffset, koffset]) - out_desc.store([moffset, noffset, koffset], block) - - -@triton.jit -def triton_tensor_descriptor_4d( - out_ptr, - x_ptr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, -): - pid0 = tl.program_id(0) - pid1 = tl.program_id(1) - pid2 = tl.program_id(2) - idx2 = pid2 // BLOCK_3 - idx3 = pid2 % BLOCK_3 - o1 = pid0 * BLOCK_0 - o2 = pid1 * BLOCK_1 - o3 = idx2 * BLOCK_2 - o4 = idx3 * BLOCK_3 - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3], - ) - block = in_desc.load([o1, o2, o3, o4]) - out_desc.store([o1, o2, o3, o4], block) - - -@triton.jit -def triton_tensor_descriptor_5d( - out_ptr, - x_ptr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, -): - pid0 = tl.program_id(0) - pid1 = tl.program_id(1) - pid2 = 
tl.program_id(2)
-    idx3 = pid2 // (BLOCK_3 * BLOCK_4)
-    idx4 = (pid2 // BLOCK_4) % BLOCK_3
-    idx5 = pid2 % BLOCK_4
-    o1 = pid0 * BLOCK_0
-    o2 = pid1 * BLOCK_1
-    o3 = idx3 * BLOCK_2
-    o4 = idx4 * BLOCK_3
-    o5 = idx5 * BLOCK_4
-    in_desc = tl.make_tensor_descriptor(
-        x_ptr,
-        shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4],
-        strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4],
-        block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4],
-    )
-    out_desc = tl.make_tensor_descriptor(
-        out_ptr,
-        shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4],
-        strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4],
-        block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4],
-    )
-    block = in_desc.load([o1, o2, o3, o4, o5])
-    out_desc.store([o1, o2, o3, o4, o5], block)
-
-
-@triton.jit
-def triton_tensor_descriptor_function_2d(
-    out_ptr,
-    x_ptr,
-    M: tl.constexpr,
-    N: tl.constexpr,
-    M_BLOCK: tl.constexpr,
-    N_BLOCK: tl.constexpr,
-):
-    in_desc = tl.make_tensor_descriptor(
-        x_ptr,
-        shape=[M, N],
-        strides=[N, 1],
-        block_shape=[M_BLOCK, N_BLOCK],
-    )
-    out_desc = tl.make_tensor_descriptor(
-        out_ptr,
-        shape=[M, N],
-        strides=[N, 1],
-        block_shape=[M_BLOCK, N_BLOCK],
-    )
-    moffset = tl.program_id(0) * M_BLOCK
-    noffset = tl.program_id(1) * N_BLOCK
-    block = tl.load_tensor_descriptor(in_desc, [moffset, noffset])
-    tl.store_tensor_descriptor(out_desc, [moffset, noffset], block)
-
-
-@pytest.mark.parametrize('dtype', full_dtype)
-@pytest.mark.parametrize('shape', TestUtils.full_shape)
-def test_tensor_descriptor_load_store_nd(dtype, shape):
-    """test tensor_descriptor load/store for nd tensor"""
-
-    if dtype in temporarily_not_support_dtype:
-        pytest.skip(f"{dtype} not supported")
-
-    inp = test_common.generate_numpy(shape, dtype)
-    inp = torch.from_numpy(inp).npu()
-    out = inp.new_empty(shape)
-    blocks = list(inp.size())
-    strides = list(inp.stride())
-    grid = (1, )
-    dims = len(shape)
-
-    # Skip if the last dimension is smaller than 16 bytes.
-    itemsize = torch.tensor([], dtype=inp.dtype).element_size()
-    if blocks[-1] * itemsize < 16:
-        pytest.skip(f"last dimension must be at least 16 bytes, but got {blocks[-1] * itemsize} bytes")
-
-    if dims == 2:
-        if inp.numel() * inp.element_size() > 8192:
-            triton_tensor_descriptor_2d[shape[0], 1, 1](out, inp, 1, shape[1], 1, shape[1])
-        else:
-            triton_tensor_descriptor_2d[grid](out, inp, *shape, *blocks)
-        test_common.validate_cmp(dtype, inp, out)
-    elif dims == 3:
-        triton_tensor_descriptor_3d[grid](out, inp, *shape, *strides, *blocks)
-        test_common.validate_cmp(dtype, inp, out)
-    elif dims == 4:
-        triton_tensor_descriptor_4d[grid](out, inp, *shape, *strides, *blocks)
-        test_common.validate_cmp(dtype, inp, out)
-    elif dims == 5:
-        triton_tensor_descriptor_5d[grid](out, inp, *shape, *strides, *blocks)
-        test_common.validate_cmp(dtype, inp, out)
-    else:
-        pytest.skip(f"{dims}d not supported")
-
-
-@pytest.mark.parametrize("dtype", test_common._uint_dtypes)
-def test_tensor_descriptor_in_function(dtype):
-    """test tensor_descriptor load/store in function"""
-
-    if dtype in temporarily_not_support_dtype:
-        pytest.skip(f"{dtype} not supported")
-
-    M, N = 32, 128
-    inp = test_common.generate_numpy((M, N), dtype)
-    inp = torch.from_numpy(inp).npu()
-    out = inp.new_empty((M, N))
-
-    M_BLOCK = 8
-    N_BLOCK = 32
-    grid_m = M // M_BLOCK
-    grid_n = N // N_BLOCK
-
-    triton_tensor_descriptor_function_2d[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK)
-    test_common.validate_cmp(dtype, inp, out)
diff --git a/third_party/ascend/unittest/generalization_cases/test_general_view.py
b/third_party/ascend/unittest/generalization_cases/test_general_view.py deleted file mode 100644 index 7f0f9b1532..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_view.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import logging -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.view(X, (ZB * YB * XB, )) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_view_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, 
None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.view(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4, )) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_view(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.view(new_shape) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_view_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.view(x.numel(), ) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_view_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_gt_op.py b/third_party/ascend/unittest/generalization_cases/test_gt_op.py deleted file mode 100644 index 7079457b2e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_gt_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_gt_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val > y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', 
TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_gt(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.gt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_gt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_gt_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_gt_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_gt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.gt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_gt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_invert.py b/third_party/ascend/unittest/generalization_cases/test_invert.py deleted file mode 100644 index 698cb8f11c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_invert.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
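All of the 4d/5d kernels in these files build their offsets and masks with the same progressive-broadcast recipe: index dim 0, then add one broadcast axis per remaining non-unit block dimension. The index arithmetic itself is plain strided addressing, so it can be checked without Triton at all; a small NumPy sketch with illustrative (hypothetical) block and shape values:

import numpy as np

# Hypothetical values, chosen only for illustration.
BLOCK_0, BLOCK_1 = 4, 8          # padded block extents
SHAPE_0, SHAPE_1 = 3, 5          # logical tensor shape
x = np.arange(SHAPE_0 * SHAPE_1, dtype=np.float32).reshape(SHAPE_0, SHAPE_1)
STRIDE_0 = x.strides[0] // x.itemsize
STRIDE_1 = x.strides[1] // x.itemsize

# Same two steps the kernels perform: dim 0 first, then broadcast in dim 1.
offsets = np.arange(BLOCK_0)[:, None] * STRIDE_0 + np.arange(BLOCK_1)[None, :] * STRIDE_1
masks = (np.arange(BLOCK_0)[:, None] < SHAPE_0) & (np.arange(BLOCK_1)[None, :] < SHAPE_1)

# Masked gather: out-of-range lanes read element 0 and are discarded by the mask.
flat = x.reshape(-1)
vals = flat[np.where(masks, offsets, 0)]
assert (vals[masks] == flat).all()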
-
-import logging
-
-import triton
-import triton.language as tl
-import torch
-import pytest
-import test_common
-from test_common import TestUtils
-import math
-
-
-def torch_invert(x0, ddtype):
-    if 'float' in str(ddtype):
-        x0 = x0.to(torch.int32)
-        y_ref = ~x0
-        y_ref = y_ref.to(ddtype)
-    else:
-        y_ref = ~x0
-    return y_ref
-
-
-@triton.jit
-def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr,
-            YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr):
-    xoffs = tl.program_id(0) * XB
-    yoffs = tl.program_id(1) * YB
-    zoffs = tl.program_id(2) * ZB
-
-    xidx = tl.arange(0, XB) + xoffs
-    yidx = tl.arange(0, YB) + yoffs
-    zidx = tl.arange(0, ZB) + zoffs
-
-    idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-    Y = tl.load(y_ptr + idx)
-
-    ret = ~X
-
-    tl.store(output_ptr + idx, ret)
-
-
-@triton.jit
-def triton_invert_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr,
-                        BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr,
-                        SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr,
-                        STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr):
-    offsets = tl.program_id(0)
-
-    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
-    masks = tl.arange(0, BLOCK_0) < SHAPE_0
-    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
-        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
-    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
-        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
-    if (BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
-        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
-    if BLOCK_4 > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
-        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
-
-    x_val = tl.load(x_ptr + offsets, masks)
-    ret = ~x_val
-    tl.store(output_ptr + offsets, ret, mask=masks)
-
-
-@pytest.mark.parametrize('shape', TestUtils.full_shape)
-@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool'])
-def test_case2(dtype, shape):
-    # generate data
-    x = test_common.generate_tensor(shape, dtype).npu()
-    y = test_common.generate_tensor(shape, dtype).npu()
-    z = test_common.generate_tensor(shape, dtype).npu()
-    new_shape = shape
-
-    output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu()
-    output1 = output
-    logging.debug(f"output.dtype={output.dtype}")
-
-    ddtype = eval('torch.'
+ dtype) - ans = torch_invert(x, ddtype) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_invert_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_invert(x, eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_le_op.py b/third_party/ascend/unittest/generalization_cases/test_le_op.py deleted file mode 100644 index d305395417..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_le_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_le_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val <= y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', 
TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_le(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.le(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_le_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_le_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_le_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_le_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.le(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_le_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_load_store.py b/third_party/ascend/unittest/generalization_cases/test_load_store.py deleted file mode 100644 index 82013d9f38..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_load_store.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
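The host-side launch logic in these tests also repeats a single heuristic: run one program when the tensor is small, otherwise split the grid along a non-unit axis once `numel * element_size` reaches 8192 bytes. A small helper capturing that recurring check; the 8 KiB threshold is simply the constant the tests use, not a documented hardware limit:

import torch

UB_SPLIT_THRESHOLD = 8192  # bytes; the cutoff used throughout these tests

def pick_grid(x: torch.Tensor, XB: int, YB: int, ZB: int):
    """Mirror of the tests' heuristic: one program for small tiles,
    otherwise parallelize the first non-unit axis across programs."""
    if x.numel() * x.element_size() < UB_SPLIT_THRESHOLD:
        return (1, 1, 1), (XB, YB, ZB)
    if XB > 1:
        return (XB, 1, 1), (1, YB, ZB)
    if YB > 1:
        return (1, YB, 1), (XB, 1, ZB)
    return (1, 1, ZB), (XB, YB, 1)

# 64x64 float32 is 16 KiB, so the Y axis is split across 64 programs.
grid, (XB, YB, ZB) = pick_grid(torch.empty(64, 64, dtype=torch.float32), 1, 64, 64)
assert grid == (1, 64, 1) and (XB, YB, ZB) == (1, 1, 64)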
-
-import triton
-import triton.language as tl
-import torch
-import torch_npu
-import pytest
-import test_common
-from test_common import TestUtils
-import logging
-
-
-@triton.jit
-def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr):
-    idx = tl.arange(0, YB)
-    X = tl.load(x_ptr + idx)
-    tl.store(output_ptr + idx, X)
-
-
-def torch_fn_npu_1d(x):
-    return x
-
-
-@triton.jit
-def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr):
-    pid = tl.program_id(0)
-    y_idx = tl.arange(0, YB)[:, None] + pid * YB
-    z_idx = tl.arange(0, ZB)[None, :]
-    idx = y_idx * ZB + z_idx
-
-    X = tl.load(x_ptr + idx)
-
-    tl.store(output_ptr + idx, X)
-
-
-def torch_fn_npu_2d(x):
-    return x
-
-
-@triton.jit
-def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr):
-    y = tl.arange(0, YB)[:, None, None]
-    z = tl.arange(0, ZB)[None, :, None]
-    k = tl.arange(0, KB)[None, None, :]
-
-    idx = y * ZB * KB + z * KB + k
-
-    X = tl.load(x_ptr + idx)
-
-    tl.store(output_ptr + idx, X)
-
-
-def torch_fn_npu_3d(x):
-    return x
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d)
-@pytest.mark.parametrize('dtype', TestUtils.dtype_list)
-def test_npu(shape, dtype):
-    logging.debug(f'dtype:{dtype} shape:{shape}')
-    data_type = eval('torch.' + dtype)
-    x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu()
-    triton_res = torch.empty(shape, dtype=data_type).npu()
-    torch_res = x
-    if len(shape) == 1:
-        torch_res = torch_fn_npu_1d(x)
-        fn_npu_1d[1, 1, 1](triton_res, x, shape[0])
-        # compare uint32 as float32, because torch_npu does not support slicing uint32 tensors
-        torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32)
-        triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32)
-        cmp_type = dtype if dtype != 'uint32' else 'float32'
-        test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3], torch_res[:2 * shape[0] // 3])
-    elif len(shape) == 2:
-        torch_res = torch_fn_npu_2d(x)
-        fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1])
-        torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32)
-        triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32)
-        cmp_type = dtype if dtype != 'uint32' else 'float32'
-        test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3],
-                                 torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3])
-    elif len(shape) == 3:
-        torch_res = torch_fn_npu_3d(x)
-        fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2])
-        torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32)
-        triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32)
-        cmp_type = dtype if dtype != 'uint32' else 'float32'
-        test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3],
-                                 torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3])
-
-
-# requirement: all the (4d and 5d) data must fit in UB at once, without overflow
-@triton.jit
-def triton_load_store_multi_d(in_ptr0, out_ptr0, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr,
-                              BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr,
-                              SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr,
-                              SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr,
-                              STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr):
-    offsets = tl.program_id(0)
-
-    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
-    masks = tl.arange(0, BLOCK_0) < SHAPE_0
-    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
-        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
-    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
-        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
-    if (BLOCK_3 * BLOCK_4) > 1:
-        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
-        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
-    if BLOCK_4 > 1:
-        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
-        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
-
-    tmp_in = tl.load(in_ptr0 + offsets, masks)
-    tmp_out = tmp_in
-    tl.store(out_ptr0 + offsets, tmp_out, masks)
-
-
-@pytest.mark.shape_4d_5d
-@pytest.mark.parametrize('param_list', [
-    ['float32', (8, 4, 16, 16)],
-    ['float16', (8, 4, 16, 16)],
-    ['int8', (8, 4, 16, 16)],
-    ['float32', (8, 8, 4, 4)],
-    ['float16', (8, 8, 4, 4)],
-    ['int8', (8, 8, 4, 4)],
-    ['float32', (3, 8, 2, 16, 16)],
-    ['float16', (3, 8, 2, 16, 16)],
-    ['int8', (9, 8, 8, 16, 16)],
-    ['float32', (11, 8, 8, 4, 4)],
-    ['float16', (11, 8, 8, 4, 4)],
-    ['int8', (11, 8, 8, 4, 4)],
-])
-def test_load_store_4d_5d(param_list):
-    # generate the input data
-    dtype, shape = param_list
-    x0 = test_common.generate_tensor(shape, dtype).npu()
-    # torch (reference) result
-    y_expect = x0
-    y_actual = test_common.generate_tensor(shape, dtype).npu()
-    # triton result
-    blocks = list(x0.size())
-    shapes = list(x0.stride())
-    while len(blocks) < 5:
-        blocks.append(1)
-        shapes.append(1)
-    triton_load_store_multi_d[(1, )](x0, y_actual, *blocks, *blocks, *shapes)
-    # compare the results
-    test_common.validate_cmp(dtype, y_actual, y_expect)
diff --git a/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py
deleted file mode 100644
index b89ca7f08f..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
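Every 4d/5d launch site repeats the same size/stride padding before calling the kernel; distilled into one helper (the name pad_to_5d is invented here, the tests inline the loop):

import torch

def pad_to_5d(t: torch.Tensor):
    # right-pad sizes and strides with 1s so tensors of rank < 5 can reuse
    # the 5-D kernel signature unchanged
    blocks, strides = list(t.size()), list(t.stride())
    while len(blocks) < 5:
        blocks.append(1)
        strides.append(1)
    return blocks, strides

blocks, strides = pad_to_5d(torch.empty(8, 4, 16, 16))
# the tests then pass *blocks twice, once as BLOCK_0..BLOCK_4 and once as
# SHAPE_0..SHAPE_4, so every mask is all-true; the masks would only start
# filtering if a caller picked blocks that over-approximate the real shape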
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils, generate_tensor -import logging - - -@triton.jit -def triton_logical_and_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - idx = lblk_idx[:, None] * M + mblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx[:, None] * M + mblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB - mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB - nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val.logical_and(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -support_typelist = [ - 'bool', -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', support_typelist) -def test_logical_and(shape, sigtype): - logging.debug(f"dtype:{sigtype} shape:{shape}") - dtype = eval('torch.' 
+ sigtype)
-    x0 = generate_tensor(shape=shape, dtype=sigtype).npu()
-    x1 = generate_tensor(shape=shape, dtype=sigtype).npu()
-    # ncore, xblock, xblock_sub = 2, 32768, 1024
-    y_ref = torch.logical_and(x0, x1)
-    output = torch.zeros(shape, dtype=dtype).npu()
-    if len(shape) == 1:
-        triton_logical_and_1d[1, 1, 1](x0, x1, output, shape[0])
-    elif len(shape) == 2:
-        shape0 = shape[0]
-        shape1 = shape[1]
-        if x0.numel() * x0.element_size() >= 8192:
-            grid = (shape0, 1, 1)
-            shape0 = 1
-        else:
-            grid = (1, 1, 1)
-        triton_logical_and_2d[grid](x0, x1, output, shape0, shape1)
-    elif len(shape) == 3:
-        # split along dim 0 only: the kernel reuses M and N both as block
-        # extents and as index strides, so splitting any other axis (or
-        # passing the full extent L together with a per-program step of 1,
-        # as the old largest-axis heuristic did) indexes out of bounds
-        triton_logical_and_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], 1, shape[1], shape[2])
-
-    test_common.validate_cmp(sigtype, output, y_ref)
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d)
-@pytest.mark.parametrize('dtype', ['bool'])
-def test_logical_and_4d_5d(shape, dtype):
-    logging.log(logging.DEBUG, f"shape = {shape}")
-    x = test_common.generate_tensor(shape, dtype).npu()
-    y = test_common.generate_tensor(shape, dtype).npu()
-
-    output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
-    logging.log(logging.DEBUG, f"output.dtype={output.dtype}")
-
-    ans = torch.logical_and(x, y)
-
-    blocks = list(x.size())
-    strides = list(x.stride())
-    while len(blocks) < 5:
-        blocks.append(1)
-        strides.append(1)
-
-    grid = (1, )
-    triton_logical_and_4d_5d[grid](x, y, output, *blocks, *blocks, *strides)
-
-    test_common.validate_cmp(dtype, ans, output)
diff --git a/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py
deleted file mode 100644
index f470de056f..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
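The 2-D variants of these element-wise tests all share one launch heuristic: once the operand reaches 8192 bytes, dim 0 is spread across programs and each program's row block shrinks to 1; anything smaller runs as a single program. A distilled sketch (hypothetical helper; the 8192-byte threshold is taken from the tests and approximates what fits comfortably in UB):

def choose_grid_2d(x):
    # x is a torch tensor; returns (grid, (block_m, block_n))
    m, n = x.shape
    if x.numel() * x.element_size() >= 8192:
        return (m, 1, 1), (1, n)   # one row per program
    return (1, 1, 1), (m, n)       # whole tensor in one program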
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils, generate_tensor -import logging - - -@triton.jit -def triton_logical_or_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): - pid = tl.program_id(0) - lblk_idx = tl.arange(0, L) + pid * L - mblk_idx = tl.arange(0, M) - idx = lblk_idx[:, None] * M + mblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx[:, None] * M + mblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB - mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB - nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val.logical_or(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -support_typelist = [ - 'bool', -] - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('sigtype', support_typelist) -def test_logical_or(shape, sigtype): - logging.debug(f"dtype:{sigtype} shape:{shape}") - dtype = eval('torch.' 
+ sigtype)
-    x0 = generate_tensor(shape=shape, dtype=sigtype).npu()
-    x1 = generate_tensor(shape=shape, dtype=sigtype).npu()
-    # ncore, xblock, xblock_sub = 2, 32768, 1024
-    y_ref = torch.logical_or(x0, x1)
-    output = torch.zeros(shape, dtype=dtype).npu()
-    if len(shape) == 1:
-        triton_logical_or_1d[1, 1, 1](x0, x1, output, shape[0])
-    elif len(shape) == 2:
-        shape0 = shape[0]
-        shape1 = shape[1]
-        if x0.numel() * x0.element_size() >= 8192:
-            grid = (shape0, 1, 1)
-            shape0 = 1
-        else:
-            grid = (1, 1, 1)
-        triton_logical_or_2d[grid](x0, x1, output, shape0, shape1)
-    elif len(shape) == 3:
-        # split along dim 0 only, for the same reason as in the logical_and
-        # test: the kernel reuses M and N as both block extents and index
-        # strides, so any other split indexes out of bounds
-        triton_logical_or_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], 1, shape[1], shape[2])
-    test_common.validate_cmp(sigtype, output, y_ref)
-
-
-@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d)
-@pytest.mark.parametrize('dtype', ['bool'])
-def test_logical_or_4d_5d(shape, dtype):
-    logging.log(logging.DEBUG, f"shape = {shape}")
-    x = test_common.generate_tensor(shape, dtype).npu()
-    y = test_common.generate_tensor(shape, dtype).npu()
-
-    output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
-    logging.log(logging.DEBUG, f"output.dtype={output.dtype}")
-
-    ans = torch.logical_or(x, y)
-
-    blocks = list(x.size())
-    strides = list(x.stride())
-    while len(blocks) < 5:
-        blocks.append(1)
-        strides.append(1)
-
-    grid = (1, )
-    triton_logical_or_4d_5d[grid](x, y, output, *blocks, *blocks, *strides)
-
-    test_common.validate_cmp(dtype, ans, output)
diff --git a/third_party/ascend/unittest/generalization_cases/test_lshift_op.py b/third_party/ascend/unittest/generalization_cases/test_lshift_op.py
deleted file mode 100644
index b70020aca5..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_lshift_op.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
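A quick host-side check (not from the test suite) that the dim-0 split used by the 3-D logical tests covers an (L, M, N) tensor exactly once: program p contributes idx = p*M*N + m*N + n over a full (M, N) tile, so the per-program index ranges abut without overlap.

import numpy as np

L, M, N = 3, 4, 5
seen = np.zeros(L * M * N, dtype=int)
for p in range(L):  # grid of L programs
    m, n = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
    seen[(p * M * N + m * N + n).ravel()] += 1
assert (seen == 1).all()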
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_lshift_1d(in_ptr0, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_2d(in_ptr0, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_3d(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_4d_5d(x_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = x_val << 2 - tl.store(output_ptr + offsets, ret, mask=masks) - - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - -typelist = [ - 'int8', - 'int16', - 'int32', - 'int64', -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_lshift(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = x0 << 2 - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - shape0 = shape[0] - shape1 = shape[1] - shape2 = shape[2] - if x0.numel() * 
x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lshift_3d[grid](x0, output, shape0, shape1, shape2) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lshift_2d[grid](x0, output, shape0, shape1) - if len(shape) == 1: - triton_lshift_1d[1, 1, 1](x0, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_lshift_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x << 2 - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_lshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_lshift_1d[1, 1, 1](x, output, N) diff --git a/third_party/ascend/unittest/generalization_cases/test_lt_op.py b/third_party/ascend/unittest/generalization_cases/test_lt_op.py deleted file mode 100644 index 8f013d7c9f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_lt_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
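The negative test above leans on a test_common decorator that asserts the kernel launch raises a CompilationError whose message matches "unexpected type". A minimal sketch of how such a decorator can be built on pytest.raises (hypothetical implementation; the real helper lives in the deleted test_common module and may differ):

import functools
import pytest

def raises_with_match(exc_type, match):
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with pytest.raises(exc_type, match=match):
                fn(*args, **kwargs)
        return inner
    return wrap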
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_lt_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val < y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_lt(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # 
ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.lt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_lt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lt_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_lt_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_lt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.lt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_lt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py b/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py deleted file mode 100644 index 0fdc244e79..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, avoid_not_support, get_dtype_size - - -@triton.jit -def matmul_kernel( - a_ptr, - b_ptr, - c_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - acc_dtype: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - matxa_ptr_in = tl.make_block_ptr(a_ptr, (M, K), (K, 1), (0, 0), (M, K), order=(1, 0)) - matxb_ptr_in = tl.make_block_ptr(b_ptr, (K, N), (N, 1), (0, 0), (K, N), order=(1, 0)) - matxc_ptr_in = tl.make_block_ptr(c_ptr, (M, N), (N, 1), (0, 0), (M, N), order=(1, 0)) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - a = tl.load(matxa_ptr_in) - b = tl.load(matxb_ptr_in) - accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store(matxc_ptr_in, c) - - -@avoid_not_support('matmul') -@pytest.mark.parametrize('shape', [(16, 32)]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_matmul(shape, dtype): - M, N, K = shape[0], shape[0], shape[1] - - BLOCK_M, BLOCK_N, BLOCK_K = M, N, K - a = test_common.generate_tensor((M, K), dtype) - b = test_common.generate_tensor((K, N), dtype) - - triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu() - accumulator_type = tl.float32 - - matmul_kernel[ - 1, - ](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, BLOCK_M, BLOCK_N, BLOCK_K, enable_nd2nz_on_vector=False) - - print("PASSED") diff --git a/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py b/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py deleted file mode 100644 index 4c95d5623b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
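With BLOCK_M, BLOCK_N and BLOCK_K equal to the full extents, the block-pointer kernel above reduces to a single tl.dot accumulated at acc_dtype. As a torch reference (a sketch under the test's assumptions: one program covers the whole problem and acc_dtype is tl.float32):

import torch

def single_tile_matmul_ref(a, b, out_dtype=torch.float32):
    # accumulate in float32, then cast the way the kernel casts
    # accumulator.to(c_ptr.dtype.element_ty)
    acc = a.to(torch.float32) @ b.to(torch.float32)
    return acc.to(out_dtype)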
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffset = tl.program_id(0) - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def triton_make_block_ptr_4d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - x = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -@triton.jit -def triton_make_block_ptr_5d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - x = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, 
STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -temporarily_not_support_dtype = ['bool'] - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_npu(dtype, shape): - if dtype in temporarily_not_support_dtype: - return - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - - a = x - blocks = list(x.size()) - strides = list(x.stride()) - grid = (1, ) - if len(shape) == 5: - triton_make_block_ptr_5d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 4: - triton_make_block_ptr_4d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 3: - fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) - elif len(shape) == 2: - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) - else: - fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) - else: - fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) - torch.testing.assert_close(output, a) diff --git a/third_party/ascend/unittest/generalization_cases/test_matmul.py b/third_party/ascend/unittest/generalization_cases/test_matmul.py deleted file mode 100644 index edeca4f170..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_matmul.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
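For the 2-D case above, the block-pointer pair is equivalent to manual index arithmetic: make_block_ptr scales each per-dimension offset by the matching stride, and order=(1, 0) only records that dim 1 is the fastest-varying (row-major) axis. An assumed-equivalent manual formulation with a simplified signature (a sketch, not from the deleted file; like the original, it omits boundary_check):

import triton
import triton.language as tl

@triton.jit
def fn_npu_2d_manual(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr):
    xoffset = tl.program_id(0)
    # (xoffset + i) * YB + j * 1  ==  offsets=(xoffset, 0) with strides=(YB, 1)
    idx = (xoffset + tl.arange(0, XB))[:, None] * YB + tl.arange(0, YB)[None, :]
    tl.store(output_ptr + idx, tl.load(x_ptr + idx))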
- -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import acc_util -import test_common -from test_common import TestUtils, avoid_not_support, get_dtype_size - - -@triton.jit -def matmul_kernel( - a_ptr, - b_ptr, - c_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - acc_dtype: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_n = tl.cdiv(N, BLOCK_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) - offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) - offs_k = tl.arange(0, BLOCK_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) - accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) - a_ptrs += BLOCK_K * stride_ak - b_ptrs += BLOCK_K * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) - - offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -@avoid_not_support('matmul') -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_matmul(shape, dtype): - M, N, K = shape[0], shape[0], shape[1] - # 32byte/Dtype_bytes - kalign = 32 // get_dtype_size(dtype) - BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32) - a = test_common.generate_tensor((M, K), dtype) - b = test_common.generate_tensor((K, N), dtype) - - if dtype == "int8": - triton_res = torch.zeros((M, N), dtype=torch.int32).npu() - accumulator_type = tl.int32 - else: - triton_res = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu()
-        accumulator_type = tl.float32
-    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
-
-    matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, a.stride(0), a.stride(1), b.stride(0),
-                        b.stride(1), triton_res.stride(0), triton_res.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)
-
-    a_gold = a.to(torch.float32)
-    b_gold = b.to(torch.float32)
-    cpu_res = torch.mm(a_gold, b_gold)
-
-    if dtype == "int8":
-        # torch_npu does not support int8 matmul
-        a_npu = a.npu().to(torch.float32)
-        b_npu = b.npu().to(torch.float32)
-        torch_res = torch.mm(a_npu, b_npu)
-        triton_res = triton_res.to(torch.float32)
-    else:
-        a_npu = a.npu()
-        b_npu = b.npu()
-        torch_res = torch.mm(a_npu, b_npu)
-
-    try:
-        print("starting compare of cpu vs triton:")
-        acc_util.assert_close(cpu_res, triton_res)
-    except Exception as e:
-        print(e)
-        print("starting compare of cpu vs triton vs torch_npu:")
-        acc_util.benchmark_compare_close(cpu_res, triton_res, torch_res)
-    print("PASSED")
-
-
-@avoid_not_support('matmul')
-@pytest.mark.parametrize('batch', TestUtils.batch)
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d)
-@pytest.mark.parametrize('dtype', TestUtils.dtype_list)
-def test_batch_matmul(shape, dtype, batch):
-    M, N, K = shape[0], shape[0], shape[1]
-    # 32byte/Dtype_bytes
-    kalign = 32 // get_dtype_size(dtype)
-    BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32)
-
-    aa = test_common.generate_tensor((batch, M, K), dtype)
-    bb = test_common.generate_tensor((batch, K, N), dtype)
-
-    if dtype == "int8":
-        final_triton_res = torch.zeros((batch, M, N), dtype=torch.int32).npu()
-        accumulator_type = tl.int32
-    else:
-        final_triton_res = torch.zeros((batch, M, N), dtype=eval('torch.' + dtype)).npu()
-        accumulator_type = tl.float32
-
-    for i in range(0, batch):
-        if dtype == "int8":
-            triton_res = torch.zeros((M, N), dtype=torch.int32).npu()
-        else:
-            triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu()
-        grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
-        a = aa[i]
-        b = bb[i]
-        matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, a.stride(0), a.stride(1),
-                            b.stride(0), b.stride(1), triton_res.stride(0), triton_res.stride(1), BLOCK_M, BLOCK_N,
-                            BLOCK_K)
-        final_triton_res[i] = triton_res
-
-    a_gold = aa.to(torch.float32)
-    b_gold = bb.to(torch.float32)
-    cpu_res = torch.bmm(a_gold, b_gold)
-
-    if dtype == "int8":
-        a_npu = aa.npu().to(torch.float32)
-        b_npu = bb.npu().to(torch.float32)
-        final_triton_res = final_triton_res.to(torch.float32)
-    else:
-        a_npu = aa.npu()
-        b_npu = bb.npu()
-    torch_res = torch.bmm(a_npu, b_npu)
-
-    try:
-        print("starting compare of cpu vs triton:")
-        acc_util.assert_close(cpu_res, final_triton_res)
-    except Exception as e:
-        print(e)
-        print("starting compare of cpu vs triton vs torch_npu:")
-        acc_util.benchmark_compare_close(cpu_res, final_triton_res, torch_res)
-    print("PASSED")
-
-
-if __name__ == "__main__":
-    test_matmul((16, 32), 'float32')
-    test_matmul((16, 32), 'int8')
-    # argument order is (shape, dtype, batch)
-    test_batch_matmul((16, 32), 'float32', 2)
-    test_batch_matmul((16, 32), 'int8', 2)
diff --git a/third_party/ascend/unittest/generalization_cases/test_max.py b/third_party/ascend/unittest/generalization_cases/test_max.py
deleted file mode 100644
index 6029b77d9b..0000000000
--- a/third_party/ascend/unittest/generalization_cases/test_max.py
+++ /dev/null
@@ -1,313 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_max_1d -def torch_max(x0, dim, keepdim): - inp = x0 if x0.device == "cpu" else x0.cpu() - return torch.max(inp, dim=dim, keepdim=keepdim)[0].npu() - - -@triton.jit -def triton_max_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.max(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_max_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." + dtype)).npu() - numel = shape[0] - triton_max_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_max(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_1d - - -# <<<<<<< test_max_2d -@triton.jit -def triton_max_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.max(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0, 1]) -def test_max_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype == 'int8' or dtype == 'bool': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch." 
+ dtype)).npu() - triton_max_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_max(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_2d - - -# <<<<<<< test_max_3d -def torch_max_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch.max(torch.max(inp, 1)[0], 1)[0].npu() - elif no_reduce_dim == 1: - return torch.max(torch.max(inp, 0)[0], 1)[0].npu() - elif no_reduce_dim == 2: - return torch.max(torch.max(inp, 0)[0], 0)[0].npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_max_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 0) - ret = tl.max(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_max_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 0) - ret = tl.max(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_max_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 1) - ret = tl.max(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_max_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_max_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_max_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_max_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_max_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch." 
+ dtype)).npu() - triton_max_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_max_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_3d - - -# <<<<<<< test_max_4d -def torch_max_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.max(x0, dim=dim)[0] - - -@triton.jit -def max_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_max_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - max_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_max_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_max_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_max_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_max_4d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_max_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_4d - - -# <<<<<<< test_max_5d -def torch_max_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.max(x0, dim=dim)[0] - - -@triton.jit -def max_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, 
ret) - - -@triton.jit -def triton_max_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - max_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_max_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_max_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_max_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_max_5d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_max_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_5d diff --git a/third_party/ascend/unittest/generalization_cases/test_min.py b/third_party/ascend/unittest/generalization_cases/test_min.py deleted file mode 100644 index 2c080e3b18..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_min.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
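All of the max/min suites in this patch share one kernel idiom: build row-major linear offsets by broadcasting one tl.arange per axis, load the whole block, then reduce it axis by axis with tl.max / tl.min until only the kept axis survives. A minimal self-contained sketch of that idiom (hypothetical shapes, and assuming a stock Triton install with a CUDA device rather than the torch_npu device these tests target):

import torch
import triton
import triton.language as tl


@triton.jit
def max_keep_last(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
    xidx = tl.arange(0, XB)
    yidx = tl.arange(0, YB)
    zidx = tl.arange(0, ZB)
    # row-major linear offsets for one (XB, YB, ZB) block
    idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
    x = tl.load(in_ptr + idx)
    # reduce axis 0 twice: after the first reduction, the old axis 1 becomes axis 0
    ret = tl.max(tl.max(x, 0), 0)
    tl.store(out_ptr + zidx, ret)


x = torch.randn(4, 8, 16, device="cuda")
out = torch.empty(16, device="cuda")
max_keep_last[(1,)](x, out, XB=4, YB=8, ZB=16)
assert torch.equal(out, x.amax(dim=(0, 1)))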
- -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_min_1d -def torch_min(x0, dim, keepdim): - inp = x0 if x0.device == "cpu" else x0.cpu() - return torch.min(inp, dim=dim, keepdim=keepdim)[0].npu() - - -@triton.jit -def triton_min_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.min(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_min_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." + dtype)).npu() - numel = shape[0] - triton_min_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_min(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_1d - - -# <<<<<<< test_min_2d -@triton.jit -def triton_min_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=float('inf'))  # +inf is the identity element for min - tmp4 = tl.min(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0, 1]) -def test_min_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype == 'int8' or dtype == 'bool': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch."
+ dtype)).npu() - triton_min_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_min(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_2d - - -# <<<<<<< test_min_3d -def torch_min_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch.min(torch.min(inp, 1)[0], 1)[0].npu() - elif no_reduce_dim == 1: - return torch.min(torch.min(inp, 0)[0], 1)[0].npu() - elif no_reduce_dim == 2: - return torch.min(torch.min(inp, 0)[0], 0)[0].npu() - else: - assert False, f"invalid no_reduce_dim: {no_reduce_dim}" - - -@triton.jit -def triton_min_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.min(x, 0) - ret = tl.min(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_min_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.min(x, 0) - ret = tl.min(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_min_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.min(x, 1) - ret = tl.min(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_min_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_min_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_min_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_min_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_min_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch."
+ dtype)).npu() - triton_min_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_min_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_3d - - -# <<<<<<< test_min_4d -def torch_min_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.min(x0, dim=dim)[0] - - -@triton.jit -def min_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_min_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - min_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_min_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_min_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_min_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_min_4d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_min_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_4d - - -# <<<<<<< test_min_5d -def torch_min_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.min(x0, dim=dim)[0] - - -@triton.jit -def min_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, 
ret) - - -@triton.jit -def triton_min_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - min_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_min_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_min_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_min_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_min_5d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_min_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_5d diff --git a/third_party/ascend/unittest/generalization_cases/test_mod.py b/third_party/ascend/unittest/generalization_cases/test_mod.py deleted file mode 100644 index ce15ea3d84..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_mod.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
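test_mod.py below clamps every operand to be positive (x[x <= 0] = 1) before exercising %. That sidesteps the classic remainder ambiguity: for negative operands, torch's % follows Python semantics (the result takes the sign of the divisor), whereas a C-style remainder takes the sign of the dividend, and a kernel backend may lower % with either convention. A torch-only illustration of the two conventions:

import torch

a = torch.tensor([-7])
b = torch.tensor([3])
print(a % b)             # tensor([2])  -> Python-style modulo, sign of the divisor
print(torch.fmod(a, b))  # tensor([-1]) -> C-style remainder, sign of the dividend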
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x, y): - res = x % y - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X % Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_mod_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_SIZE: tl.constexpr, - SUB_BLOCK: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): - base_idx = tl.arange(0, SUB_BLOCK) - pid_tensor = tl.full((SUB_BLOCK, ), pid * BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) - tmp0 = (pid_tensor + base_idx)[:, None, None, None] - tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None] - tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None] - tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :] - offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = tmp0 < SHAPE_0 - x = tl.load(x_ptr + offsets, mask=masks) - y = tl.load(y_ptr + offsets, mask=masks) - ret = x % y - tl.store(output_ptr + offsets, ret, mask=masks) - - -@triton.jit -def triton_mod_5d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_SIZE: tl.constexpr, - SUB_BLOCK: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - pid = tl.program_id(0) - for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): - base_idx = tl.arange(0, SUB_BLOCK) - pid_tensor = tl.full((SUB_BLOCK, ), pid * BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) - tmp0 = (pid_tensor + base_idx)[:, None, None, None, None] - tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None, None] - tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None, None] - tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :, None] - tmp4 = tl.arange(0, SHAPE_4)[None, None, None, None, :] - offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = tmp0 < SHAPE_0 - x = tl.load(x_ptr + offsets, mask=masks) - y = tl.load(y_ptr + offsets, mask=masks) - ret = x % y - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, 
dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - z[z <= 0] = 1 - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + [(25, 2, 3, 31), (2, 2, 39, 23), (17, 27, 3, 3), - (3, 2, 27, 37)]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mod_4d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - - n = x.numel() - block_size = min(triton.next_power_of_2(n), 64) - sub_block_size = 1 - grid = (triton.cdiv(n, block_size), ) - print(" ") - print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}") - print(f"=== grid : {grid}") - triton_mod_4d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride())) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape5d + [(32, 5, 3, 1, 8)]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mod_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - - n = x.numel() - block_size = min(triton.next_power_of_2(n), 32) - sub_block_size = 1 - grid = (triton.cdiv(n, block_size), ) - print(" ") - print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}") - print(f"=== grid : {grid}") - triton_mod_5d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride())) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_ne.py b/third_party/ascend/unittest/generalization_cases/test_ne.py deleted file mode 100644 index a05220da45..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ne.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
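test_ne.py below spreads the flat index space over up to 32 cores and guards every load and store with x_index < N, so the block size never has to divide the element count. The same tail-masking idiom in its canonical one-level form (a sketch, not the deleted kernel itself; assumes a stock Triton install with a CUDA device):

import torch
import triton
import triton.language as tl


@triton.jit
def ne_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N  # lanes past the end of the data are disabled in the last program
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x != y, mask=mask)


n = 1000  # deliberately not a multiple of the block size
x = torch.randint(0, 4, (n,), device="cuda")
y = torch.randint(0, 4, (n,), device="cuda")
out = torch.empty(n, dtype=torch.bool, device="cuda")
ne_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK=256)
assert torch.equal(out, x != y)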
- -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math -import logging - - -def torch_ne(x0, x1): - if x0.dtype != torch.uint32: - return x0 != x1 - else: - return x0.to(torch.float32) != x1.to(torch.float32) - - -@triton.jit -def triton_ne(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = XBLOCK // XBLOCK_SUB - for loop1 in range(loops1): - x_index = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N) - tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N) - tmp2 = tmp0 != tmp1 - tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N) - - -@triton.jit -def triton_ne_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val != y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_ne(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - # generate input data - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # torch reference result - torch_res = torch_ne(x0, x1).to(eval('torch.' + dtype)) - # triton result - triton_res = torch.zeros(shape, dtype=eval('torch.'
+ dtype)).npu() - N = triton_res.numel() - triton_ne[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) - # compare results - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_dtype = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_dtype, triton_res, torch_res) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_ne_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_ne(x, y).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ne_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_neg.py b/third_party/ascend/unittest/generalization_cases/test_neg.py deleted file mode 100644 index 07c738d28f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_neg.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
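The *_4d_5d kernels in this patch (triton_ne_4d_5d above, triton_neg_4d_5d below) grow the offsets and masks one axis at a time, multiplying each tl.arange by the matching input stride and only unsqueezing while the remaining block volume is greater than 1. The arithmetic is easiest to sanity-check in a numpy analogue: for a contiguous tensor, summing the per-axis index grids times the row-major strides reproduces every element's flat offset.

import numpy as np

shape = (2, 3, 4)
strides = (12, 4, 1)  # row-major element strides of a contiguous (2, 3, 4) array
i0 = np.arange(shape[0])[:, None, None] * strides[0]
i1 = np.arange(shape[1])[None, :, None] * strides[1]
i2 = np.arange(shape[2])[None, None, :] * strides[2]
offsets = i0 + i1 + i2
# same layout the kernels index: offsets[x, y, z] is the flat position of element (x, y, z)
assert (offsets == np.arange(24).reshape(shape)).all()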
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x): - res = -x - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = -X - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_neg_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = -x_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu()) - ans = ans.npu() - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_neg_4d_5d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu()) - ans = ans.npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_neg_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'bool', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_not.py b/third_party/ascend/unittest/generalization_cases/test_not.py deleted file mode 100644 index 21397985d6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_not.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_not(x0): - res = torch.bitwise_not(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = not (X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_not_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = not (x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_not(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_not_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_not(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_not_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_or.py b/third_party/ascend/unittest/generalization_cases/test_or.py deleted file mode 100644 index 9861b8daf7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_or.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
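In test_or.py below, torch_or applies the same | operator the kernel uses: on integer dtypes it is a bitwise OR and on bool it reduces to logical OR, which is why the parametrization is limited to ints and bool while the float dtypes are routed to test_invalid_types. For reference:

import torch

a = torch.tensor([0b0101], dtype=torch.int8)
b = torch.tensor([0b0011], dtype=torch.int8)
print(a | b)  # tensor([7], dtype=torch.int8): bitwise OR on integers
print(torch.tensor([True, False]) | torch.tensor([False, False]))  # tensor([ True, False])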
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_or(x0, x1): - return x0 | x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X | Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_or_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val | y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - ans = torch_or(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1,
1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_or_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x | y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_or_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py deleted file mode 100644 index 70d41abc22..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, xnumel: tl.constexpr): - idx = tl.arange(0, xnumel) - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (0, )) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute_1d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.randint(1, shape, dtype=data_type).npu() - torch_res = torch.permute(x, (0, )) - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr): - pid = tl.program_id(0) - yidx = tl.arange(0, YB) + pid * YB - zidx = tl.arange(0, ZB) - idx = yidx[:, None] * znumel + zidx[None, :] - - # XB,YB,1 - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (1, 0)) - - oidx = zidx[:, None] * ynumel + yidx[None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - ynumel = shape[0] - YB = 1 - znumel = shape[1] - ZB = shape[1] - - data_type = eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=(shape[0], shape[1]), dtype=data_type).npu() - - triton_res = torch.randint(1, (shape[1], shape[0]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0)) - fn_npu_021[shape[0], 1, 1](triton_res, x, YB, ZB, ynumel, znumel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(37, 3)]: - for dtype in TestUtils.dtype_list: - test_permute(shape, dtype) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_3d.py b/third_party/ascend/unittest/generalization_cases/test_permute_3d.py deleted file mode 100644 index 9696e09f32..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_3d.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
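The 3-D permute kernels below (fn_npu_102, fn_npu_021, and the not-yet-supported fn_npu_210) all follow the same recipe: load with offsets derived from the input shape, call tl.permute, then store with offsets rebuilt from the permuted shape's row-major strides. A numpy analogue of fn_npu_021's output indexing, as a sanity check on the oidx arithmetic:

import numpy as np

YB, ZB, KB = 2, 3, 4
x = np.arange(YB * ZB * KB).reshape(YB, ZB, KB)
ret = x.transpose(0, 2, 1)  # the analogue of tl.permute(X, (0, 2, 1))
yidx, zidx, kidx = np.arange(YB), np.arange(ZB), np.arange(KB)
# row-major offsets for the permuted (YB, KB, ZB) output: strides (KB*ZB, ZB, 1)
oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :]
out = np.empty(YB * ZB * KB, dtype=x.dtype)
out[oidx] = ret
assert (out.reshape(YB, KB, ZB) == ret).all()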
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (1, 0, 2)) - - oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (2, 1, 0)) - - oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (0, 2, 1)) - - oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ["int8", 'int16', 'int32', 'float16', 'float32', 'bfloat16', 'int64']) -def test_permute_3d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - # not support yet: need bisheng support later - # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() - # torch_res = torch.permute(x, (2, 1, 0)) - # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - # test_common.validate_cmp(dtype, triton_res, torch_res) - - triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() - torch_res = torch.permute(x, (0, 2, 1)) - fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py deleted file mode 100644 index 615ff3cd6e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def triton_permute_4d( - output_ptr, - x_ptr, - PERM: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3 - ret = tl.permute(x_val, (1, 0, 2, 3)) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - elif PERM == 1: # 0, 2, 1, 3 - ret = tl.permute(x_val, (0, 2, 1, 3)) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - else: # 0, 1, 3, 2 - ret = tl.permute(x_val, (0, 1, 3, 2)) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - - s3 = 1 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) - elif PERM == 1: # 0, 2, 1, 3 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) - else: # 0, 1, 3, 2 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - 
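triton_permute_4d above rebuilds the output strides inline from the permuted shape (s3 = 1, s2 = s3 * shape3, and so on): the innermost stride is 1 and each outer stride multiplies in the next inner dimension. The same computation as a standalone helper, a hypothetical extraction shown for readability only:

def contiguous_strides(shape):
    # row-major strides: the innermost dimension moves fastest
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.append(strides[-1] * dim)
    return tuple(reversed(strides))


assert contiguous_strides((2, 3, 4, 5)) == (60, 20, 5, 1)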
-@triton.jit -def triton_permute_5d(output_ptr, x_ptr, PERM: tl.constexpr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] - - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] - - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :, None] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None, None] - - tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] - tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3, 4 - ret = tl.permute(x_val, 1, 0, 2, 3, 4) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 1: # 0, 2, 1, 3, 4 - ret = tl.permute(x_val, 0, 2, 1, 3, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 2: # 0, 1, 3, 2, 4 - ret = tl.permute(x_val, 0, 1, 3, 2, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - shape4 = SHAPE_4 - else: # 0, 1, 2, 4, 3 - ret = tl.permute(x_val, 0, 1, 2, 4, 3) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_2 - shape3 = SHAPE_4 - shape4 = SHAPE_3 - - s4 = 1 - s3 = s4 * shape4 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3, 4 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 1: # 0, 2, 1, 3, 4 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 2: # 0, 1, 3, 2, 4 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) - else: # 0, 1, 2, 4, 3 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode -def test_permute_4d_5d(shape, dtype, perm): - logging.log(logging.DEBUG, 
f"shape = {shape}") - x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() - grid = (1, ) - if len(shape) == 4: - blocks = list(x.size()) - strides = list(x.stride()) - if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (1, 0, 2, 3)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 2, 1, 3)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 1, 3, 2)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - else: - pass - else: - blocks = list(x.size()) - strides = list(x.stride()) - - if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) - elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) - elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) - else: # 0, 1, 2, 4, 3; exchange axis 3, 4 - output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) - triton_permute_5d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_rand.py b/third_party/ascend/unittest/generalization_cases/test_rand.py deleted file mode 100644 index 8e66e48bfa..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_rand.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import numpy as np -import scipy - - -@triton.jit -def kernel_rand(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.rand(5, 10 + global_offset, n_rounds) # generate one random value per index - tl.store(x_ptr + global_offset, rand_vals) # store the random value - - -@triton.jit -def triton_rand_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.rand(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randn(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.randn(5, 10 + global_offset, n_rounds) # generate one random value per index - tl.store(x_ptr + global_offset, rand_vals) # store the random value - - -@triton.jit -def triton_randn_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0,
BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randn(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randint(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.randint(5, 10 + global_offset, n_rounds) # generate one random value per index - tl.store(x_ptr + global_offset, rand_vals) # store the random value - - -@triton.jit -def triton_randint_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randint(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randint4x(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - indices = tl.arange(0, 4) - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(0, block_size + 4, step=4): - global_offset = block_offset + inner_idx - rand_vals = tl.randint4x(5, 10 + global_offset, n_rounds) # generate four random values per iteration - mask = (global_offset + indices) < (block_offset + block_size) - tl.store(x_ptr + global_offset + indices, rand_vals, mask) # store the random values - - -@triton.jit -def triton_randint4x_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2:
tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randint4x(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -# With alpha=0.01, z=-3.0902, N=100, we have (1-0.01)+(-3.0902)*sqrt(0.01*(1-0.01)/100)=0.9593, -# so there must be 96 cases for each shape to have pvalue larger than 0.01. -# There is higher possibility to fail with small shapes, so we will use large shape. -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_rand_case(shape): - y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() - - numel = y_calf.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - for _ in range(100): - ref = np.random.random_sample(shape).flatten() - kernel_rand[ncore, 1, 1](y_calf, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randn_case(shape): - y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() - - numel = y_calf.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - for _ in range(100): - ref = np.random.standard_normal(shape).flatten() - kernel_randn[ncore, 1, 1](y_calf, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randint_case(shape): - y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() - - numel = y_cali.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - ii32 = np.iinfo(np.int32) - for _ in range(100): - ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() - kernel_randint[ncore, 1, 1](y_cali, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randint4x_case(shape): - y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() - - numel = y_cali.numel() - ncore = 
1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - ii32 = np.iinfo(np.int32) - for _ in range(100): - ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() - kernel_randint4x[ncore, 1, 1](y_cali, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -def test_rand_4d_5d(shape): - x = torch.zeros(shape, dtype=eval('torch.float32')).npu() - y = torch.zeros(shape, dtype=eval('torch.int32')).npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rand_4d_5d[grid](x, *blocks, *blocks, *strides) - triton_randn_4d_5d[grid](x, *blocks, *blocks, *strides) - triton_randint_4d_5d[grid](y, *blocks, *blocks, *strides) - triton_randint4x_4d_5d[grid](y, *blocks, *blocks, *strides) diff --git a/third_party/ascend/unittest/generalization_cases/test_range.py b/third_party/ascend/unittest/generalization_cases/test_range.py deleted file mode 100644 index 992076fc6d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_range.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
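Both loop flavors exercised in this file, tl.range(2, 5, 2) and tl.static_range(2, 5, 2), yield the values 2 and 4, exactly two iterations, so the expected result is x + y + 2 * x. A small torch sketch of that reference computation, assuming same-shaped inputs:

import torch

def range_reference(x, y):
    ret = x + y
    # range(2, 5, 2) visits 2 and 4: two extra accumulations of x.
    for _ in range(2, 5, 2):
        ret = ret + x
    return ret

x = torch.ones(3, 4)
y = torch.full((3, 4), 2.0)
assert torch.equal(range_reference(x, y), x + y + 2 * x)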
- -import pytest -import triton -import torch -import test_common -import logging - -import triton.language as tl -from test_common import TestUtils - - -@triton.jit -def triton_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - for _ in tl.range(2, 5, 2): - ret = ret + X - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_static_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - for _ in tl.static_range(2, 5, 2): - ret = ret + X - - tl.store(output_ptr + idx, ret) - - -test_shape = [(1, ), (2, ), (1, 1), (3, 4), (1, 1, 1), (2, 4, 8)] - - -@pytest.mark.parametrize('shape', test_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_range(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if dtype == 'bfloat16': - ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) - else: - ans = x + y + x + x - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', test_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_static_range(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if dtype == 'bfloat16': - ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) - else: - ans = x + y + x + x - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_static_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_reduce.py b/third_party/ascend/unittest/generalization_cases/test_reduce.py deleted file mode 100644 index 14ab4696cb..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_reduce.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import random -import pytest -import torch -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_reduce(x1, dim): - if x1.dtype == torch.float16 or x1.dtype == torch.float32: - res = torch.sum(x1.to(torch.float32), dim=dim).to(x1.dtype) - else: - res = torch.sum(x1, dim=dim).to(x1.dtype) - return res - - -@triton.jit -def _reduce_combine(a, b): - return a + b - - -@triton.jit -def tt_reduce_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - if dim == 0: - oidx = yidx - else: - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_1d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_2d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_3d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - tl.store(out_ptr, ret) - - -@triton.jit -def tt_reduce_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - if dim == 0: - oidx = yidx[:, None] * znumel + zidx[None, :] - elif dim == 1: - oidx = xidx[:, None] * znumel + zidx[None, :] - else: - oidx = xidx[:, None] * ynumel + yidx[None, :] - - tl.store(out_ptr + oidx, ret) 
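The two-axis kernels that follow (tt_reduce_3d_0_1, tt_reduce_3d_0_2, tt_reduce_3d_1_2) rely on a sum over two axes being expressible as two chained single-axis reductions; note that the second axis index is chosen after the first reduction has already removed a dimension. A quick NumPy sketch of that identity:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
# Axes (0, 1): reduce axis 0, then axis 0 of the result.
assert np.array_equal(x.sum(axis=(0, 1)), x.sum(axis=0).sum(axis=0))
# Axes (0, 2): once axis 0 is gone, the original axis 2 is axis 1.
assert np.array_equal(x.sum(axis=(0, 2)), x.sum(axis=0).sum(axis=1))
# Axes (1, 2): once axis 1 is gone, the original axis 2 is axis 1.
assert np.array_equal(x.sum(axis=(1, 2)), x.sum(axis=1).sum(axis=1))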
- - -@triton.jit -def tt_reduce_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 0, _reduce_combine) - ret = tl.reduce(tmp, 0, _reduce_combine) - oidx = zidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 0, _reduce_combine) - ret = tl.reduce(tmp, 1, _reduce_combine) - oidx = yidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 1, _reduce_combine) - ret = tl.reduce(tmp, 1, _reduce_combine) - oidx = xidx - - tl.store(out_ptr + oidx, ret) - - -def is_legal_combine(shape, dims): - return dims is None or (len(shape) == 3) or \ - (len(dims) == 1 and dims[0] < len(shape)) - - -dims_map = {(0, 1): tt_reduce_3d_0_1, (1, 2): tt_reduce_3d_1_2, (0, 2): tt_reduce_3d_0_2} - -shape_map = { - 1: {"append_shape": (1, 1), "func": tt_reduce_1d}, 2: {"append_shape": (1, ), "func": tt_reduce_2d}, 3: - {"append_shape": (), "func": tt_reduce_3d} -} - - -def reduce_check_ub_mem_overflow(dtype, shape): - dtype_size = get_dtype_size(dtype) - if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.") - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (0, 1), (1, 2), (0, 2)]) -def test_reduce(dtype, shape, dims): - if not is_legal_combine(shape, dims): - return - - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - grid = (1, 1, 1) - - y_ref = torch_reduce(x, dims) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.'
+ dtype), device="npu") - - if dims is None: - reduce_check_ub_mem_overflow(dtype, shape) - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if len(shape) == 1: - tt_reduce_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 2: - tt_reduce_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 3: - tt_reduce_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - elif len(dims) == 1: # 1d reduce, 1-3d shape - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if (len(shape) == 2) and (x.numel() * x.element_size() > 8192): - if dims[0] == 0: - grid = (1, ynumel, 1) - YB = 1 - else: - grid = (xnumel, 1, 1) - XB = 1 - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - else: # 3d shape, 2d reduce - tt_kernel = dims_map[dims] - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@triton.jit -def triton_reduce_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.reduce(x, DIM, _reduce_combine), REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.reduce(x, DIM, _reduce_combine) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, )]) -def test_reduce_4d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_reduce(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, ), (4, )]) -def test_reduce_5d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_reduce(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_relu.py b/third_party/ascend/unittest/generalization_cases/test_relu.py deleted file mode 100644 index 21880163bf..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_relu.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
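The relu test in this file splits numel elements across at most 32 cores and then walks each core's XBLOCK range in XBLOCK_SUB-sized chunks. A small sketch of that host-side split, mirroring the same ceil-division policy (hypothetical helper name):

import math

def split_work(numel, max_cores=32):
    ncore = 1 if numel <= max_cores else max_cores
    xblock = math.ceil(numel / ncore)  # elements owned by one core
    # inner chunk the kernel processes per loop step
    xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore)
    return ncore, xblock, xblock_sub

print(split_work(10))    # (1, 10, 10)
print(split_work(1000))  # (32, 32, 32)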
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -import triton.language.extra.ascend.libdevice as libdevice -from test_common import TestUtils -import math - - -def torch_relu(x0, x1): - res = x0 + torch.relu(x1) - return res - - -@triton.jit -def triton_relu(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - xoffset = tl.program_id(0) * XBLOCK - for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): - x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] - xmask = x_index < xnumel - tmp0 = tl.load(in_ptr0 + x_index, xmask) - tmp1 = tl.load(in_ptr1 + x_index, xmask) - tmp2 = tmp0 + libdevice.relu(tmp1) - tl.store(out_ptr0 + x_index, tmp2, xmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_relu(dtype, shape): - # generate input data - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # torch reference result - torch_res = torch_relu(x0, x1) - # triton result - triton_res = test_common.generate_tensor(shape, dtype).npu() - triton_relu[ncore, 1, 1](x0, x1, triton_res, x0.numel(), xblock, xblock_sub) - # compare results - test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_rshift_op.py b/third_party/ascend/unittest/generalization_cases/test_rshift_op.py deleted file mode 100644 index 33b6fffecd..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_rshift_op.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
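Every kernel in this file computes x >> 2. On signed integers this is an arithmetic shift (sign bits are replicated on the left), which matches floor division by 4 and rounds toward negative infinity, unlike C-style truncating division. A quick torch sketch of the reference semantics:

import torch

x = torch.tensor([-7, -1, 0, 1, 7], dtype=torch.int32)
# Arithmetic shift: -7 >> 2 == -2, not -1.
assert torch.equal(x >> 2, torch.tensor([-2, -1, 0, 0, 1], dtype=torch.int32))
# Equivalent to floor division by 2**2 (rounds toward -inf).
assert torch.equal(x >> 2, torch.div(x, 4, rounding_mode="floor"))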
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_rshift_1d(in_ptr0, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_2d(in_ptr0, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_3d(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_4d_5d(x_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = x_val >> 2 - tl.store(output_ptr + offsets, ret, mask=masks) - - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - -typelist = [ - 'int8', - 'int16', - 'int32', - 'int64', -] - - -# @pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_rshift(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = x0 >> 2 - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - shape0 = shape[0] -
shape1 = shape[1] - shape2 = shape[2] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_rshift_3d[grid](x0, output, shape0, shape1, shape2) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_rshift_2d[grid](x0, output, shape0, shape1) - if len(shape) == 1: - triton_rshift_1d[1, 1, 1](x0, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_rshift_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x >> 2 - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_rshift_1d[1, 1, 1](x, output, N) diff --git a/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py b/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py deleted file mode 100644 index defb936aa7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common - - -def torch_(x0, x1, op_type): - if op_type == 'mul': - return torch.tensor(x0 * x1) - elif op_type == 'lshift': - return torch.tensor(x0 << x1) - elif op_type == 'eq': - return torch.tensor(x0 == x1) - else: - raise TypeError('Invalid op_type') - - -@triton.jit -def scalar_mul(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.float32, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.float32, [])) - ret = scalar0 * scalar1 - tl.store(out_ptr0, ret) - - -@triton.jit -def scalar_lshift(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int32, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int32, [])) - ret = scalar0 << scalar1 - tl.store(out_ptr0, ret) - - -@triton.jit -def scalar_eq(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int16, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int16, [])) - ret = scalar0 == scalar1 - tl.store(out_ptr0, ret) - - -@pytest.mark.parametrize('param_list', [ - ['float32', 'mul', (1, ), 3.14, 6.66], - ['int32', 'lshift', (1, ), 6, 7], - ['bool', 'eq', (1, ), 5, 5], -]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "0d block_type is forbidden") -def test_case(param_list): - dtype, op_type, shape, lval, rval = param_list - ans = torch_(lval, rval, op_type) - ret = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - - if op_type == 'mul': - scalar_mul[1, 1, 1](ret, lval, rval) - elif op_type == 'lshift': - scalar_lshift[1, 1, 1](ret, lval, rval) - elif op_type == 'eq': - scalar_eq[1, 1, 1](ret, lval, rval) - - test_common.validate_cmp(dtype, ans, ret) diff --git a/third_party/ascend/unittest/generalization_cases/test_sort.py b/third_party/ascend/unittest/generalization_cases/test_sort.py deleted file mode 100644 index 543915ec2b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sort.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
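torch.sort does not accept bool tensors, which is evidently why every bool case in this file builds its reference with NumPy: sort ascending via np.sort, then reverse the sorted axis for descending order. A compact sketch of that fallback path (hypothetical helper name):

import numpy as np
import torch

def sort_bool_reference(x_np, descending):
    s = np.sort(x_np, axis=-1)  # np.sort is ascending-only
    if descending:
        # .copy() materializes the flipped view; torch rejects negative strides
        s = np.flip(s, axis=-1).copy()
    return torch.from_numpy(s)

x = np.array([[True, False, True], [False, False, True]])
print(sort_bool_reference(x, descending=True))
# tensor([[ True,  True, False],
#         [ True, False, False]])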
- -import triton -import pytest -import torch -import numpy as np -import triton.language as tl -import test_common -from test_common import TestUtils - - -# ---------------------- -# 1D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_1d(X, Z, M: tl.constexpr, descending: tl.constexpr): - off = tl.arange(0, M) - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_1d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - M = x.shape[0] - sort_kernel_1d[(1, )](x, triton_res, M, descending) - assert torch.equal(torch_res, triton_res) - - -# ---------------------- -# 2D sort kernel (split by rows, not cutting M axis) -# ---------------------- -@triton.jit -def sort_kernel_2d(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): - pid = tl.program_id(0) - offx = tl.arange(0, M) - offy = pid * M - off2d = offx + offy - x = tl.load(X + off2d) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off2d, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape2d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_2d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - N, M = x.shape - # one block per row - sort_kernel_2d[(N, )](x, triton_res, N, M, descending) - assert torch.equal(torch_res, triton_res), (torch_res, triton_res) - - -# ---------------------- -# 3D sort kernel (split by D0, D1, not cutting D2) -# ---------------------- -@triton.jit -def sort_kernel_3d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % D1 - batch_id = pid // D1 - - off2 = tl.arange(0, D2) - off1 = row_id * D2 - off0 = batch_id * D1 * D2 - off = off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) # sort one full row at a time - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape3d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_3d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) -
D0, D1, D2 = x.shape - # one block per (D0, D1) pair - sort_kernel_3d[(D0 * D1, )](x, triton_res, D0, D1, D2, descending) - assert torch.equal(torch_res, triton_res), (torch_res, triton_res) - - -# ---------------------- -# 4D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_4d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, D3: tl.constexpr, - descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % D2 - col_id = (pid // D2) % D1 - batch_id = pid // (D1 * D2) - - off3 = tl.arange(0, D3) - off2 = row_id * D3 - off1 = col_id * D2 * D3 - off0 = batch_id * D1 * D2 * D3 - off = off3 + off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape4d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_4d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - D0, D1, D2, D3 = x.shape - sort_kernel_4d[(D0 * D1 * D2, )](x, triton_res, D0, D1, D2, D3, descending) - assert torch.equal(torch_res, triton_res) - - -# ---------------------- -# 5D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_5d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, D3: tl.constexpr, D4: tl.constexpr, - descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % D3 - col_id = (pid // D3) % D2 - depth_id = (pid // (D2 * D3)) % D1 - batch_id = pid // (D1 * D2 * D3) - - off4 = tl.arange(0, D4) - off3 = row_id * D4 - off2 = col_id * D3 * D4 - off1 = depth_id * D2 * D3 * D4 - off0 = batch_id * D1 * D2 * D3 * D4 - off = off4 + off3 + off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape5d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_5d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, :, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - D0, D1, D2, D3, D4 = x.shape - sort_kernel_5d[(D0 * D1 * D2 * D3, )](x, triton_res, D0, D1, D2, D3, D4, descending) - assert torch.equal(torch_res, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_sqrt.py b/third_party/ascend/unittest/generalization_cases/test_sqrt.py deleted file mode 100644 index 49055ff811..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sqrt.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import triton -import triton.language as tl -import torch -import numpy as np -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sqrt(x0): - res = torch.sqrt(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.sqrt(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sqrt_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sqrt(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, 
shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_sqrt(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sqrt_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sqrt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sqrt(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py b/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py deleted file mode 100644 index a1add886b4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sqrt_rn(x0): - tmp = torch.sqrt(x0) - return tmp - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.sqrt_rn(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sqrt_rn_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sqrt_rn(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_sqrt_rn(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sqrt_rn_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sqrt_rn_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_sqrt_rn(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sqrt_rn_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py b/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py deleted file mode 100644 index d383648c4c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import test_common - -import os - -os.environ["TRITON_ALWAYS_COMPILE"] = "1" -os.environ["PYTEST_ADDOPTS"] = "-sv" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] - - -def torch_func(x0, x1): - res = x0 + x1 - return res - - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.static_print(XBLOCK) - tl.static_print(tmp2) - tl.static_assert(XBLOCK == 8) - tl.store(out_ptr0 + idx, tmp2) - - -def triton_func(x0, x1, XS): - out = torch.empty_like(x0) - triton_kernel[ - 1, - ](out, x0, x1, XS) - return out - - -@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) -def test_static_print_and_assert(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, 
dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - captured = capsys.readouterr() - - if sigtype == "float32": - assert "fp32" in captured.out - if sigtype == "float16": - assert "fp16" in captured.out - if sigtype == "bfloat16": - assert "bf16" in captured.out - if "int" in sigtype: - assert sigtype in captured.out - assert "8" in captured.out - - test_common.validate_cmp(sigtype, triton_cal, torch_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum.py b/third_party/ascend/unittest/generalization_cases/test_sum.py deleted file mode 100644 index e3caa4edbe..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum.py +++ /dev/null @@ -1,332 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import random -import pytest -import torch -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_sum(x1, dim): - if x1.dtype == torch.float16 or x1.dtype == torch.bfloat16: - res = torch.sum(x1.to(torch.float32), dim=dim, keepdim=False).to(x1.dtype) - else: - res = torch.sum(x1, dim=dim, keepdim=False).to(x1.dtype) - return res - - -@triton.jit -def tt_sum_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - if dim == 0: - oidx = yidx - else: - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_1d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_2d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_3d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - tl.store(out_ptr, ret) - - -@triton.jit -def tt_sum_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - if dim == 0: - oidx = yidx[:, None] * znumel + zidx[None, :] - elif dim == 1: - oidx = xidx[:, None] * znumel + zidx[None, :] - else: - oidx = xidx[:, None] * ynumel + yidx[None, :] - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: 
tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 0) - ret = tl.sum(tmp, 0) - oidx = zidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 0) - ret = tl.sum(tmp, 1) - oidx = yidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 1) - ret = tl.sum(tmp, 1) - oidx = xidx - - tl.store(out_ptr + oidx, ret) - - -def is_legal_combine(shape, dims): - return dims is None or (len(shape) == 3) or \ - (len(dims) == 1 and dims[0] < len(shape)) - - -dims_map = {(0, 1): tt_sum_3d_0_1, (1, 2): tt_sum_3d_1_2, (0, 2): tt_sum_3d_0_2} - -shape_map = { - 1: {"append_shape": (1, 1), "func": tt_sum_1d}, 2: {"append_shape": (1, ), "func": tt_sum_2d}, 3: - {"append_shape": (), "func": tt_sum_3d} -} - - -def reduce_check_ub_mem_overflow(dtype, shape): - dtype_size = get_dtype_size(dtype) - if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.") - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (0, 1), (1, 2), (0, 2)]) -def test_sum(dtype, shape, dims): - if not is_legal_combine(shape, dims): - return - - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - grid = (1, 1, 1) - - y_ref = torch_sum(x, dims) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.'
+ dtype), device="npu") - - if dims is None: - reduce_check_ub_mem_overflow(dtype, shape) - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if len(shape) == 1: - tt_sum_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 2: - tt_sum_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 3: - tt_sum_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - elif len(dims) == 1: # 1d sum, 1-3d shape - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if (len(shape) == 2) and (x.numel() * x.element_size() > 8192): - if dims[0] == 0: - grid = (1, ynumel, 1) - YB = 1 - else: - grid = (xnumel, 1, 1) - XB = 1 - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - else: # 3d shape, 2d sum - tt_kernel = dims_map[dims] - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@triton.jit -def triton_sum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.sum(x, DIM), REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.sum(x, DIM) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, )]) -def test_sum_4d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_sum(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, ), (4, )]) -def test_sum_5d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_sum(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py deleted file mode 100644 index 9ef39548d7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, get_dtype_size -import math - - -def torch_sum(x0): - res = torch.sum(x0, 0) - return res - - -@triton.jit -def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr, RBLOCK_SUB: tl.constexpr): - xindex = tl.arange(0, XBLOCK) - xmask = xindex[:, None] < xnumel - for roffset_sub in range(0, RBLOCK, RBLOCK_SUB): - rindex = roffset_sub + tl.arange(0, RBLOCK_SUB) - x0 = xindex - r1 = rindex - rmask = rindex < rnumel - tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK * x0[:, None])), xmask & rmask) - tmp2 = tl.reshape(tmp0, [XBLOCK, RBLOCK_SUB]) - tmp4 = tl.sum(tmp2, 0) - tl.store(out_ptr1 + (rindex), tmp4, rmask) - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - threshold = TestUtils.ub_size / 1.5 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'int32']) -def test_case(dtype, shape): - should_skip_due_to_mem(dtype, shape) - x0 = test_common.generate_tensor(shape, dtype).npu() - - rblock = shape[1] - xblock = shape[0] - ncore = 1 #if numel <= 32 else 32 - rblock_sub = rblock #if xblock <= 16 else 16 - RBLOCK_tl = 256 if rblock > 1 else 1 - - y_ref = torch_sum(x0) - y_cal = torch.zeros(shape[1], dtype=eval('torch.' + dtype)).npu() - triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, rblock, rblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py deleted file mode 100644 index dd304da524..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sum(x0): - res = torch.sum(x0, 1) - return res - - -@triton.jit -def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr): - xoffset = tl.program_id(0) * XBLOCK - rindex = tl.arange(0, RBLOCK)[None, :] - rmask = rindex < rnumel - for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): - xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB) - x0 = xindex - r1 = rindex - xmask = xindex[:, None] < xnumel - xmask_prime = xindex < xnumel - tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK * x0[:, None])), rmask & xmask) - tmp2 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) - tmp4 = tl.sum(tmp2, 1) - tl.store(out_ptr1 + (xindex), tmp4, xmask_prime) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int32']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - rblock = shape[1] - xblock = shape[0] - ncore = 1 #if numel <= 32 else 32 - xblock_sub = xblock if xblock <= 16 else 16 - RBLOCK_tl = 256 if rblock > 1 else 1 - - y_ref = torch_sum(x0) - y_cal = torch.zeros(shape[:-1], dtype=eval('torch.' + dtype)).npu() - triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, xblock_sub, rblock) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py b/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py deleted file mode 100644 index 46bd2ebdf8..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import random -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -def swizzle2d(size_i, size_j, size_g): - i = torch.arange(0, size_i)[:, None] - j = torch.arange(0, size_j)[None, :] - ij = i * size_j + j - size_gj = size_g * size_j - group_id = ij // size_gj - off_i = group_id * size_g - size_g = torch.min(size_i - off_i, torch.tensor(size_g).expand_as(off_i)) - ij = ij % size_gj - new_i = off_i + ij % size_g - new_j = ij // size_g - ret = new_i * size_i + new_j - return ret - - -@triton.jit -def fn_npu_(out0, out1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - i = tl.arange(0, XB)[:, None] - j = tl.arange(0, YB)[None, :] - ij = i * YB + j - xx, yy = tl.swizzle2d(i, j, size_i=XB, size_j=YB, size_g=ZB) - - ptr = tl.load(out0) - xx = tl.cast(xx, dtype=ptr.dtype) - yy = tl.cast(yy, dtype=ptr.dtype) - tl.store(out0 + ij, xx) - tl.store(out1 + ij, yy) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_swizzle2d(shape, dtype): - if (shape[0] > 255) or (shape[1] > 255): - return - size_g = random.randint(1, min(shape[0], shape[1])) - ans = swizzle2d(shape[0], shape[1], size_g).to(eval('torch.' + dtype)).npu() - - out0 = test_common.generate_tensor(shape, dtype).npu() - out1 = test_common.generate_tensor(shape, dtype).npu() - fn_npu_[1, 1, 1](out0, out1, shape[0], shape[1], size_g) - triton_ret = out0 * shape[0] + out1 - torch.testing.assert_close(triton_ret, ans) diff --git a/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py deleted file mode 100644 index d56ec1bbb4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, xnumel: tl.constexpr): - idx = tl.arange(0, xnumel) - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 0) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_trans_1d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.randint(1, shape, dtype=data_type).npu() - torch_res = torch.permute(x, (0, )) - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = yidx[:, None] * ZB + zidx[None, :] - - # XB,YB,1 - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0) - - oidx = zidx[:, None] * YB + yidx[None, :] - - tl.store(output_ptr + oidx, ret) - - -bisheng_notsupport_dtype = ['int64'] -tritonascend_notsupport_dtype = ['bool'] -# cases that check_ub_mem_overflow fails to catch: peak UB usage inside the kernel exceeds ub_size -mem_overflow_scene = [ - ('bfloat16', (128, 256)), - ('bfloat16', (256, 128)), - ('int8', (741, 256)), - ('int8', (256, 741)), - ('int16', (256, 256)), - ('float16', (256, 256)), - ('bfloat16', (256, 256)), - ('int32', (128, 256)), - ('int32', (256, 128)), - ('float32', (128, 256)), - ('float32', (256, 128)), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: - return - if (dtype, shape) in mem_overflow_scene: - return - if check_ub_mem_overflow(dtype, shape): - return - YB = shape[0] - ZB = shape[1] - data_type = eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=(YB, ZB), dtype=data_type).npu() - - triton_res = torch.randint(1, (ZB, YB), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0)) - fn_npu_021[1, 1, 1](triton_res, x, YB, ZB) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(37, 3)]: - for dtype in TestUtils.dtype_list: - test_permute(shape, dtype) diff --git a/third_party/ascend/unittest/generalization_cases/test_trans_3d.py b/third_party/ascend/unittest/generalization_cases/test_trans_3d.py deleted file mode 100644 index 6f8428e575..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_3d.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software.
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0, 2) - - oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 2, 1, 0) - - oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 0, 2, 1) - - oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -bisheng_notsupport_dtype = [] -tritonascend_notsupport_dtype = ['bool'] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute_3d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: - return - if check_ub_mem_overflow(dtype, shape): - return - - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - # not supported yet: needs bisheng support later - # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() - # torch_res = torch.permute(x, (2, 1, 0)) - # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - # test_common.validate_cmp(dtype, triton_res, torch_res) - - triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() - torch_res = torch.permute(x, (0, 2, 1)) - fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(1, 22, 39)]: - for dtype in TestUtils.dtype_list: - test_permute_3d(shape, dtype) - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0, 2) - - oidx = (zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 4), - ('uint8', torch.uint8, 1, 256, 16), - ('bool', torch.bool, 1, 1, 2), -]) -def test_permute_3d_u(sigtype, dtype, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), sigtype).npu() - triton_res = torch.empty((YB, XB, ZB), dtype=dtype).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, XB, YB, ZB) - test_common.validate_cmp(sigtype, triton_res, torch_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py deleted file mode 100644 index 8505e974f3..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE.
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def triton_trans_4d( - output_ptr, - x_ptr, - PERM: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3 - ret = tl.trans(x_val, (1, 0, 2, 3)) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - elif PERM == 1: # 0, 2, 1, 3 - ret = tl.trans(x_val, (0, 2, 1, 3)) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - else: # 0, 1, 3, 2 - ret = tl.trans(x_val, (0, 1, 3, 2)) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - - s3 = 1 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) - elif PERM == 1: # 0, 2, 1, 3 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) - else: # 0, 1, 3, 2 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_trans_5d(output_ptr, x_ptr, PERM: tl.constexpr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] - - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] - - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :, None] - tmp3_2 = tl.arange(0, 
BLOCK_3)[None, None, :, None, None] - - tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] - tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3, 4 - ret = tl.trans(x_val, 1, 0, 2, 3, 4) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 1: # 0, 2, 1, 3, 4 - ret = tl.trans(x_val, 0, 2, 1, 3, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 2: # 0, 1, 3, 2, 4 - ret = tl.trans(x_val, 0, 1, 3, 2, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - shape4 = SHAPE_4 - else: # 0, 1, 2, 4, 3 - ret = tl.trans(x_val, 0, 1, 2, 4, 3) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_2 - shape3 = SHAPE_4 - shape4 = SHAPE_3 - - s4 = 1 - s3 = s4 * shape4 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3, 4 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 1: # 0, 2, 1, 3, 4 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 2: # 0, 1, 3, 2, 4 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) - else: # 0, 1, 2, 4, 3 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode -def test_trans_4d_5d(shape, dtype, perm): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() - grid = (1, ) - if len(shape) == 4: - blocks = list(x.size()) - strides = list(x.stride()) - if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (1, 0, 2, 3)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 2, 1, 3)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' 
+ dtype)).npu() - ans_4d = torch.permute(x, (0, 1, 3, 2)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - else: - pass - else: - blocks = list(x.size()) - strides = list(x.stride()) - - if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) - elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) - elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) - else: # 0, 1, 2, 4, 3; exchange axis 3, 4 - output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) - triton_trans_5d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_umulhi.py b/third_party/ascend/unittest/generalization_cases/test_umulhi.py deleted file mode 100644 index 421fc77322..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_umulhi.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import triton -import torch -import pytest -import test_common - -import numpy as np -import triton.language as tl -from test_common import TestUtils - - -# inputs: two 32-bit signed integers.
-@triton.jit -def umulhi_kernel(X, Y, Z, N: tl.constexpr): - offs = tl.arange(0, N) - x = tl.load(X + offs) - y = tl.load(Y + offs) - z = tl.umulhi(x, y) - tl.store(Z + tl.arange(0, N), z) - - -@triton.jit -def triton_umulhi_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.umulhi(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -# accuracy reference -def umulhi32(a, b): - a_64 = a.astype(np.int64) - b_64 = b.astype(np.int64) - product_64 = a_64 * b_64 - # get the high part - result_high_32 = product_64 >> 32 - return result_high_32.astype(np.int32) - - -@pytest.mark.parametrize('dtype', ['int32']) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_case2(dtype, shape): - N = shape[0] - dtypes = eval('torch.' + dtype) - x = torch.randint(low=0, high=2000, size=shape, dtype=dtypes) - y = torch.randint(low=0, high=2000, size=shape, dtype=dtypes) - xx = x.npu() - yy = y.npu() - z_tri = torch.zeros(size=shape, dtype=dtypes).npu() - umulhi_kernel[(1, )](xx, yy, z_tri, N=N) - - xxx = x.numpy() - yyy = y.numpy() - z_ref = umulhi32(xxx, yyy) - z_ref1 = torch.from_numpy(z_ref).npu() - assert torch.equal(z_tri, z_ref1) - - -invalid_types = [ - 'int8', - 'int16', - 'int64', - 'float16', - 'float32', - 'bfloat16', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_umulhi_invalid_dtype_case(dtype): - x0 = test_common.generate_tensor((1, ), dtype).npu() - x1 = test_common.generate_tensor((1, ), dtype).npu() - - y_cal = torch.zeros((1, ), dtype=eval('torch.' + dtype)).npu() - umulhi_kernel[(1, )](x0, x1, y_cal, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int32']) -def test_umulhi_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - y = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - xx = x.npu() - yy = y.npu() - - output = torch.zeros(size=shape, dtype=eval('torch.'
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - xxx = x.numpy() - yyy = y.numpy() - z = umulhi32(xxx, yyy) - ans = torch.from_numpy(z).npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_umulhi_4d_5d[grid](output, xx, yy, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_where.py b/third_party/ascend/unittest/generalization_cases/test_where.py deleted file mode 100644 index 79345b6f21..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_where.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
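For reference, the umulhi coverage deleted above validates `tl.umulhi` against a NumPy model that widens both 32-bit operands to 64 bits and keeps the high word of the product. A minimal standalone sketch of that reference follows (illustrative only: the `test_common`/NPU harness is omitted, and none of the names below come from the deleted file). It also shows why the 0..2000 operand range used by the deleted test always produces a zero high word.

```python
# Hypothetical standalone check of the umulhi reference (not part of the diff).
import numpy as np


def umulhi32(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Widen to 64 bits before multiplying, then keep the high 32 bits.
    product = a.astype(np.int64) * b.astype(np.int64)
    return (product >> 32).astype(np.int32)


rng = np.random.default_rng(0)
small = rng.integers(0, 2000, size=32, dtype=np.int32)
# Products of operands below 2000 stay under 2**32, so the high word is 0.
assert (umulhi32(small, small) == 0).all()
# Larger operands exercise a nonzero high word: (2**20)**2 >> 32 == 2**8.
big = np.array([1 << 20], dtype=np.int32)
assert umulhi32(big, big)[0] == 256
```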
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x0, x1): - res = torch.where(x0 < x1, x0, 1) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - tmp2 = X < Y - ret = tl.where(tmp2, X, 1) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - ans = torch_pointwise(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_multi_d(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + offsets) - Y = tl.load(y_ptr + offsets) - - tmp2 = X < Y - ret = tl.where(tmp2, X, 1) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (2, 4, 2, 8, 1), - (4, 3, 8, 1), - (3, 4, 2, 8, 4), -]) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case_4d_5d(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() - - ans = torch_pointwise(x, y) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, y, *triton_shape, len(shape)) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_xor.py b/third_party/ascend/unittest/generalization_cases/test_xor.py deleted file mode 100644 index fe696552af..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_xor.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_xor(x0, x1): - return x0 ^ x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X ^ Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_xor_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, 
BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val ^ y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # generate input data - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_xor(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_xor_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x ^ y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_xor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_xorsum.py b/third_party/ascend/unittest/generalization_cases/test_xorsum.py deleted file mode 100644 index 633db01c15..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_xorsum.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -import functools -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_xorsum_1d -def torch_xorsum(tensor, dim=None, keepdim=False): - if dim is None: - result = tensor.flatten()[0] - for x in tensor.flatten()[1:]: - result = result ^ x - return result - else: - assert dim < tensor.dim(), f"Invalid dim {dim} for tensor shape {tensor.shape}" - result = tensor.select(dim, 0) - for i in range(1, tensor.size(dim)): - result = result ^ tensor.select(dim, i) - if keepdim: - result = result.unsqueeze(dim) - return result - - -@triton.jit -def triton_xorsum_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.xor_sum(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_xorsum_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." 
+ dtype)).npu() - numel = shape[0] - triton_xorsum_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_xorsum(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_1d - - -# <<<<<<< test_xorsum_2d -@triton.jit -def triton_xorsum_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.xor_sum(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_xorsum_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype in ['bool']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch." + dtype)).npu() - triton_xorsum_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_2d - - -# <<<<<<< test_xorsum_3d -def torch_xorsum_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch_xorsum(torch_xorsum(inp, 1), 1).npu() - elif no_reduce_dim == 1: - return torch_xorsum(torch_xorsum(inp, 0), 1).npu() - elif no_reduce_dim == 2: - return torch_xorsum(torch_xorsum(inp, 0), 0).npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_xorsum_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 0) - ret = tl.xor_sum(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_xorsum_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 0) - ret = tl.xor_sum(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_xorsum_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * 
znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 1) - ret = tl.xor_sum(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_xorsum_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_xorsum_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_xorsum_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_xorsum_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_xorsum_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch." + dtype)).npu() - triton_xorsum_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_xorsum_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_3d - - -# <<<<<<< test_xorsum_4d -@triton.jit -def triton_xorsum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.xor_sum(x, DIM), REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.xor_sum(x, DIM) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1, 2, 3]) -def test_xorsum_4d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - print(f"dtype:{dtype} shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - triton_res = torch.empty_like(torch_res, dtype=eval("torch." 
+ dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_4d - - -# <<<<<<< test_xorsum_5d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1, 2, 3, 4]) -def test_xorsum_5d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - print(f"dtype:{dtype} shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - triton_res = torch.empty_like(torch_res, dtype=eval("torch." + dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_5d - -if __name__ == "__main__": - test_xorsum_3d('int8', (3, 3, 3), 0) diff --git a/third_party/ascend/unittest/generalization_cases/test_zeros_op.py b/third_party/ascend/unittest/generalization_cases/test_zeros_op.py deleted file mode 100644 index 7e5304d153..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_zeros_op.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
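For reference, the xor-sum tests deleted above reduce a tensor with bitwise XOR along one dimension, which is what `tl.xor_sum` computes. A minimal CPU-only sketch of that folding reference (illustrative only; `test_common`, `torch_npu`, and the kernel-launch plumbing are omitted, and the names below are not from the deleted file):

```python
# Hypothetical CPU-only xor-sum reference (not part of the diff): fold the
# slices of a tensor with ^ along the reduced dimension.
import torch


def xorsum(t: torch.Tensor, dim: int) -> torch.Tensor:
    out = t.select(dim, 0)
    for i in range(1, t.size(dim)):
        out = out ^ t.select(dim, i)
    return out


x = torch.randint(0, 1 << 16, (4, 8), dtype=torch.int32)
# Reducing dim 1 of a (4, 8) tensor leaves a length-4 vector.
assert xorsum(x, dim=1).shape == (4,)
# XOR is its own inverse, so reducing two identical copies yields zeros.
assert torch.equal(xorsum(torch.stack([x, x]), dim=0), torch.zeros_like(x))
```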
- -import math -import pytest -import random -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow - - -@triton.jit -def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int8) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int64) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = 
tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.bfloat16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int1) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int8) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int64) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - 
yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.bfloat16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int1) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int8) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int32) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int64_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int64) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.float16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.float32) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.bfloat16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int1) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int8) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_int16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int16) - tl.store(output_ptr, zero) - - -@triton.jit -def 
fn_npu_int32_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int32) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_int64_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int64) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_fp16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.float16) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_fp32_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.float32) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_bf16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.bfloat16) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_bool_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int1) - tl.store(output_ptr, zero) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape0d = [()] -test_shape1d = TestUtils.test_shape1d -test_shape2d = TestUtils.test_shape2d -test_shape3d = TestUtils.test_shape3d - -# Map each dtype to its (test_func, test_sigtype) pair -dtype_mapping3d = { - 'int8': (fn_npu_int8_3d, torch.int8), - 'int16': (fn_npu_int16_3d, torch.int16), - 'int32': (fn_npu_int32_3d, torch.int32), - 'int64': (fn_npu_int64_3d, torch.int64), - 'float16': (fn_npu_fp16_3d, torch.float16), - 'float32': (fn_npu_fp32_3d, torch.float32), - 'bfloat16': (fn_npu_bf16_3d, torch.bfloat16), - 'bool': (fn_npu_bool_3d, torch.bool), -} -dtype_mapping2d = { - 'int8': (fn_npu_int8_2d, torch.int8), - 'int16': (fn_npu_int16_2d, torch.int16), - 'int32': (fn_npu_int32_2d, torch.int32), - 'int64': (fn_npu_int64_2d, torch.int64), - 'float16': (fn_npu_fp16_2d, torch.float16), - 'float32': (fn_npu_fp32_2d, torch.float32), - 'bfloat16': (fn_npu_bf16_2d, torch.bfloat16), - 'bool': (fn_npu_bool_2d, torch.bool), -} -dtype_mapping1d = { - 'int8': (fn_npu_int8_1d, torch.int8), - 'int16': (fn_npu_int16_1d, torch.int16), - 'int32': (fn_npu_int32_1d, torch.int32), - 'int64': (fn_npu_int64_1d, torch.int64), - 'float16': (fn_npu_fp16_1d, torch.float16), - 'float32': (fn_npu_fp32_1d, torch.float32), - 'bfloat16': (fn_npu_bf16_1d, torch.bfloat16), - 'bool': (fn_npu_bool_1d, torch.bool), -} -dtype_mapping0d = { - 'int8': (fn_npu_int8_0d, torch.int8), - 'int16': (fn_npu_int16_0d, torch.int16), - 'int32': (fn_npu_int32_0d, torch.int32), - 'int64': (fn_npu_int64_0d, torch.int64), - 'float16': (fn_npu_fp16_0d, torch.float16), - 'float32': (fn_npu_fp32_0d, torch.float32), - 'bfloat16': (fn_npu_bf16_0d, torch.bfloat16), - 'bool': (fn_npu_bool_0d, torch.bool), -} - -# Build the test cases -testlist = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape0d - for func, dtype in [dtype_mapping0d[sigtype]] # unpack the mapped pair directly - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape1d - for func, dtype in [dtype_mapping1d[sigtype]] # unpack the mapped pair directly - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape2d - for func, dtype in [dtype_mapping2d[sigtype]] # unpack the mapped pair directly - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape3d - for func, dtype in [dtype_mapping3d[sigtype]] # unpack the mapped pair directly - ] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist) -def test_npu(testfunc, sigtype, dtype, shape): - if check_ub_mem_overflow(sigtype, shape): - pytest.skip(f"dtype:{sigtype} shape:{shape} mem overflow") - x = 0 - output = 0 - if len(shape) == 3: - x = torch.full((shape[0], shape[1], shape[2]), 0,
dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1], shape[2]), dtype=dtype).npu() - testfunc[(1, 1, 1)](output, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2]) - if len(shape) == 2: - x = torch.full((shape[0], shape[1]), 0, dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1]), dtype=dtype).npu() - shape0 = shape[0] - shape1 = shape[1] - if x.numel() * x.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - testfunc[grid](output, shape0, shape1, shape0, shape1) - if len(shape) == 1: - x = torch.full((shape[0], ), 0, dtype=dtype).npu() - output = torch.randint(1, (shape[0], ), dtype=dtype).npu() - testfunc[1, 1, 1](output, shape[0], shape[0]) - if len(shape) == 0: - output = torch.randint(1, size=shape, dtype=dtype).npu() - x = torch.zeros_like(output) - testfunc[(1, )](output_ptr=output, N=1) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - dtype = output_ptr.type.element_ty - - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - if (YB * ZB * MB * NB) == 1: - ret = tl.zeros((XB, ), dtype=dtype) - elif (ZB * MB * NB) == 1: - ret = tl.zeros((XB, YB), dtype=dtype) - elif (MB * NB) == 1: - ret = tl.zeros((XB, YB, ZB), dtype=dtype) - elif NB == 1: - ret = tl.zeros((XB, YB, ZB, MB), dtype=dtype) - else: - ret = tl.zeros((XB, YB, ZB, MB, NB), dtype=dtype) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('param_list', [ - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), -]) -def test_case_4d_5d(param_list): - dtype, shape = param_list - if check_ub_mem_overflow(dtype, shape): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - y_ref = torch.full(shape, 0, dtype=eval('torch.' + dtype)).npu() - print(f"y_ref = {torch.flatten(y_ref)[0:4]}") - - y_cal = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - fn_npu_multi_d[(1, )](y_cal, *triton_shape) - print(f"y_cal = {torch.flatten(y_cal)[0:4]}") - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_zeroslike.py b/third_party/ascend/unittest/generalization_cases/test_zeroslike.py deleted file mode 100644 index 014ba4bbdc..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_zeroslike.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow - - -@triton.jit -def fn_npu_0d(output_ptr, x_ptr, YB: tl.constexpr): - yidx = tl.arange(0, YB) - - idx = yidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr): - yidx = tl.arange(0, YB) - - idx = yidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - pid = tl.program_id(0) - yidx = tl.arange(0, YB)[:, None] + pid * YB - zidx = tl.arange(0, ZB)[None, :] - - idx = yidx * ZB + zidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx * ZB + zidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB)[:, None, None] * ZB * KB - zidx = tl.arange(0, ZB)[None, :, None] * KB - kidx = tl.arange(0, KB)[None, None, :] - - idx = yidx + zidx + kidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx + zidx + kidx - - tl.store(output_ptr + oidx, ret) - - -test_shape0d = [()] -testlist = test_shape0d + TestUtils.test_shape1_2_3d - - -@pytest.mark.parametrize('shape', testlist) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_npu(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - if check_ub_mem_overflow(dtype, shape): - return - x = torch.full(shape, 0, dtype=eval('torch.' + dtype)).npu() - triton_res = torch.empty(shape, dtype=eval('torch.' 
+ dtype)).npu() - torch_res = x - - if len(shape) == 0: - fn_npu_0d[1, 1, 1](triton_res, x, 1) - elif len(shape) == 1: - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - elif len(shape) == 2: - fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1]) - elif len(shape) == 3: - fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -@triton.jit -def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + offsets) - ret = tl.zeros_like(X) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('param_list', [ - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), - ('float32', (4, 2, 16, 16)), - ('float32', (2, 4, 2, 16, 16)), -]) -def test_case_4d_5d(param_list): - dtype, shape = param_list - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype) - y_ref = torch.zeros_like(x0, dtype=eval('torch.' + dtype)).npu() - print(f"y_ref = {torch.flatten(y_ref)[0:4]}") - y_cal = torch.ones(shape, dtype=eval('torch.' + dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - fn_npu_multi_d[(1, )](y_cal, x0, *triton_shape) - print(f"y_cal = {torch.flatten(y_cal)[0:4]}") - test_common.validate_cmp(dtype, y_cal, y_ref) - - -if __name__ == "__main__": - for dtype in TestUtils.dtype_list: - for shape in [(37, ), (37, 3), (1, 22, 39)]: - test_npu(shape, dtype) diff --git a/third_party/ascend/unittest/kernels/README.md b/third_party/ascend/unittest/kernels/README.md deleted file mode 100644 index 20eb7e42aa..0000000000 --- a/third_party/ascend/unittest/kernels/README.md +++ /dev/null @@ -1,62 +0,0 @@ -# Guide: adding a new kernel test case -Adding a kernel test case involves three main steps: -1. Prepare the pt file -2. Add the kernel to the triton-ascend repo and run the kernel test locally -3. Upload the pt file to the OBS bucket - -## 1. Prepare the pt file - -The pt file stores the inputs and outputs captured on GPU (or from a reference implementation) as golden data; the tests later run the Triton kernel on NPU and compare against that data. - -**Three-step generation flow** - -- **Step 1 — build the GPU inputs and save a copy preprocessed into NPU kernel inputs**: build `input_data` from the parameters of the GPU kernel or PyTorch op (key names must match the kernel parameters) and clone every Tensor to CPU to form `input_data_before`; if the GPU op's inputs differ from the NPU op's, preprocess `input_data_before` in advance so it satisfies the NPU op's argument requirements. -- **Step 2 — run the GPU kernel to get the outputs**: run the GPU kernel on GPU to obtain `gpu_output`, converting the Tensors to CPU. -- **Step 3 — pack and save**: wrap `input_data_before`, `grid`, and `gpu_output` into a dict and save it as `{kernel_name}.pt` via `torch.save`. If there are multiple cases, save a list of dicts (`[case0, case1]`). - -**Minimal example** - -```python -import copy -import torch - -DEVICE = torch.device("cuda:0") -batch_size = 2 -grid = (batch_size,) - -input_data = { - "output_token_ids_ptr": torch.zeros((batch_size, 4), dtype=torch.int32, device=DEVICE), - "cu_num_draft_tokens_ptr": torch.tensor([2, 1], dtype=torch.int32, device=DEVICE), - # ... other fields -} - -# Save a CPU copy of the inputs -input_data_before = { - k: (v.clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)) - for k, v in input_data.items() -} -# Preprocess input_data_before to match the NPU kernel inputs -input_data_before["npu_need_param_key"] = NPU_NEED_PARAMS_VALUE -# Run the kernel (on GPU / the reference implementation) and collect the outputs -triton_kernel[grid](**input_data) -# input_data is used here as an example; call the corresponding triton/pytorch function in practice -gpu_output = {k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in input_data.items()} - -save_obj = {"input_data": input_data_before, "grid": grid, "gpu_output": gpu_output} -torch.save(save_obj, ".pt") -# Multiple cases: torch.save([save_obj1, save_obj2], ".pt") -``` - -## 2. Add a third-party kernel test case in triton-ascend - -- **Step 1 — add the kernel to the triton-ascend repo**: for local verification, add a Python file named after the kernel under kernels/xxx (e.g. vllm, sglang) whose content is the Triton kernel function. -- **Step 2 — test locally**: place the pt file in the kernels directory and, from the project root, run -python -m pytest -v third_party/ascend/unittest/kernels/test_triton_kernel.py - -**Notes** -- To run a single kernel, execute python -m pytest -v ascend/test/common/test_triton_kernel.py --kernel={kernel_name} from the project root. -- pt file lookup strategy: a matching local pt inside the repo is preferred; if none exists, {kernel_name}.pt is downloaded on demand from the remote OBS bucket. -- pt files that already exist locally are not deleted after the tests run; files fetched from the OBS bucket are deleted by the test program once the run finishes. - -## 3. Upload the pt file to the OBS bucket -After local verification passes, upload the pt file to the OBS bucket at https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com/test/kernels/{xxx}_pt/{kernel_name}.pt, where xxx is vllm or sglang diff --git a/third_party/ascend/unittest/kernels/common_kernel.py b/third_party/ascend/unittest/kernels/common_kernel.py deleted file mode 100644 index fbbce42dd9..0000000000 --- a/third_party/ascend/unittest/kernels/common_kernel.py +++ /dev/null @@ -1,7 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def safe_exp(x): - return tl.exp(tl.where(x <= 0, x, float("-inf"))) diff --git a/third_party/ascend/unittest/kernels/test_common.py b/third_party/ascend/unittest/kernels/test_common.py deleted file mode 100644 index ababcb1540..0000000000 --- a/third_party/ascend/unittest/kernels/test_common.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Optional -import torch -import pytest - -DEVICE_TYPE_NPU = 'npu' - - -def validate_cmp(dtype, y_cal, y_ref, overflow_mode: Optional[str] = None, device_type: Optional[str] = None): - if device_type is not None: - target_device = torch.device(device_type) - y_cal = y_cal.to(target_device) - y_ref = y_ref.to(target_device) - else: - y_cal = y_cal.npu() - y_ref = y_ref.npu() - if overflow_mode == "saturate": - if dtype in ['float32', 'float16']: - min_value = torch.finfo(dtype).min - max_value = torch.finfo(dtype).max - elif dtype in ['int32', 'int16', 'int8']: - min_value = torch.iinfo(dtype).min - max_value = torch.iinfo(dtype).max - elif dtype == 'bool': - min_value = 0 - max_value = 1 - else: - raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) - y_ref = torch.clamp(y_ref, min=min_value, max=max_value) - if dtype == 'float16': - torch.testing.assert_close(y_ref, y_cal, rtol=5e-03, atol=5e-03, equal_nan=True) - elif dtype == 'bfloat16': - torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=5e-03, atol=5e-03, - equal_nan=True) - elif dtype == 'float32': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype in ['int64', 'int32', 'int16', 'int8']: - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def convert_tensor_with_device_type(indata: dict, device_type: str): -
target_device = torch.device(device_type) - outdata = {} - - for key, value in indata.items(): - if isinstance(value, torch.Tensor): - if value.device.type != target_device.type: - outdata[key] = value.to(target_device) - else: - outdata[key] = value - else: - outdata[key] = value - - return outdata - - -def compare_data_precision(dict_ref: dict, dict_cal: dict, device_type: str): - keys_ref, keys_cal = set(dict_ref.keys()), set(dict_cal.keys()) - if not keys_ref.issubset(keys_cal): - raise ValueError("The keys of dict_ref are not a subset of the keys of dict_cal") - - for key in dict_ref.keys(): - val_a, val_b = dict_ref[key], dict_cal[key] - if not isinstance(val_b, type(val_a)): - raise ValueError("The data types of the two dicts differ") - - if isinstance(val_a, torch.Tensor): - validate_cmp(dtype=str(val_a.dtype).split('.')[-1], y_ref=val_a, y_cal=val_b, device_type=device_type) - - -def run_and_compare_ptfile(ptfile_path: str, kernel_runner, device_type: str = DEVICE_TYPE_NPU): - try: - datas = torch.load(ptfile_path, map_location=torch.device('cpu')) - except Exception as e: - pytest.fail(f"load file {ptfile_path} failed: {e}") - - def _run_single_case(data): - if not isinstance(data, dict): - pytest.fail("Each case loaded from pt file must be a dict") - - input_data = convert_tensor_with_device_type(data.get("input_data", {}), device_type=device_type) - grid = data.get("grid") - try: - kernel_runner(input_data, grid) - except Exception as e: - pytest.fail(f"kernel_runner execution failed: {e}") - - output_data_cpu = convert_tensor_with_device_type(input_data, device_type='cpu') - expected = data.get("gpu_output", {}) - expected_filtered = {k: expected[k] for k in output_data_cpu.keys() if k in expected} - if not expected_filtered: - pytest.fail("No matching expected outputs found in pt file for comparison") - try: - compare_data_precision(expected_filtered, output_data_cpu, device_type='cpu') - except Exception as e: - pytest.fail(f"The testcase failed: {e}") - - # Supports three scenarios: - # 1) The file stores a single dict (existing behavior) - # 2) The file stores a list, where each element is a case dict - # 3) The file stores a dict, but some tensors represent multiple cases in batch on the 0th dimension (no automatic splitting; it is recommended to use a list) - if isinstance(datas, list): - for _, data in enumerate(datas): - _run_single_case(data) - elif isinstance(datas, dict): - _run_single_case(datas) - else: - pytest.fail("Unsupported pt file format: must be a dict or a list of dicts") diff --git a/third_party/ascend/unittest/kernels/test_triton_kernel.py b/third_party/ascend/unittest/kernels/test_triton_kernel.py deleted file mode 100644 index 528c8eb088..0000000000 --- a/third_party/ascend/unittest/kernels/test_triton_kernel.py +++ /dev/null @@ -1,73 +0,0 @@ -import importlib -import os -import urllib.request -from pathlib import Path - -import pytest - -import test_common - - -def discover_kernels(): - kernels = [] - kernels_root_path = Path(__file__).parents[0] - for p in kernels_root_path.rglob("*.py"): - if not p.is_file(): - continue - if p.parent == kernels_root_path: - continue - rel = p.relative_to(kernels_root_path) - if len(rel.parts) == 1 or p.name == "__init__.py": - continue - module_path = ".".join(rel.with_suffix("").parts) - kernels.append((module_path, p.stem)) - return sorted(kernels, key=lambda x: x[1]) - - -KERNEL_ITEMS = discover_kernels() - - -@pytest.mark.parametrize("module_path, kernel_name", KERNEL_ITEMS) -def test_triton_kernel(module_path, kernel_name,
pytestconfig): - selected = pytestconfig.getoption("kernel") - if selected: - if kernel_name not in selected: - pytest.skip(f"skip {kernel_name} due to --kernel filter") - base_url = "https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com" - rel = module_path - parts = rel.split(".") if rel else [] - pt_url = f"{base_url}/test/kernels/{parts[0]}_pt/{kernel_name}.pt" - local_pt = Path(__file__).parent / f"{kernel_name}.pt" - downloaded = False - if not local_pt.exists(): - try: - urllib.request.urlretrieve(pt_url, local_pt) - downloaded = True - except Exception as e: - pytest.fail( - f"Failed to download the {kernel_name}.pt file. Please check whether the {kernel_name}.pt file has been uploaded to the OBS bucket: {e}" - ) - try: - mod = importlib.import_module(module_path) - except Exception as e: - pytest.fail(f"import {module_path} failed: {e}") - - if hasattr(mod, kernel_name): - kernel_attr = kernel_name - else: - candidates = [a for a in dir(mod) if a.endswith("_kernel")] - kernel_attr = candidates[0] if candidates else None - - if not kernel_attr: - pytest.fail(f"No kernel callable found in {module_path}") - - kernel_callable = getattr(mod, kernel_attr) - - def runner(input_data, grid): - kernel_callable[grid](**input_data) - - try: - test_common.run_and_compare_ptfile(str(local_pt), runner, device_type='npu') - finally: - if downloaded and local_pt.exists(): - local_pt.unlink() diff --git a/third_party/ascend/unittest/kernels/vllm/expand_kernel.py b/third_party/ascend/unittest/kernels/vllm/expand_kernel.py deleted file mode 100644 index 8c87b6d0b2..0000000000 --- a/third_party/ascend/unittest/kernels/vllm/expand_kernel.py +++ /dev/null @@ -1,33 +0,0 @@ -import triton -import triton.language as tl -import triton.language.extra.cann.extension as extension - - -@triton.jit(do_not_specialize=["replace_from", "replace_to"]) -def expand_kernel( - output_ptr, # [num_tokens] - input_ptr, # [batch_size] - cu_num_tokens_ptr, # [batch_size] - replace_from, - replace_to, - vec_len, - MAX_NUM_TOKENS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - req_idx = tl.program_id(0) - offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - len_mask = offset < vec_len - - start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) - end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask) - num_tokens = end_idx - start_idx - - src_val = tl.load(input_ptr + offset, len_mask) - src_val = tl.where(src_val == replace_from, replace_to, src_val) - - for i in tl.range(0, BLOCK_SIZE): - num_tokens1 = extension.get_element(num_tokens, (i, )) - start_idx1 = extension.get_element(start_idx, (i, )) - src_val1 = extension.get_element(src_val, (i, )) - offset1 = tl.arange(0, MAX_NUM_TOKENS) - tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1) diff --git a/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py b/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py deleted file mode 100644 index b5d124a912..0000000000 --- a/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py +++ /dev/null @@ -1,55 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit(do_not_specialize=["max_spec_len"]) -def rejection_random_sample_kernel( - output_token_ids_ptr, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens_ptr, # [batch_size] - draft_token_ids_ptr, # [num_tokens] - draft_probs_ptr, # [num_tokens, vocab_size] or None - target_probs_ptr, # [num_tokens, vocab_size] - 
bonus_token_ids_ptr, # [batch_size] - recovered_token_ids_ptr, # [num_tokens] - uniform_probs_ptr, # [num_tokens] - is_greedy_ptr, # [batch_size] - max_spec_len, - vocab_size, - NO_DRAFT_PROBS: tl.constexpr, -): - req_idx = tl.program_id(0) - is_greedy = tl.load(is_greedy_ptr + req_idx) - if is_greedy: - # Early exost for greedy sampling requests - return - - start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) - end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) - num_draft_tokens = end_idx - start_idx - - rejected = False - for pos in range(num_draft_tokens): - if not rejected: - draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - if NO_DRAFT_PROBS: - draft_prob = 1 - else: - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) - target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) - uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) - if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: - # Accept - token_id = draft_token_id - else: - # Reject. Use recovered token - rejected = True - token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id) - - if not rejected: - # If all tokens are accepted, append the bonus token - bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) - tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, - bonus_token_id, - ) diff --git a/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py b/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py deleted file mode 100644 index 24aa9c7b7c..0000000000 --- a/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py +++ /dev/null @@ -1,77 +0,0 @@ -import triton -import triton.language as tl -import triton.language.extra.cann.extension as extension - - -@triton.jit -def sample_recovered_tokens_kernel( - output_token_ids_ptr, # [num_tokens] - cu_num_draft_tokens_ptr, # [batch_size] - draft_token_ids_ptr, # [num_tokens] - draft_probs_ptr, # [num_tokens, vocab_size] or None - target_probs_ptr, # [num_tokens, vocab_size] - q_ptr, # [batch_size, vocab_size] - vocab_size, - PADDED_VOCAB_SIZE: tl.constexpr, - NO_DRAFT_PROBS: tl.constexpr, - SUB_BLOCK: tl.constexpr, -): - req_idx = tl.program_id(0) - start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) - end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) - num_draft_tokens = end_idx - start_idx - - # Early exit for out-of-range positions. - pos = tl.program_id(1) - if pos >= num_draft_tokens: - return - - loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK - global_recovered_id = -1 - global_max_p = -1.0 - if NO_DRAFT_PROBS: - draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) - # Temporarily zero out the probability of the draft token. - # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. 
- tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0) - for loop_i in range(loop): - vocab_start = loop_i * SUB_BLOCK - vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, - other=float("-inf")) - new_p = prob / q - recovered_id = tl.argmax(new_p, axis=-1) - max_p = extension.get_element(new_p, (recovered_id, )) - if max_p > global_max_p: - global_max_p = max_p - global_recovered_id = vocab_start + recovered_id - else: - for loop_i in range(loop): - vocab_start = loop_i * SUB_BLOCK - vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - prob = tl.maximum(target_prob - draft_prob, 0) - # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because - # `tl.argmax` will select the maximum value. - - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, - other=float("-inf")) - new_p = prob / q - recovered_id = tl.argmax(new_p, axis=-1) - max_p = extension.get_element(new_p, (recovered_id, )) - if max_p > global_max_p: - global_max_p = max_p - global_recovered_id = vocab_start + recovered_id - - tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. - tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) diff --git a/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py b/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py new file mode 100644 index 0000000000..8d59474dfb --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Vector Addition - Pytest Version +""" + +import torch +import torch_npu + +import triton +import triton.language as tl + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. 
+ output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + + +# %% +# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: +def test_vector_addition(): + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device='npu') + y = torch.rand(size, device='npu') + output_torch = x + y + output_triton = add(x, y) + torch.testing.assert_close(output_triton, output_torch) diff --git a/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py b/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py new file mode 100644 index 0000000000..91fde18a95 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py @@ -0,0 +1,121 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+""" +Fused Softmax +============= +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, + BLOCK_SIZE: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +kernels = {} + + +def softmax(x): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Allocate output + y = torch.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + num_programs = 32 + kernel = softmax_kernel + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE) + return y + + +@pytest.mark.parametrize( + "shape", + [ + (1823, 781), + (128, 257), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_fused_softmax(shape, dtype): + torch.manual_seed(0) + x = torch.randn(shape, dtype=dtype, device="npu") + + y_triton = softmax(x) + y_torch = torch.softmax(x, axis=1) + + torch.testing.assert_close(y_triton, y_torch, atol=1e-4, rtol=1e-4) diff --git a/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py new file mode 100644 index 0000000000..ec5f8e8655 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py @@ -0,0 +1,211 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Matrix Multiplication +=============== +""" + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.cann.extension as extension + +DEV = "npu" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetic` section for details
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    msk_m = offs_am < M
+    msk_n = offs_bn < N
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generating a mask over the K dimension;
+        # out-of-bounds elements are set to 0.
+        a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak
+        b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk
+        a = tl.load(
+            a_ptrs,
+            mask=msk_m[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=msk_n[None, :] & (offs_k[:, None] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        # We accumulate along the K dimension.
+        accumulator = tl.dot(a, b, accumulator)
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    # Original (single write-back) vector operations:
+    # # -----------------------------------------------------------
+    # # Write back the block of the output matrix C with masks.
+    # The `extension.parallel` loop below instead splits the write-back (and the
+    # optional activation) across the two vector cores.
+    SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2
+    for s in extension.parallel(0, 2, bind_sub_block=True):
+        vec_sub_blk = extension.extract_slice(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1))
+        if ACTIVATION == "leaky_relu_custom":
+            vec_sub_blk = leaky_relu_custom(vec_sub_blk)
+        c_sub_blk = vec_sub_blk.to(tl.float16)
+        # -----------------------------------------------------------
+        # Write back the block of the output matrix C with masks.
+        offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, c_sub_blk, mask=c_mask)
+
+
+# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
+@triton.jit
+def leaky_relu_custom(x):
+    return tl.where(x >= 0, x, 0.01 * x) + 1.0
+
+
+def torch_matmul(a, b, activation=""):
+    c = torch.matmul(a, b)
+    if activation == "leaky_relu_custom":
+        c = torch.where(c >= 0, c, 0.01 * c) + 1.0
+    return c
+
+
+# %%
+# We can now create a convenience wrapper function that only takes two input tensors,
+# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
+
+
+def matmul(a, b, activation=""):
+    # Check constraints.
+    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    M, K = a.shape
+    K, N = b.shape
+    # Allocates output.
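+    # Note (added for clarity): the kernel accumulates in fp32 and down-casts
+    # each sub-block to fp16 on store, so the output buffer is allocated
+    # directly as float16 here.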
+ c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + # 1D launch kernel where each block gets its own program. + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. +@pytest.mark.parametrize( + "shape", + [ + (512, 512, 512), + (256, 384, 128), + ], +) +@pytest.mark.parametrize( + "activation", + [ + "", + pytest.param("leaky_relu_custom", + marks=pytest.mark.skip(reason="temporarily skip leaky_relu_custom ub overflow case")), + ], +) +def test_matrix_multiplication(shape, activation): + m, k, n = shape + torch.manual_seed(0) + + a = torch.randn((m, k), device=DEV, dtype=torch.float16) + b = torch.randn((k, n), device=DEV, dtype=torch.float16) + + triton_output = matmul(a, b, activation) + torch_output = torch_matmul(a, b, activation) + + torch.testing.assert_close(triton_output, torch_output, atol=1e-3, rtol=1e-3) diff --git a/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py new file mode 100644 index 0000000000..f615608e7a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py @@ -0,0 +1,134 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Low-Memory Dropout +================== +""" + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl + +DEV = "npu" + + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! 
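+    # In short, this is "inverted dropout": each element is kept with
+    # probability (1 - p), and kept elements are rescaled by 1 / (1 - p),
+    # so the expectation is preserved:
+    #   E[output] = (1 - p) * x / (1 - p) + p * 0 = x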
+ output = tl.where(x_keep != 0, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +@pytest.mark.parametrize("shape,p", [((10, ), 0.5), ((256, ), 0.5), ((513, ), 0.2), ((32, 64), 0.35)]) +def test_dropout_matches_reference(shape, p): + torch.manual_seed(0) + x = torch.randn(size=shape, device=DEV, dtype=torch.float32) + x_keep = (torch.rand(size=shape, device=DEV) > p).to(torch.int32) + + output = dropout(x, x_keep=x_keep, p=p) + expected = torch.where(x_keep != 0, x / (1 - p), torch.zeros_like(x)) + + torch.testing.assert_close(output, expected, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("shape,p,seed", [((10, ), 0.5, 123), ((256, ), 0.5, 123), ((513, ), 0.2, 7), + ((32, 64), 0.35, 999)]) +def test_seeded_dropout_is_deterministic(shape, p, seed): + torch.manual_seed(0) + x = torch.randn(size=shape, device=DEV, dtype=torch.float32) + + output = seeded_dropout(x, p=p, seed=seed) + output_same_seed = seeded_dropout(x, p=p, seed=seed) + output_different_seed = seeded_dropout(x, p=p, seed=512) + + torch.testing.assert_close(output, output_same_seed, atol=1e-6, rtol=0) + + assert output.shape == x.shape + assert output.dtype == x.dtype + assert torch.count_nonzero(output != output_different_seed).item() > 0 + assert torch.count_nonzero(output).item() <= x.numel() diff --git a/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py new file mode 100644 index 0000000000..84e3268b9b --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py @@ -0,0 +1,119 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Layer Normalization +============= +""" + +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + return y + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_layer_norm(dtype): + M, N = 128, 128 + eps = 1e-5 + device = 'npu' + + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device) + bias = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + assert 
torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) diff --git a/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py new file mode 100644 index 0000000000..8f83d6a730 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py @@ -0,0 +1,340 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants + STAGE: tl.constexpr, offs_m: tl.constexpr, + offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, + fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + # Set the processing range [lo, hi) for the current stage (in column block units) + # Causal attention, as the name implies, restricts the flow of information during computation, + # only allowing the model to see the current and previous positions. + # In other words, the output at the current position can only depend on the input at or before this position, + # and cannot access information from future positions. + # Causal attention ensures sequential order and prevents "leakage of future information." 
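+    # A concrete example (assuming BLOCK_M == BLOCK_N): the query block with
+    # start_m == 2 runs stage 1 over key/value columns [0, 2 * BLOCK_M) with no
+    # mask, then stage 2 over the diagonal block [2 * BLOCK_M, 3 * BLOCK_M),
+    # where everything above the diagonal is masked out.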
+ # But the following logic will also be triggered + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K is [HEAD_DIM, N_CTX], shift along the second dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator , used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Modify K + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Construct upper triangular mask + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Set invalid positions to -∞ + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. 
Write the updated accumulator back to the workspace
+            tl.store(acc_ptr + block2d_acc, acc)
+
+        m_i = m_ij  # Update current block max
+        # Advance V and K block pointers to next BLOCK_N range
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
+    # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i
+    return acc_ptr, l_i, m_i
+
+
+@triton.jit
+def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr,
+              stride_qk: tl.constexpr, stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr,
+              stride_kk: tl.constexpr, stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr,
+              stride_vk: tl.constexpr, stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr,
+              stride_on: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, N_CTX: tl.constexpr, HEAD_DIM: tl.constexpr,
+              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr):
+    # Total number of blocks in sequence dimension (M)
+    NUM_BLOCKS_M = N_CTX // BLOCK_M
+    # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H)
+    NUM_BLOCKS = NUM_BLOCKS_M * Z * H
+
+    # Index of the current program (core); tasks are strided across the launch grid
+    pid = tl.program_id(0)
+
+    # The stride of 20 matches the number of launched cores (`num_cores` in `_attention.forward`)
+    for block_idx in range(pid, NUM_BLOCKS, 20):
+        task_hz_idx = block_idx // NUM_BLOCKS_M
+        task_m_idx = block_idx % NUM_BLOCKS_M
+        off_z = task_hz_idx // H
+        off_h = task_hz_idx % H
+        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+        # Create block pointers for Q, K, V, Output
+        Q_block_ptr = tl.make_block_ptr(
+            base=Q + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_qm, stride_qk),
+            offsets=(task_m_idx * BLOCK_M, 0),
+            block_shape=(BLOCK_M, HEAD_DIM),
+            order=(1, 0),
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_vn, stride_vk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_N, HEAD_DIM),
+            order=(1, 0),
+        )
+        K_block_ptr = tl.make_block_ptr(
+            base=K + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_kn, stride_kk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_N, HEAD_DIM),
+            order=(1, 0),
+        )
+        O_block_ptr = tl.make_block_ptr(
+            base=Out + qvk_offset,
+            shape=(N_CTX, HEAD_DIM),
+            strides=(stride_om, stride_on),
+            offsets=(task_m_idx * BLOCK_M, 0),
+            block_shape=(BLOCK_M, HEAD_DIM),
+            order=(1, 0),
+        )
+        # Initialize offsets
+        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+        offs_n = tl.arange(0, BLOCK_N)
+
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
+
+        # Initialize accumulator
+        if HEAD_DIM < 256:
+            acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
+        else:
+            acc_offset = (off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM +
+                          off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + task_m_idx * BLOCK_M * HEAD_DIM)
+            acc_ptr = acc + acc_offset
+
+        # load q: it will stay in SRAM throughout
+        q = tl.load(Q_block_ptr)
+
+        # stage 1: off-band
+        # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+        # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
+        if STAGE & 1:
+            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
+                                                task_m_idx, sm_scale,  #
+                                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
+                                                )
+        # stage 2: on-band
+        if STAGE & 2:
+            # the barrier makes it easier for the compiler to schedule the
+            # two loops independently
+            acc_ptr, l_i, 
m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + accumulator = tl.load(acc_ptr + block2d_acc) + accumulator = accumulator / l_i[:, None] + + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, BM, BN): + """ + Forward computation interface: + Args: + ctx: Context object + q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM] + k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM] + v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM] + causal: Whether to enable causal attention + sm_scale: Scaling factor for QK product + BM: Q block size (BLOCK_M) + BN: K/V block size (BLOCK_N) + Returns: + o: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM] + """ + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + out = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + + # Number of NPU cores (adjust based on hardware) + num_cores = 20 + acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[(num_cores, )](q, k, v, M, out, acc, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), + v.stride(2), v.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], HEAD_DIM=HEAD_DIM_K, BLOCK_M=BM, BLOCK_N=BN, + STAGE=stage, **extra_kern_args) + + ctx.save_for_backward(q, k, v, out, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return out + + +attention = _attention.apply + + +# ==================== Pytest Test ==================== +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [ + (1, 1, 128, 128, False, torch.float16, 32, 128), + (1, 1, 128, 128, False, torch.bfloat16, 64, 128), + (1, 2, 256, 256, False, torch.bfloat16, 32, 256), + (2, 2, 128, 256, False, torch.float16, 64, 128), + (4, 32, 64, 64, False, torch.float16, 32, 64), + (4, 32, 1024, 64, False, torch.bfloat16, 64, 128), + (4, 32, 4096, 64, False, torch.float16, 128, 128), +]) +def test_attention_fused(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): + if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0: + pytest.skip("Skipping non-divisible case") + + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + + sm_scale = 0.5 + tri_out = attention(q, k, v, causal, sm_scale, BM, BN) + ref_out = 
torch_npu.npu_fusion_attention( + q, + k, + v, + H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) diff --git a/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py b/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py new file mode 100644 index 0000000000..01e41209c7 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py @@ -0,0 +1,83 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Libdevice (`tl.extra.libdevice`) function +============================== +""" + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +from triton.backends.ascend.compiler import get_libdevice + +DEV = "npu" + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + + +extern_libs = {'libdevice': get_libdevice()} + + +def run_asin_case(size, use_extern_libs): + torch.manual_seed(0) + x = torch.rand(size, device=DEV) + output_triton = torch.empty_like(x) + output_torch = torch.asin(x) + assert x.device.type == DEV and output_triton.device.type == DEV + + n_elements = output_torch.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + launch_kwargs = {"BLOCK_SIZE": 1024} + if use_extern_libs: + launch_kwargs["extern_libs"] = extern_libs + + asin_kernel[grid](x, output_triton, n_elements, **launch_kwargs) + torch.testing.assert_close(output_torch, output_triton, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize("size", [98432, 1024]) +def test_asin_kernel_matches_torch(size): + run_asin_case(size=size, use_extern_libs=False) + + +@pytest.mark.parametrize("size", [98432, 1024]) +def test_asin_kernel_matches_torch_with_extern_libs(size): + run_asin_case(size=size, use_extern_libs=True) diff --git a/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py b/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py new file mode 100644 index 
0000000000..830f9e57d6 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py @@ -0,0 +1,287 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Group GEMM +============================ +""" + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +NUM_CORES = get_npu_properties()["num_aicore"] + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + group_gemm_sizes, + g_lds, + group_size, + NUM_SM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + 
offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for _ in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + tl.store(c_ptrs, c) + tile_idx += NUM_SM + + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + device = torch.device(DEV) + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=device, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + d_a_ptrs = torch.tensor(A_addrs, device=device) + d_b_ptrs = torch.tensor(B_addrs, device=device) + d_c_ptrs = torch.tensor(C_addrs, device=device) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + + def grid(meta): + return (meta['NUM_SM'], ) + + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +def build_group_inputs(group_m, group_n, group_k): + assert len(group_m) == len(group_n) + assert len(group_n) == len(group_k) + + group_A = [] + group_B = [] + for m, n, k in zip(group_m, group_n, group_k): + group_A.append(torch.rand((m, k), device=DEV, dtype=torch.float16)) + group_B.append(torch.rand((k, n), device=DEV, dtype=torch.float16)) + return group_A, group_B + + +def run_group_gemm_case(group_m, group_n, group_k): + group_A, group_B = build_group_inputs(group_m, group_n, group_k) + + tri_out = group_gemm_fn(group_A, group_B) + ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] + + assert len(tri_out) == len(ref_out) + for tri_tensor, ref_tensor, m, n in zip(tri_out, ref_out, group_m, group_n): + assert tri_tensor.shape == (m, n) + assert tri_tensor.dtype == torch.float16 + torch.testing.assert_close(ref_tensor, tri_tensor, atol=1e-2, rtol=1e-3) + + +@pytest.mark.parametrize( + "group_m,group_n,group_k", + [([1024, 512, 256, 128], [1024, 512, 256, 128], [1024, 512, 256, 128])], +) +def test_grouped_gemm_tutorial_example(group_m, group_n, group_k): + run_group_gemm_case( + group_m=group_m, + group_n=group_n, + group_k=group_k, + ) + + +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + + def grid(meta): + return (meta['NUM_SM'], ) + + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +def run_benchmark_case(N, provider): + group_size = 4 + group_A = [] + group_B = [] + group_C = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + for _ in range(group_size): + A = torch.rand((N, N), 
device=DEV, dtype=torch.float16) + B = torch.rand((N, N), device=DEV, dtype=torch.float16) + C = torch.empty((N, N), device=DEV, dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device=DEV) + d_b_ptrs = torch.tensor(B_addrs, device=DEV) + d_c_ptrs = torch.tensor(C_addrs, device=DEV) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEV) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEV) + + quantiles = [0.5, 0.2, 0.8] + + def bench_torch(): + torch_perf_fn(group_A, group_B) + + def bench_triton(): + triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size) + + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(bench_torch, quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(bench_triton, quantiles=quantiles) + + assert ms >= 0 + assert min_ms >= 0 + assert max_ms >= 0 + + +@pytest.mark.parametrize("N", [2**i for i in range(7, 11)]) +@pytest.mark.parametrize("provider", ["torch", "triton"]) +def test_grouped_gemm_benchmark_cases(N, provider): + run_benchmark_case(N=N, provider=provider) diff --git a/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py b/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py new file mode 100644 index 0000000000..f73a3fc46f --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py @@ -0,0 +1,325 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+""" +Persistent Matmul +===================== +""" + +import time + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" +DTYPE = torch.float16 + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def get_num_compute_cores(): + return get_npu_properties()["num_aicore"] + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + bytes_per_elem = args["c_ptr"].element_size() + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_sm = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_sm += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + offs_k_for_mask = 
tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_sm): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def get_configs(dtype): + return { + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + } + }[dtype] + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def matmul_persistent(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + num_sms = get_num_compute_cores() + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + return (min(num_sms, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])), ) + + matmul_kernel_persistent[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + 
BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + NUM_SMS=num_sms, + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def torch_matmul(a, b): + return torch.matmul(a, b) + + +def bench(K, reps=10): + M = 8192 + N = 8192 + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + for _ in range(reps): + _ = torch_matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul_persistent(a, b) + time.sleep(0.01) + + +def validate(M, N, K): + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + torch_result = torch_matmul(a, b) + naive_result = matmul(a, b) + persistent_result = matmul_persistent(a, b) + return torch_result, naive_result, persistent_result + + +@pytest.mark.skip(reason="temporarily skip persistent matmul validate cases until UB overflow issue is fixed") +@pytest.mark.parametrize( + "M,N,K", + [ + (32, 32, 32), + (8192, 8192, 512), + ], +) +def test_persistent_matmul_validate_cases(M, N, K): + torch.manual_seed(0) + torch_result, naive_result, persistent_result = validate(M, N, K) + + torch.testing.assert_close(naive_result, torch_result, atol=1.0, rtol=0) + torch.testing.assert_close(persistent_result, torch_result, atol=1.0, rtol=0) + torch.testing.assert_close(naive_result, persistent_result, atol=1.0, rtol=0) diff --git a/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py b/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py new file mode 100644 index 0000000000..8fcc3f6b9e --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py @@ -0,0 +1,201 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Gather sorted +=============== +This is an example only for npu. 
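+
+The kernel exploits that `sorted_indices` is sorted: an embedding row is only
+reloaded from global memory when the index changes, and index 0 keeps the
+current buffer contents. A minimal sketch of that reuse (illustrative only,
+`default_row` is a hypothetical name):
+
+    prev, buf = -1, default_row
+    for idx, dst in zip(sorted_indices, aux_indices):
+        if idx != 0 and idx != prev:
+            buf = embeddings[idx]
+        res[dst] = buf
+        prev = idx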
+""" + +import pytest +import torch +import torch_npu +import triton +import triton.runtime.driver as driver +import triton.language as tl + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +# a torch-version gather_sorted benchmark +def torch_gather_sorted(embeddings, sorted_idxes, aux_idxes): + # make the result tensor + res = torch.empty((aux_idxes.shape[0], embeddings.shape[-1]), dtype=embeddings.dtype, device=embeddings.device) + + # scatter embeddings + res[aux_idxes] = embeddings[sorted_idxes] + + return res + + +# triton-version gather_sorted's kernel +@triton.jit +def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, res_ptr, rows, cols, + DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, + COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): + SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1 + + emb_dtype = embeddings_ptr.type.element_ty + default_value = tl.cast(DEFAULT_VALUE, dtype=emb_dtype) + + core_idx = tl.program_id(0) + # compute the the size and start index of block + row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE + row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else ( + BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) + + # this version has 3-buffers, initilize for buffers + row_block_size_0 = tl.cdiv(row_block_size, 3) + remain_row_block_size = row_block_size - row_block_size_0 + row_block_size_1 = tl.cdiv(remain_row_block_size, 2) + row_block_size_2 = remain_row_block_size - row_block_size_1 + + row_start_idx_0 = row_start_idx + row_start_idx_1 = row_start_idx + row_block_size_0 + row_start_idx_2 = row_start_idx + row_block_size_0 + row_block_size_1 + + # process blocks witn shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one + for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB): + + embedding_0 = tl.full((COL_BLOCK_SIZE_SUB, ), default_value, dtype=emb_dtype) + embedding_1 = embedding_0 + 0 + embedding_2 = embedding_0 + 0 + + emb_offsets = col_idx + tl.arange(0, COL_BLOCK_SIZE_SUB) + emb_mask = emb_offsets < cols + + prev_embedding_idx_0 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_1 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_2 = tl.cast(-1, dtype=tl.int32) + for row_idx in tl.range(row_start_idx_0, row_start_idx_1): + # process the first buffer + embedding_idx_0 = tl.load(sorted_indices_ptr + row_idx) + res_idx_0 = tl.load(aux_indices_ptr + row_idx) + + if (embedding_idx_0 != 0) and (embedding_idx_0 != prev_embedding_idx_0): + embedding_0 = tl.load(embeddings_ptr + embedding_idx_0 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + else: + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + + prev_embedding_idx_0 = embedding_idx_0 + + # process the second buffer + if (row_idx + row_block_size_0) < (row_start_idx_1 + row_block_size_1): + embedding_idx_1 = tl.load(sorted_indices_ptr + row_idx + row_block_size_0) + res_idx_1 = tl.load(aux_indices_ptr + row_idx + row_block_size_0) + + if (embedding_idx_1 != 0) and (embedding_idx_1 != prev_embedding_idx_1): + embedding_1 = tl.load(embeddings_ptr + embedding_idx_1 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_1 * cols + emb_offsets, embedding_1, emb_mask) + else: + tl.store(res_ptr + res_idx_1 * cols + 
+
+                prev_embedding_idx_1 = embedding_idx_1
+
+            # process the third buffer
+            if (row_idx + row_block_size_0 + row_block_size_1) < (row_start_idx_2 + row_block_size_2):
+                embedding_idx_2 = tl.load(sorted_indices_ptr + row_idx + row_block_size_0 + row_block_size_1)
+                res_idx_2 = tl.load(aux_indices_ptr + row_idx + row_block_size_0 + row_block_size_1)
+
+                if (embedding_idx_2 != 0) and (embedding_idx_2 != prev_embedding_idx_2):
+                    embedding_2 = tl.load(embeddings_ptr + embedding_idx_2 * cols + emb_offsets, emb_mask)
+                    tl.store(res_ptr + res_idx_2 * cols + emb_offsets, embedding_2, emb_mask)
+                else:
+                    tl.store(res_ptr + res_idx_2 * cols + emb_offsets, embedding_2, emb_mask)
+
+                prev_embedding_idx_2 = embedding_idx_2
+
+
+# triton-version gather_sorted's host-side launcher
+def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, aux_indices: torch.Tensor,
+                         default_value=1.0):
+    # constant settings for npu
+    ALIGNED = 32
+    USE_SIZE = 96 * 1024
+    CORE_NUM = get_npu_properties()["num_vectorcore"]
+
+    n_rows = sorted_indices.shape[0]
+    n_cols = embeddings.shape[1]
+    # make the result tensor
+    output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device)
+
+    # when writing an npu kernel using triton, note the difference between
+    # BLOCK_SIZE and BLOCK_SIZE_SUB:
+    # BLOCK_SIZE specifies the amount of data processed by one program
+    col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(),
+                                   ALIGNED) * ALIGNED // embeddings.element_size()
+    # the data are scattered across multiple programs, and the split may be
+    # uneven: some programs process more data, some less
+    big_row_block_size = triton.cdiv(n_rows, CORE_NUM)
+    big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows)
+    col_block_size = col_size_aligned
+    # BLOCK_SIZE_SUB specifies the amount of data processed in one loop iteration of a program
+    col_block_size_sub = min(1024, col_size_aligned)
+
+    grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size))
+    # launch the kernel
+    gather_sorted_kernel[grid](embeddings, sorted_indices, aux_indices, output, n_rows, n_cols, default_value,
+                               BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size,
+                               COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub)
+
+    return output
+
+
+# generate the desired inputs
+def generate_inputs(index_shape, table_shape, dtype):
+    sorted_indices = torch.randint(1, table_shape[0], index_shape, dtype=torch.int32).npu()
+    mask = torch.rand_like(sorted_indices, dtype=torch.float).npu() < 0.2
+
+    # make sorted_indices
+    sorted_indices[mask] = 0
+    sorted_indices, _ = torch.sort(sorted_indices)
+    counts = torch.bincount(sorted_indices)
+    _, _indices = torch.sort(counts[sorted_indices], descending=True, stable=True)
+    sorted_indices = sorted_indices[_indices]
+
+    # make aux_indices
+    aux_indices = torch.arange(0, index_shape[0], dtype=torch.int32).npu()
+    _indices = torch.randperm(aux_indices.size(0))
+    aux_indices = aux_indices[_indices]
+
+    # make the table; the first row contains only 1.0
+    table = torch.randn(table_shape, dtype=dtype).npu()
+    table[0] = 1.0
+
+    return table, sorted_indices, aux_indices
+
+
+# ==================== Pytest Test ====================
+@pytest.mark.parametrize("table_rows", [500, 1000])
+@pytest.mark.parametrize("table_cols", [16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000])
+@pytest.mark.parametrize("index_num", [19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000])
+def 
test_gather_sorted(table_rows, table_cols, index_num): + table, sorted_indices, aux_indices = generate_inputs((index_num, ), (table_rows, table_cols), torch.float) + + expect = torch_gather_sorted(table, sorted_indices, aux_indices).cpu() + torch.npu.synchronize() + actual = triton_gather_sorted(table, sorted_indices, aux_indices).cpu() + torch.npu.synchronize() + + torch.testing.assert_close(actual, expect) diff --git a/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py new file mode 100644 index 0000000000..6603a2a457 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py @@ -0,0 +1,351 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
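+
+# NOTE (illustrative only): the forward kernel below maps each timestamp
+# difference d to a bucket index via log(clamp(|d|, 1, clamp_max)) / divisor,
+# with clamp_max = exp(NUM_BUCKETS * divisor). A minimal sketch:
+#
+#     import math
+#
+#     def bucketize(d, divisor=0.301, num_buckets=128):
+#         clamp_max = math.exp(num_buckets * divisor)
+#         return int(math.log(min(max(abs(d), 1.0), clamp_max)) / divisor)
+#
+#     bucketize(0)   # -> 0
+#     bucketize(10)  # -> 7, since log(10) / 0.301 ~= 7.65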
+""" +Relative Attention Bias Timestamps +=============== +""" + +import math +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +NUM_BUCKETS = 128 +BUCKET_DIVISOR = 0.301 + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def create_pos_w(train_len: int, num_layers: int) -> torch.Tensor: + return torch.arange(0, 2 * train_len + 1).unsqueeze(1).repeat(1, num_layers) + + +def create_past_valid_lens(bs: int, past_len: int) -> torch.Tensor: + return torch.randint(0, past_len, (bs, )) + + +def create_timestamps(train_len: int, candidate_len: int, past_valid_lens: torch.Tensor) -> torch.Tensor: + bs = past_valid_lens.size(0) + timestamps = torch.zeros(bs, train_len + candidate_len // 2) + for i, valid_len in enumerate(past_valid_lens): + if valid_len > 0: + timestamps[i, :valid_len] = torch.arange(1, valid_len.int() + 1) + + if candidate_len <= 0: + return timestamps + timestamps[:, -candidate_len // 2:] = train_len + 1 + + return timestamps + + +def create_timestamps_weights(num_layers: int): + return (torch.arange(0, NUM_BUCKETS + 1).repeat(num_layers).reshape(NUM_BUCKETS + 1, num_layers)) + + +def create_rab_time_grad(num_layers: int, batchsize: int, s: int): + return torch.rand(num_layers, batchsize, s, s) * 1e-4 + + +def create_bucket_timestamps(batchsize: int, s: int): + result = torch.arange(batchsize * s) % NUM_BUCKETS + result = result.unsqueeze(-1).repeat(1, 1, s) + return result + + +@triton.jit +def rab_time_forward_kernel( + inp, + out, + index, + index_len: tl.constexpr, + inp_row_stride: tl.constexpr, + clamp_max: tl.constexpr, + bucketization_divisor: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + COL_BLOCK_SIZE: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + col_iter_num = tl.cdiv(BLOCK_SIZE, COL_BLOCK_SIZE) + + for col_idx in tl.range(0, col_iter_num): + cols_offsets = (pid0 * BLOCK_SIZE + col_idx * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE)) + cols_mask = cols_offsets < index_len + + out_mask = cols_offsets < index_len + + index_val = tl.load(index + cols_offsets, mask=cols_mask, other=0.0) + index_val = tl.abs(index_val) + index_val = tl.minimum(tl.maximum(index_val, 1.0), clamp_max) + index_val = tl.log(index_val) + index_val = index_val / bucketization_divisor + index_val = tl.cast(index_val, tl.int64) + + inp_val = tl.load(inp + pid1 * inp_row_stride + tl.arange(0, inp_row_stride)) + out_val = tl.gather(inp_val, index_val, 0) + + tl.store(out + pid1 * index_len + cols_offsets, out_val, mask=out_mask) + + +def get_outer_loop_num(num_layers, index_len): + sub_num_layers = num_layers + while sub_num_layers * index_len >= 2**31 - 1: + sub_num_layers = sub_num_layers // 2 + outer_loop_num = (num_layers + sub_num_layers - 1) // sub_num_layers + remain_layers = num_layers % sub_num_layers + return outer_loop_num, sub_num_layers, remain_layers + + +def rab_time_forward_triton(ts_w, timestamps, bucketization_divisor): + ts_w_trans = ts_w.t().contiguous() + + bs, seq_len = timestamps.shape + infer_len = 2 * seq_len + num_layers = ts_w.shape[1] + num_buckets = ts_w.shape[0] - 1 + + timestamps_expanded = timestamps.unsqueeze(-1).repeat(1, 1, 2) + timestamps_expanded = timestamps_expanded.reshape(bs, infer_len, 1) - timestamps_expanded.reshape(bs, 1, infer_len) + + timestamps_expanded = timestamps_expanded.view(-1) + timestamps_expanded = 
timestamps_expanded.contiguous() + + clamp_max = torch.exp(torch.tensor(num_buckets * bucketization_divisor)).item() + index_len = bs * infer_len * infer_len + + out = torch.empty((num_layers, index_len), dtype=ts_w.dtype, device=ts_w.device) + outer_loop_num, sub_num_layers, remain_layers = get_outer_loop_num(num_layers, index_len) + + CORE_NUM = get_npu_properties()["num_vectorcore"] + BLOCK_SIZE = math.ceil(index_len / CORE_NUM) + COL_BLOCK_SIZE = 8 * 1024 + + curr_layers = sub_num_layers + for i in range(outer_loop_num): + if i == outer_loop_num - 1 and remain_layers != 0: + curr_layers = remain_layers + + def grid(meta): + return (triton.cdiv(index_len, meta["BLOCK_SIZE"]), curr_layers) + + rab_time_forward_kernel[grid]( + ts_w_trans[i * sub_num_layers], + out[i * sub_num_layers], + timestamps_expanded, + index_len, + num_buckets + 1, + clamp_max, + bucketization_divisor, + BLOCK_SIZE, + COL_BLOCK_SIZE, + ) + + out = out.view(num_layers, bs, infer_len, infer_len) + + return out + + +@triton.jit +def rab_time_backward_kernel(inp, src, index, index_len, BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr): + pid0 = tl.program_id(axis=0) + total_col_num = (BLOCK_SIZE if pid0 * BLOCK_SIZE + BLOCK_SIZE < index_len else index_len - pid0 * BLOCK_SIZE) + COL_BLOCK_SIZE = min(COL_BLOCK_SIZE, total_col_num) + col_iter_num = (total_col_num + COL_BLOCK_SIZE - 1) // COL_BLOCK_SIZE + + for col_idx in tl.range(0, col_iter_num): + base_idx = 0 + base_idx = base_idx.to(index.dtype.element_ty) + + col_start_offset = col_idx * COL_BLOCK_SIZE + + acc_result = 0.0 + acc_result = acc_result.to(inp.dtype.element_ty) + cur_col_num = (COL_BLOCK_SIZE if col_start_offset + COL_BLOCK_SIZE < total_col_num else total_col_num - + col_start_offset) + + for cur_idx in range(0, cur_col_num): + cur_offset = pid0 * BLOCK_SIZE + col_start_offset + cur_idx + + src_val = tl.load(src + cur_offset) + new_idx = tl.load(index + cur_offset) + + if base_idx == new_idx: + acc_result += src_val + else: + tl.atomic_add(inp + base_idx, acc_result) + + base_idx = new_idx + acc_result = 0.0 + acc_result = acc_result.to(inp.dtype.element_ty) + acc_result += src_val + + tl.atomic_add(inp + base_idx, acc_result) + + +def rab_time_backward_triton(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): + num_layers, b, s, _ = rab_time_grad.shape + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) + + bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, + 1).repeat(1, 1, 2, 1, 2).reshape(b, s, + s).to(torch.int64)).view(-1) + + index_len = bucket_timestamps_expand.numel() + + rab_time_grad_f32 = rab_time_grad.to(torch.float32) + sorted_bucket_timestamps_expand, sorted_idx = torch.sort(bucket_timestamps_expand.view(-1)) + + torch.npu.synchronize() + + def grid(meta): + return (triton.cdiv(index_len, meta["BLOCK_SIZE"]), ) + + CORE_NUM = get_npu_properties()["num_vectorcore"] + BLOCK_SIZE = math.ceil(index_len / CORE_NUM) + + COL_BLOCK_SIZE = 8 * 1024 + + for layer_idx in range(num_layers): + curr_sorted_grad_f32 = rab_time_grad_f32[layer_idx].view(-1)[sorted_idx] + rab_time_backward_kernel[grid]( + tsw_grad[layer_idx], + curr_sorted_grad_f32, + sorted_bucket_timestamps_expand, + index_len, + BLOCK_SIZE, + COL_BLOCK_SIZE, + ) + + return tsw_grad + + +def rab_time_forward_golden(ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float) -> torch.Tensor: + """ + torch realization of rab time forward for reference. 
+ """ + infer_len = timestamps.shape[1] * 2 + bs = timestamps.shape[0] + num_layers = ts_w.shape[1] + + timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2) + diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) + + clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) + diff_timestamps = (torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / bucketization_divisor) + bucket_timestamps = diff_timestamps.long() + bucket_timestamps = bucket_timestamps.view(-1) + result = torch.index_select(ts_w, dim=0, index=bucket_timestamps) + + result = result.t() + + result = result.view(num_layers, bs, infer_len, infer_len) + return result + + +def rab_time_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): + """ + torch realization of rab time backward for reference. + """ + num_layers, b, s, _ = rab_time_grad.shape + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) + + bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, + 1).repeat(1, 1, 2, 1, 2).reshape(b, s, s).to(torch.int64)) + for n, grad in enumerate(rab_time_grad.to(torch.float32)): + tsw_grad[n] = tsw_grad[n].scatter_add(src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0) + return tsw_grad + + +def run_rab_time_forward_case(num_layers, train_len, candidate_len, bs, dtype): + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) + timestamps_weights = create_timestamps_weights(num_layers).to(dtype) + timestamps = timestamps.npu() + timestamps_weights = timestamps_weights.npu() + + torch_npu.npu.synchronize() + + # triton output + rab_time_out_triton = rab_time_forward_triton( + ts_w=timestamps_weights, + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR, + ) + torch_npu.npu.synchronize() + + # pytorch output + rab_time_out_golden = rab_time_forward_golden( + ts_w=timestamps_weights, + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR, + ) + torch_npu.npu.synchronize() + + torch.testing.assert_close(rab_time_out_triton, rab_time_out_golden) + + +def run_rab_time_backward_case(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): + grad = create_rab_time_grad(num_layers, batchsize, s).to(dtype).npu() + bucket_timestamps = (create_bucket_timestamps(batchsize, s // 2).to(torch.int32).npu()) + + torch_npu.npu.synchronize() + + golden_result = (rab_time_backward_golden(grad, bucket_timestamps).to(torch.float32).cpu()) + op_result = (rab_time_backward_triton(grad, bucket_timestamps).to(torch.float32).cpu()) + + loss = 1e-4 if dtype == torch.float32 else 1e-3 + torch.testing.assert_close(op_result, golden_result, rtol=loss, atol=loss) + + +@pytest.mark.parametrize( + "num_layers, train_len, candidate_len, batch_size, dtype", + [ + pytest.param( + 8, + 500, + 500, + 4, + torch.float32, + marks=pytest.mark.skip(reason="temporarily skip UB overflow case"), + ), + ], +) +def test_rab_time_forward(num_layers, train_len, candidate_len, batch_size, dtype): + torch.manual_seed(0) + run_rab_time_forward_case(num_layers, train_len, candidate_len, batch_size, dtype) + + +@pytest.mark.parametrize( + "num_layers, batch_size, seq_len, dtype", + [ + (8, 4, 1500, torch.float32), + ], +) +def test_rab_time_backward(num_layers, batch_size, seq_len, dtype): + torch.manual_seed(0) + run_rab_time_backward_case(num_layers, batch_size, seq_len, dtype) diff --git 
a/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py
new file mode 100644
index 0000000000..f1acf686b5
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py
@@ -0,0 +1,930 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+HSTU Attention
+===============
+"""
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
+import numpy as np
+import torch.nn.functional as F
+
+DEVICE = "npu"
+BLOCK_FWD = 64
+BLOCK_BWD = 32
+
+
+@dataclass
+class JaggedData:
+    grad: torch.Tensor
+    q: torch.Tensor
+    k: torch.Tensor
+    v: torch.Tensor
+    bias: torch.Tensor
+    mask: torch.Tensor
+    max_seq_len: int
+    seq_offset: np.ndarray
+
+
+def get_npu_properties(coreType):
+    device = torch.npu.current_device()
+    return driver.active.utils.get_device_properties(device)[coreType]
+
+
+@triton.jit
+def _hstu_attn_fwd_one_block(
+    q,
+    k_block_ptr,
+    v_block_ptr,
+    bias_block_ptr,
+    alpha,
+    MAX_SEQ_LEN,
+    CAUSAL: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    mask_block,
+):
+    k = tl.load(k_block_ptr)
+    qk = tl.dot(q, tl.trans(k)) * alpha
+    if HAS_BIAS:
+        rel_attn_bias = tl.load(bias_block_ptr)
+        qk = qk + rel_attn_bias
+    silu = qk / (1.0 + tl.exp(-qk)) * (1.0 / MAX_SEQ_LEN)
+    if CAUSAL:
+        silu = tl.where(mask_block != 0, silu, 0)
+    v = tl.load(v_block_ptr)
+    silu = silu.to(v.dtype)
+    return tl.dot(silu, v)
+
+
+@triton.jit
+def _hstu_attn_fwd_compute(  # noqa C901
+    Q,
+    K,
+    V,
+    seq_offsets,
+    Out,
+    stride_qm: tl.constexpr,
+    stride_qh: tl.constexpr,
+    stride_kn: tl.constexpr,
+    stride_kh: tl.constexpr,
+    stride_vn: tl.constexpr,
+    stride_vh: tl.constexpr,
+    stride_om: tl.constexpr,
+    stride_oh: tl.constexpr,
+    alpha,
+    head_num,
+    MAX_SEQ_LEN,
+    off_batch,
+    off_head,
+    start_m,
+    seq_start,
+    seq_len,
+    CAUSAL: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    BLOCK_D_Q: tl.constexpr,
+    BLOCK_D_V: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    mask_block,
+    bias,
+):
+    off_head = off_head.to(tl.int64)
+    off_seq = seq_start.to(tl.int64)
+    start_m = start_m.to(tl.int32)
+
+    # initialize offsets
+    q_offset = off_seq * stride_qm + off_head * stride_qh
+    k_offset = off_seq * stride_kn + off_head * stride_kh
+    v_offset = off_seq * stride_vn + off_head * stride_vh
+
+    
Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_kn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_Q), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seq_len, BLOCK_D_V), + strides=(stride_vn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_V), + order=(1, 0), + ) + q = tl.load(Q_block_ptr) + + acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32) + if CAUSAL: + low = 0 + high = start_m + BLOCK_M + else: + low = 0 + high = seq_len + + bias_block_ptr = None + if HAS_BIAS: + bias_block_ptr = tl.make_block_ptr( + base=bias + off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + for start_n in range(low, high, BLOCK_N): + acc += _hstu_attn_fwd_one_block( + q=q, + k_block_ptr=k_block_ptr, + v_block_ptr=v_block_ptr, + bias_block_ptr=bias_block_ptr, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL and start_m == start_n, + HAS_BIAS=HAS_BIAS, + mask_block=mask_block, + ) + k_block_ptr = tl.advance(k_block_ptr, (BLOCK_N, 0)) + v_block_ptr = tl.advance(v_block_ptr, (BLOCK_N, 0)) + if HAS_BIAS: + bias_block_ptr = tl.advance(bias_block_ptr, (0, BLOCK_N)) + + # rematerialize offsets to save registers + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_seq * stride_om + off_head * stride_oh + offs_m = start_m + tl.arange(0, BLOCK_M) + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +@triton.jit +def _hstu_attn_fwd( # noqa C901 + Q, + K, + V, + seq_offsets, + Out, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_om: tl.constexpr, + stride_oh: tl.constexpr, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + CORE_NUM: tl.constexpr, + tasks: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + mask, + bias, +): + core_id = tl.program_id(0) + cur_batch = 0 + mask_block = None + if CAUSAL and mask is not None: + mask_ptr = tl.make_block_ptr( + base=mask, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_M), + order=(1, 0), + ) + mask_block = tl.load(mask_ptr) + for col in range(core_id, tasks, CORE_NUM): + seq_end = tl.load(seq_offsets + cur_batch + 1) + start_m = col * BLOCK_M + while start_m >= seq_end * head_num // 2: + cur_batch += 1 + seq_end = tl.load(seq_offsets + cur_batch + 1) + seq_start = tl.load(seq_offsets + cur_batch) + seq_len = seq_end - seq_start + off_batch = cur_batch + off_head = (start_m - seq_start * head_num // 2) // (seq_len // 2) + start_m_1 = (start_m - seq_start * head_num // 2) % (seq_len // 2) + start_m_2 = seq_len - start_m_1 - BLOCK_M + _hstu_attn_fwd_compute( + Q, + K, + V, + seq_offsets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + head_num, + MAX_SEQ_LEN, + off_batch, + off_head, + start_m_1, + seq_start, + seq_len, 
+ CAUSAL, + HAS_BIAS, + head_dim, + head_dim, + BLOCK_M, + BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + _hstu_attn_fwd_compute( + Q, + K, + V, + seq_offsets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + head_num, + MAX_SEQ_LEN, + off_batch, + off_head, + start_m_2, + seq_start, + seq_len, + CAUSAL, + HAS_BIAS, + head_dim, + head_dim, + BLOCK_M, + BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + + +@triton.jit +def _hstu_attn_bwd_one_block( # noqa C901 + start_m, + offs_n, + offs_m, + q_ptrs, + dq_ptrs, + mask_n, + do_ptrs, + dk, + dv, + k, + v, + pos_offs_n, + seq_len, + max_ids, + stride_qm, + stride_dom, + stride_dqm, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias_block_ptr, +): + pos_offs_m = offs_m + start_m + mask_m = pos_offs_m < seq_len + # recompute qk and silu + q = tl.load( + q_ptrs + start_m * stride_qm, + mask=mask_m[:, None], + other=0.0, + ) + q_trans = tl.trans(q) + qk_trans = tl.dot(k, q_trans) * alpha + if HAS_BIAS: + rel_attn_bias = tl.load(bias_block_ptr) + qk_trans = qk_trans + tl.trans(rel_attn_bias) + sig_trans = 1.0 / (1.0 + tl.exp(-qk_trans)) + silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + invalid_mask_trans = pos_offs_m[None, :] == offs_n[:, None] + pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None] + invalid_mask_trans = invalid_mask_trans | (pos_offs_m_minus_n > 0) + silu_trans = tl.where(invalid_mask_trans, silu_trans, 0) + silu_trans = silu_trans.to(k.dtype) + # compute dv + do = tl.load( + do_ptrs + start_m * stride_dom, + mask=mask_m[:, None], + other=0.0, + ) + dv += tl.dot(silu_trans, do) + # compute dk and dq (dqk = do * v^T dk = dqk^T * q dq = dqk * k) + dqk_trans = tl.dot(v, tl.trans(do)) + dqk_trans = dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + dqk_trans = tl.where(invalid_mask_trans, dqk_trans, 0) + dqk_trans = dqk_trans.to(k.dtype) + dq = tl.load( + dq_ptrs + start_m * stride_dqm, + mask=mask_m[:, None], + other=0.0, + ) + dq += tl.dot(tl.trans(dqk_trans), k) * alpha + tl.store( + dq_ptrs + start_m * stride_dqm, + dq, + mask=mask_m[:, None], + ) + # Note: the factor `alpha` is delayed until the end of the function to reduce the cost + dk += tl.dot(dqk_trans, q) + return dk, dv + + +@triton.jit +def _hstu_attn_bwd_one_col_block( # noqa C901 + start_n, + seq_len, + Q, + K, + V, + DOut, + DQ, + DK, + DV, + stride_qm, + stride_kn, + stride_vn, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + # Work on the subsequence dv[start_n, start_n + BLOCK_N, :] + if CAUSAL: + low = start_n + high = seq_len + else: + low = 0 + high = seq_len + + # initialize row/col offsets + offs_m = tl.arange(0, BLOCK_M) + offs_qk_d = tl.arange(0, BLOCK_D_Q) + offs_v_d = tl.arange(0, BLOCK_D_V) + offs_n = start_n + tl.arange(0, BLOCK_N) + + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_qk_d[None, :]) + dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32) + + mask_n = offs_n < seq_len + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_qk_d[None, :]) + do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + 
offs_qk_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :]) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + + max_ids = seq_len + pos_offs_n = offs_n + # loop over rows + for start_m in tl.range(low, high, BLOCK_M): + bias_block_ptr = None + if HAS_BIAS: + bias_block_ptr = tl.make_block_ptr( + base=bias, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(start_m, start_n), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs=q_ptrs, + dq_ptrs=dq_ptrs, + mask_n=mask_n, + do_ptrs=do_ptrs, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias_block_ptr=bias_block_ptr, + ) + # write-back + dk = dk * alpha + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :]) + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None]) + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None]) + + +@triton.jit +def _hstu_attn_bwd( # noqa C901 + Q, + K, + V, + Grad, + DQ, + DK, + DV, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_dom: tl.constexpr, + stride_doh: tl.constexpr, + seq_offsets, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + off = tl.program_id(0) + off_batch = off // head_num + off_head = off % head_num + off_head = off_head.to(tl.int64) + seq_start = tl.load(seq_offsets + off_batch).to(tl.int64) + seq_end = tl.load(seq_offsets + off_batch + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + # offset pointers for batch/head + q_offset = seq_start * stride_qm + off_head * stride_qh + k_offset = seq_start * stride_kn + off_head * stride_kh + v_offset = seq_start * stride_vn + off_head * stride_vh + grad_offset = seq_start * stride_dom + off_head * stride_doh + bias_offset = off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN + for start_n in range(0, seq_len, BLOCK_N): + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + Q=Q + q_offset, + K=K + k_offset, + V=V + v_offset, + DOut=Grad + grad_offset, + DQ=DQ + q_offset, + DK=DK + k_offset, + DV=DV + v_offset, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_qm, + stride_dkn=stride_kn, + stride_dvn=stride_vn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_D_Q=head_dim, + BLOCK_D_V=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias=bias + bias_offset if HAS_BIAS else bias, + ) + + +def triton_hstu_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + alpha: float, + causal: bool, + mask: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch = seq_offsets.numel() - 1 + total_seq, head_num, head_dim = q.shape + out = 
torch.empty_like(v)
+    BLOCK_M = BLOCK_FWD
+    BLOCK_N = BLOCK_FWD
+    if total_seq == 0:
+        print("error")
+        return out
+    has_bias = bias is not None
+    core_num = get_npu_properties('num_aicore')
+    tasks = total_seq * head_num // BLOCK_M // 2
+    grid = (core_num, 1, 1)
+    _hstu_attn_fwd[grid](
+        q,
+        k,
+        v,
+        seq_offsets,
+        out,
+        q.stride(0),
+        q.stride(1),
+        k.stride(0),
+        k.stride(1),
+        v.stride(0),
+        v.stride(1),
+        out.stride(0),
+        out.stride(1),
+        alpha,
+        batch,
+        head_num,
+        max_seq_len,
+        head_dim,
+        causal,
+        has_bias,
+        core_num,
+        tasks,
+        BLOCK_M,
+        BLOCK_N,
+        mask,
+        bias,
+    )
+    return out
+
+
+def triton_hstu_attention_bwd(
+    grad: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seq_offsets: torch.Tensor,
+    max_seq_len: int,
+    alpha: float,
+    causal: bool,
+    bias: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    dq = torch.zeros_like(q)
+    dk = torch.zeros_like(k)
+    dv = torch.zeros_like(v)
+    if grad.shape[0] == 0:
+        return dq, dk, dv
+    batch = seq_offsets.numel() - 1
+    _, head_num, head_dim = q.shape
+    has_bias = bias is not None
+    grid = (
+        batch * head_num,
+        1,
+    )
+    _hstu_attn_bwd[grid](
+        q,
+        k,
+        v,
+        grad,
+        dq,
+        dk,
+        dv,
+        q.stride(0),
+        q.stride(1),
+        k.stride(0),
+        k.stride(1),
+        v.stride(0),
+        v.stride(1),
+        grad.stride(0),
+        grad.stride(1),
+        seq_offsets,
+        alpha,
+        batch,
+        head_num,
+        max_seq_len,
+        head_dim,
+        causal,
+        has_bias,
+        BLOCK_BWD,
+        BLOCK_BWD,
+        bias,
+    )
+    return dq, dk, dv
+
+
+def jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, dataType) -> JaggedData:
+    seq_array = np.arange(256, max_seq_len + 1, 256)
+    seq_lens = np.random.choice(seq_array, size=batch_size)
+    if not np.isin(max_seq_len, seq_lens):
+        seq_lens[np.random.randint(0, batch_size)] = max_seq_len
+    seq_offset = torch.concat((torch.zeros((1, ), dtype=torch.int64), torch.cumsum(torch.from_numpy(seq_lens),
+                                                                                   axis=0))).to(torch.int64).numpy()
+    max_seq_len = np.max(seq_lens)
+    total_seqs = np.sum(seq_lens)
+    grad = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType)
+    q = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType)
+    k = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType)
+    v = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType)
+
+    bias = torch.empty(batch_size, num_heads, max_seq_len, max_seq_len, dtype=dataType).uniform_(-1, 1)
+    mask = 1 - torch.triu(torch.ones(batch_size, num_heads, max_seq_len, max_seq_len), diagonal=1).to(torch.float32)
+    return JaggedData(
+        grad=grad,
+        q=q,
+        k=k,
+        v=v,
+        bias=bias,
+        mask=mask,
+        max_seq_len=max_seq_len,
+        seq_offset=seq_offset,
+    )
+
+
+def dense_to_jagged(q, dense_tensor, seq_lens):
+    tensor = torch.zeros_like(q)
+    offset = 0
+    for batch_id, seq_len in enumerate(seq_lens):
+        tensor[offset:offset + seq_len, :, :] = dense_tensor[batch_id, 0:seq_len, :, :]
+        offset = offset + seq_len
+    return tensor
+
+
+def jagged_to_dense(jagged_tensor, seq_lens, head_nums, atten_dim):
+    need_pad_seq = []
+    offset = 0
+    for _, seq_len in enumerate(seq_lens):
+        src_tensor = jagged_tensor[offset:offset + seq_len, :, :].reshape(seq_len, head_nums, atten_dim)
+        need_pad_seq.append(src_tensor)
+        offset = offset + seq_len
+
+    dense_tensor = torch.nn.utils.rnn.pad_sequence(need_pad_seq, batch_first=True)
+    return dense_tensor
+
+
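+# NOTE (illustrative only): the jagged layout used throughout this file
+# concatenates all sequences along dim 0 and addresses them through
+# cumulative offsets, e.g.:
+#
+#     seq_lens   = [3, 1, 2]
+#     seq_offset = [0, 3, 4, 6]   # cumsum with a leading zero
+#
+# so sequence b occupies rows seq_offset[b]:seq_offset[b + 1] of q, k and v.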
+def golden_fwd(q, k, v, mask, alpha, seq_offset, attnBias, max_seq_len, enable_mask, enableBias, dataType):
+    head_nums = q.shape[1]
+    head_dim = q.shape[2]
+    batch_size = attnBias.shape[0]
+    seq_lens = np.zeros((batch_size, )).astype(np.int64)
+    for batch_id in range(batch_size):
+        seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id]
+    q_dens = jagged_to_dense(q, seq_lens, head_nums, head_dim).to(dataType)
+    k_dens = jagged_to_dense(k, seq_lens, head_nums, head_dim).to(dataType)
+    v_dens = jagged_to_dense(v, seq_lens, head_nums, head_dim).to(dataType)
+    q_dens = q_dens.permute(0, 2, 1, 3)
+    k_dens = k_dens.permute(0, 2, 3, 1)
+    v_dens = v_dens.permute(0, 2, 1, 3)
+
+    qk_attn = torch.matmul(q_dens, k_dens) * alpha
+    qk_attn = qk_attn.to(torch.float32)
+    attnBias = attnBias.to(torch.float32)
+    mask = mask.to(torch.float32)
+    if enableBias:
+        qk_attn = qk_attn + attnBias
+    silu = F.silu(qk_attn) * (1 / max_seq_len)
+    if enable_mask:
+        silu = silu * mask
+    silu = silu.to(dataType)
+    atten_output = torch.matmul(silu, v_dens)
+
+    atten_output = atten_output.permute(0, 2, 1, 3)
+    atten_output = dense_to_jagged(q, atten_output, seq_lens)
+    return atten_output.to(dataType)
+
+
+def run_fwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type):
+    alpha = 1
+    jagged_data = jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, data_type)
+    # golden output
+    golden_output = golden_fwd(
+        jagged_data.q,
+        jagged_data.k,
+        jagged_data.v,
+        jagged_data.mask,
+        alpha,
+        jagged_data.seq_offset,
+        jagged_data.bias,
+        jagged_data.max_seq_len,
+        True,
+        False,
+        data_type,
+    )
+    # triton output
+    seq_offsets = torch.tensor(jagged_data.seq_offset, dtype=torch.int64, device=DEVICE)
+    triton_output = triton_hstu_attention_fwd(
+        q=jagged_data.q.npu(),
+        k=jagged_data.k.npu(),
+        v=jagged_data.v.npu(),
+        seq_offsets=seq_offsets,
+        max_seq_len=int(jagged_data.max_seq_len),
+        alpha=alpha,
+        causal=True,
+        mask=jagged_data.mask.npu(),
+    )
+    loss = 1e-4
+    if data_type == torch.float16:
+        loss = 1e-3
+    elif data_type == torch.bfloat16:
+        loss = 1e-2
+    torch.testing.assert_close(triton_output.cpu(), golden_output, atol=loss, rtol=loss)
+
+
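+# NOTE (illustrative only): the backward path relies on the SiLU derivative
+#
+#     silu(x)  = x * sigmoid(x)
+#     silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+#
+# which is exactly the `F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb)))`
+# factor in golden_bwd below and the matching term in _hstu_attn_bwd_one_block.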
+def golden_bwd(grad, q, k, v, bias, mask, max_seq_len, seq_offset, enable_mask, silu_scale, enable_bias, data_type):
+
+    def jagged_to_dense_bwd(jagged_tensor, seq_lens, max_seq_len, head_num, head_dim):
+        batch_size = len(seq_lens)
+        dense_tensor = torch.zeros(batch_size, max_seq_len, head_num, head_dim, dtype=jagged_tensor.dtype)
+
+        offset = 0
+        for batch_id, seq_len in enumerate(seq_lens):
+            dense_tensor[batch_id, :seq_len, :, :] = jagged_tensor[offset:offset + seq_len, :, :]
+            offset = offset + seq_len
+
+        return dense_tensor
+
+    def dense_to_jagged_bwd(jagged_tensor, dense_tensor, seq_lens):
+        tensor = torch.zeros_like(jagged_tensor)
+
+        offset = 0
+        for batch_id, seq_len in enumerate(seq_lens):
+            tensor[offset:offset + seq_len, :, :] = dense_tensor[batch_id, 0:seq_len, :, :]
+            offset = offset + seq_len
+
+        return tensor
+
+    q = q.cpu()
+    k = k.cpu()
+    v = v.cpu()
+    grad = grad.cpu()
+    head_nums = grad.shape[1]
+    head_dim = grad.shape[2]
+    batch_size = bias.shape[0]
+    seq_lens = np.zeros((batch_size, )).astype(np.int64)
+    for batch_id in range(batch_size):
+        seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id]
+    grad_dens = jagged_to_dense_bwd(grad, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
+    q_dens = jagged_to_dense_bwd(q, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
+    k_dens = jagged_to_dense_bwd(k, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
+    v_dens = jagged_to_dense_bwd(v, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
+    actual_seq_lens = torch.from_numpy(seq_lens).reshape(batch_size, 1, 1, 1).to(data_type)
+    actual_seq_lens = torch.broadcast_to(actual_seq_lens, bias.shape)
+    qk = torch.matmul(q_dens.permute(0, 2, 1, 3), k_dens.permute(0, 2, 3, 1))
+    gv = torch.matmul(grad_dens.permute(0, 2, 1, 3), v_dens.permute(0, 2, 3, 1))
+    qk = qk.float()
+    gv = gv.float()
+    bias = bias.float()
+    if enable_mask:
+        mask = mask.to(data_type)
+        mask = mask.float()
+    if enable_bias:
+        bias = bias.to(data_type)
+        bias = bias.float()
+        qkb = qk + bias
+    else:
+        qkb = qk
+    real_silu_scale = 1 / max_seq_len if silu_scale == 0.0 else silu_scale
+
+    if enable_mask:
+        score = F.silu(qkb) * real_silu_scale * mask
+    else:
+        score = F.silu(qkb) * real_silu_scale
+    score = score.to(data_type)
+    v_grad_dens = torch.matmul(score.permute(0, 1, 3, 2), grad_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
+    if enable_mask:
+        bias_grad = gv * real_silu_scale * mask * F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb)))
+    else:
+        bias_grad = gv * real_silu_scale * F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb)))
+    bias_grad = bias_grad.to(data_type)
+    k_grad_dens = torch.matmul(bias_grad.permute(0, 1, 3, 2), q_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
+    q_grad_dens = torch.matmul(bias_grad, k_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
+    bias_grad = bias_grad.cpu()
+    q_grad_dens = q_grad_dens.cpu()
+    q_grad = dense_to_jagged_bwd(q, q_grad_dens, seq_lens)
+    k_grad_dens = k_grad_dens.cpu()
+    k_grad = dense_to_jagged_bwd(k, k_grad_dens, seq_lens)
+    v_grad_dens = v_grad_dens.cpu()
+    v_grad = dense_to_jagged_bwd(v, v_grad_dens, seq_lens)
+    torch.npu.synchronize()
+    return q_grad, k_grad, v_grad, bias_grad
+
+
+def run_bwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type):
+    alpha = 1
+    jagged_data = jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, data_type)
+    # golden output
+    q_grad_golden, k_grad_golden, v_grad_golden, _ = golden_bwd(
+        jagged_data.grad,
+        jagged_data.q,
+        jagged_data.k,
+        jagged_data.v,
+        jagged_data.bias,
+        jagged_data.mask,
+        jagged_data.max_seq_len,
+        jagged_data.seq_offset,
+        True,
+        0,
+        False,
+        data_type,
+    )
+
+    # triton output
+    seq_offsets = torch.tensor(jagged_data.seq_offset, dtype=torch.int64, device=DEVICE)
+    dq, dk, dv = triton_hstu_attention_bwd(
+        grad=jagged_data.grad.npu(),
+        q=jagged_data.q.npu(),
+        k=jagged_data.k.npu(),
+        v=jagged_data.v.npu(),
+        seq_offsets=seq_offsets,
+        max_seq_len=int(jagged_data.max_seq_len),
+        alpha=alpha,
+        causal=True,
+    )
+    loss = 1e-4
+    if data_type == torch.float16:
+        loss = 1e-3
+    elif data_type == torch.bfloat16:
+        loss = 1e-2
+    torch.testing.assert_close(dq.cpu(), q_grad_golden.cpu(), atol=loss, rtol=loss)
+    torch.testing.assert_close(dk.cpu(), k_grad_golden.cpu(), atol=loss, rtol=loss)
+    torch.testing.assert_close(dv.cpu(), v_grad_golden.cpu(), atol=loss, rtol=loss)
+
+
+@pytest.mark.parametrize(
+    "batch_size, max_seq_len, num_heads, attention_dim, data_type",
+    [
+        (2, 1024, 2, 32, torch.float32),
+    ],
+)
+def test_hstu_attention_fwd(batch_size, max_seq_len, num_heads, attention_dim, data_type):
+    np.random.seed(0)
+    torch.manual_seed(0)
+    run_fwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type)
+
+
+@pytest.mark.parametrize(
+    "batch_size, max_seq_len, num_heads, attention_dim, data_type",
+    [
+        (2, 1024, 2, 32, torch.float32),
+    ],
+)
+def test_hstu_attention_bwd(batch_size, max_seq_len, num_heads, attention_dim, data_type):
+    np.random.seed(0)
+    torch.manual_seed(0)
+    run_bwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type)
diff --git a/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py
new file mode 100644
index 0000000000..27df8cf162
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py
@@ -0,0 +1,216 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
+import triton.language.extra.cann.extension as extension
+
+
+# get device properties of npu
+def get_npu_properties():
+    device = torch.npu.current_device()
+    return driver.active.utils.get_device_properties(device)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
+    ], key=["M", "N", "K"])
+@triton.jit
+def matmul_kernel(
+    mat_a,
+    mat_b,
+    mat_c,
+    M: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    num_cores: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_TRESHHOLD: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    task_m_idx = 0
+    task_n_idx = 0
+    '''
+    With horizontal (row-major) partitioning, the task blocks are numbered as follows:
+    [0, 1, 2, 3, 4, 5, 6, 7]
+    [8, 9, 10, 11, 12, 13, 14, 15]
+    [16, 17, 18, 19, 20, 21, 22, 23]
+    [24, 25, 26, 27, 28, 29, 30, 31]
+    [32, 33, 34, 35, 36, 37, 38, 39]
+    [40, 41, 42, 43, 44, 45, 46, 47]
+    [48, 49, 50, 51, 52, 53, 54, 55]
+    [56, 57, 58, 59, 60, 61, 62, 63]
+    With 20 cores, core 0 handles blocks 0, 20, 40, 60 (4 blocks),
+    core 1 handles blocks 1, 21, 41, 61 (4 blocks),
+    core 2 handles blocks 2, 22, 42, 62 (4 blocks),
+    ...
+    core 19 handles blocks 19, 39, 59 (3 blocks).
+
+    For large shapes, this traditional horizontal partitioning has two problems:
+    1: at any given moment many cores read the same region of the left matrix,
+       causing bank conflicts and lowering memory-access efficiency;
+    2: by the time one full row of mat_c is finished, the whole right matrix has
+       been read; when the right matrix exceeds the L2 cache capacity this forces
+       L2 loads and evictions, so every following row produces cache misses,
+       lowering the L2 cache hit rate and the operator's throughput.
+    Partitioning along the diagonal in 8 * 8 squares of blocks largely mitigates
+    both problems.
+
+    An 8 * 8 diagonal partition is used as the example here; in practice
+    BLOCK_TRESHHOLD is an autotuned parameter that picks the best threshold.
+    With 8 * 8 diagonal partitioning, the task blocks inside each 8 * 8 square
+    are numbered as follows:
+    [0, 8, 16, 24, 32, 40, 48, 56]
+    [57, 1, 9, 17, 25, 33, 41, 49]
+    [50, 58, 2, 10, 18, 26, 34, 42]
+    [43, 51, 59, 3, 11, 19, 27, 35]
+    [36, 44, 52, 60, 4, 12, 20, 28]
+    [29, 37, 45, 53, 61, 5, 13, 21]
+    [22, 30, 38, 46, 54, 62, 6, 14]
+    [15, 23, 31, 39, 47, 55, 63, 7]
+
+    When the M axis spans more than 8 basic blocks, diagonal partitioning clearly
+    reduces bank conflicts, and when the right matrix is larger than the L2 cache
+    it improves L2 cache utilization. So diagonal partitioning pays off whenever
+    the matrix exceeds 8 blocks in both the M and N directions, and the gain is
+    most pronounced when the right matrix exceeds the L2 cache size.
+    '''
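+    # A minimal host-side sketch of the diagonal numbering above (illustrative
+    # only, not executed here): task t of one full 8 * 8 square sits at
+    # row = t % 8, col = (t % 8 + t // 8) % 8, i.e.
+    #
+    #     def diagonal_grid(n=8):
+    #         grid = [[0] * n for _ in range(n)]
+    #         for t in range(n * n):
+    #             grid[t % n][(t % n + t // n) % n] = t
+    #         return grid
+    #
+    # reproduces the table: diagonal_grid()[0] == [0, 8, 16, 24, 32, 40, 48, 56]
+    # and diagonal_grid()[1] == [57, 1, 9, 17, 25, 33, 41, 49].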
+ 19核处理 19 39 59 3块任务 + + 大shape下如果使用传统水平分核方式,会有如下问题 + 1:同一时间大量核心需要访问同一块左矩阵内存,产生Bank冲突,导致硬件访问效率降低 + 2:当完成一整行mat_c运算时,已经将所有右矩阵数据全部使用上,右矩阵较大时会超过L2Cache的容量上限, + 从而导致L2Cache的搬入及换出,此后每行运算都会或多或少产生CacheMiss,导致L2Cche命中率较低,影响 + 算子执行效率 + 此处使用8 * 8对角线分核方式可以按8 * 8的方块沿对角线方向分核计算,可以很大程度优化上面两点。 + + 此处以8*8对角线分核为例,实际以BLOCK_TRESHHOLD为tune参数选择最优的阈值 + 8 * 8 对角线分核方式中,每8 * 8分格内任务块编号如下 + [0, 8, 16, 24, 32, 40, 48, 56] + [57, 1, 9, 17, 25, 33, 41, 49] + [50, 58, 2, 10, 18, 26, 34, 42] + [43, 51, 59, 3, 11, 19, 27, 35] + [36, 44, 52, 60, 4, 12, 20, 28] + [29, 37, 45, 53, 61, 5, 13, 21] + [22, 30, 38, 46, 54, 62, 6, 14] + [15, 23, 31, 39, 47, 55, 63, 7] + + M轴方向超过8个基本块时,使用对角线分核可以明显减小Bank冲突 + 当右矩阵大小超过L2Cache大小时,采取对角线分核可以提升L2Cache利用率 + 所以当矩阵在M和N方向均超过8块时使能对角线分核即可有优化,当右矩阵大小超过L2Cache大小时优化效果尤为明显 + ''' + NUM_BLOCKS_M = triton.cdiv(M, BLOCK_M) + NUM_BLOCKS_N = triton.cdiv(N, BLOCK_N) + NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N + # 当任务量较多时,可以使能对角线分核策略进行优化 + if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD: + for block_idx in range(pid, NUM_BLOCKS, num_cores): + # 8 * 8 对角线分核代码实现 + curThresholdM = BLOCK_TRESHHOLD if block_idx < ( + NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD + curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD + curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < ( + curThresholdM * + NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD + localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM) + task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * + NUM_BLOCKS_N) * BLOCK_TRESHHOLD + # 求最小公倍数,方便求基本块的坐标 + x, y = curThresholdM, curThresholdN if curThresholdM > curThresholdN else curThresholdN, curThresholdM + while y != 0: + x, y = y, x % y + lcm = curThresholdM * curThresholdN // x + task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % ( + BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD + + m_start = task_m_idx * BLOCK_M + n_start = task_n_idx * BLOCK_N + + mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k_start in range(0, K, BLOCK_K): + mat_a_offset = ( + (m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (k_start + tl.arange(0, BLOCK_K))[None, :] + mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (k_start + tl.arange(0, BLOCK_K)) < K)[None, :] + mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0) + extension.compile_hint(mat_a_block, "dot_pad_only_k") + mat_b_offset = ( + (k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] + mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] + mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0) + extension.compile_hint(mat_b_block, "dot_pad_only_k") + mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] + mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] + tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask) + else: + # 传统顺序分核 + for block_idx in range(pid, NUM_BLOCKS, num_cores): + task_m_idx = block_idx // NUM_BLOCKS_N + task_n_idx = block_idx % NUM_BLOCKS_N + m_start = task_m_idx 
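For intuition, the 8 * 8 diagonal numbering in the kernel's docstring can be reproduced with a much simpler host-side swizzle when both tile dimensions equal the threshold; a small illustrative sketch (my own, not part of the patch):

# For a T x T tile, linear task id k lands on row k % T and column (k % T + k // T) % T.
T = 8
grid = [[0] * T for _ in range(T)]
for k in range(T * T):
    grid[k % T][(k % T + k // T) % T] = k
for row in grid:
    print(row)
# Prints [0, 8, 16, 24, 32, 40, 48, 56], then [57, 1, 9, 17, 25, 33, 41, 49], and so on,
# exactly the diagonal layout from the docstring. The extra arithmetic in the kernel
# generalizes this to edge tiles whose height or width falls below the threshold.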
diff --git a/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py
new file mode 100644
index 0000000000..bb7720ec4a
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py
@@ -0,0 +1,144 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+
+
+def run_add(x0, x1):
+    """
+    Check that the Triton vector add matches the PyTorch result in an accuracy
+    comparison.
+
+    Steps:
+    1. Compute the reference result with PyTorch (torch_ref).
+    2. Compute the result with a hand-written Triton kernel (triton_cal).
+    3. Call accuracy_comparison to compare the two.
+    """
+
+    # 1. PyTorch reference implementation (golden truth)
+    def torch_func(x0, x1):
+        res = x0 + x1
+        return res
+
+    # 2. Define the Triton kernel (executed on the NPU/GPU)
+    @triton.jit
+    def triton_kernel_add(out_ptr0,  # output pointer: where the result is stored
+                          in_ptr0,  # input pointer 0: start address of x0
+                          in_ptr1,  # input pointer 1: start address of x1
+                          XS: tl.constexpr  # constexpr argument: vector length, fixed at compile time
+                          ):
+        # Build the index vector [0, 1, 2, ..., XS-1]
+        idx = tl.arange(0, XS)
+        # Load x0 from in_ptr0 + idx
+        tmp0 = tl.load(in_ptr0 + idx)
+        # Load x1 from in_ptr1 + idx
+        tmp1 = tl.load(in_ptr1 + idx)
+        # Add
+        tmp2 = tmp0 + tmp1
+        # Write the result to out_ptr0 + idx
+        tl.store(out_ptr0 + idx, tmp2)
+
+    # 3. Triton wrapper: launch the kernel and return the result
+    def triton_func(x0, x1):
+        y0 = torch.empty_like(x0)  # output tensor with the same shape and dtype as the input
+        # Launch the kernel: grid = [1, 1, 1] uses a single block
+        # Note: XS must be passed as an argument because it is a tl.constexpr
+        triton_kernel_add[1, 1, 1](y0, x0, x1, XS=x0.numel())
+        return y0
+
+    # 4. Collect the reference result and the Triton result
+    torch_ref = torch_func(x0, x1)
+    triton_cal = triton_func(x0, x1)
+
+    # 5. Accuracy comparison
+    accuracy_comparison(triton_cal, torch_ref)
+
+    # 6. Report success
+    print(
+        f"== dtype:{triton_cal.dtype} == The accuracy comparison between triton_result and torch_result was successful."
+    )
+
+
+def accuracy_comparison(y_cal, y_ref):
+    """
+    Accuracy comparison: choose the comparison strategy based on the dtype.
+
+    Strategy per dtype:
+    - Floating point (float16/32, bfloat16): torch.testing.assert_close with
+      per-dtype relative/absolute tolerances
+    - Integer (int8/16/32/64): exact equality required (torch.equal)
+    - Boolean (bool): strict comparison on the CPU (avoids device differences)
+    """
+    # The output dtypes must match
+    assert y_cal.dtype == y_ref.dtype, f"dtype mismatch: {y_cal.dtype} vs {y_ref.dtype}"
+    tensor_dtype = y_cal.dtype
+
+    # Move the tensors to the NPU (the test is assumed to run on an NPU)
+    y_cal = y_cal.npu()
+    y_ref = y_ref.npu()
+
+    # Select the comparison according to the dtype
+    if tensor_dtype == torch.float16:
+        # float16 has limited precision, so allow a slightly larger error
+        torch.testing.assert_close(y_ref, y_cal, rtol=1e-3, atol=1e-3, equal_nan=True)
+    elif tensor_dtype == torch.bfloat16:
+        # bfloat16 is even less precise; convert to float32 before comparing
+        torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-3, atol=1e-3,
+                                   equal_nan=True)
+    elif tensor_dtype == torch.float32:
+        # float32 is precise enough for tighter tolerances
+        torch.testing.assert_close(y_ref, y_cal, rtol=1e-4, atol=1e-4, equal_nan=True)
+    elif tensor_dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint32]:
+        # Integer results must match exactly
+        assert torch.equal(y_cal, y_ref), f"Integer tensors are not equal for dtype {tensor_dtype}"
+    elif tensor_dtype == torch.bool:
+        # Compare booleans on the CPU to avoid device-specific boolean representations
+        assert torch.equal(y_cal.cpu(), y_ref.cpu()), "Boolean tensors are not equal"
+    else:
+        raise ValueError(f'Invalid or unsupported tensor dtype: {tensor_dtype}')
+
+
+# ==================== Pytest Test ====================
+@pytest.mark.parametrize("dtype_name, dtype, low, high", [
+    ("fp32", torch.float32, 0, 1),
+    ("fp16", torch.float16, 0, 1),
+    ("bf16", torch.bfloat16, 0, 1),
+    ("i64", torch.int64, 1, 100),
+    ("i32", torch.int32, 1, 100),
+    ("i16", torch.int16, 1, 100),
+    ("i8", torch.int8, 1, 100),
+    ("i1", torch.bool, 0, 2),
+])
+def test_all_dtypes(dtype_name, dtype, low, high):
+    N = 1024
+    if dtype == torch.bool:
+        x0 = torch.randint(low=low, high=high, size=(N, )).bool().npu()
+        x1 = torch.randint(low=low, high=high, size=(N, )).bool().npu()
+    elif dtype.is_floating_point:
+        x0 = torch.rand((N, ), dtype=dtype).npu()
+        x1 = torch.rand((N, ), dtype=dtype).npu()
+    else:
+        x0 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu()
+        x1 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu()
+
+    print(f"Running test for {dtype_name}...")
+    run_add(x0, x1)
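The per-dtype tolerances above track machine epsilon. A quick way to see why fp16 gets rtol=1e-3 while fp32 gets 1e-4 (an illustrative aside, not part of the patch):

import torch
for dt in (torch.float16, torch.bfloat16, torch.float32):
    print(dt, torch.finfo(dt).eps)
# float16  eps ~ 9.77e-04 -> rtol 1e-3 is about one ulp at 1.0
# bfloat16 eps ~ 7.81e-03 -> compared in float32 with the same relative budget
# float32  eps ~ 1.19e-07 -> rtol 1e-4 leaves ample headroom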
diff --git a/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py b/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py
new file mode 100644
index 0000000000..a63a0b7ec3
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py
@@ -0,0 +1,88 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
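The autotune demo below drives a meta-dependent grid: the launcher re-evaluates the grid callable for whichever config the tuner is currently timing. A host-side illustration of how the candidate XS values below translate into launch grids for N = 192 * 1024 (my own aside, not part of the patch):

import triton
n = 192 * 1024
for xs in (1 * 128, 8 * 1024, 12 * 1024):  # candidate XS values from get_autotune_config below
    print(xs, triton.cdiv(n, xs))  # 128 -> 1536 programs, 8192 -> 24, 12288 -> 16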
+""" +Autotune +============= +""" +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +# Return a set of different kernel configurations for autotune +def get_autotune_config(): + return [ + triton.Config({'XS': 1 * 128, 'multibuffer': True}), + triton.Config({'XS': 12 * 1024, 'multibuffer': True}), + triton.Config({'XS': 12 * 1024, 'multibuffer': False}), + triton.Config({'XS': 8 * 1024, 'multibuffer': True}), + ] + + +# Use @autotune decorator to automatically select the best kernel configuration +@triton.autotune( + configs=get_autotune_config(), + key=["numel"], +) +@triton.jit +def triton_calc_kernel(out_ptr0, in_ptr0, in_ptr1, numel, + XS: tl.constexpr # Block size controlling how many elements each thread block processes + ): + pid = tl.program_id(0) + idx = pid * XS + tl.arange(0, XS) + msk = idx < numel + for i in range(10000): + tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0) + tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0) + tmp2 = tl.math.exp(tmp0) + tmp1 + i + tl.store(out_ptr0 + idx, tmp2, mask=msk) + + +# Function to call the Triton kernel with autotuned configuration +def triton_calc_func(x0, x1): + n = x0.numel() + y0 = torch.empty_like(x0) + + def grid(meta): + return (triton.cdiv(n, meta["XS"]), 1, 1) + + triton_calc_kernel[grid](y0, x0, x1, n) + return y0 + + +# Reference implementation using PyTorch for correctness check +def torch_calc_func(x0, x1): + return torch.exp(x0) + x1 + 10000 - 1 + + +# ==================== Pytest Test ==================== +def test_triton_autotune(): + DEV = "npu" + DTYPE = torch.float32 + N = 192 * 1024 + x0 = torch.randn((N, ), dtype=DTYPE, device=DEV) + x1 = torch.randn((N, ), dtype=DTYPE, device=DEV) + + torch_ref = torch_calc_func(x0, x1) + triton_cal = triton_calc_func(x0, x1) + + torch.testing.assert_close(triton_cal, torch_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_16_profiler.py b/third_party/ascend/unittest/pytest_ut/test_16_profiler.py new file mode 100644 index 0000000000..393969cf63 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_16_profiler.py @@ -0,0 +1,116 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +@triton.jit +def triton_kernel_add(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): + idx = tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + idx, tmp2) + + +@triton.jit +def triton_kernel_or(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): + idx = tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 | tmp1 + tl.store(out_ptr0 + idx, tmp2) + + +def triton_add_func(x0, x1, N): + y0 = torch.empty_like(x0) + triton_kernel_add[1, 1, 1](y0, x0, x1, N) + return y0 + + +def triton_or_func(x0, x1, N): + y0 = torch.empty_like(x0) + triton_kernel_or[1, 1, 1](y0, x0, x1, N) + return y0 + + +# ==================== Pytest Test ==================== +@pytest.mark.parametrize("dtype, low, high", [ + (torch.float32, 0, 1), + (torch.float16, 0, 1), + (torch.bfloat16, 0, 1), + (torch.int64, 1, 100), + (torch.int32, 1, 100), + (torch.int16, 1, 100), + (torch.int8, 1, 100), + (torch.bool, 0, 2), +]) +def test_elementwise_ops(dtype, low, high): + N = 1024 + test_case_is_inductor = False + + if dtype == torch.bool: + x0 = torch.randint(low=low, high=high, size=(N, )).bool().npu() + x1 = torch.randint(low=low, high=high, size=(N, )).bool().npu() + triton_cal = triton_or_func(x0, x1, N) + ref = x0 | x1 + else: + if dtype.is_floating_point: + x0 = torch.rand((N, ), dtype=dtype).npu() + x1 = torch.rand((N, ), dtype=dtype).npu() + else: + x0 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() + x1 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() + + triton_cal = triton_add_func(x0, x1, N) + ref = x0 + x1 + + torch.testing.assert_close(triton_cal, ref) + + def wrapper(): + _ = triton_add_func(x0, x1, N) if dtype != torch.bool else triton_or_func(x0, x1, N) + + profiler_wrapper(wrapper) diff --git a/third_party/ascend/unittest/pytest_ut/test_17_demo_libentry.py b/third_party/ascend/unittest/pytest_ut/test_17_demo_libentry.py new file mode 100644 index 0000000000..a6172552f4 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_17_demo_libentry.py @@ -0,0 +1,131 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import time
+
+import pytest
+import torch
+import torch_npu
+
+import triton
+import triton.language as tl
+from triton.runtime.libentry import libentry
+
+DEV = "npu"
+DTYPE = torch.float32
+SEQ_LEN = 2 * 1024
+device = torch.npu.current_device()
+stream = torch.npu.current_stream(device)
+
+
+def benchmark(func):
+    warmup = 10
+    repeat = 100
+
+    def wrapper(*args, **kwargs):
+        # warmup runs (not timed)
+        for _ in range(warmup):
+            result = func(*args, **kwargs)
+        stream.synchronize()
+        # timed runs
+        start_time = time.perf_counter_ns()
+        for _ in range(repeat):
+            result = func(*args, **kwargs)
+        stream.synchronize()
+        end_time = time.perf_counter_ns()
+        # convert ns to us and average over the timed runs
+        start_time = start_time * 1e-3
+        end_time = end_time * 1e-3
+        elapsed_time = (end_time - start_time) / repeat
+        return (result, elapsed_time)
+
+    return wrapper
+
+
+@libentry()
+@triton.jit
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    XBLOCK: tl.constexpr,
+    XBLOCK_SUB: tl.constexpr,
+    RBLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_start = pid * XBLOCK
+    rblk_idx = tl.arange(0, XBLOCK_SUB)
+    col_idx = tl.arange(0, RBLOCK)
+    for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB):
+        row_offsets = row_start + row_idx + rblk_idx[:, None]
+        col_offsets = col_idx[None, :]
+        xmask = row_offsets < n_rows
+        ymask = col_offsets < n_cols
+        mask = xmask & ymask
+        input_idx = row_offsets * input_row_stride + col_offsets
+        input_ptrs = input_ptr + input_idx
+        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+        row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1)
+        softmax_output = numerator / denominator
+        output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets)
+        tl.store(output_ptrs, softmax_output, mask=mask)
+
+
+@benchmark
+def torch_func(x0: torch.Tensor):
+    m = torch.nn.Softmax(dim=1)
+    return m(x0)
+
+
+@benchmark
+def triton_func(y0: torch.Tensor, x0: torch.Tensor):
+    n_rows, n_cols = x0.shape
+    ncore = 40
+    xs = (n_rows + ncore - 1) // ncore
+    xss = min(xs, 5)
+    softmax_kernel[(ncore, 1, 1)](
+        y0,
+        x0,
+        x0.stride(0),
+        y0.stride(0),
+        n_rows,
+        n_cols,
+        XBLOCK=xs,
+        XBLOCK_SUB=xss,
+        RBLOCK=n_cols,
+    )
+    return y0
+
+
+@pytest.mark.parametrize("batch", [1000 * x for x in range(1, 16 + 1)])
+def test_demo_libentry_softmax(batch):
+    torch.manual_seed(0)
+    x = torch.rand((batch, SEQ_LEN), dtype=DTYPE, device=DEV)
+    y = torch.empty_like(x)
+
+    torch_out, _ = torch_func(x)
+    triton_out, _ = triton_func(y, x)
+
+    torch.testing.assert_close(triton_out, torch_out)
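triton_func above splits the softmax rows evenly across 40 vector cores and caps the per-iteration row block at 5. The tiling arithmetic, worked through for a hypothetical batch of 3000 rows (illustrative aside, not part of the patch):

n_rows, ncore = 3000, 40
xs = (n_rows + ncore - 1) // ncore  # XBLOCK: rows owned by each core -> 75
xss = min(xs, 5)                    # XBLOCK_SUB: rows per inner loop iteration -> 5
print(xs, xss)                      # each core runs ceil(75 / 5) = 15 inner iterations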
diff --git a/third_party/ascend/unittest/pytest_ut/test_18_gather.py b/third_party/ascend/unittest/pytest_ut/test_18_gather.py
new file mode 100644
index 0000000000..f8a4c70898
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_18_gather.py
@@ -0,0 +1,139 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+Gather
+===============
+This is an example only for npu.
+"""
+
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.runtime.driver as driver
+import triton.language as tl
+
+
+# get device properties of npu
+def get_npu_properties():
+    device = torch.npu.current_device()
+    return driver.active.utils.get_device_properties(device)
+
+
+# a torch-version gather benchmark
+def torch_gather(embeddings, idxes, default_value=0.0):
+    # make the result tensor
+    res = torch.empty((idxes.shape[0], embeddings.shape[-1]), dtype=embeddings.dtype, device=embeddings.device)
+
+    # scatter embeddings
+    res[idxes >= 0] = embeddings[idxes[idxes >= 0]]
+    # set default values
+    res[idxes < 0] = default_value
+
+    return res
+
+
+# triton-version gather's kernel
+@triton.jit
+def gather_kernel(embeddings_ptr, idxes_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr,
+                  BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr,
+                  COL_BLOCK_SIZE_SUB: tl.constexpr):
+    SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1
+
+    embedding_dtype = embeddings_ptr.type.element_ty
+    default_value = tl.cast(DEFAULT_VALUE, dtype=embedding_dtype)
+    default_embedding = tl.full((COL_BLOCK_SIZE_SUB, ), default_value, dtype=embedding_dtype)
+
+    core_idx = tl.program_id(0)
+    # compute the size and start index of this core's block
+    row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE
+    row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else (
+        BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE)
+
+    # process blocks with shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one
+    for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB):
+        emb_col_offsets = col_idx + tl.arange(0, COL_BLOCK_SIZE_SUB)
+        emb_col_mask = emb_col_offsets < cols
+
+        for row_idx in tl.range(row_start_idx, min(row_start_idx + row_block_size, rows)):
+            idx_val = tl.load(idxes_ptr + row_idx)
+
+            write_row_offset = row_idx * cols
+            write_emb_mask = emb_col_mask
+
+            if idx_val >= 0:
+                read_row_offset = idx_val * cols
+                read_emb_mask = emb_col_mask
+                # read embedding
+                embedding = tl.load(embeddings_ptr + read_row_offset + emb_col_offsets, mask=read_emb_mask)
+                tl.store(res_ptr + write_row_offset + emb_col_offsets, embedding, write_emb_mask)
+            else:
+                # set default values
+                tl.store(res_ptr + write_row_offset + emb_col_offsets, default_embedding, write_emb_mask)
+
+
+# triton-version gather's host
+def triton_gather(embeddings: torch.Tensor, indices: torch.Tensor, default_value=0.0):
+    # constant settings for npu
+    USE_SIZE = 96 * 1024
+    CORE_NUM = get_npu_properties()["num_vectorcore"]
+
+    n_rows = indices.shape[0]
+    n_cols = embeddings.shape[1]
+    # make the result tensor
+    output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device)
+
+    # when writing an npu kernel using triton,
+    # you should note the difference between BLOCK_SIZE and BLOCK_SIZE_SUB:
+    # BLOCK_SIZE specifies the size of the data that is processed by one program
+    col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(),
+                                   32) * 32 // embeddings.element_size()
+    # the data are scattered to multiple programs, which cannot always be even:
+    # some programs process more data, some less
+    big_row_block_size = triton.cdiv(n_rows, CORE_NUM)
+    big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows)
+    col_block_size = col_size_aligned
+
+    # BLOCK_SIZE_SUB specifies the size of the data that is processed in one loop iteration of a program
+    max_col_block_size_sub = USE_SIZE // embeddings.element_size() // 2
+    col_block_size_sub = min(col_size_aligned, max_col_block_size_sub)
+
+    grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size))
+    # launch the kernel
+    gather_kernel[grid](embeddings, indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num,
+                        BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size,
+                        COL_BLOCK_SIZE_SUB=col_block_size_sub)
+
+    return output
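triton_gather above rounds the row length up to a 32-byte boundary before choosing its column block. Worked through for a few fp32 row lengths (an illustrative aside, not part of the patch):

import triton
element_size = 4  # fp32
for n_cols in (16, 17, 31, 819):
    aligned = triton.cdiv(n_cols * element_size, 32) * 32 // element_size
    print(n_cols, aligned)  # 16 -> 16, 17 -> 24, 31 -> 32, 819 -> 824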
+
+
+# ==================== Pytest Test ====================
+@pytest.mark.parametrize("n_rows", [500, 1000])
+@pytest.mark.parametrize("n_cols", [16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000])
+@pytest.mark.parametrize("index_num", [19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000])
+def test_gather(n_rows, n_cols, index_num):
+    indices = torch.randint(0, n_rows, (index_num, ), dtype=torch.int32).npu()
+    embeddings = torch.randn(n_rows, n_cols, dtype=torch.float).npu()
+
+    expect = torch_gather(embeddings, indices).cpu()
+    actual = triton_gather(embeddings, indices).cpu()
+    torch.npu.synchronize()
+
+    torch.testing.assert_close(actual, expect)
diff --git a/third_party/ascend/unittest/pytest_ut/test_add.py b/third_party/ascend/unittest/pytest_ut/test_add.py
index f6e9dd5a42..be88ce1a5d 100644
--- a/third_party/ascend/unittest/pytest_ut/test_add.py
+++ b/third_party/ascend/unittest/pytest_ut/test_add.py
@@ -74,3 +74,18 @@ def test_all_blocks_parallel(param_list, monkeypatch):
         triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub)
         test_common.validate_cmp(dtype, y_cal, y_ref)
     monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL")
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (2, 4096, 8), 2, 32768, 1024],
+])
+def test_auto_blockify(param_list, monkeypatch):
+    monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1")
+    dtype, shape, ncore, xblock, xblock_sub = param_list
+    x0 = test_common.generate_tensor(shape,
dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp(dtype, y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_address_check.py b/third_party/ascend/unittest/pytest_ut/test_address_check.py new file mode 100644 index 0000000000..3a8429c08f --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_address_check.py @@ -0,0 +1,68 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +@triton.jit +def simple_kernel(x_ptr, y_ptr, output_ptr, n_elements): + pid = tl.program_id(axis=0) + offsets = pid * 1024 + tl.arange(0, 1024) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + ret = x + y + tl.store(output_ptr + offsets, ret, mask=mask) + + +def test_npu_tensor_should_success(): + print("Test the NPU tensor. The NPU tensor should be passed and executed properly.") + + size = 1024 + x_npu = torch.rand(size, device='npu') + y_npu = torch.rand(size, device='npu') + output = torch.empty(size, device='npu') + + simple_kernel[(1, )](x_npu, y_npu, output, size) + + expected = x_npu + y_npu + actual = output + + torch.testing.assert_close(expected, actual, rtol=1e-03, atol=1e-03) + + +def test_cpu_tensor_should_fail(): + print("Test the CPU tensor. An address check exception should be raised.") + + size = 1024 + x_cpu = torch.rand(size, device='cpu') + y_cpu = torch.rand(size, device='cpu') + output = torch.empty(size, device='npu') + + with pytest.raises(ValueError) as exc_info: + simple_kernel[(1, )](x_cpu, y_cpu, output, size) + + error_msg = str(exc_info.value) + assert "cannot be accessed from Triton (cpu tensor?)" in error_msg, \ + f"Expected error message to contain CPU tensor rejection hint, but got: {error_msg}" diff --git a/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py new file mode 100644 index 0000000000..f9052cd28e --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import triton
+import triton.language as tl
+
+import torch
+import torch_npu
+import pytest
+
+
+@triton.jit
+def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
+    block_ptr_in = tl.make_block_ptr(base=x_ptr, shape=(XB, YB, ZB), strides=(YB * ZB, ZB, 1), offsets=(0, 0, 0),
+                                     block_shape=(XB, YB, 2), order=(2, 1, 0))
+    block_ptr_out = tl.make_block_ptr(base=output_ptr, shape=(XB, YB, ZB), strides=(YB * ZB, ZB, 1), offsets=(0, 0, 0),
+                                      block_shape=(XB, YB, 2), order=(2, 1, 0))
+    pid = tl.program_id(axis=0)  # pid=0,1 BLOCK_SIZE_N=8
+    for _ in range(ZB // 2):
+        X = tl.load(block_ptr_in, boundary_check=(0, 1, 2))
+        tl.store(block_ptr_out, X, boundary_check=(0, 1, 2))
+        block_ptr_in = tl.advance(block_ptr_in, (0, 0, 2))
+        block_ptr_out = tl.advance(block_ptr_out, (0, 0, 2))
+
+
+@pytest.mark.parametrize('dtype', ["int32", "float32", "int16"])
+@pytest.mark.parametrize('shape', [(33, 9, 6), (8, 8, 4)])
+def test_advance_with_boundary_check(dtype, shape):
+    x = torch.randint(low=-128, high=128, size=shape, dtype=eval('torch.' + dtype)).npu()
+    expected = x
+    output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu()
+    fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2])
+    torch.testing.assert_close(output, expected)
diff --git a/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py b/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py
new file mode 100644
index 0000000000..ba3054b655
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py
@@ -0,0 +1,114 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.extra.cann.extension as al +import pytest + + +def test_extension_reexports_affine_bindings(): + assert al.affine_map is ascend_ir.affine_map + assert al.affine_expr is ascend_ir.affine_expr + assert al.affine_constant_expr is ascend_ir.affine_constant_expr + assert al.affine_dim_expr is ascend_ir.affine_dim_expr + assert al.affine_symbol_expr is ascend_ir.affine_symbol_expr + assert al.affine_binary_op_expr is ascend_ir.affine_binary_op_expr + assert al.AffineMap is al.affine_map + assert al.AffineExpr is al.affine_expr + assert al.AffineConstantExpr is al.affine_constant_expr + assert al.AffineDimExpr is al.affine_dim_expr + assert al.AffineSymbolExpr is al.affine_symbol_expr + assert al.AffineBinaryOpExpr is al.affine_binary_op_expr + + +def test_make_affine_map(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c2 = ascend_ir.affine_expr.get_constant(2) + + expr = (d0 + c2) * d1 + assert "d0" in str(expr) and "d1" in str(expr) + assert not expr.is_pure_affine() + assert hash(expr) == hash(expr) + assert d0 == ascend_ir.affine_expr.get_dim(0) + assert c2 == ascend_ir.affine_expr.get_constant(2) + assert isinstance(c2, ascend_ir.affine_expr) + assert isinstance(d0, ascend_ir.affine_expr) + + identity_map = ascend_ir.affine_map.get_identity(2) + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + transpose_map_by_expr = ascend_ir.affine_map.get(2, 0, [d1, d0]) + sum_map = ascend_ir.affine_map.get(2, 0, [d0 + d1, d1]) + const_map = ascend_ir.affine_map.get_constant(7) + minor_identity_map = ascend_ir.affine_map.get_minor_identity(3, 2) + + assert identity_map.is_identity() + assert identity_map.is_permutation() + assert identity_map.get_num_dims() == 2 + assert identity_map.get_num_symbols() == 0 + assert identity_map.get_num_results() == 2 + assert str(identity_map) == "(d0, d1) -> (d0, d1)" + + assert not transpose_map.is_identity() + assert transpose_map.is_permutation() + assert str(transpose_map) == "(d0, d1) -> (d1, d0)" + assert str(transpose_map_by_expr) == "(d0, d1) -> (d1, d0)" + assert str(sum_map) == "(d0, d1) -> (d0 + d1, d1)" + assert transpose_map.to_dict() == { + "num_dims": 2, + "num_symbols": 0, + "results": [1, 0], + } + assert str(sum_map.get_sub_map([1])) == "(d0, d1) -> (d1)" + assert str(sum_map.compose(transpose_map)) == "(d0, d1) -> (d1 + d0, d0)" + assert str(transpose_map.inverse_permutation()) == "(d0, d1) -> (d1, d0)" + assert transpose_map == transpose_map_by_expr + assert hash(transpose_map) == hash(transpose_map) + assert [str(r) for r in sum_map.get_results()] == ["d0 + d1", "d1"] + assert const_map.is_single_constant() + assert const_map.get_constant_result() == 7 + assert str(minor_identity_map) == "(d0, d1, d2) -> (d1, d2)" + + +def test_build_buffer_type_with_affine_map(): + with ir.context() as ctx: + ir.load_dialects(ctx) + 
ascend_ir.load_dialects(ctx) + builder = ascend_ir.ascendnpu_ir_builder(ctx) + + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + ub_attr = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + + buffer_ty = builder.get_buffer_ty_with_affine_map([8, 16], builder.get_float_ty(), transpose_map, ub_attr) + + assert "memref<8x16xf32" in str(buffer_ty) + assert "affine_map<(d0, d1) -> (d1, d0)>" in str(buffer_ty) + assert "ub" in str(buffer_ty) + + +if __name__ == '__main__': + test_build_buffer_type_with_affine_map() + test_extension_reexports_affine_bindings() + test_make_affine_map() diff --git a/third_party/ascend/unittest/pytest_ut/test_alloc.py b/third_party/ascend/unittest/pytest_ut/test_alloc.py index a1b2b4360c..450e0c5b98 100644 --- a/third_party/ascend/unittest/pytest_ut/test_alloc.py +++ b/third_party/ascend/unittest/pytest_ut/test_alloc.py @@ -66,6 +66,7 @@ def allocate_local_buffer(XBLOCK: tl.constexpr): bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True) # ============== Main for manual testing ============== diff --git a/third_party/ascend/unittest/pytest_ut/test_arch.py b/third_party/ascend/unittest/pytest_ut/test_arch.py new file mode 100644 index 0000000000..4946ca3538 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_arch.py @@ -0,0 +1,94 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os +import pytest +import triton +import triton.language as tl +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + arch = "Ascend950" + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def copy( + A_ptr, + A1_ptr, + M: tl.constexpr, + N: tl.constexpr, +): + offs_a = tl.arange(0, M)[:, None] + offs_b = tl.arange(0, N)[None, :] + + offs_c = (offs_a) * M + (offs_b) + a_ptr = A_ptr + offs_c + a_val = tl.load(a_ptr) + a1_ptr = A1_ptr + offs_c + a1_val = tl.load(a1_ptr) + + add = tl.add(a_val, a1_val) + + add_ub = bl.to_buffer(add, al.ascend_address_space.UB) + A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1) + al.copy(add_ub, A_l1) + + +def test_arch(): + print("=" * 60) + print("Test 1: copy ") + print("=" * 60) + mlir = compile_kernel( + copy, + {"A_ptr": "*fp32", "A1_ptr": "*fp32"}, + {"M": 16, "N": 16}, + ) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) + + +# ============== Main for manual testing ============== +if __name__ == "__main__": + test_arch() diff --git a/third_party/ascend/unittest/pytest_ut/test_argmax.py b/third_party/ascend/unittest/pytest_ut/test_argmax.py new file mode 100644 index 0000000000..bb26a40884 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_argmax.py @@ -0,0 +1,65 @@ +import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common + + +def torch_argmax(x0, dim, keepdim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + return torch.argmax(x0, dim=dim, keepdim=keepdim).npu() + + +@triton.jit +def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmax(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', [(128, ), (256, ), (37, ), (741, )]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +def test_argmax_1d(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmax_1d[(1, )](x0, triton_res, numel, numel) + torch_res = torch_argmax(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +@triton.jit +def triton_argmax_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=float('-inf')) + tmp4 = tl.argmax(x, dim) + if dim == 0: + 
tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', [(37, 125), (29, 4), (7, 31)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_argmax_2d(dtype, shape, dim): + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([ + shape[1 - dim], + ], dtype=torch.int32).npu() + triton_argmax_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmax(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_argmin.py b/third_party/ascend/unittest/pytest_ut/test_argmin.py new file mode 100644 index 0000000000..25d50b71f0 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_argmin.py @@ -0,0 +1,65 @@ +import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common + + +def torch_argmin(x0, dim, keepdim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + return torch.argmin(x0, dim=dim, keepdim=keepdim).npu() + + +@triton.jit +def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmin(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', [(128, ), (256, ), (37, ), (741, )]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +def test_argmin_1d(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmin_1d[(1, )](x0, triton_res, numel, numel) + torch_res = torch_argmin(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +@triton.jit +def triton_argmin_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=float('inf')) + tmp4 = tl.argmin(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', [(37, 125), (29, 4), (7, 31)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_argmin_2d(dtype, shape, dim): + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([ + shape[1 - dim], + ], dtype=torch.int32).npu() + triton_argmin_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmin(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm.py b/third_party/ascend/unittest/pytest_ut/test_asm.py index 02e69bddd6..189e8c2e4c 100644 --- a/third_party/ascend/unittest/pytest_ut/test_asm.py +++ b/third_party/ascend/unittest/pytest_ut/test_asm.py @@ -1,52 +1,96 @@ -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common - - -def torch_add(x, y): 
- res = x + y - return res - - -@triton.jit -def triton_asm_add( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = tl.inline_asm_elementwise( - asm=""" - ADD.s64 $0, $1, $2 - """, - constraints=("=l,l,l"), - args=[x, y], - dtype=tl.int64, - is_pure=True, - pack=1, - ) - tl.store(output_ptr + offsets, output, mask=mask) - - -@pytest.mark.parametrize('param_list', [ - ['int64', 4096, 1024], -]) -def test_case(param_list): - dtype, length, block_size = param_list - ncore = length // block_size - x = test_common.generate_tensor((length, ), dtype).npu() - y = test_common.generate_tensor((length, ), dtype).npu() - res_ref = torch_add(x, y) - res_cal = torch.zeros((length, ), dtype=eval('torch.' + dtype)).npu() - triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) - test_common.validate_cmp(dtype, res_cal, res_ref) +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + + +def torch_add(x, y): + res = x + y + return res + + +@triton.jit +def triton_asm_add( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = tl.inline_asm_elementwise( + asm=""" + ADD.s64 $0, $1, $2 + """, + constraints=("=l,l,l"), + args=[x, y], + dtype=tl.int64, + is_pure=True, + pack=1, + ) + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('param_list', [ + ['int64', 4096, 1024], +]) +def test_case(param_list): + dtype, length, block_size = param_list + ncore = length // block_size + x = test_common.generate_tensor((length, ), dtype).npu() + y = test_common.generate_tensor((length, ), dtype).npu() + res_ref = torch_add(x, y) + res_cal = torch.zeros((length, ), dtype=eval('torch.' + dtype)).npu() + triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) + test_common.validate_cmp(dtype, res_cal, res_ref) + + +@triton.jit +def triton_asm_add_2d( + x_ptr, + y_ptr, + output_ptr, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + row_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M) + col_offsets = tl.arange(0, BLOCK_N) + offsets = row_offsets[:, None] * N + col_offsets[None, :] + mask = (row_offsets[:, None] < M) & (col_offsets[None, :] < N) + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = tl.inline_asm_elementwise( + asm=""" + ADD.s64 $0, $1, $2 + """, + constraints=("=l,l,l"), + args=[x, y], + dtype=tl.int64, + is_pure=True, + pack=1, + ) + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('param_list', [ + ['int64', 64, 32, 16, 32], +]) +def test_case_2d(param_list): + dtype, M, N, block_m, block_n = param_list + ncore = M // block_m + x = test_common.generate_tensor((M, N), dtype).npu() + y = test_common.generate_tensor((M, N), dtype).npu() + res_ref = torch_add(x, y) + res_cal = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_asm_add_2d[(ncore, )](x, y, res_cal, M, N, BLOCK_M=block_m, BLOCK_N=block_n) + test_common.validate_cmp(dtype, res_cal, res_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py new file mode 100644 index 0000000000..23b8f6553c --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py @@ -0,0 +1,32 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +@triton.jit +def triton_asm_time(output_ptr, ): + y = tl.inline_asm_elementwise( + asm=""" + MOV $0, SYS_CNT + """, + constraints="=l", + args=[], + dtype=(tl.int64), + is_pure=False, + pack=1, + ) + tl.store(output_ptr, y) + + +@pytest.mark.parametrize( + "param_list", + [[ + "int64", + ]], +) +def test_case(param_list): + (dtype, ) = param_list + res_cal = torch.zeros((1, ), dtype=eval("torch." + dtype)).npu() + triton_asm_time[(1, )](res_cal, ) diff --git a/third_party/ascend/unittest/pytest_ut/test_assume1.py b/third_party/ascend/unittest/pytest_ut/test_assume1.py new file mode 100644 index 0000000000..05d5630092 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_assume1.py @@ -0,0 +1,32 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +from triton._internal_testing import (is_interpreter) + + +@triton.jit +def assume(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= BLOCK_N: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + +@pytest.mark.parametrize('dtype', ["float32"]) +def test_assume(dtype): + NBLOCKS = 1024 // 128 + BLOCK_N = 128 + N = 1024 + output = torch.zeros(NBLOCKS, device='npu') + pgm = assume[(NBLOCKS, )](output, N=N, BLOCK_N=BLOCK_N) + + if is_interpreter(): + return + + assert 'llvm.intr.assume' in pgm.asm['ttadapter'] diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py index 10aa305204..967717ad33 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py @@ -51,11 +51,28 @@ def atomic_add_supply(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr): tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) +@triton.jit +def atomic_add_for_load_offset(index_ptr, in_ptr0, out_ptr0): + index = tl.atomic_add(index_ptr, 1) + val = tl.load(in_ptr0 + index) + tl.store(out_ptr0, val) + + +@triton.jit +def atomic_add_for_store_offset(index_ptr, out_ptr0): + index = tl.atomic_add(index_ptr, 1) + tl.store(out_ptr0 + index, 1) + + @pytest.mark.parametrize('param_list', [ + ['int64', (256, 32), 2], + ['int32', (32, 32), 2], ['int16', (32, 32), 2], ['int8', (32, 32), 2], + ['uint8', (32, 32), 2], ['float32', (32, 32), 2], ['float16', (64, 64), 4], + ['bfloat16', (64, 64), 4], ['float32', (128, 128), 8], ['float16', (128, 128), 16], ['float32', (32768, 16), 32], @@ -66,7 +83,10 @@ def test_atomic_add(param_list): split_size = shape[0] // ncore x0_value = 3 x0 = torch.full(shape, x0_value, dtype=eval(f'torch.{dtype}')).npu() - x1 = torch.full((split_size, shape[1]), 2, dtype=eval(f'torch.{dtype}')).npu() + if dtype == 'int64': + x1 = torch.randint(-10**15, 10**15, (split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + else: + x1 = torch.full((split_size, 
shape[1]), 2, dtype=eval(f'torch.{dtype}')).npu()
     y = torch.full((split_size, shape[1]), -10, dtype=eval(f'torch.{dtype}')).npu()
     y_ref = x1 + 0
@@ -155,6 +175,33 @@ def test_atomic_add_2d_supply(dtype, shape):
     test_common.validate_cmp(dtype, x1, x1_ref)
 
 
+def test_atomic_add_for_load_offset():
+    index = torch.tensor([1]).npu()
+    input_tensor = torch.zeros(5).npu()
+    output = torch.tensor([1.0]).npu()
+    index_ref = index.clone()
+    index_ref += 1
+    # tl.atomic_add returns the value held *before* the update, so the kernel
+    # loads input_tensor[1] and writes it to output[0].
+    output_ref = input_tensor[index]
+
+    atomic_add_for_load_offset[(1, )](index, input_tensor, output)
+    assert torch.equal(index, index_ref)
+    assert torch.equal(output, output_ref)
+
+
+def test_atomic_add_for_store_offset():
+    index = torch.tensor([1]).npu()
+    output = torch.zeros(5).npu()
+    index_ref = index.clone()
+    index_ref += 1
+    output_ref = output.clone()
+    output_ref[index] = 1
+
+    atomic_add_for_store_offset[(1, )](index, output)
+    assert torch.equal(index, index_ref)
+    assert torch.equal(output, output_ref)
+
+
 if __name__ == "__main__":
     param_list = ['float32', (32, 32), 2]
     test_atomic_add_2d(param_list)
diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py
index 19a3bb6958..675f65f773 100644
--- a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py
+++ b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py
@@ -44,6 +44,7 @@ def atomic_and(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr
     ['int32', (32, 32), 2],
     ['int16', (32, 32), 2],
     ['int8', (16, 16), 4],
+    ['uint8', (16, 16), 4],
 ])
 def test_atomic_and(param_list):
     dtype, shape, ncore = param_list
diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py
index 3e3b3a6fc6..e834aa8d79 100644
--- a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py
+++ b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py
@@ -25,6 +25,14 @@
 import torch
 import torch_npu
 
+types_all = [
+    (torch.float32, 'float32'),
+]
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
 
 @triton.jit
 def atomic_cas(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.
tl.store(out_ptr1 + (x1), tmp1, xmask) +@triton.jit +def atomic_cas_with_full( + ptr, + out, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = x < n_elements + + cmp = tl.full((BLOCK_SIZE, ), 2.0, tl.float32) + val = tl.full((BLOCK_SIZE, ), 1.0, tl.float32) + + old = tl.atomic_cas(ptr + x, cmp, val) # in_ptr(origin 2) -> ref: 1 X + tl.store(out + x, old, mask=mask) # out(origin 1) -> ref: old in_ptr(2) √ + + +@triton.jit +def atomic_cas_without_full( + ptr, + cmp_ptr, + val_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = x < n_elements + + cmp = tl.load(cmp_ptr + x, mask) # 2 + val = tl.load(val_ptr + x, mask) # 1 + + old = tl.atomic_cas(ptr + x, cmp, val) # old : 2 + tl.store(out_ptr + x, old, mask=mask) + + @pytest.mark.parametrize('param_list', [ ['int16', (8, 8), 2], ['int32', (32, 32), 6], @@ -110,3 +156,35 @@ def test_atomic_cas_return_value(param_list): atomic_cas[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) test_common.validate_cmp(dtype, pointer, pointer_ref) test_common.validate_cmp(dtype, pointer_old, pointer_old_ref) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) +@pytest.mark.skip(reason="full tensor has problem, skipped") +def test_atomic_cas_with_full(n_elements, BLOCK_SIZE, dtype, sigtype): + in_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() + out_ptr = torch.empty_like(in_ptr) + + grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) + atomic_cas_with_full[grid](in_ptr, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + # old should be all 2 (for in-range) + torch.testing.assert_close(out_ptr, torch.full_like(out_ptr, 2.0)) + + # final ptr should be all 1 + torch.testing.assert_close(in_ptr, torch.ones_like(in_ptr)) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) +def test_atomic_cas_without_full(n_elements, BLOCK_SIZE, dtype, sigtype): + in_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() + cmp_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() + val_ptr = torch.full((n_elements, ), 1, dtype=dtype).npu() + out_ptr = torch.full((n_elements, ), 1, dtype=dtype).npu() # ref: in_ptr + + grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) + atomic_cas_without_full[grid](in_ptr, cmp_ptr, val_ptr, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + torch.testing.assert_close(in_ptr, torch.full_like(in_ptr, 1.0)) + torch.testing.assert_close(out_ptr, torch.full_like(out_ptr, 2.0)) diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py index 942f9429fd..7a7a200e24 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py @@ -54,13 +54,18 @@ def triton_test_fn_atomic_max_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const # torch.max do not support int @pytest.mark.parametrize('param_list', [ + ['uint8', (32, 32), 2], ['int16', (32, 32), 2], + ['bfloat16', (32, 32), 2], ['float16', (32, 32), 2], ['float32', (128, 128), 8], ['float32', (32768, 16), 32], ['int32', (32, 32), 2], ['int32', (128, 128), 8], ['int32', (32768, 16), 32], + ['int64', (32, 32), 2], + ['int64', (128, 128), 8], + ['int64', (8192, 16), 32], ]) def test_atomic_max(param_list): dtype, shape, ncore 
= param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py index 213740548d..461a3baa9c 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py @@ -52,9 +52,12 @@ def triton_test_fn_atomic_min_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const @pytest.mark.parametrize('param_list', [ + ['uint8', (32, 32), 2], ['int8', (32, 32), 2], ['int16', (32, 32), 2], ['int32', (32, 32), 2], + ['int64', (32, 32), 2], + ['bfloat16', (64, 64), 4], ['float16', (64, 64), 4], ['float32', (32, 32), 2], ]) diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py new file mode 100644 index 0000000000..da1da5eafc --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py @@ -0,0 +1,70 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def atomic_rmw_useanalysis_kernel( + input_ptr, + output_ptr, + m_ptr, + d_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + base_idx = pid * 8 + + term1 = 15.0 * 15.0 + term2 = 8.0 * (7.0 - base_idx) + + delta = term1 + term2 + sqrt_delta = tl.sqrt(delta) + + task_idx = tl.ceil((15.0 - sqrt_delta) / 2.0) + task_idx_i32 = task_idx.to(tl.int32) + + block_start = task_idx_i32 * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + m_val = tl.load(m_ptr + offsets, mask=mask, other=0.0) + d_val = tl.load(d_ptr + offsets, mask=mask, other=0.0) + + scaled = data - m_val + p = tl.exp(scaled) + + result = p * (data * 2.0 - d_val) + + output_offsets = offsets + tl.atomic_add(output_ptr + output_offsets, result, mask=mask) + + +def test_atomic_rmw_useanalysis(): + DEVICE = "npu" + N = 1024 + BLOCK_SIZE = 128 + + torch.manual_seed(42) + input_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + m_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + d_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + output_data = torch.zeros(N, dtype=torch.float32, device=DEVICE) + + grid = (8, ) + + atomic_rmw_useanalysis_kernel[grid]( + input_data, + output_data, + m_data, + d_data, + N=N, + BLOCK_SIZE=BLOCK_SIZE, + ) + output_sum = output_data.abs().sum().item() + + if output_sum == 0: + raise AssertionError("UseAnalysis bug detected: atomic_rmw dependencies were erased") + else: + print(" AtomicRMW UseAnalysis is working correctly.") diff --git a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py index 5719e7a373..e8132a4508 100644 --- a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py +++ b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py @@ -84,3 +84,112 @@ def test_npu(para_type, data_type, XB, YB, ZB): fn_npu_[1, 1, 1](output, x, y, z, output1, XB=XB, YB=YB, ZB=ZB, debug=True) print(output) torch.testing.assert_close(output, a) + + +@triton.jit +def dma_block_ptr( + input_ptr, + output_ptr, + scale_ptr, + batch_size, + cu_seqlens_ptr, + stride_i_m, + stride_i_n, + stride_o_m, + stride_o_n, + stride_s_b, + HEAD_DIM, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + n_progs = tl.num_programs(0) + pid = tl.program_id(0) + + cu_num_blocks = 0 + for bid in range(batch_size): + start_loc = tl.load(cu_seqlens_ptr + bid) + end_loc = tl.load(cu_seqlens_ptr + bid 
+ 1) + scale = tl.load(scale_ptr + bid * stride_s_b) + + len_loc = end_loc - start_loc + prev_num_blocks = cu_num_blocks + new_num_blocks = tl.cdiv(len_loc, BLOCK_SIZE_M).to(tl.int32) + i_block_ptr_bbase = tl.make_block_ptr( + input_ptr + start_loc * stride_i_m, + shape=(len_loc, HEAD_DIM), + strides=(stride_i_m, stride_i_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0), + ) + o_block_ptr_bbase = tl.make_block_ptr( + output_ptr + start_loc * stride_o_m, + shape=(len_loc, HEAD_DIM), + strides=(stride_o_m, stride_o_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0), + ) + cu_num_blocks += new_num_blocks + for m_id in range((prev_num_blocks + pid) % n_progs, new_num_blocks, n_progs): + i_block_ptr = tl.advance(i_block_ptr_bbase, (m_id * BLOCK_SIZE_M, 0)) + o_block_ptr = tl.advance(o_block_ptr_bbase, (m_id * BLOCK_SIZE_M, 0)) + i_tile = tl.load(i_block_ptr, boundary_check=[0, 1], padding_option="zero") + o_tile = i_tile.to(tl.float32) * scale + tl.store(o_block_ptr, o_tile.to(i_tile.dtype), boundary_check=[0, 1]) + + +def ref_func(inputs, scale, cu_lens): + outputs = torch.zeros_like(inputs) + bsz = cu_lens.size(0) - 1 + for bid in range(bsz): + tmp = inputs[cu_lens[bid]:cu_lens[bid + 1]].to(torch.float32) * scale[bid] + outputs[cu_lens[bid]:cu_lens[bid + 1]] = tmp.to(outputs.dtype) + return outputs + + +def tt_func(inputs, scale, cu_lens): + bsz = cu_lens.size(0) - 1 + outputs = torch.zeros_like(inputs) + head_dim = inputs.size(-1) + assert head_dim <= 1024 + BLOCK_SIZE_N = 1024 + BLOCK_SIZE_M = 4 + dma_block_ptr[ + 20, + ]( + inputs, + outputs, + scale, + bsz, + cu_lens, + inputs.stride(0), + inputs.stride(1), + outputs.stride(0), + outputs.stride(1), + scale.stride(0), + head_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return outputs + + +@pytest.mark.parametrize('param_list', [ + [8, 1024, 1024, True], + [8, 1024, 1024, False], +]) +def test_func(param_list): + bsz, max_len, max_n, test_align = param_list + lens = torch.randint(max_len // 2, max_len, (bsz, ), dtype=torch.int32, device="npu") + n = torch.randint(max_n // 2, max_n, (1, ), dtype=torch.int32, device="npu")[0].item() + if test_align: + lens = (lens + 1023) // 1024 * 1024 + n = (n + 1023) // 1024 * 1024 + cu_lens = torch.cumsum(lens, dim=0) + cu_lens = torch.cat([torch.zeros(1, dtype=torch.int32, device="npu"), cu_lens], dim=0) + inputs = torch.randn(cu_lens[-1], n, dtype=torch.float16, device="npu") + scale = torch.randn(bsz, dtype=torch.float32, device="npu") + ref_output = ref_func(inputs, scale, cu_lens) + tt_output = tt_func(inputs, scale, cu_lens) + torch.testing.assert_close(ref_output, tt_output) diff --git a/third_party/ascend/unittest/pytest_ut/test_boundary_check.py b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py new file mode 100644 index 0000000000..c7dabda498 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py @@ -0,0 +1,275 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import triton +import triton.language as tl +import pytest + + +# ========== Test 1: Static base address + boundary_check ========== +@triton.jit +def static_base_boundary_check_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def ref_static_base(in_tensor, BLOCK_SIZE): + return in_tensor[:BLOCK_SIZE].sum().item() + + +def test_static_base(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + static_base_boundary_check_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = ref_static_base(in_tensor.cpu(), BLOCK_SIZE) + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 2: Simple dynamic base address + boundary_check ========== +@triton.jit +def simple_dynamic_base_boundary_check_kernel( + out_ptr, + in_ptr, + offset: tl.int32, + BLOCK_SIZE: tl.constexpr, +): + base = in_ptr + offset + ptr = tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_simple_dynamic_base(): + BLOCK_SIZE = 64 + offset = 32 + in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + simple_dynamic_base_boundary_check_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + offset=offset, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[offset:offset + BLOCK_SIZE].sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 3: Nested loop + dynamic base address + advance + boundary_check ========== +@triton.jit +def nested_dynamic_advance_boundary_kernel( + out_ptr, + in_ptr, + stride_in: tl.int32, + OUTER_LOOP: tl.constexpr, + INNER_LOOP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Smallest reproducible code: The dynamic base address is in the outer loop, + and tl.advance is in the inner loop, where there is a boundary_check. 
+ """ + for i in range(OUTER_LOOP): + base = in_ptr + i * stride_in + ptr = tl.make_block_ptr(base=base, shape=(INNER_LOOP * BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + for j in range(INNER_LOOP): + cur_ptr = tl.advance(ptr, (j * BLOCK_SIZE, )) + data = tl.load(cur_ptr, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr + i * INNER_LOOP + j, result) + + +def ref_nested_dynamic(in_tensor, OUTER_LOOP, INNER_LOOP, BLOCK_SIZE): + """ + PyTorch equivalent implementation: + - Treat in_tensor as a tensor of shape [OUTER_LOOP, INNER_LOOP * BLOCK_SIZE] + - For each (i, j) block: take the BLOCK_SIZE elements starting from j*BLOCK_SIZE in the i-th row and sum them up. + - Note: There is boundary_check + zero padding, but there is no out-of-bound access in this case, so no special handling is needed. + """ + reshaped = in_tensor[:OUTER_LOOP * INNER_LOOP * BLOCK_SIZE].view(OUTER_LOOP, INNER_LOOP * BLOCK_SIZE) + blocks = reshaped.unfold(1, BLOCK_SIZE, BLOCK_SIZE) + return blocks.sum(dim=-1).flatten() + + +def test_nested_dynamic(): + BLOCK_SIZE = 8 + OUTER_LOOP = 2 + INNER_LOOP = 2 + in_tensor = torch.randn(OUTER_LOOP * INNER_LOOP * BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(OUTER_LOOP * INNER_LOOP, dtype=torch.float32).npu() + nested_dynamic_advance_boundary_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + stride_in=INNER_LOOP * BLOCK_SIZE, + OUTER_LOOP=OUTER_LOOP, + INNER_LOOP=INNER_LOOP, + BLOCK_SIZE=BLOCK_SIZE, + ) + ref = ref_nested_dynamic(in_tensor.cpu(), OUTER_LOOP, INNER_LOOP, BLOCK_SIZE) + assert torch.allclose(out_tensor.cpu(), ref, atol=1e-4) + + +# ========== Test 4: Explicit out-of-bounds access + zero padding + boundary_check ========== +@triton.jit +def out_of_bound_zero_padding_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE * 2, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_out_of_bound(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + out_of_bound_zero_padding_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu().sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 5:padding_option = NAN + boundary_check========== +@triton.jit +def nan_padding_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE * 2, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="nan") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_nan_padding(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + try: + nan_padding_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + assert torch.isnan(out_tensor.cpu()).any() + except Exception as e: + print(f"Warning: NAN padding test may not be supported: {e}") + + +# ========== Test 6:Multi-layer advance + boundary_check ========== +@triton.jit +def multi_advance_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: 
tl.constexpr, +): + base = in_ptr + ptr0 = tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 4, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + ptr1 = tl.advance(ptr0, (BLOCK_SIZE, )) + ptr2 = tl.advance(ptr1, (BLOCK_SIZE, )) + data = tl.load(ptr2, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_multi_advance(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + multi_advance_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[2 * BLOCK_SIZE:3 * BLOCK_SIZE].sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 7:Complex base address calculation + boundary_check ========== +@triton.jit +def complex_base_calculation_kernel( + out_ptr, + in_ptr, + offset1: tl.int32, + offset2: tl.int32, + scale: tl.int32, + BLOCK_SIZE: tl.constexpr, +): + base = in_ptr + offset1 * scale + offset2 + ptr = tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_complex_base(): + BLOCK_SIZE = 64 + offset1, offset2, scale = 2, 16, 32 + total_offset = offset1 * scale + offset2 + in_tensor = torch.randn(total_offset + BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + complex_base_calculation_kernel[(1, )]( + out_ptr=out_tensor, + in_ptr=in_tensor, + offset1=offset1, + offset2=offset2, + scale=scale, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[total_offset:total_offset + BLOCK_SIZE].sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +if __name__ == "__main__": + print("Running all boundary_check tests...") + test_static_base() + test_simple_dynamic_base() + test_nested_dynamic() + test_out_of_bound() + test_nan_padding() + test_multi_advance() + test_complex_base() + print("All tests completed successfully!") diff --git a/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py new file mode 100644 index 0000000000..4b102b9335 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py @@ -0,0 +1,728 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import logging +import random +import pytest + +import triton +import triton.language as tl +import torch +import test_common +import numpy as np +import triton.language.extra.cann.extension as extension + + +def gen_1d_cat_shapes(min_val=1, max_val=4096): + shape1 = random.randint(min_val, max_val) + shape2 = random.randint(min_val, max_val) + return (shape1, ), (shape2, ), 0 + + +def gen_2d_cat_shapes(dim=0, min_val=1, max_val=4096): + if dim == 0: + common_col = random.randint(min_val, max_val) + row1 = random.randint(min_val, max_val) + row2 = random.randint(min_val, max_val) + shape1 = (row1, common_col) + shape2 = (row2, common_col) + elif dim == 1: + common_row = random.randint(min_val, max_val) + col1 = random.randint(min_val, max_val) + col2 = random.randint(min_val, max_val) + shape1 = (common_row, col1) + shape2 = (common_row, col2) + else: + raise ValueError("2d shape only support dim=0 or dim=1") + return shape1, shape2, dim + + +def gen_3d_cat_shapes(dim=0, min_val=1, max_val=4096): + if dim not in [0, 1, 2]: + raise ValueError("3d shape only support dim=0/1/2") + + if dim == 0: + common_d1 = random.randint(min_val, max_val) + common_d2 = random.randint(min_val, max_val) + d0_1 = random.randint(min_val, max_val) + d0_2 = random.randint(min_val, max_val) + shape1 = (d0_1, common_d1, common_d2) + shape2 = (d0_2, common_d1, common_d2) + + elif dim == 1: + common_d0 = random.randint(min_val, max_val) + common_d2 = random.randint(min_val, max_val) + d1_1 = random.randint(min_val, max_val) + d1_2 = random.randint(min_val, max_val) + shape1 = (common_d0, d1_1, common_d2) + shape2 = (common_d0, d1_2, common_d2) + + else: # dim == 2 + common_d0 = random.randint(min_val, max_val) + common_d1 = random.randint(min_val, max_val) + d2_1 = random.randint(min_val, max_val) + d2_2 = random.randint(min_val, max_val) + shape1 = (common_d0, common_d1, d2_1) + shape2 = (common_d0, common_d1, d2_2) + + return shape1, shape2, dim + + +def gen_100_cat_shapes(num_groups=100, mix_ratio=(0.3, 0.3, 0.4), min_val=1, max_val=4096): + + shape_list = [] + num_1d = int(num_groups * mix_ratio[0]) + num_2d = int(num_groups * mix_ratio[1]) + num_3d = num_groups - num_1d - num_2d + + for _ in range(num_1d): + shape_list.append(gen_1d_cat_shapes(min_val, max_val)) + + for _ in range(num_2d): + dim = random.choice([0, 1]) + shape_list.append(gen_2d_cat_shapes(dim, min_val, max_val)) + + for _ in range(num_3d): + dim = random.choice([0, 1, 2]) + shape_list.append(gen_3d_cat_shapes(dim, min_val, max_val)) + + random.shuffle(shape_list) + return shape_list + + +full_shape = gen_100_cat_shapes(num_groups=100, mix_ratio=(0.3, 0.4, 0.3), min_val=1, max_val=4096) + + +@triton.jit +def _cat_helper_func_2D_1( + in_ptr0, + in_ptr1, + out_ptr0, + in0_x: tl.constexpr, + in1_x: tl.constexpr, + y0_numel, + x1_numel, + Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, +): + y0_offset = tl.program_id(0) * Y0BLOCK_SUB + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_input0_x1 = tl.arange(0, in0_x)[None, :] + base_input1_x1 = tl.arange(0, in1_x)[None, :] + x1 = tl.arange(0, in0_x + in1_x)[None, :] + + for loop in range(loops_y0): + y0 = y0_offset + (loop * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < 
min(Y0BLOCK + y0_offset, y0_numel) + x1_mask = x1 < x1_numel + tmp0 = tl.load(in_ptr0 + (base_input0_x1 + in0_x * y0), y0_mask) + tmp1 = tl.load(in_ptr1 + (base_input1_x1 + in1_x * y0), y0_mask) + tmp2 = tl.zeros((Y0BLOCK_SUB, in0_x + in1_x), dtype=tmp0.dtype) + tmp3 = extension.insert_slice(tmp2, tmp0, [0, 0], [Y0BLOCK_SUB, in0_x], [1, 1]) + tmp4 = extension.insert_slice(tmp3, tmp1, [0, in0_x], [Y0BLOCK_SUB, in1_x], [1, 1]) + tl.store(out_ptr0 + (x1 + (in0_x + in1_x) * y0), tmp4, x1_mask & y0_mask) + + +@triton.jit +def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((2 * Y0BLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [Y0BLOCK_SUB, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, Y0BLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[None, :, None] + new_z0 = tl.arange(0, 2)[:, None, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + x1_numel * (new_y1 + y0_numel * new_z0)), tmp13, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_numel, x1_numel, YBLOCK: tl.constexpr, + YBLOCK_2: tl.constexpr, YBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * YBLOCK + base_y0 = tl.arange(0, YBLOCK_SUB) + loops_y0 = (YBLOCK + YBLOCK_SUB - 1) // YBLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if y0_numel < y1_numel: + min_numel = y0_numel + max_numel = y1_numel + clone_numel = y1_numel - y0_numel + else: + min_numel = y1_numel + max_numel = y0_numel + clone_numel = y0_numel - y1_numel + + for loops_y in range(loops_y0): + y0 = y0_offset + (loops_y * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK + y0_offset, min_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((2 * YBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [YBLOCK_SUB, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, YBLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + 
new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y0_offset + (loops_y * YBLOCK_SUB) + new_base_y1[None, :, None] + new_z0 = tl.arange(0, 2)[:, None, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < min_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1 + x1_numel * y0_numel * new_z0), tmp13, + new_x2_mask & new_y1_mask) + + loops_y1 = (YBLOCK_2 + YBLOCK_SUB - 1) // YBLOCK_SUB + y2_offset = tl.program_id(0) * YBLOCK_2 + min_numel + if y0_numel < y1_numel: + for loops_y1 in range(loops_y1): + y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK_2 + y2_offset, y1_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y2_offset + y0_numel + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None] + sum_numel = y0_numel + y1_numel + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < sum_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loops_y1 in range(loops_y1): + y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK_2 + y2_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp8 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y2_offset + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, X1BLOCK_SUB], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + 2 * x1_numel * new_y1 + x1_numel * new_z0), tmp13, + new_x2_mask & new_y1_mask) + + +@triton.jit +def 
triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if x0_numel < x1_numel: + min_numel = x0_numel + max_numel = x1_numel + clone_numel = x1_numel - x0_numel + else: + min_numel = x1_numel + max_numel = x0_numel + clone_numel = x0_numel - x1_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_y1_mask = new_y1 < y0_numel + sum_numel = x0_numel + x1_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + if x0_numel < x1_numel: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x1_numel + + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < sum_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < x0_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y1_numel, x1_numel, ZBLOCK: tl.constexpr, + 
ZBLOCK_2: tl.constexpr, ZBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + z0_offset = tl.program_id(0) * ZBLOCK + base_z0 = tl.arange(0, ZBLOCK_SUB) + loops_z0 = (ZBLOCK + ZBLOCK_SUB - 1) // ZBLOCK_SUB + xy_numel = x1_numel * y1_numel + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (xy_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if z0_numel < z1_numel: + min_numel = z0_numel + max_numel = z1_numel + clone_numel = z1_numel - z0_numel + else: + min_numel = z1_numel + max_numel = z0_numel + clone_numel = z0_numel - z1_numel + + for loops_z in range(loops_z0): + z0 = z0_offset + (loops_z * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK + z0_offset, min_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp0 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + tmp10 = tl.zeros((2 * ZBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [ZBLOCK_SUB, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, ZBLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z0_offset + (loops_z * ZBLOCK_SUB) + new_base_z1[None, :, None] + new_z0 = tl.arange(0, 2)[:, None, None] + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < min_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1 + xy_numel * z0_numel * new_z0), tmp13, + new_x2_mask & new_z1_mask) + + loops_z1 = (ZBLOCK_2 + ZBLOCK_SUB - 1) // ZBLOCK_SUB + z2_offset = tl.program_id(0) * ZBLOCK_2 + min_numel + if z0_numel < z1_numel: + for loops_z1 in range(loops_z1): + z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z1_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z2_offset + z0_numel + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None] + sum_numel = z0_numel + z1_numel + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < sum_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + else: + for loops_z1 in range(loops_z1): + z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp8 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z2_offset + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None] + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y1_numel, x0_numel, + Z0BLOCK: tl.constexpr, Z0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + z0_offset = tl.program_id(0) * 
Z0BLOCK + base_z0 = tl.arange(0, Z0BLOCK_SUB) + loops_z0 = (Z0BLOCK + Z0BLOCK_SUB - 1) // Z0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if y0_numel < y1_numel: + min_numel = y0_numel * x0_numel + max_numel = y1_numel * x0_numel + clone_numel = (y1_numel - y0_numel) * x0_numel + else: + min_numel = y1_numel * x0_numel + max_numel = y0_numel * x0_numel + clone_numel = (y0_numel - y1_numel) * x0_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask) + tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask) + tmp10 = tl.zeros((Z0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Z0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_z1_mask = new_z1 < z0_numel + sum_numel = min_numel + max_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1 + x0_numel * y0_numel * new_z0), tmp13, + new_x2_mask & new_z1_mask) + + if y0_numel == y1_numel: + return + + if y0_numel < y1_numel: + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < y1_numel * x0_numel + + tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel * y0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None] + sum_numel = min_numel + max_numel + new_x2_mask = new_x2 < sum_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + else: + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel * y0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None] + sum_numel = min_numel + max_numel + new_x2_mask = new_x2 < x0_numel * y0_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x0_numel, x1_numel, + Y0BLOCK: 
tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + zy_numel = z0_numel * y0_numel + if x0_numel < x1_numel: + min_numel = x0_numel + max_numel = x1_numel + clone_numel = x1_numel - x0_numel + else: + min_numel = x1_numel + max_numel = x0_numel + clone_numel = x0_numel - x1_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_y1_mask = new_y1 < zy_numel + sum_numel = x0_numel + x1_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + if x0_numel == x1_numel: + return + + if x0_numel < x1_numel: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x1_numel + + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < sum_numel + new_y1_mask = new_y1 < zy_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < x0_numel + new_y1_mask = new_y1 < zy_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +testlist = [ + # ===================== 1D场景(15组,dim=0) ===================== + ((3, ), (3, ), 0), + ((7, ), (9, ), 0), + ((13, ), (11, ), 0), + ((2047, ), (2047, ), 0), + 
((2701, ), (3003, ), 0),
+    ((4093, ), (3095, ), 0),
+
+    # ===================== 2D cases (concat on dim 0 or dim 1) =====================
+    # dim0 (concatenate rows; the column dim must match)
+    ((3, 5), (3, 5), 0),
+    ((1005, 300), (2007, 300), 0),
+    ((1307, 400), (309, 400), 0),
+    ((303, 500), (303, 500), 0),
+    # dim1 (concatenate columns; the row dim must match)
+    ((7, 9), (7, 9), 1),
+    ((100, 1001), (100, 2003), 1),
+    ((200, 2005), (200, 207), 1),
+    ((300, 707), (300, 707), 1),
+
+    # ===================== 3D cases (concat on dim 0, 1 or 2) =====================
+    # dim0 (concatenate along dim 0; d1/d2 must match)
+    ((378, 200, 300), (101, 200, 300), 0),
+    ((378, 70, 50), (601, 70, 50), 0),
+    # dim1 (concatenate along dim 1; d0/d2 must match)
+    ((100, 452, 300), (100, 201, 300), 1),
+    ((65, 1735, 57), (65, 2001, 57), 1),
+    # dim2 (concatenate along dim 2; d0/d1 must match)
+    ((87, 200, 387), (87, 200, 501), 2),
+    ((20, 337, 543), (20, 337, 401), 2),
+]
+
+
+@pytest.mark.parametrize('testlists', testlist)
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16'])
+def test_cat_bigshape(testlists, dtype):
+    torch_dtype = eval('torch.' + dtype)
+    np_x0 = test_common.generate_numpy(testlists[0], dtype)
+    np_x1 = test_common.generate_numpy(testlists[1], dtype)
+    cat_dim = testlists[2]
+
+    x0 = torch.from_numpy(np_x0).to(torch_dtype).npu()
+    x1 = torch.from_numpy(np_x1).to(torch_dtype).npu()
+
+    if len(x0.shape) > 3:
+        pytest.skip("tensors with more than 3 dims are not covered by these kernels")
+
+    torch_res = torch.cat([x0, x1], dim=cat_dim)
+    triton_res = torch.zeros_like(torch_res)
+    num_core = 32
+    # Dispatch to the kernel variant matching the rank, the concat dim, and
+    # whether the two shapes agree outside the concat dim.
+    if len(x0.shape) == 3:
+        if cat_dim == 0:
+            ZBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core
+            ZBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core
+            triton_unk_fused_cat_3d_dim0[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x0.shape[1],
+                                                         x0.shape[2], ZBLOCK, ZBLOCK_2, 1, 256)
+        elif cat_dim == 1:
+            Z0BLOCK = (x0.shape[0] + num_core - 1) // num_core
+            triton_unk_fused_cat_3d_dim1[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1],
+                                                         x1.shape[2], Z0BLOCK, 1, 256)
+        else:
+            Y0BLOCK = (x0.shape[0] * x0.shape[1] + num_core - 1) // num_core
+            triton_unk_fused_cat_3d_dim2[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x0.shape[2],
+                                                         x1.shape[2], Y0BLOCK, 1, 256)
+        test_common.validate_cmp(dtype, torch_res, triton_res)
+        return
+    numel_large = torch_res.numel() > 512 and len(x0.shape) < 3
+    if numel_large or (cat_dim == 0 and len(x0.shape) == 2):
+        squeeze_flag = False
+        if len(x0.shape) == 1:
+            squeeze_flag = True
+            x0 = torch.unsqueeze(x0, dim=0)
+            x1 = torch.unsqueeze(x1, dim=0)
+            triton_res = torch.unsqueeze(triton_res, dim=0)
+            cat_dim = 1
+        if cat_dim == 1:
+            Y0BLOCK = (x0.shape[0] + num_core - 1) // num_core
+            if x0.shape[1] == x1.shape[1]:
+                triton_unk_fused_cat_dim1_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1],
+                                                                    Y0BLOCK, 1, 256)
+            else:
+                triton_unk_fused_cat_dim1_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1],
+                                                                    x1.shape[1], Y0BLOCK, 1, 256)
+        else:
+            if x0.shape[0] == x1.shape[0]:
+                Y0BLOCK = (x0.shape[0] + num_core - 1) // num_core
+                triton_unk_fused_cat_dim0_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1],
+                                                                    Y0BLOCK, 1, 256)
+            else:
+                YBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core
+                YBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core
+                triton_unk_fused_cat_dim0_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0],
+                                                                    x1.shape[1], YBLOCK, YBLOCK_2, 1, 256)
+        if
squeeze_flag: + triton_res = triton_res.squeeze() + else: + squeeze_flag = False + if len(x0.shape) == 1: + squeeze_flag = True + x0 = torch.unsqueeze(x0, dim=0) + x1 = torch.unsqueeze(x1, dim=0) + triton_res = torch.unsqueeze(triton_res, dim=0) + _cat_helper_func_2D_1[num_core, 1, 1](x0, x1, triton_res, x0.shape[1], x1.shape[1], x0.shape[0], + x0.shape[1] + x1.shape[1], 256, 16) + if squeeze_flag: + triton_res = triton_res.squeeze() + + test_common.validate_cmp(dtype, torch_res, triton_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py new file mode 100644 index 0000000000..69c1f9b0a2 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py @@ -0,0 +1,112 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
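+
+# The kernel in this file computes indexed cross-entropy with a streaming
+# log-sum-exp over the class dimension: only BLOCK_C logits are resident at a
+# time, and the running max/sum pair is rescaled per block via the identity
+# logsumexp(x) = m + log(sum(exp(x - m))).
+# A minimal scalar sketch of the same rescaling, for reference only
+# (`ref_logsumexp` is a hypothetical helper, not used by this test):
+#
+#   import math
+#   import torch
+#
+#   def ref_logsumexp(x: torch.Tensor, block: int = 256) -> float:
+#       m, s = float("-inf"), 0.0
+#       for chunk in x.split(block):
+#           cur_m = max(m, chunk.max().item())
+#           s = s * math.exp(m - cur_m) + (chunk - cur_m).exp().sum().item()
+#           m = cur_m
+#       return m + math.log(s)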
+ +import torch +import triton +import triton.language as tl + + +@triton.jit +def celoss_indices_kernel( + inp_ptr, + tgt_ptr, + w_ptr, + out_ptr, + w_tgt_ptr, + ignore_index, + C, + D, + BLOCK_C: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_d = tl.program_id(0).to(tl.int64) + pid_n = tl.program_id(1).to(tl.int64) + offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D).to(tl.int64) + + tgt_ptrs = tgt_ptr + pid_n * D + offset_d + tgt_mask = offset_d < D + tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0) + + ignore_mask = not (tgt == ignore_index) and tgt_mask + + tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32) + tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32) + + for off in range(0, C, BLOCK_C): + offset_c = off + tl.arange(0, BLOCK_C) + inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :] + inp_mask = offset_c[:, None] < C and offset_d[None, :] < D + inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32) + cur_max = tl.maximum(tmp_max, inp) + cur_exp = tl.exp(inp - cur_max) + tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp + tmp_max = cur_max + + final_max = tl.max(tmp_max, axis=0) + tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :]) + final_sum = tl.log(tl.sum(tmp_sum, axis=0)) + inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d + inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32) + + out = final_sum + final_max - inp_tgt + w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d + + if w_ptr is None: + w_tgt = ignore_mask + else: + w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32) + w_tgt = tl.where(ignore_mask, w_tgt, 0) + + tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask) + out *= w_tgt + out_ptrs = out_ptr + pid_n * D + offset_d + tl.store(out_ptrs, out, mask=tgt_mask) + + +def test_celoss_indices_kernel(shape=(1, 2)): + device = "npu" + dtype = torch.float16 + ignore_index = -100 + BLOCK_C = 256 + BLOCK_D = 1 + + N, C = shape + D = 1 + + inp = torch.randn(shape, dtype=dtype, device=device) + tgt = torch.randint(0, C, (N, ), dtype=torch.int64, device=device) + wgt = torch.randn(C, dtype=dtype, device=device) + + out_triton = torch.empty((N * D, ), dtype=torch.float32, device=device) + w_tgt_triton = torch.empty((N * D, ), dtype=torch.float32, device=device) + + grid = (triton.cdiv(D, BLOCK_D), N) + celoss_indices_kernel[grid]( + inp, + tgt, + wgt, + out_triton, + w_tgt_triton, + ignore_index, + C, + D, + BLOCK_C=BLOCK_C, + BLOCK_D=BLOCK_D, + ) diff --git a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py index 87b7fb3463..4582201936 100644 --- a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py +++ b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py @@ -45,6 +45,7 @@ def triton_compile_hint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_ tl.store(out_ptr0 + (xindex), tmp2, xmask) +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") @pytest.mark.parametrize('param_list', [ ['float32', (2, 4096, 8), 2, 32768, 1024], ]) diff --git a/third_party/ascend/unittest/pytest_ut/test_complex_mask.py b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py new file mode 100644 index 0000000000..a8f3ada327 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py @@ -0,0 +1,69 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + + +def copy(x): + return x.clone() + + +@triton.jit +def copy_kernel(in_ptr, out_ptr, N: tl.constexpr, NUMEL): + idx_block = tl.arange(0, N) + is_valid = N <= NUMEL + x = tl.load(in_ptr + idx_block, mask=idx_block < N) + mask_i1 = is_valid[:, None] & (idx_block < N)[None, :] + tl.store(out_ptr + idx_block[None, :], x[None, :], mask=mask_i1) + + +@triton.jit +def permute_copy_kernel(in_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, NUMEL): + idx_block_n = tl.arange(0, N) + idx_block_m = tl.arange(0, M) + idx_block = idx_block_m[:, None] + idx_block_n[None, :] * M + is_valid = N <= NUMEL + x = tl.load(in_ptr + idx_block, mask=(idx_block_m[:, None] < M) & (idx_block_n[None, :] < N)) + mask_i1 = (is_valid[:, None, None]) & (idx_block_m[None, :, None] < M) & (idx_block_n[None, None, :] < N) + tl.store(out_ptr + idx_block[None, :], x[None, :], mask=mask_i1) + + +def test_complex_mask_copy(): + N = 1024 + x = torch.randn(N, dtype=torch.float32).npu() + y = torch.empty_like(x).npu() + copy_kernel[(1, )](x, y, N=N, NUMEL=N) + torch.testing.assert_close(x, y) + + +def test_complex_mask_permute_copy(): + M = 4 + N = 32 + x = torch.randn(M * N, dtype=torch.float32).npu() + y = torch.empty_like(x).npu() + permute_copy_kernel[(1, )](x, y, M=M, N=N, NUMEL=M * N) + torch.testing.assert_close(x, y) diff --git a/third_party/ascend/unittest/pytest_ut/test_copy.py b/third_party/ascend/unittest/pytest_ut/test_copy.py index f0a2a778be..4cf9daa322 100644 --- a/third_party/ascend/unittest/pytest_ut/test_copy.py +++ b/third_party/ascend/unittest/pytest_ut/test_copy.py @@ -70,15 +70,16 @@ def copy( a1_val = tl.load(a1_ptr) add = tl.add(a_val, a1_val) - add_ub = bl.to_buffer(add, al.ascend_address_space.UB) + A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1) al.copy_from_ub_to_l1(add_ub, A_l1) + A_ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB) + al.copy(add_ub, A_ub) -# ============== Main for manual testing ============== -if __name__ == "__main__": +def test_copy(): print("=" * 60) print("Test 1: copy ") print("=" * 60) @@ -89,3 +90,8 @@ def copy( ) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) + + +# ============== Main for manual testing ============== +if __name__ == "__main__": + test_copy() diff --git a/third_party/ascend/unittest/pytest_ut/test_cumprod.py 
b/third_party/ascend/unittest/pytest_ut/test_cumprod.py index c4c6833574..feed59b7c4 100644
--- a/third_party/ascend/unittest/pytest_ut/test_cumprod.py
+++ b/third_party/ascend/unittest/pytest_ut/test_cumprod.py
@@ -89,7 +89,7 @@ def cumprod_generate_tensor(shape, dtype):
 @pytest.mark.parametrize("dtype", support_dtypes)
 @pytest.mark.parametrize("shape", [(7, 23)])
 @pytest.mark.parametrize("dim", [0, 1])
-@pytest.mark.parametrize("reverse", [False])
+@pytest.mark.parametrize("reverse", [False, True])
 def test_cumprod(dtype, shape, dim, reverse):
     x0 = cumprod_generate_tensor(shape=shape, dtype=dtype).npu()
     triton_cal = triton_func(x0, dim, reverse)
diff --git a/third_party/ascend/unittest/pytest_ut/test_cumsum.py b/third_party/ascend/unittest/pytest_ut/test_cumsum.py
index edc4553a0a..c9196b6bf7 100644
--- a/third_party/ascend/unittest/pytest_ut/test_cumsum.py
+++ b/third_party/ascend/unittest/pytest_ut/test_cumsum.py
@@ -73,7 +73,7 @@ def triton_func(x, dim, reverse):
 @pytest.mark.parametrize("dtype", support_dtypes)
 @pytest.mark.parametrize("shape", [(7, 23)])
 @pytest.mark.parametrize("dim", [0, 1])
-@pytest.mark.parametrize("reverse", [False])
+@pytest.mark.parametrize("reverse", [False, True])
 def test_cumsum(dtype, shape, dim, reverse):
     x0 = generate_tensor(shape=shape, dtype=dtype).npu()
     triton_cal = triton_func(x0, dim, reverse)
diff --git a/third_party/ascend/unittest/pytest_ut/test_custom.py b/third_party/ascend/unittest/pytest_ut/test_custom.py
index 9c2d35c0b8..8b3b6d0743 100755
--- a/third_party/ascend/unittest/pytest_ut/test_custom.py
+++ b/third_party/ascend/unittest/pytest_ut/test_custom.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import subprocess
+import os
 import triton
 import triton.language as tl
 import triton.language.extra.cann.extension as al
@@ -39,9 +40,31 @@ class my_custom_op:
     core = al.CORE.VECTOR
     pipe = al.PIPE.PIPE_V
     mode = al.MODE.SIMT
+    symbol = "my_custom_func"
+    # Fake path: this test only checks that Triton lowers successfully to MLIR.
+    bitcode = os.path.abspath(__file__)
+    iterator_types = [
+        al.IteratorType.Parallel,
+        al.IteratorType.Broadcast,
+        al.IteratorType.Transpose,
+        al.IteratorType.Reduction,
+        al.IteratorType.Interleave,
+        al.IteratorType.Deinterleave,
+        al.IteratorType.Inverse,
+        al.IteratorType.Pad,
+        al.IteratorType.Concat,
+        al.IteratorType.Gather,
+        al.IteratorType.Cumulative,
+        al.IteratorType.Opaque,
+    ]
 
     def __init__(self, x, ptr1, ptr2, offset: tl.int64, other, out=None):
-        pass
+        # Add an optional custom-op attribute (an ArrayAttr).
+        self.indexing_map = [al.affine_map.get_identity(1)]
+
+        # Tag ptr2 (by name) as an argument to be aligned at dimension 1.
+        # Tag the 2nd positional argument to be aligned at dimension 0.
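+        # (align_dim keys may be either a parameter name or a positional
+        # index over the user arguments; values give the dimension to align.)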
+ self.align_dim = {"ptr2": 1, 1: 0} @triton.jit @@ -55,6 +78,90 @@ def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): tl.store(out_ptr + i, result, mask=i < n) +@al.register_custom_op +class my_custom_op_extra_buf: + """Custom op declaring extra_buffers with several scalar Triton dtypes.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_func" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = [ + (tl.bfloat16, 256), + (tl.float64, 424242), + (tl.int8, 11), + (tl.float16, 22), + (tl.int32, 33), + ] + + +@al.register_custom_op +class my_custom_op_extra_buf_single_buf: + """Custom op declaring extra_buffers with single buf.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_func_single_buf" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = (tl.bfloat16, 256) + + +@triton.jit +def kernel_extra_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +@triton.jit +def kernel_extra_buf_single_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf_single_buf", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +@al.register_custom_op +class my_custom_op_extra_buf_wide: + """Cover more integer widths and unsigned dtypes in extra_buffers_types.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_wide_func" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = [ + (tl.int16, 1001), + (tl.uint16, 1002), + (tl.int64, 1003), + (tl.uint32, 1004), + (tl.uint8, 1005), + ] + + +@triton.jit +def kernel_extra_buf_wide(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf_wide", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + # ============== Pytest tests ============== @@ -75,14 +182,116 @@ def test_custom_op(): assert "hivm.pipe = #hivm.pipe" in line assert "hivm.tcore_type = #hivm.tcore_type" in line assert "hivm.vf_mode = #hivm.vf_mode" in line + # Optional indexing map attribute should be attached. + assert "indexing_map = [" in line + # Tagged argument alignment info is attached as integer operand attr. + assert "align_dim = 1" in line + assert "align_dim = 0" in line # All offset converted to int64. assert 'i64, ' in line assert 'i32, ' not in line + assert "iterator_types" in line + for iterator_name in ( + "parallel", + "broadcast", + "transpose", + "reduction", + "interleave", + "deinterleave", + "inverse", + "pad", + "concat", + "gather", + "cumulative", + "opaque", + ): + assert iterator_name in line + + +def _custom_lines(mlir: str, op_name: str): + # Match the MLIR string attribute exactly (avoid `my_custom_op` matching + # `my_custom_op_extra_buf`). 
+ quoted = f'"{op_name}"' + return [line for line in mlir.splitlines() if "hivm.hir.custom" in line and quoted in line] + + +def test_custom_op_extra_buffers_mixed_scalar_types(): + """extra_buffers_types must preserve bf16/f64/i8/f16/i32 (not all lowered to f32).""" + mlir = compile_kernel( + kernel_extra_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf") + assert lines, "expected at least one hivm.hir.custom line for my_custom_op_extra_buf" + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "bf16" in line + assert "f64" in line + assert "i8" in line + assert "f16" in line + assert "i32" in line + assert "424242" in line + + +def test_custom_op_extra_buffers_single_buffer(): + mlir = compile_kernel( + kernel_extra_buf_single_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf_single_buf") + assert lines, "expected at least one hivm.hir.custom line for my_custom_op_extra_buf_single_buf" + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "f32" in line + + +def test_custom_op_extra_buffers_integer_variants(): + """extra_buffers accept int16/uint16/int64/uint32/uint8 (IR uses i* storage types).""" + mlir = compile_kernel( + kernel_extra_buf_wide, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf_wide") + assert lines + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "i16" in line + assert "i64" in line + assert "i32" in line + assert "i8" in line + assert "1001" in line and "1005" in line + + +def test_custom_op_without_extra_buffers_has_no_extra_buffer_attrs(): + """Ops that do not set extra_buffers should not emit extra_buffers_* attributes.""" + mlir = compile_kernel( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir + for line in _custom_lines(mlir, "my_custom_op"): + assert "extra_buffers_types" not in line + assert "extra_buffers_sizes" not in line # ============== Main for manual testing ============== if __name__ == "__main__": + test_custom_op() + test_custom_op_without_extra_buffers_has_no_extra_buffer_attrs() + test_custom_op_extra_buffers_integer_variants() + test_custom_op_extra_buffers_mixed_scalar_types() + test_custom_op_extra_buffers_single_buffer() mlir = compile_kernel(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py new file mode 100644 index 0000000000..c0527c7dc4 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py @@ -0,0 +1,48 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import triton +import triton.language as tl +import torch_npu +import pytest + + +@triton.jit +def single_disc_mask_atomic_add_kernel( + in_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + tl.atomic_add(ptr_in, 1, mask=disc_mask) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_atomic_add(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + expected = in_tensor.clone() + single_disc_mask_atomic_add_kernel[(1, )](in_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected[:half] += 1 + assert torch.allclose(in_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{in_tensor.cpu()}" diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py index a13322f361..cb03ff1bec 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py @@ -18,6 +18,21 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
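+# Terminology used throughout this suite: a "continuous" mask compares the
+# raw index range against a bound, e.g. offs < M, while a "discrete" mask
+# compares a derived index, e.g. (offs * 2) < N. A minimal NumPy sketch of
+# the distinction (illustration only; numpy is not used by these tests):
+#
+#   import numpy as np
+#   offs = np.arange(8)
+#   cont = offs < 4        # continuous form: cmpi(range, M)
+#   disc = (offs * 2) < 8  # discrete form: cmpi(muli(range, 2), N)
+#
+# Both select the same first four elements here, but the compiler classifies
+# the masks by form, so the two shapes exercise different lowering paths.
+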
+# ============================================================================= +# Discrete mask access conversion test suite +# +# Test matrix (mask type x operation type): +# +# | mask type | load only | store only | load + store | +# |---------------------------------|-----------|------------|--------------| +# | single discrete mask | (A) | (B) | - | +# | single continuous mask | (C) | (D) | - | +# | continuous & discrete 2-way | (E) | (F) | (G) | +# | continuous & discrete 4-way | - | - | (H) | +# | broadcast(cont & disc) 2-D AND | (I) | - | (J) | +# +# ============================================================================= + import torch import triton import triton.language as tl @@ -25,8 +40,205 @@ import pytest +# ============================================================================= +# (A) Single discrete mask -- load only +# ============================================================================= +@triton.jit +def single_disc_mask_load_kernel( + in_ptr, + out_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + ptr_out = out_ptr + col_offs + data = tl.load(ptr_in, mask=disc_mask, other=0.0) + tl.store(ptr_out, data) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_load(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + out_tensor = torch.empty(BLOCK_N, dtype=torch.float16, device='npu') + + single_disc_mask_load_kernel[(1, )](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected = torch.zeros(BLOCK_N, dtype=torch.float16, device='npu') + expected[:half] = in_tensor[:half] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (B) Single discrete mask -- store only +# ============================================================================= +@triton.jit +def single_disc_mask_store_kernel( + in_ptr, + out_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + ptr_out = out_ptr + col_offs + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=disc_mask) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_store(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_N, ), -1.0, dtype=torch.float16, device='npu') + + single_disc_mask_store_kernel[(1, )](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected = torch.full((BLOCK_N, ), -1.0, dtype=torch.float16, device='npu') + expected[:half] = in_tensor[:half] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (C) Single continuous mask -- load only +# ============================================================================= +@triton.jit +def single_cont_mask_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + cont_mask = row_offs < M # Continuous mask + ptr_in = in_ptr + row_offs + ptr_out = out_ptr + row_offs + data = tl.load(ptr_in, mask=cont_mask, other=0.0) + tl.store(ptr_out, data, mask=cont_mask) + + +@pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) +def test_single_continuous_mask_load(M, BLOCK_M): + 
in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') + + single_cont_mask_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + + expected = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') + expected[:M] = in_tensor[:M] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (D) Single continuous mask -- store only +# ============================================================================= +@triton.jit +def single_cont_mask_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + cont_mask = row_offs < M + ptr_in = in_ptr + row_offs + ptr_out = out_ptr + row_offs + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=cont_mask) + + +@pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) +def test_single_continuous_mask_store(M, BLOCK_M): + in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') + single_cont_mask_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + expected = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') + expected[:M] = in_tensor[:M] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (E) Continuous & discrete 2-way AND -- load only +# ============================================================================= +@triton.jit +def cont_disc_combined_mask_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + # Continuous mask + row_boundary = row_offs < M + # Discrete mask + col_stride = (col_offs * 2) < BLOCK_N + mask = row_boundary[:, None] & col_stride[None, :] + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + data = tl.load(ptr_in, mask=mask, other=0.0) + tl.store(ptr_out, data) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(6, 8, 8)]) +def test_cont_disc_combined_mask_load(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + cont_disc_combined_mask_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + half_n = BLOCK_N // 2 + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:M, :half_n] = 1.0 + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (F) Continuous & discrete 2-way AND -- store only +# ============================================================================= +@triton.jit +def cont_disc_combined_mask_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + row_boundary = row_offs < M # continuous -> contLeaf + col_stride = (col_offs * 2) < BLOCK_N # discrete -> discLeaf + mask = 
row_boundary[:, None] & col_stride[None, :] + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=mask) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(6, 8, 8)]) +def test_cont_disc_combined_mask_store(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + cont_disc_combined_mask_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + half_n = BLOCK_N // 2 + expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + expected[:M, :half_n] = 1.0 + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (G) Continuous & discrete 2-way AND -- load + store (complex interleave, original) +# ============================================================================= @triton.jit -def simple_discrete_mask_load_kernel( +def interleave_cont_disc_mask_kernel( in_ptr, out_ptr, M: tl.constexpr, @@ -35,9 +247,9 @@ def simple_discrete_mask_load_kernel( pid = tl.program_id(0) col_offs = tl.arange(0, N) even_col_offs = tl.arange(0, N // 2) * 2 - even_col_mask = even_col_offs < N + even_col_mask = even_col_offs < N # discrete: cmpi(muli(range,2), N) row_offs = tl.arange(0, M) - row_mask = row_offs < M + row_mask = row_offs < M # continuous: cmpi(range_M, M) in_even_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] in_odd_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] + 1 even_data = tl.load(in_even_ptr, mask=row_mask[:, None] & even_col_mask[None, :], other=0.0) @@ -47,21 +259,150 @@ def simple_discrete_mask_load_kernel( tl.store(out_ptr, rotated_data) -@pytest.mark.parametrize("M", [(4)]) -@pytest.mark.parametrize("N", [(8)]) +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") +@pytest.mark.parametrize("M", [4]) +@pytest.mark.parametrize("N", [8]) def test_discrete_mask_load_store(M, N): + """Regression test: mask=row_mask & even_col_mask (continuous & discrete 2-way)""" input_tensor = torch.arange(M * N, dtype=torch.float16, device='npu').reshape(M, N) output_tensor = torch.empty_like(input_tensor) - grid = (1, ) - simple_discrete_mask_load_kernel[grid]( - input_tensor, - output_tensor, - M=M, - N=N, - ) + interleave_cont_disc_mask_kernel[(1, )](input_tensor, output_tensor, M=M, N=N) even_cols = input_tensor[:, 0::2] odd_cols = input_tensor[:, 1::2] ref_output = torch.empty_like(input_tensor) ref_output[:, 0::2] = -odd_cols ref_output[:, 1::2] = even_cols assert torch.allclose(output_tensor.float(), ref_output.float()) + + +# ============================================================================= +# (H) Continuous & discrete 4-way AND -- load + store +# ============================================================================= +@triton.jit +def multi_cont_disc_mask_load_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M # continuous mask + col_boundary = col_offs < N # continuous mask + row_stride = (row_offs * 2) < BLOCK_M # discrete mask + col_stride = (col_offs * 2) < BLOCK_N 
# discrete mask + + mask = (row_boundary[:, None] & col_boundary[None, :] & row_stride[:, None] & col_stride[None, :]) + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=mask, other=0.0) + result = data + 1.0 + tl.store(ptr_out, result, mask=mask) + + +@pytest.mark.parametrize("M,N,BLOCK_M,BLOCK_N", [ + (6, 6, 8, 8), +]) +def test_multi_cont_disc_mask_load_store(M, N, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + multi_cont_disc_mask_load_store_kernel[(1, )](in_tensor, out_tensor, M=M, N=N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + half_m = BLOCK_M // 2 # = 4 + half_n = BLOCK_N // 2 # = 4 + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:half_m, :half_n] = 2.0 + + assert torch.allclose(out_tensor, expected), (f"BLOCK=({BLOCK_M},{BLOCK_N}), valid=({M},{N})\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") + + +# ============================================================================= +# (I) broadcast(continuous & discrete) 2-D AND -- load only +# ============================================================================= +@triton.jit +def broadcast_cont_disc_2d_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M + row_disc = (row_offs * 2) < BLOCK_M + mask = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=mask, other=0.0) + tl.store(ptr_out, data) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(3, 4, 8)]) +def test_broadcast_cont_disc_2d_load(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + broadcast_cont_disc_2d_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + disc_true_rows = BLOCK_M // 2 + both_true_rows = min(M, disc_true_rows) + + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:both_true_rows, :] = 1.0 + + assert torch.allclose(out_tensor, expected), (f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") + + +# ============================================================================= +# (J) broadcast(continuous & discrete) 2-D AND -- load + store +# ============================================================================= +@triton.jit +def broadcast_cont_disc_2d_load_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M + row_disc = (row_offs * 2) < BLOCK_M + + combined = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=combined, other=0.0) + tl.store(ptr_out, data, mask=combined) + + 
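+# For the broadcast kernels (I) and (J), the column term
+# (col_offs < BLOCK_N) is always true, so the combined mask reduces to a
+# per-row predicate: a row survives iff row < M and 2 * row < BLOCK_M,
+# i.e. the first min(M, BLOCK_M // 2) rows. With (M=3, BLOCK_M=4) below,
+# that is rows 0 and 1.
+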
+@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(3, 4, 8)]) +def test_broadcast_cont_disc_2d_load_store(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + + broadcast_cont_disc_2d_load_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + disc_true_rows = BLOCK_M // 2 + both_true_rows = min(M, disc_true_rows) + + expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + expected[:both_true_rows, :] = 1.0 + + assert torch.allclose(out_tensor, expected), (f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py new file mode 100644 index 0000000000..d8129d2677 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py @@ -0,0 +1,238 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +# ============================================================================= +# MTE (Memory Tag Extension) OOB regression test for DiscreteMaskAccessConversionPass +# +# This file verifies that DiscreteMaskAccessConversionPass correctly bounds +# global-memory accesses when the load/store mask is a combined discrete mask +# +# Test strategy +# ------------- +# The test engineers this condition in four steps: +# Step 1 — probe: Trigger a fresh 2 MB NPU segment; measure its size. +# Step 2 — pre_fill: Fill the segment with small tensors until the remaining +# free space is in [IN_BYTES, TARGET_FREE]. +# Step 3 — in_tensor: Allocate the test tensor; it lands at the segment tail +# with only ~7680 bytes gap to the boundary. +# Step 4 — kernel: Run the kernel + synchronize. Before the fix the +# OOB read (24576 bytes) crosses the boundary → MTE. +# After the fix the copy is bounded to IN_BYTES → no MTE. 
+# +# Memory layout at the time of the kernel call (before fix): +# +# ┌──────────────────────────── 2 MB segment ────────────────────────────────┐ +# │ probe(512 B) │←────── pre_fill (~2025 KB) ──────→│ in_tensor(8192 B) │gap│ +# └──────────────────────────────────────────────────────────────────────────┘ +# ↑ segment end +# ├──── OOB_BYTES (24576 B) ─────→ +# crosses boundary → MTE ✓ +# +# ============================================================================= + +import math +import torch +import triton +import triton.language as tl +import torch_npu +import pytest + + +@triton.jit +def cont_disc_oob_inplace_2d_kernel( + ptr, + M, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + # Continuous bound (contMask): marks the M valid rows in this tile. + row_boundary = row_offs < (M - pid_m * BLOCK_M) + row_disc = (row_offs * 2) < BLOCK_M + combined = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + row_start = pid_m * BLOCK_M + ptr_2d = ptr + (row_start + row_offs[:, None]) * BLOCK_N + col_offs[None, :] + + # load triggers DiscreteMaskAccessConversionPass. + # Before fix: copy size = BLOCK_M × BLOCK_N × 2 bytes = 32768 bytes (OOB). + # After fix: copy size = M × BLOCK_N × 2 bytes = 8192 bytes (safe). + data = tl.load(ptr_2d, mask=combined, other=0.0) + tl.store(ptr_2d, data, mask=row_boundary[:, None]) + + +# ============================================================================= +# Memory setup helper +# ============================================================================= +def _fill_segment_to_boundary(dtype, device, in_bytes, target_free, chunk_max_bytes): + """Allocate a fresh NPU segment and fill it so that only ~target_free bytes remain. + + Returns + ------- + pre_fillers : list of torch.Tensor + All tensors allocated (probe + fill chunks). The caller is responsible + for deleting them in `finally`. + pool_free_after_fill : int + Segment free space after filling, in bytes. + seg_size : int + Total size of the triggered segment, in bytes. + """ + elem_size = torch.finfo(dtype).bits // 8 + + # --- Step 1: probe — trigger a fresh 2 MB small-alloc segment ---------- + pool0 = torch.npu.memory_reserved(0) + alloc0 = torch.npu.memory_allocated(0) + + probe = torch.empty(1, dtype=dtype, device=device) + + pool1 = torch.npu.memory_reserved(0) + alloc1 = torch.npu.memory_allocated(0) + + seg_size = pool1 - pool0 # should be 2 MB = 2097152 bytes + probe_actual = alloc1 - alloc0 # NPU 512-byte aligned → 512 bytes + + print(f"\n[mte] Step 1: probe") + print(f"[mte] segment_size = {seg_size} bytes ({seg_size // 1024} KB)") + print(f"[mte] probe_actual = {probe_actual} bytes") + print(f"[mte] pool_free = {seg_size - probe_actual} bytes") + + # --- Step 2: pre_fill — leave only [in_bytes, target_free] bytes free --- + # Chunks are kept ≤ chunk_max_bytes to stay in the small-alloc pool and + # avoid opening a new segment via the large-alloc path. 
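+    # The chunk ladder below halves from chunk_max_bytes down to the 512 B
+    # allocation granularity, so whenever the target window is reachable at
+    # all, some chunk size should be able to land the free space inside it.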
+ pre_fillers = [probe] + + for chunk in [ + chunk_max_bytes, chunk_max_bytes // 2, chunk_max_bytes // 4, chunk_max_bytes // 8, 32 * 1024, 16 * 1024, + 8 * 1024, 4 * 1024, 2 * 1024, 1024, 512 + ]: + while True: + free = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0) + if free <= in_bytes: + break # not enough room even for in_tensor; stop + if free <= target_free: + break # already in target range; try smaller chunk + if free <= target_free + chunk: + break # this chunk would overshoot; try smaller chunk + try: + t = torch.empty(chunk // elem_size, dtype=dtype, device=device) + pre_fillers.append(t) + except RuntimeError: + break # segment exhausted; try smaller chunk + + pool_free_after_fill = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0) + pre_bytes = sum(t.numel() * elem_size for t in pre_fillers) + print(f"\n[mte] Step 2: pre_fill") + print(f"[mte] tensors = {len(pre_fillers)}, total = {pre_bytes} bytes ({pre_bytes // 1024} KB)") + print(f"[mte] pool_free = {pool_free_after_fill} bytes (target [{in_bytes}, {target_free}] bytes)") + + return pre_fillers, pool_free_after_fill, seg_size + + +# ============================================================================= +# Test: MTE OOB via segment-boundary placement +# ============================================================================= +@pytest.mark.parametrize("BLOCK_M,BLOCK_N,M", [ + (4, 4096, 1), +]) +def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): + """Regression: combined discrete mask load causes OOB on tail blocks. + + Verifies that DiscreteMaskAccessConversionPass correctly bounds + the memory copy to M rows (the contiguous range), not BLOCK_M rows (the full tile). + + Test outcome: + - Before fix: RuntimeError (MTE OOB) — the test would fail. + - After fix: no exception — the test passes. + """ + dtype = torch.float16 + device = 'npu' + elem_size = 2 # float16 + + in_bytes = M * BLOCK_N * elem_size # 8192 bytes + oob_bytes = (BLOCK_M - M) * BLOCK_N * elem_size # 24576 bytes + # TARGET_FREE: midpoint between in_bytes and oob_bytes. + # Ensures in_tensor fits AND gap < oob_bytes so OOB crosses segment boundary. + target_free = (in_bytes + oob_bytes) // 2 # 16384 bytes + chunk_max_bytes = 512 * 1024 # 512 KB + + print(f"\n[mte] BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N} M={M}") + print(f"[mte] in_bytes = {in_bytes} bytes (in_tensor: {M}×{BLOCK_N}×{elem_size})") + print(f"[mte] oob_bytes = {oob_bytes} bytes (unfixed copy: {BLOCK_M}×{BLOCK_N}×{elem_size} - in_bytes)") + print(f"[mte] target_free = {target_free} bytes (must satisfy in_bytes < target_free < oob_bytes)") + + torch.npu.empty_cache() + + pre_fillers = [] + in_tensor = None + + try: + pre_fillers, pool_free_after_fill, _ = _fill_segment_to_boundary(dtype, device, in_bytes, target_free, + chunk_max_bytes) + except Exception as exc: + torch.npu.empty_cache() + pytest.skip(f"Memory layout setup failed (allocator behaviour may differ): {exc}") + + # Verify pre_fill achieved the required free-space window. + if not (in_bytes <= pool_free_after_fill <= target_free): + for t in reversed(pre_fillers): + del t + torch.npu.empty_cache() + pytest.skip(f"pre_fill did not reach target range [{in_bytes}, {target_free}] bytes; " + f"got {pool_free_after_fill} bytes. " + f"Skipping MTE check (NPU allocator behaviour may differ).") + + try: + # Step 3: allocate in_tensor — lands at the very end of the segment. + # NPU 512-byte alignment means the allocator consumes + # in_bytes + 512 = 8704 bytes, leaving gap ≈ target_free - 8704 = 7680 bytes. 
+ in_tensor = torch.ones(M * BLOCK_N, dtype=dtype, device=device).view(M, BLOCK_N) + + gap = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0) + print(f"\n[mte] Step 3: in_tensor") + print(f"[mte] address = [{in_tensor.data_ptr():#x}, {in_tensor.data_ptr() + in_bytes:#x})") + print(f"[mte] gap = {gap} bytes (in_tensor end → segment end)") + + if oob_bytes <= gap: + pytest.skip(f"gap ({gap} bytes) >= oob_bytes ({oob_bytes} bytes): " + f"OOB would not cross the segment boundary. " + f"Skipping MTE check.") + print(f"[mte] oob_bytes({oob_bytes} B) > gap({gap} B) → MTE expected if unfixed ✓") + + # Step 4: run kernel + num_pids_m = math.ceil(M / BLOCK_M) + print(f"\n[mte] Step 4: kernel (grid=({num_pids_m},))") + cont_disc_oob_inplace_2d_kernel[(num_pids_m, )](in_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + torch.npu.synchronize() + print("[mte] PASSED: fix is effective, no OOB.") + + except RuntimeError as exc: + pytest.fail(f"MTE OOB triggered — DiscreteMaskAccessConversionPass fix " + f"may not be applied or is incomplete.\nError: {exc}") + + finally: + if in_tensor is not None: + del in_tensor + for t in reversed(pre_fillers): + del t + torch.npu.empty_cache() + print("[mte] Memory released.") diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py new file mode 100644 index 0000000000..2d4f6437dc --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py @@ -0,0 +1,235 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
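+
+# Overview: two programs write overlapping 16-row windows of a 24-row C
+# (pid=0 -> rows [0, 16), pid=1 -> rows [8, 24)); the runtime masks are
+# constructed so the two writes partition C exactly at row HALF = 12.
+# A minimal NumPy sketch of the expected result (illustration only; numpy
+# is not used by the test itself):
+#
+#   import numpy as np
+#   C = np.full((24, 8), 2)   # sentinel fill; any BLOCK_N works, 8 shown
+#   C[0:12] = 0               # pid=0 copies A's zero rows via idx < HALF
+#   C[12:24] = 1              # pid=1 copies A's one rows via idx >= HALF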
+ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Fixed Constants +# --------------------------------------------------------------------------- +_M_ROWS = 16 # Rows per program +_OFFS = 8 # Write window offset for two programs +_HALF = 12 # Mask threshold +_NUM_C = 24 # Rows of matrix C (= OFFS + M_ROWS, ensure pid=1 write window not out of bounds) + +assert _OFFS < _HALF < _M_ROWS, "OFFS < HALF < M_ROWS to ensure True/False on both sides" +assert _NUM_C >= _OFFS + _M_ROWS, "NUM_C must accommodate upper bound of pid=1 write window" + + +# --------------------------------------------------------------------------- +# Triton Kernel +# --------------------------------------------------------------------------- +@triton.jit +def _copy_matrix_kernel( + A_ptr, + idx_ptr, + C_ptr, + idx_stride, + A_row_stride, + A_col_stride, + C_row_stride, + C_col_stride, + BLOCK_N: tl.constexpr, + HALF: tl.constexpr, +): + """ + Discrete memory access + overlapping write window + runtime mask. + + pid=0 write window: rows [0, 15], mask=True when idx < HALF + pid=1 write window: rows [8, 23], mask=True when idx >= HALF + Overlap region : rows [8, 15] -> triggers load-select-store RMW + """ + program_id = tl.program_id(axis=0).to(tl.int64) + N_id = tl.program_id(axis=1).to(tl.int64) + + OFFS: tl.constexpr = 8 + M_ROWS: tl.constexpr = 16 + + N_BLOCK = N_id * BLOCK_N + tl.arange(0, BLOCK_N) # shape: (BLOCK_N,) + M_BLOCK = tl.arange(0, M_ROWS) # shape: (M_ROWS,) + + # Discrete row indices (loaded at runtime -> mask cannot be statically analyzed) + idx = tl.load(idx_ptr + program_id * idx_stride + M_BLOCK) + + # Runtime mask (generates scf.if -> compiler converts to load-select-store) + if program_id == 0: + mask = idx < HALF + else: + mask = idx >= HALF + + val = tl.load( + A_ptr + idx[:, None] * A_row_stride + tl.arange(0, BLOCK_N)[None, :] * A_col_stride, + mask=mask[:, None], + ) + + # Write to C (mask=False rows rely on load-select-store to preserve original values) + tl.store( + C_ptr + (OFFS * program_id + M_BLOCK[:, None]) * C_row_stride + N_BLOCK[None, :] * C_col_stride, + val, + mask=mask[:, None], + ) + # C (24 × N) Program 0 Program 1 + # Row 0~7 ──────────── write value (mask=True) ── Not involved + # Row 8~11 ──────────── write value (mask=True) ── Not written (mask=False, overwritten by P0 to 0) + # Row 12~15 ──────────── Not written (mask=False) ── write value (mask=True) + # Row 16~23 ──────────── Not involved ── write value (mask=True) + + +# --------------------------------------------------------------------------- +# Helper: Construct discrete index vector +# --------------------------------------------------------------------------- +def _make_idx(device: str) -> torch.Tensor: + """ + Construct 2x16 index matrix that meets mask distribution requirements. 
+ + pid=0 row (idx0): + First HALF=12 values ∈ [0, HALF) -> mask=True + Last 4 values ∈ [HALF, M_ROWS) -> mask=False + pid=1 row (idx1): + First 4 values ∈ [OFFS, OFFS+4) -> mask=False + Last HALF=12 values ∈ [HALF, HALF*2) -> mask=True + """ + + def shuffle_quads(lst: list) -> list: + """Reverse each group of 4 elements (ignore if less than 4).""" + out = lst[:] + for i in range(0, len(out) - 3, 4): + out[i], out[i + 1], out[i + 2], out[i + 3] = \ + out[i + 3], out[i + 2], out[i + 1], out[i] + return out + + num_false = _M_ROWS - _HALF # = 4 + + seg0_true = shuffle_quads(list(range(0, _HALF))) # 12 values, < 12 + seg0_false = shuffle_quads(list(range(_HALF, _HALF + num_false))) # 4 values, >= 12 + idx0 = seg0_true + seg0_false # Total length 16 + + seg1_false = shuffle_quads(list(range(_OFFS, _OFFS + num_false))) # 4 values, < 12 + seg1_true = shuffle_quads(list(range(_HALF, _HALF + _HALF))) # 12 values, >= 12 + idx1 = seg1_false + seg1_true # Total length 16 + + assert len(idx0) == _M_ROWS, f"idx0 length error: {len(idx0)}" + assert len(idx1) == _M_ROWS, f"idx1 length error: {len(idx1)}" + assert all(v < _HALF for v in seg0_true), "pid=0 True segment should all be < HALF" + assert all(v >= _HALF for v in seg0_false), "pid=0 False segment should all be >= HALF" + assert all(v < _HALF for v in seg1_false), "pid=1 False segment should all be < HALF" + assert all(v >= _HALF for v in seg1_true), "pid=1 True segment should all be >= HALF" + + return torch.tensor([idx0, idx1], dtype=torch.int32, device=device) + + +# --------------------------------------------------------------------------- +# Dtype Mapping +# --------------------------------------------------------------------------- +_DTYPE_MAP = { + 'int32': torch.int32, + 'float32': torch.float32, + 'float16': torch.float16, + 'int16': torch.int16, +} + + +# --------------------------------------------------------------------------- +# Single Execution + Verification +# --------------------------------------------------------------------------- +def _run_once(BLOCK_N: int, dtype_str: str) -> None: + """ + Execute kernel once and verify results. 
+ + Expectations: + C[0:HALF, :] all 0 -- pid=0 writes rows [0,HALF) of A (all 0) + C[HALF:NUM_C, :] all 1 -- pid=1 writes rows [HALF, NUM_C) of A (all 1) + """ + dev = 'npu' + td = _DTYPE_MAP[dtype_str] + zero_val = 0.0 if dtype_str.startswith('float') else 0 + one_val = 1.0 if dtype_str.startswith('float') else 1 + + # A: First HALF rows all 0, last HALF rows all 1 + A = torch.zeros((_NUM_C, BLOCK_N), dtype=td, device=dev) + A[_HALF:, :] = one_val + + idx = _make_idx(dev) + + # C: Fill all with 2 + C = torch.full((_NUM_C, BLOCK_N), 2, dtype=td, device=dev) + + grid = (2, 1) + _copy_matrix_kernel[grid]( + A_ptr=A, + idx_ptr=idx, + C_ptr=C, + idx_stride=idx.stride(0), + A_row_stride=A.stride(0), + A_col_stride=A.stride(1), + C_row_stride=C.stride(0), + C_col_stride=C.stride(1), + BLOCK_N=BLOCK_N, + HALF=_HALF, + enable_sync_block_lock=True, + ) + + # Verification + assert torch.all(C[:_HALF] == zero_val), ( + f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " + f"C[:HALF] should all be {zero_val}, actual unique values: {C[:_HALF].unique().tolist()}") + assert torch.all(C[_HALF:] == one_val), ( + f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " + f"C[HALF:] should all be {one_val}, actual unique values: {C[_HALF:].unique().tolist()}") + + +@pytest.mark.parametrize("param_list", [ + # --- int32 --- + (16, 'int32'), + (32, 'int32'), + (64, 'int32'), + # --- float32 --- + (16, 'float32'), + (32, 'float32'), + (64, 'float32'), +]) +def test_discrete_overlap_mask(param_list): + """ + Verify no precision issues in discrete access + overlapping write window + runtime mask scenario. + + Race condition errors are probabilistic. Each parameter combination is executed 10 times + to fully cover concurrent timing scenarios. + If sync_block_lock fix is effective, all 10 runs pass; if race condition exists, assertion failure + occurs with high probability. 
+ """ + BLOCK_N, dtype_str = param_list + for _ in range(10): + _run_once(BLOCK_N, dtype_str) + + +if __name__ == "__main__": + configs = [ + (32, 'int32'), + (32, 'float32'), + ] + for BLOCK_N, dtype_str in configs: + print(f"Testing BLOCK_N={BLOCK_N}, dtype={dtype_str} ...", end=" ", flush=True) + for _ in range(10): + _run_once(BLOCK_N, dtype_str) + print("PASS (10 rounds)") + print("All tests passed.") diff --git a/third_party/ascend/unittest/pytest_ut/test_dot.py b/third_party/ascend/unittest/pytest_ut/test_dot.py index a4837d0449..635df74c5e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_dot.py +++ b/third_party/ascend/unittest/pytest_ut/test_dot.py @@ -26,6 +26,16 @@ import test_common +@pytest.fixture(scope="function") +def restore_npu_hf32_setting(): + original_allow_hf32 = torch_npu.npu.matmul.allow_hf32 + try: + torch_npu.npu.matmul.allow_hf32 = True + yield + finally: + torch_npu.npu.matmul.allow_hf32 = original_allow_hf32 + + def torch_dot_None(x0, x1): res = torch.matmul(x0, x1) return res @@ -49,6 +59,60 @@ def triton_dot_2_None(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr tl.store(output_ptr + oidx, ret, mask=out_mask) +@triton.jit +def triton_dot_2_allow_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, allow_tf32=True) + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + +@triton.jit +def triton_dot_2_input_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, input_precision="tf32") + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + +@triton.jit +def triton_dot_2_ignore_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, input_precision="hf32") + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + testlist1 = [ (10, 13, 35, 39), ] @@ -60,12 +124,58 @@ def triton_dot_2_None(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr ] +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") @pytest.mark.parametrize("B, C, D", testlist2) 
@pytest.mark.parametrize("sigtype", typelist) -def test_dot_2(sigtype, B, C, D): +def test_dot_2(restore_npu_hf32_setting, sigtype, B, C, D): x = test_common.generate_tensor((B, C), sigtype).npu() y = test_common.generate_tensor((C, D), sigtype).npu() z_ref = torch_dot_None(x, y).to(torch.float32) z = torch.zeros((B, D), dtype=torch.float32).npu() triton_dot_2_None[1, 1, 1](z, x, y, B, C, D) test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.xfail( + reason="Temporarily disabled: TA backend does not support allow_tf32 yet. Will be fixed in follow-up.") +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_allow_tf32(restore_npu_hf32_setting, sigtype, B, C, D): + x = test_common.generate_tensor((B, C), sigtype).npu() + y = test_common.generate_tensor((C, D), sigtype).npu() + z_ref = torch_dot_None(x, y).to(torch.float32) + z = torch.zeros((B, D), dtype=torch.float32).npu() + triton_dot_2_allow_tf32[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_input_tf32(restore_npu_hf32_setting, sigtype, B, C, D): + x = test_common.generate_tensor((B, C), sigtype).npu() + y = test_common.generate_tensor((C, D), sigtype).npu() + z_ref = torch_dot_None(x, y).to(torch.float32) + z = torch.zeros((B, D), dtype=torch.float32).npu() + triton_dot_2_input_tf32[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_ignore_tf32(sigtype, B, C, D): + input_type = "bfloat16" + x = test_common.generate_tensor((B, C), input_type).npu() + y = test_common.generate_tensor((C, D), input_type).npu() + z = torch.zeros((B, D), dtype=torch.float32).npu() + + original_allow_hf32 = torch_npu.npu.matmul.allow_hf32 + try: + torch_npu.npu.matmul.allow_hf32 = False + z_ref = torch_dot_None(x.to(torch.float32), y.to(torch.float32)).to(torch.float32) + + finally: + torch_npu.npu.matmul.allow_hf32 = original_allow_hf32 + + triton_dot_2_ignore_tf32[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_erfinv.py b/third_party/ascend/unittest/pytest_ut/test_erfinv.py index 9e45a5c553..51d155a115 100644 --- a/third_party/ascend/unittest/pytest_ut/test_erfinv.py +++ b/third_party/ascend/unittest/pytest_ut/test_erfinv.py @@ -82,3 +82,27 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_erfinv[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = test_common.generate_tensor(shape, dtype).npu() + x[0][0][0] = 1 # erfinv(1) -> ∞ + x[0][0][1] = -1 # erfinv(-1) -> -∞ + + # Avoid numerical instability near ±1 + # Move values in (threshold, 1) to threshold and (-1, -threshold) to -threshold + threshold = 1 - 1.1e-4 + too_close_pos = (x > threshold) & (x < 1) + too_close_neg = (x < -threshold) & (x > -1) + x[too_close_pos] = threshold + x[too_close_neg] = -threshold + y_ref = 
torch.erfinv(x).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_erfinv[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp(dtype, y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_expm1.py b/third_party/ascend/unittest/pytest_ut/test_expm1.py index 90665b030f..7a06320d61 100644 --- a/third_party/ascend/unittest/pytest_ut/test_expm1.py +++ b/third_party/ascend/unittest/pytest_ut/test_expm1.py @@ -45,6 +45,7 @@ def triton_expm1(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constex tl.store(out_ptr0 + (x0), tmp1, None) +@pytest.mark.skip(reason="expm1 failed sometimes, wait for fix") @pytest.mark.parametrize('param_list', [ ['float32', (2, 4096, 8), 2, 32768, 1024], ]) diff --git a/third_party/ascend/unittest/generalization_cases/test_tan.py b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py similarity index 69% rename from third_party/ascend/unittest/generalization_cases/test_tan.py rename to third_party/ascend/unittest/pytest_ut/test_fast_dividef.py index 4d6b6454cb..be534a5fdb 100644 --- a/third_party/ascend/unittest/generalization_cases/test_tan.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py @@ -1,62 +1,59 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import triton.language.extra.ascend.libdevice as libdevice - - -def torch_pointwise(x0): - res = torch.tan(x0) - return res - - -@triton.jit -def triton_tan(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop1 in range(loops1): - x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 - x0 = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + (x0), None) - tmp2 = libdevice.tan(tmp0) - tl.store(out_ptr0 + (x0), tmp2, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - y_ref = torch_pointwise(x0) - y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_tan[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
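+
+# Note: fast_dividef is validated against plain elementwise division below.
+# The comparison runs through test_common.validate_cmp, which takes the
+# dtype (presumably to pick a matching tolerance); a fast divide intrinsic
+# is assumed here, not guaranteed, to match x0 / x1 only to within a few ULPs.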
+ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common +import triton.language.extra.cann.libdevice as libdevice + + +def torch_pointwise(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = libdevice.fast_dividef(tmp0, tmp1) + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], +]) +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_fast_dividef[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_log1p.py b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py similarity index 69% rename from third_party/ascend/unittest/generalization_cases/test_log1p.py rename to third_party/ascend/unittest/pytest_ut/test_fast_expf.py index fa37cbd298..be534a5fdb 100644 --- a/third_party/ascend/unittest/generalization_cases/test_log1p.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py @@ -1,62 +1,59 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import triton.language.extra.ascend.libdevice as libdevice - - -def torch_pointwise(x0): - res = torch.log1p(x0) - return res - - -@triton.jit -def triton_log1p(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop1 in range(loops1): - x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 - x0 = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + (x0), None) - tmp2 = libdevice.log1p(tmp0) - tl.store(out_ptr0 + (x0), tmp2, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - y_ref = torch_pointwise(x0) - y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_log1p[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common +import triton.language.extra.cann.libdevice as libdevice + + +def torch_pointwise(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = libdevice.fast_dividef(tmp0, tmp1) + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], +]) +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_fast_dividef[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_gamma.py b/third_party/ascend/unittest/pytest_ut/test_gamma.py index 388ed36af2..b3cdf3b9af 100644 --- a/third_party/ascend/unittest/pytest_ut/test_gamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_gamma.py @@ -68,3 +68,19 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 2048, 8), 2, 32768, 512], +]) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = torch.abs(test_common.generate_tensor(shape, dtype)) + x_np = x.cpu().numpy() + x = x.npu() + y_ref = torch.from_numpy(gamma(x_np)).to(x.device).to(x.dtype).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu()
+    triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
+    monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL")
diff --git a/third_party/ascend/unittest/pytest_ut/test_if_advance.py b/third_party/ascend/unittest/pytest_ut/test_if_advance.py
new file mode 100644
index 0000000000..2c63e7ebba
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_if_advance.py
@@ -0,0 +1,40 @@
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.language.extra.cann.extension as al
+
+
+@triton.jit
+def triton_if_advance_kernel(in_ptr0, in_ptr1, out_ptr, xnumel, ynumel, k_loops, XBLOCK: tl.constexpr,
+                             YBLOCK: tl.constexpr):
+
+    K_block_ptr = tl.make_block_ptr(base=in_ptr0, shape=(xnumel, ynumel), strides=(ynumel, 1), offsets=(0, 0),
+                                    block_shape=(XBLOCK, YBLOCK), order=(1, 0))
+    V_block_ptr = tl.make_block_ptr(base=in_ptr1, shape=(ynumel, xnumel), strides=(xnumel, 1), offsets=(0, 0),
+                                    block_shape=(YBLOCK, XBLOCK), order=(1, 0))
+    O_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(xnumel, xnumel), strides=(xnumel, 1), offsets=(0, 0),
+                                    block_shape=(XBLOCK, XBLOCK), order=(1, 0))
+    res = tl.zeros([XBLOCK, XBLOCK], tl.float32)
+    for i in range(0, k_loops):
+        if i > 0:
+            K_block_ptr = tl.advance(K_block_ptr, (0, YBLOCK))
+            V_block_ptr = tl.advance(V_block_ptr, (YBLOCK, 0))
+        a = tl.load(K_block_ptr)
+        b = tl.load(V_block_ptr)
+        res = tl.dot(a, b, acc=res)
+    tl.store(O_block_ptr, res)
+
+
+def test_if_advance():
+    x = torch.randn((64, 256), dtype=torch.float32, device="npu")
+    y = torch.randn((256, 64), dtype=torch.float32, device="npu")
+    out_tri = torch.empty((64, 64), dtype=torch.float32, device="npu")
+    out_std = torch.empty((64, 64), dtype=torch.float32, device="npu")
+    torch.matmul(x, y, out=out_std)
+    triton_if_advance_kernel[1, 1, 1](x, y, out_tri, 64, 256, 4, 64, 64)
+    torch.testing.assert_close(out_std, out_tri, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    test_if_advance()
diff --git a/third_party/ascend/unittest/pytest_ut/test_if_load.py b/third_party/ascend/unittest/pytest_ut/test_if_load.py
new file mode 100644
index 0000000000..7932366cb4
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_if_load.py
@@ -0,0 +1,82 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
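The conditional tl.advance pattern in triton_if_advance_kernel above is simply a K-blocked matmul: each iteration multiplies one 64-wide slice of the inner dimension and accumulates. A torch sketch of the same accumulation (illustration only; shapes match the test's k_loops=4, YBLOCK=64):

import torch

x = torch.randn(64, 256)
y = torch.randn(256, 64)
acc = torch.zeros(64, 64)
for i in range(4):  # k_loops chunks of the inner dimension
    acc += x[:, i * 64:(i + 1) * 64] @ y[i * 64:(i + 1) * 64, :]
# Splitting the reduction only changes summation order, so this matches
# the full matmul up to float32 rounding.
torch.testing.assert_close(acc, x @ y, atol=1e-3, rtol=1e-3)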
+
+import pytest
+
+import triton
+import triton.language as tl
+import triton.language.extra.cann.libdevice as libdevice
+import test_common
+
+import torch
+import torch_npu
+
+
+@triton.jit
+def triton_if_load(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK)
+    index = base1
+    if tl.program_id(0) == 0:
+        base1 = base1 * 1
+    else:
+        base1 = base1 * 2
+    tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK, other=0.0)
+    tl.store(out_ptr0 + index, tmp0, None)
+
+
+@triton.jit
+def triton_for_if_load(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK_SUB)
+    index = base1
+    loops = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
+    for i in range(loops):
+        base1 = base1 + i * XBLOCK_SUB
+        index = index + i * XBLOCK_SUB
+        if tl.program_id(0) != 0:
+            base1 = base1 + 1
+
+        tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK, other=0.0)
+        tl.store(out_ptr0 + index, tmp0, None)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (32, ), 32],
+])
+def test_if_load(param_list):
+    dtype, shape, xblock = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    y_ref = x0.clone()
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_if_load[(1, )](x0, y_cal, xblock)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (32, ), 32, 16],
+])
+def test_for_if_load(param_list):
+    dtype, shape, xblock, xblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    y_ref = x0.clone()
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_for_if_load[(1, )](x0, y_cal, xblock, xblock_sub)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py
new file mode 100644
index 0000000000..b7a539f392
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py
@@ -0,0 +1,140 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
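A torch sketch of the x + 4*y linearisation and the values the atomic tests below check (illustration only, using the xnumel=4, ynumel=512 case from the parametrization). tl.atomic_add returns the old contents, so on a zero-initialised buffer the returned tensor is all zeros and the buffer itself ends up holding val:

import torch

xnumel, ynumel = 4, 512
y = torch.arange(ynumel).unsqueeze(1)             # [512, 1]
x = torch.arange(xnumel).unsqueeze(0)             # [1, 4]
val = 1.0 + 0.01 * x.float() + 0.001 * y.float()  # what the kernel adds
idx = (x + 4 * y).reshape(-1)                     # same linearisation as the kernel
assert idx.max().item() == xnumel * ynumel - 1    # the map covers the whole buffer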
+ +import torch +import pytest +import triton +import test_common +import triton.language as tl + +types_all = [ + (torch.float32, 'float32'), +] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def addptr_implicit_perm_atomic_add_2d( + ptr, + out, + ynumel, + xnumel, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YB] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XB, 1] + + val = 1.0 + (x.to(tl.float32) * 0.01) + (y.to(tl.float32) * 0.001) # [XB, YB] + xmask = x < xnumel + ymask = y < ynumel + old = tl.atomic_add(ptr + (x + 4 * y), val, xmask & ymask) + + tl.store(out + (x + 4 * y), old) + + +@triton.jit +def addptr_implicit_perm_atomic_cas_2d( + ptr, + out, + cmp_ptr, + val_ptr, + ynumel, + xnumel, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] + + xmask = x < xnumel + ymask = y < ynumel + mask = xmask & ymask + + offset = x + 4 * y + + cmp = tl.load(cmp_ptr + offset, mask=mask, other=0.0).to(tl.float32) + val = tl.load(val_ptr + offset, mask=mask, other=0.0).to(tl.float32) + + old = tl.atomic_cas(ptr + offset, cmp, val) + + tl.store(out + offset, old, mask=mask) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) +def test_addptr_implicit_perm_atomic_add_2d( + dtype, + sigtype, + xnumel, + ynumel, + XBLOCK, + YBLOCK, +): + in_ptr = torch.zeros((ynumel * 4, ), dtype=dtype).npu() + out_ptr = torch.ones_like(in_ptr) + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + addptr_implicit_perm_atomic_add_2d[grid](in_ptr, out_ptr, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK) + + y_idx = torch.arange(ynumel).unsqueeze(1).npu() + x_idx = torch.arange(xnumel).unsqueeze(0).npu() + idx = (x_idx + 4 * y_idx).reshape(-1) + torch.testing.assert_close(out_ptr[idx], torch.zeros_like(out_ptr[idx])) + + val_ref = (1.0 + 0.01 * x_idx.to(torch.float32) + 0.001 * y_idx.to(torch.float32)).reshape(-1) + torch.testing.assert_close(in_ptr[idx], val_ref, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) +def test_addptr_implicit_perm_atomic_cas_2d( + dtype, + sigtype, + xnumel, + ynumel, + XBLOCK, + YBLOCK, +): + in_ptr = torch.full((ynumel * 4, ), 2, dtype=dtype).npu() + out_ptr = torch.full((ynumel * 4, ), 1, dtype=dtype).npu() + cmp_ptr = torch.full((ynumel * 4, ), 2, dtype=dtype).npu() + val_ptr = torch.full((ynumel * 4, ), 1, dtype=dtype).npu() + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + addptr_implicit_perm_atomic_cas_2d[grid](in_ptr, out_ptr, cmp_ptr, val_ptr, ynumel, xnumel, YBLOCK=YBLOCK, + XBLOCK=XBLOCK) + + y_idx = torch.arange(ynumel).unsqueeze(1).npu() + x_idx = torch.arange(xnumel).unsqueeze(0).npu() + idx = (x_idx + 4 * y_idx).reshape(-1) + + torch.testing.assert_close(out_ptr[idx], torch.full_like(out_ptr[idx], 2.0)) + + torch.testing.assert_close(in_ptr[idx], torch.ones_like(in_ptr[idx])) + + +if __name__ == '__main__': + case_2d = (4, 512, 4, 64) + test_addptr_implicit_perm_atomic_add_2d(*types_all[0], *case_2d) + test_addptr_implicit_perm_atomic_cas_2d(*types_all[0], *case_2d) diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py new 
file mode 100644 index 0000000000..9aaafa2371 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py @@ -0,0 +1,1038 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import pytest +import triton +import triton.language as tl +import test_common + +types_all = [ + (torch.float32, 'float32'), +] + +case_2d = [ + # X, Y, XBLOCK, YBLOCK + (512, 32, 4, 64), +] + +case_3d = [ + # X, Y, Z, XBLOCK, YBLOCK, ZBLOCK + (100, 40, 32, 10, 4, 4), +] + +case_4d = [ + # X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK + (100, 80, 20, 16, 20, 4, 4, 4), +] + + +# ---------------------------------------------------------- +# Triton kernel +# ---------------------------------------------------------- +@triton.jit +def addptr_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, stride_y: tl.constexpr, + stride_x: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): + # logical indices (A^T view) + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + mask = (x < xnumel) & (y < ynumel) + + # IMPORTANT: + # ptr is a row-major A, but we interpret it as A^T via stride + offset = x * stride_x + y * stride_y + + val = tl.load(ptr + offset, mask) + tl.store(out + offset, val, mask) + + +@triton.jit +def addptr_implicit_perm_load_store_2d( + ptr, + out, + ynumel, + xnumel, + stride_y, + stride_x, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + # logical indices (A^T view) + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + + mask = (x < xnumel) & (y < ynumel) + + # IMPORTANT: + # ptr is a row-major A, but we interpret it as A^T via stride + offset = x * stride_x + y * stride_y + + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + stride_z: tl.constexpr, + stride_y: tl.constexpr, + stride_x: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x = pid_x * XBLOCK + tl.arange(0, XBLOCK)[:, None, None] + y = pid_y * YBLOCK + tl.arange(0, YBLOCK)[None, :, None] + z = pid_z * ZBLOCK + tl.arange(0, 
ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + offset = x * stride_x + y * stride_y + z * stride_z + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_3d( + ptr, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x = pid_x * XBLOCK + tl.arange(0, XBLOCK)[:, None, None] + y = pid_y * YBLOCK + tl.arange(0, YBLOCK)[None, :, None] + z = pid_z * ZBLOCK + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + offset = x * stride_x + y * stride_y + z * stride_z + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_4d_static_stride( + ptr, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w: tl.constexpr, + stride_z: tl.constexpr, + stride_y: tl.constexpr, + stride_x: tl.constexpr, + WBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z + + xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK + + w_pid = pid0 // xblocks_per_w + x_pid = pid0 - w_pid * xblocks_per_w + + x0 = x_pid * XBLOCK + y0 = pid1 * YBLOCK + z0 = pid2 * ZBLOCK + w0 = w_pid * WBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :, None] + w = w0 + tl.arange(0, WBLOCK)[None, None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) & (w < wnumel) + + offset = x * stride_x + y * stride_y + z * stride_z + w * stride_w + val = tl.load(ptr + offset, mask=mask, other=0.0) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_4d( + ptr, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z + + xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK + + w_pid = pid0 // xblocks_per_w + x_pid = pid0 - w_pid * xblocks_per_w + + x0 = x_pid * XBLOCK + y0 = pid1 * YBLOCK + z0 = pid2 * ZBLOCK + w0 = w_pid * WBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :, None] + w = w0 + tl.arange(0, WBLOCK)[None, None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) & (w < wnumel) + + offset = x * stride_x + y * stride_y + z * stride_z + w * stride_w + val = tl.load(ptr + offset, mask=mask, other=0.0) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): + y0 = tl.program_id(1) * YBLOCK + x0 = tl.program_id(0) * XBLOCK + y = y0 + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + x = x0 + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + xmask = x < xnumel + ymask = y < ynumel + mask = xmask & ymask + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel), + strides=(STRIDE_X, STRIDE_Y), + offsets=(x0, 
y0), + block_shape=(XBLOCK, YBLOCK), + order=(0, 1), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + STRIDE_Y * y), val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z: tl.constexpr, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_store_3d( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z, + STRIDE_Y, + STRIDE_X, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_3d_static_stride( + ptr, + out, + znumel, # logical z (== X) + ynumel, # logical y (== Y) + xnumel, # logical x (== Z) + STRIDE_Z: tl.constexpr, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + # out is row-major with shape (xnumel, ynumel, znumel) + OUT_STRIDE0: tl.constexpr, # = ynumel*znumel + OUT_STRIDE1: tl.constexpr, # = znumel + OUT_STRIDE2: tl.constexpr, # = 1 + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + # load: implicit permute view + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + val = tl.load(tptr) + + # store: row-major output (no implicit permute) + out_offset = x * OUT_STRIDE0 + y * OUT_STRIDE1 + z * OUT_STRIDE2 + tl.store(out + out_offset, val, mask=mask) + + +@triton.jit +def advance_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): + y0 = tl.program_id(1) * YBLOCK + x0 = tl.program_id(0) * XBLOCK + y = y0 + tl.arange(0, YBLOCK)[None, :] + x = x0 + 
tl.arange(0, XBLOCK)[:, None] + mask = (x < xnumel) & (y < ynumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel), + strides=(STRIDE_X, STRIDE_Y), + offsets=(0, 0), + block_shape=(XBLOCK, YBLOCK), + order=(0, 1), + ) + tptr2 = tl.advance(tptr, (x0, y0)) + val = tl.load(tptr2) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y), val, mask=mask) + + +@triton.jit +def advance_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z: tl.constexpr, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr0 = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(0, 0, 0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + tptr = tl.advance(tptr0, (x0, y0, z0)) + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +# ---------------------------------------------------------- +# pytest case +# ---------------------------------------------------------- +def ceil_div(a, b): + return (a + b - 1) // b + + +def _assert_row_major_2d(A, X, Y): + assert tuple(A.shape) == (X, Y) + assert A.is_contiguous() + assert A.stride() == (Y, 1) + + +def _assert_row_major_3d(A, X, Y, Z): + # [X, Y, Z] contiguous -> stride = (Y*Z, Z, 1) + assert tuple(A.shape) == (X, Y, Z) + assert A.is_contiguous() + assert A.stride() == (Y * Z, Z, 1) + + +def _assert_row_major_4d(A, X, Y, Z, W): + # [X, Y, Z, W] contiguous -> stride = (Y*Z*W, Z*W, W, 1) + assert tuple(A.shape) == (X, Y, Z, W) + assert A.is_contiguous() + assert A.stride() == (Y * Z * W, Z * W, W, 1) + + +# ---------------------------------------------------------- +# pytest case: addptr kernels +# ---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_2d_static_stride( + X, + Y, + XBLOCK, + YBLOCK, + dtype, + sigtype, +): + """ + Test goal: + - Real memory layout: A[X, Y], row-major (stride = (Y, 1)) + - Kernel view: A^T[Y, X], stride = (1, Y) + - Kernel does load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor( + shape=(X, Y), + dtype=sigtype, + ).npu() + + _assert_row_major_2d(A, X, Y) + + out = torch.zeros_like(A) + + # A^T logical shape + xnumel = Y # cols of A + ynumel = X # rows of A + + # A^T logical stride + stride_x = 1 + stride_y = Y + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + 1, + ) + + addptr_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + stride_y, + stride_x, + XBLOCK=XBLOCK, + YBLOCK=YBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_2d( + X, + Y, + XBLOCK, + YBLOCK, + dtype, + sigtype, +): + """ + Same as static-stride version, but stride passed as runtime values. 
+ """ + A = test_common.generate_tensor( + shape=(X, Y), + dtype=sigtype, + ).npu() + + _assert_row_major_2d(A, X, Y) + + out = torch.zeros_like(A) + + # A^T logical shape + xnumel = Y # cols of A + ynumel = X # rows of A + + # A^T logical stride + stride_x = 1 + stride_y = Y + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + 1, + ) + + addptr_implicit_perm_load_store_2d[grid]( + A, + out, + ynumel, + xnumel, + stride_y, + stride_x, + XBLOCK=XBLOCK, + YBLOCK=YBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + """ + Test goal: + - Real memory layout: A[X, Y, Z], row-major (stride = (Y*Z, Z, 1)) + - Kernel view: treat as permuted logical coords via stride: + offset = x*1 + y*Z + z*(Y*Z) + i.e. (x,y,z) mapped to base index (z, y, x) + - Kernel does load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + # Logical shape for "A^(perm)" (x fastest) + xnumel = Z + ynumel = Y + znumel = X + + # Logical strides (in elements): (1, Z, Y*Z) + stride_x = 1 + stride_y = Z + stride_z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_3d(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + """ + Same as static-stride version, but stride passed as runtime values. + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + + stride_x = 1 + stride_y = Z + stride_z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_3d[grid]( + A, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_4d_static_stride(X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype): + """ + Test goal: + - Real memory layout: A[X, Y, Z, W], row-major (stride = (Y*Z*W, Z*W, W, 1)) + - Kernel view: treat as permuted logical coords via stride: + offset = x*1 + y*W + z*(Z*W) + w*(Y*Z*W) + i.e. 
(x,y,z,w) mapped to base index (w, z, y, x) + - Kernel does load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor(shape=(X, Y, Z, W), dtype=sigtype).npu() + _assert_row_major_4d(A, X, Y, Z, W) + out = torch.zeros_like(A) + + # Logical shape (x fastest) + xnumel = W + ynumel = Z + znumel = Y + wnumel = X + + # Logical strides (in elements): (1, W, Z*W, Y*Z*W) + stride_x = 1 + stride_y = W + stride_z = Z * W + stride_w = Y * Z * W + + # Kernel maps pid0 over (w, x). It uses xblocks_per_w computed from xnumel. + xblocks_per_w = ceil_div(xnumel, XBLOCK) + grid0 = wnumel * xblocks_per_w + grid = ( + grid0, + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_4d_static_stride[grid]( + A, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK=WBLOCK, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_4d(X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype): + """ + Same as static-stride version, but stride passed as runtime values. + """ + A = test_common.generate_tensor(shape=(X, Y, Z, W), dtype=sigtype).npu() + _assert_row_major_4d(A, X, Y, Z, W) + out = torch.zeros_like(A) + + xnumel = W + ynumel = Z + znumel = Y + wnumel = X + + stride_x = 1 + stride_y = W + stride_z = Z * W + stride_w = Y * Z * W + + xblocks_per_w = ceil_div(xnumel, XBLOCK) + grid0 = wnumel * xblocks_per_w + grid = ( + grid0, + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_4d[grid]( + A, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK=WBLOCK, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +# ---------------------------------------------------------- +# pytest case: make_tensor_ptr kernels +# ---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride(X, Y, XBLOCK, YBLOCK, dtype, sigtype): + """ + Test goal matches addptr_2d_static_stride, but uses tl.make_block_ptr + tl.load(tptr). + Real layout: A[X,Y] row-major stride=(Y,1) + Kernel view: A^T[Y,X] stride=(1,Y) + Store is by explicit linear offset with same logical stride. 
+ """ + A = test_common.generate_tensor(shape=(X, Y), dtype=sigtype).npu() + _assert_row_major_2d(A, X, Y) + out = torch.zeros_like(A) + + xnumel = Y + ynumel = X + STRIDE_X = 1 + STRIDE_Y = Y + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + make_tensor_ptr_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + """ + Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) + Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_3d(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + """ + Same as static stride but STRIDE_* passed at runtime. + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_store_3d[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z, + STRIDE_Y, + STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + + # logical/permuted shape + xnumel = Z + ynumel = Y + znumel = X + + # implicit-permute strides (elements) + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + # output is row-major of shape (Z, Y, X) + out = torch.empty((xnumel, ynumel, znumel), device="npu", dtype=A.dtype) + assert out.is_contiguous() + OUT_STRIDE0 = ynumel * znumel # Y*X + OUT_STRIDE1 = znumel # X + OUT_STRIDE2 = 1 + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + OUT_STRIDE0=OUT_STRIDE0, + OUT_STRIDE1=OUT_STRIDE1, + OUT_STRIDE2=OUT_STRIDE2, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + # expected: out[x,y,z] == A[z,y,x] => out == A.permute(2,1,0) + ref = A.permute(2, 1, 0).contiguous() + 
torch.testing.assert_close(out, ref) + + +# ---------------------------------------------------------- +# pytest case: advance kernels +# ---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_advance_implicit_perm_load_store_2d_static_stride(X, Y, XBLOCK, YBLOCK, dtype, sigtype): + """ + Same goal as addptr_2d_static_stride, but uses tl.make_block_ptr + tl.advance. + """ + A = test_common.generate_tensor(shape=(X, Y), dtype=sigtype).npu() + _assert_row_major_2d(A, X, Y) + out = torch.zeros_like(A) + + xnumel = Y + ynumel = X + STRIDE_X = 1 + STRIDE_Y = Y + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + advance_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_advance_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): + """ + Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) + Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + advance_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +if __name__ == "__main__": + test_addptr_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_2d(*case_2d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_3d(*case_3d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_4d_static_stride(*case_4d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_4d(*case_4d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_3d(*case_3d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_3d_static_stride(*case_3d[0], *types_all[0]) + test_advance_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_advance_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) diff --git a/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py new file mode 100644 index 0000000000..b9a34efccd --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py @@ -0,0 +1,98 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + + +@triton.jit +def gather_after_reduce_kernel( + logits_ptr, + topk_ids_ptr, + output_ptr, + logits_stride, + vocab_size, + BLOCK: tl.constexpr, +): + req_idx = tl.program_id(0) + + max_val = -float('inf') + for start in range(0, vocab_size, BLOCK): + offsets = start + tl.arange(0, BLOCK) + mask = offsets < vocab_size + vals = tl.load( + logits_ptr + req_idx * logits_stride + offsets, + mask=mask, + other=-float('inf'), + ) + block_max = tl.max(vals) + max_val = tl.maximum(max_val, block_max) + + topk_id = tl.load(topk_ids_ptr + req_idx + tl.arange(0, 1)) + val = tl.load(logits_ptr + req_idx * logits_stride + topk_id) + tl.store(output_ptr + req_idx + tl.arange(0, 1), val - max_val) + + +def torch_reference(logits, topk_ids): + num_rows = logits.shape[0] + output = torch.empty(num_rows, dtype=logits.dtype) + for i in range(num_rows): + max_val = logits[i].max() + output[i] = logits[i, topk_ids[i]] - max_val + return output + + +shapes = [ + (4, 128), + (8, 256), + (16, 1024), +] + + +@pytest.mark.parametrize('num_rows,vocab_size', shapes) +def test_gather_after_reduce(num_rows, vocab_size): + BLOCK = 128 + + logits_ref = test_common.generate_tensor(shape=(num_rows, vocab_size), dtype='float32') + logits = logits_ref.npu() + logits_flat = logits.reshape(-1) + + topk_ids_ref = torch.randint(0, vocab_size, (num_rows, ), dtype=torch.int64) + topk_ids = topk_ids_ref.npu() + + output = torch.empty(num_rows, dtype=torch.float32).npu() + + gather_after_reduce_kernel[(num_rows, )]( + logits_flat, + topk_ids, + output, + vocab_size, + vocab_size, + BLOCK=BLOCK, + ) + + output_ref = torch_reference(logits_ref, topk_ids_ref) + test_common.validate_cmp('float32', output, output_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py new file mode 100644 index 0000000000..7f491e0494 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py @@ -0,0 +1,125 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+import triton
+import triton.language as tl
+
+import torch
+import torch_npu
+import pytest
+import test_common
+
+
+def torch_interleave_load(q, k, head_dim_half, bias):
+    d_indices = torch.arange(0, head_dim_half)
+    k[d_indices * 2 + bias] = q[d_indices * 2 + bias]
+    k[d_indices * 2 + 1 + bias] = -q[d_indices * 2 + 1 + bias]
+    return k
+
+
+def torch_interleave_load_with_mask(q, k, head_dim_half, bias, numel):
+    d_indices = torch.arange(0, min(head_dim_half, numel))
+    k[d_indices * 2 + bias] = q[d_indices * 2 + bias]
+    k[d_indices * 2 + 1 + bias] = -q[d_indices * 2 + 1 + bias]
+    return k
+
+
+def torch_interleave_loadstore_with_mask(q, head_dim_half, bias, numel):
+    d_indices = torch.arange(0, min(head_dim_half, numel))
+    # Storing q[d_indices * 2 + bias] back unchanged would be unnecessary,
+    # so only the negated odd lanes are written.
+    q[d_indices * 2 + 1 + bias] = -q[d_indices * 2 + 1 + bias]
+    return q
+
+
+@triton.jit
+def triton_interleave_load(q_ptr, k_ptr, head_dim_half: tl.constexpr, bias: tl.constexpr):
+    d_indices = tl.program_id(0) + tl.arange(0, head_dim_half)
+    q_real = tl.load(q_ptr + d_indices * 2 + bias)
+    q_imag = tl.load(q_ptr + d_indices * 2 + 1 + bias)
+    new_q_real = q_real
+    new_q_imag = -q_imag
+    tl.store(k_ptr + d_indices * 2 + bias, new_q_real)
+    tl.store(k_ptr + d_indices * 2 + 1 + bias, new_q_imag)
+
+
+@triton.jit
+def triton_interleave_load_with_mask(q_ptr, k_ptr, head_dim_half: tl.constexpr, bias: tl.constexpr,
+                                     numel: tl.constexpr):
+    d_indices = tl.program_id(0) + tl.arange(0, head_dim_half)
+    mask = d_indices < numel
+    q_real = tl.load(q_ptr + d_indices * 2 + bias, mask)
+    q_imag = tl.load(q_ptr + d_indices * 2 + 1 + bias, mask)
+    new_q_real = q_real
+    new_q_imag = -q_imag
+    tl.store(k_ptr + d_indices * 2 + bias, new_q_real, mask)
+    tl.store(k_ptr + d_indices * 2 + 1 + bias, new_q_imag, mask)
+
+
+# When load and store target the same pointer, sometimes only the masked
+# store can be optimized.
+@triton.jit
+def triton_interleave_loadstore_with_mask(q_ptr, head_dim_half: tl.constexpr, bias: tl.constexpr, numel: tl.constexpr):
+    d_indices = tl.arange(0, head_dim_half)
+    mask = d_indices < numel
+    q_real = tl.load(q_ptr + d_indices * 2 + bias, mask)
+    q_imag = tl.load(q_ptr + d_indices * 2 + 1 + bias, mask)
+    new_q_real = q_real
+    new_q_imag = -q_imag
+    tl.store(q_ptr + d_indices * 2 + bias, new_q_real, mask)
+    tl.store(q_ptr + d_indices * 2 + 1 + bias, new_q_imag, mask)
+
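With bias == 0, the interleave kernels above amount to viewing the flat buffer as (head_dim_half, 2) real/imag pairs and negating the second lane of each pair. A hypothetical torch equivalent (illustration only; the tests below stick to the explicit index form):

import torch


def torch_interleave_pairs(q: torch.Tensor) -> torch.Tensor:
    pairs = q.view(-1, 2).clone()  # one row per (real, imag) pair
    pairs[:, 1] = -pairs[:, 1]     # negate the "imaginary" lane
    return pairs.view(-1)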
+@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias', [ + ['float32', torch.float32, 16, 4], +]) +def test_interleave(para_type, data_type, head_dim_half, bias): + length = bias + head_dim_half * 2 + q = torch.randn((length, ), dtype=data_type).npu() + k = torch.zeros_like(q, dtype=data_type).npu() + k_ref = torch.zeros_like(q, dtype=data_type).npu() + + triton_interleave_load[(1, )](q, k, head_dim_half, bias) + k_ref = torch_interleave_load(q, k_ref, head_dim_half, bias) + assert torch.allclose(k, k_ref) + + +@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', [ + ['float32', torch.float32, 16, 0, 8], +]) +def test_interleave_with_mask(para_type, data_type, head_dim_half, bias, numel): + length = bias + head_dim_half * 2 + q = torch.randn((length, ), dtype=data_type).npu() + k = torch.zeros_like(q, dtype=data_type).npu() + k_ref = torch.zeros_like(q, dtype=data_type).npu() + + triton_interleave_load_with_mask[(1, )](q, k, head_dim_half, bias, numel) + k_ref = torch_interleave_load_with_mask(q, k_ref, head_dim_half, bias, numel) + assert torch.allclose(k, k_ref) + + +@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', [ + ['float32', torch.float32, 16, 0, 8], +]) +def test_interleave_loadstore_with_mask(para_type, data_type, head_dim_half, bias, numel): + length = bias + head_dim_half * 2 + q = torch.randn((length, ), dtype=data_type).npu() + q_ref = q.clone() + + triton_interleave_loadstore_with_mask[(1, )](q, head_dim_half, bias, numel) + q_ref = torch_interleave_loadstore_with_mask(q_ref, head_dim_half, bias, numel) + assert torch.allclose(q, q_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_lgamma.py b/third_party/ascend/unittest/pytest_ut/test_lgamma.py index bc5db118ac..fc56fffa24 100644 --- a/third_party/ascend/unittest/pytest_ut/test_lgamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_lgamma.py @@ -86,3 +86,29 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_lgamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 2048, 8), 2, 32768, 512], +]) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = test_common.generate_tensor(shape, dtype).npu() + + # Avoid numerical instability near negative integer + nearest_int = torch.round(x) + neg_mask = nearest_int <= -1 + threshold = torch.zeros_like(x) + if neg_mask.any(): + neg_ints = nearest_int[neg_mask] + threshold[neg_mask] = 5.75e-5 * (2.42**(-1 - neg_ints)) + mask = (torch.abs(x - nearest_int) < threshold) & (nearest_int <= -1) + if mask.any(): + x = torch.where(mask, nearest_int, x) + + y_ref = torch.lgamma(x).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu()
+    triton_lgamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
+    monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL")
diff --git a/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py
index 64790e1935..b3993de102 100644
--- a/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py
+++ b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py
@@ -103,6 +103,46 @@ def triton_linearize_mask_broadcast(in_tensor, BLOCK_SIZE):
         optimize_dynamic_offset=True)
 
 
+@triton.jit
+def rem_kernel(in_ptr0, in_ptr1, out_ptr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    x = tl.arange(0, BLOCK_SIZE)
+
+    base_offset = pid * BLOCK_SIZE + x
+
+    rem_result = base_offset % 128
+    mask = rem_result < 64
+
+    tmp0 = tl.load(in_ptr0 + base_offset, mask=mask, other=0.0)
+    tmp1 = tl.load(in_ptr1 + base_offset, mask=mask, other=0.0)
+    tmp2 = tmp0 + tmp1
+
+    tl.store(out_ptr + base_offset, tmp2, mask=mask)
+
+
+def test_linearize_mask_rem():
+    N = 1024
+    BLOCK_SIZE = 256
+    dtype = 'float32'
+    shape = (N, )
+
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    x1 = test_common.generate_tensor(shape, dtype).npu()
+    triton_res = torch.zeros(shape).npu()
+
+    grid = (ceil_div(N, BLOCK_SIZE), )
+    rem_kernel[grid](x0, x1, triton_res, N, BLOCK_SIZE=BLOCK_SIZE)
+
+    base_offsets = torch.arange(N).npu()
+    rem_results = base_offsets % 128
+    mask_bool = rem_results < 64
+
+    torch_res = torch.zeros((N, )).npu()
+    torch_res[mask_bool] = x0[mask_bool] + x1[mask_bool]
+
+    test_common.validate_cmp(dtype, triton_res, torch_res)
+
+
 def profile_performance_test(M, N, dtype, BLOCK_SIZE):
     print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, block_size={BLOCK_SIZE}")
diff --git a/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py
new file mode 100644
index 0000000000..4d339997ab
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py
@@ -0,0 +1,136 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
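A small sketch of the padding semantics exercised below (assumed behaviour, mirroring how output_ref is built in the tests): loading a (4, 4) block at row offset -1 with boundary_check and zero padding yields one zero row followed by the first three rows of the input.

import torch

a = torch.arange(1, 17).view(4, 4)
expected = torch.zeros_like(a)
expected[1:] = a[:3]  # rows shift down by one; the padded row stays zero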
+ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + + +@triton.jit +def negative_padding_with_load_kernel( + input_ptr, + output_ptr, + x_offset: tl.constexpr, + y_offset: tl.constexpr, + x_size: tl.constexpr, + y_size: tl.constexpr, +): + in_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(x_size, y_size), + strides=(y_size, 1), + offsets=(x_offset, y_offset), + block_shape=(x_size, y_size), + order=(1, 0), + ) + out_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(x_size, y_size), + strides=(y_size, 1), + offsets=(0, 0), + block_shape=(x_size, y_size), + order=(1, 0), + ) + in_val = tl.load(in_ptr, boundary_check=(0, 1), padding_option="zero") + tl.store(out_ptr, in_val) + + +@triton.jit +def negative_padding_with_store_kernel( + input_ptr, + output_ptr, + x_offset: tl.constexpr, + y_offset: tl.constexpr, + x_size: tl.constexpr, + y_size: tl.constexpr, +): + in_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(x_size, y_size), + strides=(y_size, 1), + offsets=(0, 0), + block_shape=(x_size, y_size), + order=(1, 0), + ) + out_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(x_size, y_size), + strides=(y_size, 1), + offsets=(x_offset, y_offset), + block_shape=(x_size, y_size), + order=(1, 0), + ) + in_val = tl.load(in_ptr) + tl.store(out_ptr, in_val, boundary_check=(0, 1)) + + +@pytest.mark.parametrize('param_list', [(8, 8), (16, 16), (32, 32), (64, 64)]) +def test_makeblockptr_load_with_negative_padding(param_list): + shape = param_list + torch.manual_seed(1) + x_offset = torch.randint(shape[0], size=()).item() + # y_offset = torch.randint(shape[1], size=()).item() + y_offset = 0 + input_tensor = torch.arange(start=1, end=shape[0] * shape[1] + 1, dtype=torch.int32).view(shape).npu() + output = torch.zeros(shape, dtype=torch.int32).npu() + negative_padding_with_load_kernel[(1, )]( + input_tensor, + output, + -x_offset, + -y_offset, + shape[0], + shape[1], + ) + output_ref = torch.zeros((shape[0] + x_offset, shape[1] + y_offset), dtype=torch.int32).cpu() + output_subview = torch.narrow(output_ref, 0, x_offset, shape[0]) + output_subview = torch.narrow(output_subview, 1, y_offset, shape[1]) + output_subview.copy_(input_tensor) + output_ref = torch.narrow(output_ref, 0, 0, shape[0]) + output_ref = torch.narrow(output_ref, 1, 0, shape[1]) + test_common.validate_cmp("int32", output, output_ref) + + +@pytest.mark.parametrize('param_list', [(8, 8), (16, 16), (32, 32), (64, 64)]) +def test_makeblockptr_store_with_negative_padding(param_list): + shape = param_list + torch.manual_seed(1) + x_offset = torch.randint(shape[0], size=()).item() + # y_offset = torch.randint(shape[1], size=()).item() + y_offset = 0 + input_tensor = torch.arange(start=1, end=shape[0] * shape[1] + 1, dtype=torch.int32).view(shape).npu() + output = torch.zeros(shape, dtype=torch.int32).npu() + negative_padding_with_store_kernel[(1, )]( + input_tensor, + output, + -x_offset, + -y_offset, + shape[0], + shape[1], + ) + output_ref = torch.zeros(shape, dtype=torch.int32).cpu() + input_subview = torch.narrow(input_tensor, 0, x_offset, shape[0] - x_offset) + input_subview = torch.narrow(input_subview, 1, y_offset, shape[1] - y_offset) + output_subview = torch.narrow(output_ref, 0, 0, shape[0] - x_offset) + output_subview = torch.narrow(output_subview, 1, 0, shape[1] - y_offset) + output_subview.copy_(input_subview) + test_common.validate_cmp("int32", output, output_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_mod.py 
b/third_party/ascend/unittest/pytest_ut/test_mod.py index 5886403225..fc91c60afb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_mod.py +++ b/third_party/ascend/unittest/pytest_ut/test_mod.py @@ -26,8 +26,15 @@ import test_common -def torch_pointwise(x0, x1): - res = x0 % x1 +def torch_pointwise(x0, x1, dtype): + if dtype == 'float16': + x0 = x0.to(torch.float32) + x1 = x1.to(torch.float32) + elif dtype == 'float32': + x0 = x0.to(torch.float64) + x1 = x1.to(torch.float64) + res = torch.div(x0, x1, rounding_mode="trunc") + res = x0 - x1 * res return res @@ -58,8 +65,9 @@ def test_case(param_list): else: x0 = test_common.generate_tensor(shape, dtype).npu() x1 = test_common.generate_tensor(shape, dtype).npu() - y_ref = torch_pointwise(x0.cpu(), x1.cpu()) - y_ref = y_ref.npu() + y_ref = torch_pointwise(x0, x1, dtype) + if dtype == "float16": + y_ref = y_ref.to(torch.float16) y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() triton_mod[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) #test_common.validate_cmp(dtype, y_cal, y_ref.npu()) diff --git a/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py b/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py new file mode 100644 index 0000000000..5ca4246331 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py @@ -0,0 +1,55 @@ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import numpy as np + + +@triton.jit +def minimum(a, b): + ret = tl.minimum(a, b, tl.PropagateNan.ALL) + if a.dtype == tl.bfloat16: + ret = ret.to(tl.bfloat16) + return ret + + +@triton.jit +def triton_pw_rdc5d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret0 = x0 * x1 + ret = tl.reduce(ret0, 4, minimum, keep_dims=True) + zblk_idx = tl.arange(0, 1) + odx = (lblk_idx[:, None, None, None, None] * K * N * M + mblk_idx[None, :, None, None, None] * K * N + + nblk_idx[None, None, :, None, None] * K + kblk_idx[None, None, None, :, None] + + zblk_idx[None, None, None, None, :]) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("shape", [(16, 1, 1, 1, 1)]) # L=16, others=1 +def test_pw_rdc5d(dtype, shape): + L, M, N, K, Z = shape + a = torch.randn(*shape, dtype=dtype, device='npu') + b = torch.randn(*shape, dtype=dtype, device='npu') + out = torch.empty(*shape, dtype=dtype, device='npu') + + expected = (a * b).to(dtype) + + triton_pw_rdc5d[(1, )](a, b, out, L=L, M=M, N=N, K=K, Z=Z) + + torch.testing.assert_close(out.cpu(), expected.cpu(), rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/third_party/ascend/unittest/pytest_ut/test_multibuffer.py b/third_party/ascend/unittest/pytest_ut/test_multibuffer.py new file mode 100644 index 0000000000..04d2471e22 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_multibuffer.py @@ -0,0 +1,73 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os + +import pytest +import triton +import triton.language as tl +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def multibuffer(XBLOCK: tl.constexpr): + buf = bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB) + al.multibuffer(buf, 2) + + +def test_multibuffer(): + print("=" * 60) + print("Test 1: test_alloc_ub_multibuffer") + print("=" * 60) + mlir = compile_kernel(multibuffer, {}, {"XBLOCK": 256}) + print(f"Generated MLIR ({len(mlir)} chars):\n") + print(mlir) + + +# ============== Main for manual testing ============== +if __name__ == "__main__": + test_multibuffer() diff --git a/third_party/ascend/unittest/generalization_cases/test_general_arange.py b/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py similarity index 52% rename from third_party/ascend/unittest/generalization_cases/test_general_arange.py rename to third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py index 4e7f2df67b..55ea0cb752 100644 --- a/third_party/ascend/unittest/generalization_cases/test_general_arange.py +++ b/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py @@ -18,49 +18,31 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
-import math import pytest -import torch -import torch_npu + import triton import triton.language as tl -import test_common -from test_common import TestUtils - -def torch_pointwise(length): - res = (torch.arange(0, length) / 2.7) * torch.arange(0, length) - return res - - -def torch_arange(start, end): - TRITON_MAX_TENSOR_NUMEL = 1048576 - if end < start: - raise ValueError("arange's end argument must be greater than the start argument") - if end - start > TRITON_MAX_TENSOR_NUMEL: - raise ValueError( - f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") - return torch.arange(start, end) +import torch +import torch_npu @triton.jit -def triton_arange(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): - off = tl.arange(0, BLOCK) - val = tl.arange(START, END) - tl.store(z + off, val) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -def test_case(shape): - start = 0 - end = shape[0] - shape = [end - start] - block = end - start - dtype = 'int32' - - y_ref = torch_arange(start, end) - y_cal = torch.zeros(shape, dtype=torch.int32).npu() - - triton_arange[(1, )](y_cal, START=start, END=end, BLOCK=block) - - assert torch.equal(y_cal.cpu(), y_ref.cpu()) +def triton_negative_mask_dim(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + index = tl.arange(0, XBLOCK) + mask = (index < 1) & (index + 1 >= XBLOCK) + tmp0 = tl.load(in_ptr0 + index, mask, other=0.0) + tl.store(out_ptr0 + index, tmp0, None) + + +@pytest.mark.parametrize('param_list', [ + ['float32', (32, ), 32], +]) +def test_negative_mask_dim(param_list): + dtype, shape, xblock = param_list + x0 = torch.ones(shape, dtype=eval('torch.' + dtype)).npu() + y_ref = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + + y_cal = torch.ones(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_negative_mask_dim[(1, )](x0, y_cal, xblock) + assert torch.allclose(y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_nextafter.py b/third_party/ascend/unittest/pytest_ut/test_nextafter.py index 4df371dd11..007f74d9dc 100644 --- a/third_party/ascend/unittest/pytest_ut/test_nextafter.py +++ b/third_party/ascend/unittest/pytest_ut/test_nextafter.py @@ -50,7 +50,6 @@ def triton_nextafter(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SU @pytest.mark.parametrize('param_list', [ ['float32', (2, 4096, 8), 2, 32768, 1024], ['float16', (2, 4096, 8), 2, 32768, 1024], - ['bfloat16', (2, 4096, 8), 2, 32768, 1024], ]) def test_nextafter(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py b/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py new file mode 100644 index 0000000000..819f131f12 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl +import pytest + + +@triton.jit +def rope_like_load_kernel( + Kv_cache, + Req_to_tokens, + output_ptr, + stride_kv: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + ROPE_DIM: tl.constexpr, +): + + offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM) + offs_n = tl.arange(0, BLOCK_N) + + kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + + offs_k_pe = kv_loc[None, :] * stride_kv + offs_d_kpe[:, None] + + k_pe = tl.load(Kv_cache + offs_k_pe) + + offs_out = offs_n[:, None] * ROPE_DIM + tl.arange(0, ROPE_DIM)[None, :] + tl.store(output_ptr + offs_out, tl.trans(k_pe)) + + +def test_bubbleup_extract_nonzero_offset(): + device = "npu" + + PAGE_SIZE = 2 + BLOCK_N = 4 + head_dim = 32 + head_dim_v = 24 + rope_dim = head_dim - head_dim_v + num_pages = BLOCK_N // PAGE_SIZE + + req_to_tokens = torch.arange(num_pages, dtype=torch.int32, device=device) + total_tokens = num_pages * PAGE_SIZE + kv_cache = torch.zeros(total_tokens, head_dim, dtype=torch.float32, device=device) + for token_id in range(total_tokens): + kv_cache[token_id, :head_dim_v] = (torch.arange(head_dim_v, dtype=torch.float32) + token_id * 100) + kv_cache[token_id, head_dim_v:] = (torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000) + output = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) + + rope_like_load_kernel[(1, )]( + kv_cache.flatten(), + req_to_tokens, + output.flatten(), + stride_kv=head_dim, + HEAD_DIM_V=head_dim_v, + HEAD_DIM=head_dim, + PAGE_SIZE=PAGE_SIZE, + BLOCK_N=BLOCK_N, + ROPE_DIM=rope_dim, + ) + + expected = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) + for token_id in range(BLOCK_N): + expected[token_id] = (torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000) + + buggy = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) + for token_id in range(BLOCK_N): + buggy[token_id] = (torch.arange(rope_dim, dtype=torch.float32) + token_id * 100) + + assert torch.allclose(output, expected, atol=1e-5) diff --git a/third_party/ascend/unittest/pytest_ut/test_parallel.py b/third_party/ascend/unittest/pytest_ut/test_parallel.py index ff41149de0..f81c060811 100644 --- a/third_party/ascend/unittest/pytest_ut/test_parallel.py +++ b/third_party/ascend/unittest/pytest_ut/test_parallel.py @@ -81,6 +81,7 @@ def 
get_torch_typename(dtype): typelist = ['int8', 'int16', 'int32', 'int64'] +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") @pytest.mark.parametrize('L, M, N', testlist) @pytest.mark.parametrize('sigtype', typelist) def test_add_bind_false(sigtype, L, M, N): diff --git a/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py b/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py new file mode 100644 index 0000000000..9ff97ac885 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py @@ -0,0 +1,38 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +@triton.jit +def zj_fa_fwd_pattern(in_ptr0, in_ptr1, out_ptr, M, K, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr, + KBLOCK: tl.constexpr): + a_ptr = tl.make_block_ptr(base=in_ptr0, shape=(M, K), # 8, 3 + strides=(K, 1), offsets=(0, 0), block_shape=(MBLOCK, KBLOCK), order=(1, 0)) + + b_ptr = tl.make_block_ptr(base=in_ptr1, shape=(K, N), # 3, 8 + strides=(1, K), offsets=(0, 0), block_shape=(KBLOCK, NBLOCK), order=(0, 1)) + + c_ptr = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(1, M), offsets=(0, 0), block_shape=(MBLOCK, NBLOCK), + order=(0, 1)) + + a = tl.load(a_ptr, boundary_check=(0, ), padding_option="zero") + b = tl.load(b_ptr, boundary_check=(0, ), padding_option="zero") + c = tl.dot(a, b) + tl.store(c_ptr, c, boundary_check=(0, 1)) + + +def test_permute_boundary_check(): + M = 8 + K = 3 + N = 8 + MBLOCK = 8 + NBLOCK = 8 + KBLOCK = 4 + a = torch.randn((M, K), device="npu") # 8, 3 + b = torch.randn((N, K), device="npu") # 8, 3 + c = torch.empty((N, M), device="npu") + zj_fa_fwd_pattern[(1, 1, 1)](a, b, c, M, K, N, MBLOCK, NBLOCK, KBLOCK) + std = a @ b.T + torch.testing.assert_close(std, c.T, atol=1e-2, rtol=1e-2) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py new file mode 100644 index 0000000000..ec29ccbbfe --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py @@ -0,0 +1,308 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
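+
+# Reduces a 5-D block of shape (L, M, N, K, Z) along single axes and along
+# axis combinations via tl.reduce with the custom `maximum` combine function
+# below, then checks the results against successive torch.max calls.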
+
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import pytest
+import test_common
+
+
+@triton.jit
+def maximum(a, b):
+    ret = tl.maximum(a, b, tl.PropagateNan.ALL)
+    # Testing shows that tl.maximum promotes its result to float32 only when
+    # the inputs are bfloat16, which causes a compilation error; the same
+    # faulty behavior reproduces on GPU, matching what we observe on NPU.
+    # As a workaround for bfloat16 inputs, cast the result back to bfloat16
+    # so the error no longer breaks compilation.
+    if a.dtype == tl.bfloat16:
+        ret = ret.to(tl.bfloat16)
+    return ret
+
+
+@triton.jit
+def triton_max_5d_dim024(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                         Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+    idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + \
+        nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[
+            None, None, None, None, :]
+    odx = mblk_idx[:, None] * K + kblk_idx[None, :]
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 4, maximum)
+    ret1 = tl.reduce(ret, 2, maximum)
+    ret2 = tl.reduce(ret1, 0, maximum)
+    tl.store(out_ptr0 + odx, ret2)
+
+
+@triton.jit
+def triton_max_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                        Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (lblk_idx[:, None, None] * N * Z + nblk_idx[None, :, None] * Z + zblk_idx[None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret_k = tl.reduce(x, 3, maximum)  # [L, M, N, Z]
+    ret_m = tl.reduce(ret_k, 1, maximum)  # [L, N, Z]
+    tl.store(out_ptr0 + odx, ret_m)
+
+
+@triton.jit
+def triton_max_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                       Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (mblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z +
+           kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 0, maximum)
+    tl.store(out_ptr0 + odx, ret)
+
+
+@triton.jit
+def triton_max_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                       Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (lblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z +
+           kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 1, maximum)
+    tl.store(out_ptr0 + 
odx, ret) + + +@triton.jit +def triton_max_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * K * Z + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 2, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * Z + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 3, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * K + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + kblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret1 = tl.reduce(x, 4, maximum) + ret2 = tl.reduce(ret1, 3, maximum) + ret3 = tl.reduce(ret2, 2, maximum) + ret4 = tl.reduce(ret3, 1, maximum) + ret5 = tl.reduce(ret4, 0, maximum) + tl.store(out_ptr0, ret5) + + +testlist = [ + (triton_max_5d_dim024, (1, 1, 1, 1, 1), "dim024"), + (triton_max_5d_dim024, (2, 2, 2, 2, 2), "dim024"), + (triton_max_5d_dim024, (3, 11, 1, 3, 42), "dim024"), + (triton_max_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), + (triton_max_5d_dim0, (2, 2, 2, 2, 2), "dim0"), + (triton_max_5d_dim1, (2, 2, 2, 2, 2), "dim1"), + (triton_max_5d_dim2, (2, 2, 2, 2, 2), 
"dim2"), + (triton_max_5d_dim3, (2, 2, 2, 2, 2), "dim3"), + (triton_max_5d_dim4, (2, 2, 2, 2, 2), "dim4"), + (triton_max_5d_all, (3, 11, 1, 3, 42), "all"), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = [ + "{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) for testfunc, shape, dim_name in testlist +] + + +@pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) +@pytest.mark.parametrize('dtype', typelist) +def test_max(testfunc, dtype, shape, dim_name): + x0 = test_common.generate_tensor(shape=shape, dtype=dtype).npu() + + if dim_name == "dim024": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 4) + ans, _ = torch.max(ans.to(torch.int64), 2) + ans, _ = torch.max(ans.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 4) + ans, _ = torch.max(ans, 2) + ans, _ = torch.max(ans, 0) + output = torch.zeros((shape[1], ) + (shape[3], ), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim13": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 3) + ans, _ = torch.max(ans.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 3) + ans, _ = torch.max(ans, 1) + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[4], ), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim0": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 0) + output = torch.zeros((shape[1], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim1": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 1) + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim2": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 2) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 2) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim3": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 3) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 3) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim4": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 4) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 4) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[3], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "all": + if 'int' in dtype: + ans = torch.max(x0.to(torch.int64)) + ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) + else: + ans = torch.tensor([torch.max(x0)], dtype=eval('torch.' + dtype)) + output = torch.zeros((1, ), dtype=eval('torch.' 
+ dtype)).npu() + + testfunc[(1, )](x0, output, *shape) + + test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py new file mode 100644 index 0000000000..3428508f69 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import time + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl +import test_common + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1, ), tl.int1) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if promote_to_tensor(a_value).dtype.is_floating(): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def triton_min_5d_dim4_keepdim(in_ptr0, in_ptr1, out_ptr0, out_ptr1, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, + K: tl.constexpr, Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + \ + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret, ret1 = tl.reduce((x, x1), 4, minimum_with_index, keep_dims=True) + zblk_idx = tl.arange(0, 1) + odx = lblk_idx[:, None, None, None, None] * K * N * M + mblk_idx[None, :, None, None, None] * K * N + \ + nblk_idx[None, None, :, None, None] * K + kblk_idx[None, None, None, :, None] \ + + zblk_idx[None, None, None, None, :] + tl.store(out_ptr0 + odx, ret) + tl.store(out_ptr1 + odx, ret1) + + +testlist = [ + # 5D + (triton_min_5d_dim4_keepdim, (1, 1, 1, 1, 1)), + (triton_min_5d_dim4_keepdim, (2, 2, 2, 2, 2)), + (triton_min_5d_dim4_keepdim, (9, 3, 2, 4, 17)), + (triton_min_5d_dim4_keepdim, (3, 11, 1, 3, 42)), + (triton_min_5d_dim4_keepdim, (2, 51, 3, 13, 1)), + (triton_min_5d_dim4_keepdim, (129, 1, 5, 1, 4)), + (triton_min_5d_dim4_keepdim, (203, 1, 2, 2, 3)), + (triton_min_5d_dim4_keepdim, (512, 1, 1, 1, 1)), + (triton_min_5d_dim4_keepdim, (3, 1, 1, 2, 600)), + (triton_min_5d_dim4_keepdim, (1, 1, 1, 1, 1024)), + (triton_min_5d_dim4_keepdim, (15, 2, 2, 2, 54)), + (triton_min_5d_dim4_keepdim, (2, 91, 4, 2, 4)), + (triton_min_5d_dim4_keepdim, (1, 1, 3, 2, 600)), + (triton_min_5d_dim4_keepdim, (5, 2, 4, 1, 26)), + (triton_min_5d_dim4_keepdim, (2, 2, 2, 4, 8)), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = ["{}-{}".format(testfunc.__name__, "-".join(map(str, shape))) for testfunc, shape in testlist] + + +@pytest.mark.parametrize('testfunc, shape', testlist, ids=ids) +@pytest.mark.parametrize('sigtype', typelist) +def test_min_dim4_keepdim(testfunc, sigtype, shape): + dtype = eval('torch.' 
+ sigtype) + x0 = torch.randn(shape).to(dtype).npu() + + x1 = torch.arange(x0.numel()).view(x0.shape).npu().to(torch.int32) + if 'int' in sigtype: + ans, ans1 = torch.min(x0.to(torch.int64), 4) + ans = ans.to(dtype) + else: + ans, ans1 = torch.min(x0, 4) + output = torch.zeros(shape[0:4], dtype=dtype).npu() + output1 = torch.zeros(shape[0:4], dtype=torch.int32).npu() + testfunc[(1, )](x0, x1, output, output1, *shape, debug=True) + test_common.validate_cmp(sigtype, output, ans) + test_common.validate_cmp('int32', output1, ans1) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py new file mode 100644 index 0000000000..f472eef67e --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py @@ -0,0 +1,308 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
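+
+# Mirror of test_reduce_maximum.py: the same 5-D reductions, but with a
+# custom `minimum` combine function and torch.min as the host reference.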
+
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import pytest
+import test_common
+
+
+@triton.jit
+def minimum(a, b):
+    ret = tl.minimum(a, b, tl.PropagateNan.ALL)
+    # Testing shows that tl.minimum promotes its result to float32 only when
+    # the inputs are bfloat16, which causes a compilation error; the same
+    # faulty behavior reproduces on GPU, matching what we observe on NPU.
+    # As a workaround for bfloat16 inputs, cast the result back to bfloat16
+    # so the error no longer breaks compilation.
+    if a.dtype == tl.bfloat16:
+        ret = ret.to(tl.bfloat16)
+    return ret
+
+
+@triton.jit
+def triton_min_5d_dim024(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                         Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+    idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + \
+        nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[
+            None, None, None, None, :]
+    odx = mblk_idx[:, None] * K + kblk_idx[None, :]
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 4, minimum)
+    ret1 = tl.reduce(ret, 2, minimum)
+    ret2 = tl.reduce(ret1, 0, minimum)
+    tl.store(out_ptr0 + odx, ret2)
+
+
+@triton.jit
+def triton_min_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                        Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (lblk_idx[:, None, None] * N * Z + nblk_idx[None, :, None] * Z + zblk_idx[None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret_k = tl.reduce(x, 3, minimum)  # [L, M, N, Z]
+    ret_m = tl.reduce(ret_k, 1, minimum)  # [L, N, Z]
+    tl.store(out_ptr0 + odx, ret_m)
+
+
+@triton.jit
+def triton_min_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                       Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (mblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z +
+           kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 0, minimum)
+    tl.store(out_ptr0 + odx, ret)
+
+
+@triton.jit
+def triton_min_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                       Z: tl.constexpr):
+    lblk_idx = tl.arange(0, L)
+    mblk_idx = tl.arange(0, M)
+    nblk_idx = tl.arange(0, N)
+    kblk_idx = tl.arange(0, K)
+    zblk_idx = tl.arange(0, Z)
+
+    idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N +
+           nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z +
+           zblk_idx[None, None, None, None, :])
+
+    odx = (lblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z +
+           kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :])
+
+    x = tl.load(in_ptr0 + idx)
+    ret = tl.reduce(x, 1, minimum)
+    tl.store(out_ptr0 + 
odx, ret) + + +@triton.jit +def triton_min_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * K * Z + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 2, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * Z + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 3, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * K + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + kblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret1 = tl.reduce(x, 4, minimum) + ret2 = tl.reduce(ret1, 3, minimum) + ret3 = tl.reduce(ret2, 2, minimum) + ret4 = tl.reduce(ret3, 1, minimum) + ret5 = tl.reduce(ret4, 0, minimum) + tl.store(out_ptr0, ret5) + + +testlist = [ + (triton_min_5d_dim024, (1, 1, 1, 1, 1), "dim024"), + (triton_min_5d_dim024, (2, 2, 2, 2, 2), "dim024"), + (triton_min_5d_dim024, (3, 11, 1, 3, 42), "dim024"), + (triton_min_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), + (triton_min_5d_dim0, (2, 2, 2, 2, 2), "dim0"), + (triton_min_5d_dim1, (2, 2, 2, 2, 2), "dim1"), + (triton_min_5d_dim2, (2, 2, 2, 2, 2), 
"dim2"), + (triton_min_5d_dim3, (2, 2, 2, 2, 2), "dim3"), + (triton_min_5d_dim4, (2, 2, 2, 2, 2), "dim4"), + (triton_min_5d_all, (3, 11, 1, 3, 42), "all"), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = [ + "{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) for testfunc, shape, dim_name in testlist +] + + +@pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) +@pytest.mark.parametrize('dtype', typelist) +def test_min(testfunc, dtype, shape, dim_name): + x0 = test_common.generate_tensor(shape=shape, dtype=dtype).npu() + + if dim_name == "dim024": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 4) + ans, _ = torch.min(ans.to(torch.int64), 2) + ans, _ = torch.min(ans.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 4) + ans, _ = torch.min(ans, 2) + ans, _ = torch.min(ans, 0) + output = torch.zeros((shape[1], ) + (shape[3], ), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim13": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 3) + ans, _ = torch.min(ans.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 3) + ans, _ = torch.min(ans, 1) + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[4], ), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim0": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 0) + output = torch.zeros((shape[1], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim1": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 1) + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim2": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 2) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 2) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim3": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 3) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 3) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim4": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 4) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 4) + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[3], ), + dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "all": + if 'int' in dtype: + ans = torch.min(x0.to(torch.int64)) + ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) + else: + ans = torch.tensor([torch.min(x0)], dtype=eval('torch.' + dtype)) + output = torch.zeros((1, ), dtype=eval('torch.' 
+ dtype)).npu() + + testfunc[(1, )](x0, output, *shape) + + test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py b/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py new file mode 100644 index 0000000000..6a70866cfe --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +import logging +import os +from triton.backends.ascend import utils + + +def test_get_logger(): + logger = utils.get_logger("test_utils", "INFO") + assert logger.level == logging.INFO + + +def test_get_ascend_arch_from_env(): + os.environ["TRITON_ASCEND_ARCH"] = "Ascend910_9599" + result = utils.get_ascend_arch_from_env() + assert result == "Ascend910_9599" diff --git a/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py index 84ecdbeb19..53517eb4eb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py +++ b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py @@ -137,7 +137,7 @@ def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): def torch_func(x0): y = x0[0] - y = y % 2.0 + y = y - 2.0 * torch.div(y, 2.0, rounding_mode="trunc") return torch.tensor(y) dtype, N = param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py new file mode 100644 index 0000000000..6f7f99b110 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py @@ -0,0 +1,121 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def kernel_cal_select_mask_bool( + Output_ptr, + Indices_ptr, + numel: tl.constexpr, + BLOCK: tl.constexpr, +): + offs = tl.arange(0, BLOCK) + indice = tl.load(Indices_ptr) + + true_tensor = tl.arange(0, BLOCK) < numel + false_tensor = tl.arange(0, BLOCK) >= numel + mask = offs < indice + res = tl.where(mask, true_tensor, false_tensor) + tl.store(Output_ptr + offs, res) + + +@triton.jit +def kernel_cal_select_mask( + QK_ptr, + Other_ptr, + Output_ptr, + Indices_ptr, + stride_qk: tl.constexpr, + numel: tl.constexpr, + BLOCK: tl.constexpr, +): + rows = tl.arange(0, BLOCK) * stride_qk + cols = tl.arange(0, BLOCK) + offs = rows[:, None] + cols[None, :] + row_indices = tl.load(Indices_ptr) + col_indices = tl.load(Indices_ptr + 1) + + qk_ub = tl.load(QK_ptr + offs) + other = tl.load(Other_ptr + offs) + mask_rows = rows < row_indices * stride_qk + mask_cols = cols < col_indices + + res = tl.where(mask_rows[:, None] & mask_cols[None, :], qk_ub, other) + tl.store(Output_ptr + offs, res) + + +def torch_cal_select_mask_bool( + Indice: torch.Tensor, + numel, + BLOCK, +): + offs = torch.arange(0, BLOCK) + true_tensor = torch.arange(0, BLOCK) < numel + false_tensor = torch.arange(0, BLOCK) >= numel + mask = offs < Indice + + res = torch.where(mask, true_tensor, false_tensor) + return res + + +def torch_cal_select_mask( + QK: torch.Tensor, + Other: torch.Tensor, + Indices: torch.Tensor, +): + row_limit_idx = Indices[0].item() + col_limit_idx = Indices[1].item() + Output = Other.clone() + Output[:row_limit_idx, :col_limit_idx] = QK[:row_limit_idx, :col_limit_idx] + return Output + + +@pytest.mark.parametrize('param_list', [['bool', 64, 63]]) +def test_select_analysis_bool(param_list): + dtype, SEQ_LEN, indice = param_list + assert dtype == 'bool' + qk_cal = torch.empty(SEQ_LEN).npu() + indices = torch.tensor([indice]).npu() + qk_ref = torch_cal_select_mask_bool(indice, SEQ_LEN, SEQ_LEN) + kernel_cal_select_mask_bool[(1, )](qk_cal, indices, SEQ_LEN, SEQ_LEN) + test_common.validate_cmp(dtype, qk_cal, qk_ref) + + +@pytest.mark.parametrize('param_list', [ + ['float16', 64, 63, 62], + ['float32', 64, 63, 62], +]) +def test_select_analysis(param_list): + dtype, SEQ_LEN, indice_x, indice_y = param_list + assert dtype != 'bool' + qk = torch.rand([SEQ_LEN, SEQ_LEN], dtype=eval('torch.' 
+ dtype), device='npu') + qk_cal = torch.empty_like(qk).npu() + other = torch.zeros_like(qk).npu() + indices_tensor = torch.tensor([indice_x, indice_y]).npu() + qk_ref = torch_cal_select_mask(qk, other, indices_tensor) + kernel_cal_select_mask[(1, )](qk, other, qk_cal, indices_tensor, qk.stride(0), SEQ_LEN * SEQ_LEN, SEQ_LEN) + test_common.validate_cmp(dtype, qk_cal, qk_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py new file mode 100644 index 0000000000..64abd9b3cb --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py @@ -0,0 +1,153 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
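+
+# The kernel masks a block-sparse attention score matrix: each program walks
+# the top-k block indices for its tile and overwrites the columns of blocks
+# that start beyond the current row tile with -inf. torch_cal_atten_mask
+# below rebuilds the same masking with pure tensor ops as the reference.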
+ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def cal_atten_mask_kernel( + QK_ptr, + Indices_ptr, + stride_qk_m, + stride_qk_n, + stride_ik, + SEQ_LEN: tl.constexpr, + sparse_block_size: tl.constexpr, + BLOCK_SBS: tl.constexpr, + TOPK_BASE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + idx_sub_sbs = pid_n + cur_s1 = pid_m * BLOCK_SBS + cur_s2 = cur_s1 + BLOCK_SBS + + if cur_s1 >= SEQ_LEN: + return + + beg_sbs = idx_sub_sbs * BLOCK_SBS // sparse_block_size + end_sbs = ((idx_sub_sbs + 1) * BLOCK_SBS) // sparse_block_size + + valid_col_end = cur_s1 + (cur_s2 - cur_s1) + + offs_m = cur_s1 + tl.arange(0, BLOCK_SBS) + offs_n_base = idx_sub_sbs * BLOCK_SBS + offs_n = offs_n_base + tl.arange(0, BLOCK_SBS) + + mask_m = offs_m < SEQ_LEN + mask_n = offs_n < SEQ_LEN + mask_load = mask_m[:, None] & mask_n[None, :] + + qk_ub = tl.load(QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, mask=mask_load, other=0.0) + + for idx_k in range(beg_sbs, end_sbs): + idx_s2 = tl.load(Indices_ptr + TOPK_BASE + idx_k * stride_ik) + if idx_s2 != -1 and idx_s2 * sparse_block_size > valid_col_end: + idx_lower_sbs = idx_k * sparse_block_size - \ + idx_sub_sbs * BLOCK_SBS + idx_higher_sbs = (idx_k + 1) * sparse_block_size - \ + idx_sub_sbs * BLOCK_SBS + mask_lower_sbs = tl.arange(0, BLOCK_SBS) >= idx_lower_sbs + mask_higher_sbs = tl.arange(0, BLOCK_SBS) < idx_higher_sbs + qk_ub = tl.where((mask_lower_sbs & mask_higher_sbs)[None, :], float("-inf"), qk_ub) + + tl.store(QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, qk_ub, mask=mask_load) + + +def launch_cal_atten_mask(qk_tensor, indices_tensor, sparse_block_size=64, block_sbs=128): + """ + qk_tensor: (SEQ_LEN, SEQ_LEN) + indices_tensor: (K,) / (BATCH, K, ...) 
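+    sparse_block_size: number of key columns covered by one top-k index block
+    block_sbs: tile size handled by one (pid_m, pid_n) program instance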
+ """ + assert qk_tensor.is_contiguous() + M, N = qk_tensor.shape + + stride_qk_m = qk_tensor.stride(0) + stride_qk_n = qk_tensor.stride(1) + + stride_ik = 1 + topk_base = 0 + + grid = (triton.cdiv(M, block_sbs), triton.cdiv(N, block_sbs)) + cal_atten_mask_kernel[grid]( + qk_tensor, + indices_tensor, + stride_qk_m, + stride_qk_n, + stride_ik, + SEQ_LEN=M, + sparse_block_size=sparse_block_size, + BLOCK_SBS=block_sbs, + TOPK_BASE=topk_base, + ) + return qk_tensor + + +def torch_cal_atten_mask( + qk, + indices, + sparse_block_size, + block_sbs, + topk_base=0, +): + device = qk.device + dtype = qk.dtype + M, N = qk.shape + + row_ids = torch.arange(M, device=device).unsqueeze(1) + col_ids = torch.arange(N, device=device).unsqueeze(0) + + k_idx_global = col_ids // sparse_block_size + lookup_idx = k_idx_global + topk_base + max_valid_idx = indices.numel() - 1 + + valid_lookup = (lookup_idx >= 0) & (lookup_idx <= max_valid_idx) + safe_lookup_idx = lookup_idx.clamp(0, max_valid_idx) + idx_s2_map = indices.gather(0, safe_lookup_idx.squeeze(0)).unsqueeze(0) + idx_s2_map = torch.where(valid_lookup, idx_s2_map, torch.tensor(-1, device=device)) + + row_block_ends = ((row_ids // block_sbs) + 1) * block_sbs + row_block_ends = torch.min(row_block_ends, torch.tensor(N, device=device)) + + start_pos_k_map = idx_s2_map * sparse_block_size + cond_valid = (idx_s2_map != -1) + cond_exceed = (start_pos_k_map > row_block_ends) + final_mask = cond_valid & cond_exceed + + qk_out = torch.where(final_mask, torch.tensor(float("-inf"), dtype=dtype, device=device), qk) + return qk_out + + +@pytest.mark.parametrize('param_list', [['float32', 1024, 128, 64]]) +def test_divsiop_select_analysis1(param_list): + dtype, SEQ_LEN, BLOCK_SBS, SPARSE_BLOCK = param_list + qk = torch.zeros((SEQ_LEN, SEQ_LEN), dtype=eval('torch.' 
+ dtype), device='npu') + K_SIZE = 20 + indices = torch.full((K_SIZE, ), -1, dtype=torch.int32, device='npu') + indices[10] = 20 + qk_ref = torch_cal_atten_mask(qk.clone(), indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) + qk_cal = launch_cal_atten_mask(qk, indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) + test_common.validate_cmp(dtype, qk_cal, qk_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_signbit.py b/third_party/ascend/unittest/pytest_ut/test_signbit.py index 693b601c23..8350094340 100644 --- a/third_party/ascend/unittest/pytest_ut/test_signbit.py +++ b/third_party/ascend/unittest/pytest_ut/test_signbit.py @@ -64,3 +64,18 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_signbit[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp('bool', y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', [ + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch.signbit(x).npu() + y_cal = torch.zeros(shape).bool().npu() + triton_signbit[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp('bool', y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py new file mode 100644 index 0000000000..d1dabaf669 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py @@ -0,0 +1,78 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
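+
+# Both kernels carry a broadcast column index (`base2`) across loop
+# iterations, once built via expand_dims-style slicing and once via an
+# explicit broadcast_to, so the iter-arg simplification is exercised on both
+# forms of the same copy pattern.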
+
+import pytest
+
+import triton
+import triton.language as tl
+import triton.language.extra.cann.libdevice as libdevice
+import test_common
+
+import torch
+import torch_npu
+
+
+@triton.jit
+def triton_fn_expanddims(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr, YBLOCK_SUB: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK)[:, None]
+    base2 = tl.arange(0, YBLOCK_SUB)[None, :]
+    loops1: tl.constexpr = YBLOCK // YBLOCK_SUB  # assume it's divisible
+    for _ in range(loops1):
+        x0 = base1 * YBLOCK + base2
+        base2 = base2 + YBLOCK_SUB
+        tmp0 = tl.load(in_ptr0 + (x0), None)
+        tl.store(out_ptr0 + (x0), tmp0, None)
+
+
+@triton.jit
+def triton_fn_broadcast(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr, YBLOCK_SUB: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK)[:, None]
+    base2 = tl.arange(0, YBLOCK_SUB)[None, :]
+    base2 = base2.broadcast_to((XBLOCK, YBLOCK_SUB))
+    loops1: tl.constexpr = YBLOCK // YBLOCK_SUB  # assume it's divisible
+    for _ in range(loops1):
+        x0 = base1 * YBLOCK + base2
+        base2 = base2 + YBLOCK_SUB
+        tmp0 = tl.load(in_ptr0 + (x0), None)
+        tl.store(out_ptr0 + (x0), tmp0, None)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (128, 128), 128, 128, 32],
+])
+def test_expanddims(param_list):
+    dtype, shape, xblock, yblock, yblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_fn_expanddims[(1, )](x0, y_cal, xblock, yblock, yblock_sub)
+    test_common.validate_cmp(dtype, y_cal, x0)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (128, 128), 128, 128, 32],
+])
+def test_broadcast(param_list):
+    dtype, shape, xblock, yblock, yblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_fn_broadcast[(1, )](x0, y_cal, xblock, yblock, yblock_sub)
+    test_common.validate_cmp(dtype, y_cal, x0)
diff --git a/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py
new file mode 100644
index 0000000000..99b347246d
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py
@@ -0,0 +1,97 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
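+
+# The kernels below broadcast `base1` (per-row offsets) to the full tile
+# before it feeds any arithmetic, so broadcast1/2 read each row's first
+# element for every column; the sink-broadcast pass is expected to push the
+# broadcast past the adds and masks. Hedged torch sketch of the expected
+# broadcast1/2 output (illustrative only):
+#
+#   y = x[:, 0].unsqueeze(1).expand(-1, x.size(1))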
+
+import pytest
+
+import triton
+import triton.language as tl
+import triton.language.extra.cann.libdevice as libdevice
+import test_common
+
+import torch
+import torch_npu
+
+
+@triton.jit
+def triton_sink_broadcast1(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK)[:, None] * YBLOCK
+    base2 = tl.arange(0, YBLOCK)[None, :]
+    base1 = base1.broadcast_to((XBLOCK, YBLOCK))
+    tmp0 = tl.load(in_ptr0 + base1, None)
+    index = base1 + base2
+    tl.store(out_ptr0 + index, tmp0, None)
+
+
+@triton.jit
+def triton_sink_broadcast2(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
+    base1 = tl.arange(0, XBLOCK)[:, None] * YBLOCK
+    base2 = tl.arange(0, YBLOCK)[None, :]
+    base1 = base1.broadcast_to((XBLOCK, YBLOCK))
+    tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK * YBLOCK, other=0.0)
+    index = base1 + base2
+    tl.store(out_ptr0 + index, tmp0, index < XBLOCK * YBLOCK)
+
+
+@triton.jit
+def triton_sink_broadcast3(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
+    base1 = (tl.arange(0, XBLOCK) * YBLOCK)[:, None]
+    base2 = tl.arange(0, YBLOCK)[None, :]
+    base1 = base1.broadcast_to((XBLOCK, YBLOCK))
+    tmp0 = tl.load(in_ptr0 + base1 + base2, (base1 + base2) < XBLOCK * YBLOCK, other=0.0)
+    index = base1 + base2
+    tl.store(out_ptr0 + index, tmp0, index < XBLOCK * YBLOCK)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (32, 32), 32, 32],
+])
+def test_sink_broadcast(param_list):
+    dtype, shape, xblock, yblock = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    y_ref = x0.clone()
+    y_ref = y_ref[:, 0].unsqueeze(1).expand(-1, x0.size(1))
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    y_cal2 = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_sink_broadcast1[(1, )](x0, y_cal, xblock, yblock)
+    triton_sink_broadcast2[(1, )](x0, y_cal2, xblock, yblock)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
+    test_common.validate_cmp(dtype, y_cal2, y_ref)
+
+
+@pytest.mark.parametrize('param_list', [
+    ['float32', (32, 32), 32, 32],
+])
+def test_sink_broadcast3(param_list):
+    dtype, shape, xblock, yblock = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    y_ref = x0.clone()
+
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_sink_broadcast3[(1, )](x0, y_cal, xblock, yblock)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
diff --git a/third_party/ascend/unittest/pytest_ut/test_sync_block.py b/third_party/ascend/unittest/pytest_ut/test_sync_block.py
index 0c84b3166a..8d1df6d243 100644
--- a/third_party/ascend/unittest/pytest_ut/test_sync_block.py
+++ b/third_party/ascend/unittest/pytest_ut/test_sync_block.py
@@ -17,6 +17,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
+import torch
 
 import triton
 import triton.language as tl
@@ -91,7 +92,7 @@ def test_matmul_exp(dtype, ashape, bshape):
     C_ref = (A @ B).exp()
 
     # compare
-    test_common.validate_cmp(dtype, C, C_ref)
+    torch.testing.assert_close(C_ref, C, rtol=3e-2, atol=3e-2, equal_nan=True)
 
 
 if __name__ == "__main__":
diff --git a/third_party/ascend/unittest/pytest_ut/test_use_analysis.py b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py
new file mode 100644
index 0000000000..06ce0a2470
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py
@@ -0,0 +1,74 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+
+import triton
+import triton.language as tl
+import triton.language.extra.cann.libdevice as libdevice
+import test_common
+
+import torch
+import torch_npu
+
+
+@triton.jit
+def triton_reduce_deadcode(v_ptr, in_ptr0, in_ptr1, out_ptr0, VBLOCK: tl.constexpr, XBLOCK: tl.constexpr,
+                           YBLOCK: tl.constexpr):
+    v_idx = tl.arange(0, VBLOCK)
+    v = tl.load(v_ptr + v_idx)
+    v_ret = tl.argmax(v, 0)
+    if v_ret < v_ret + 1:
+        for _ in range(v_ret, v_ret + 1):
+            cube_idx = tl.arange(0, XBLOCK)[:, None] * YBLOCK + tl.arange(0, YBLOCK)[None, :]
+            c0 = tl.load(in_ptr0 + cube_idx)
+            c1 = tl.load(in_ptr1 + cube_idx)
+            ret = tl.dot(c0, c1) + 1
+            tl.store(out_ptr0 + cube_idx, ret)
+    else:
+        for _ in range(v_ret - 1, v_ret):
+            cube_idx = tl.arange(0, XBLOCK)[:, None] * YBLOCK + tl.arange(0, YBLOCK)[None, :]
+            c0 = tl.load(in_ptr0 + cube_idx)
+            c1 = tl.load(in_ptr1 + cube_idx)
+            ret = tl.dot(c0, c1) + 1
+            tl.store(out_ptr0 + cube_idx, ret)
+
+
+def torch_reduce_deadcode(in0, in1, v):
+    v_ret = torch.argmax(v)
+    if v_ret < v_ret + 1:
+        ret = torch.matmul(in0, in1) + 1
+    else:
+        ret = torch.matmul(in0, in1) + 1
+    return ret
+
+
+def test_reduce_deadcode():
+    VBLOCK, XBLOCK, YBLOCK = 16, 16, 16
+    sigtype = 'float32'
+    dtype = torch.float32
+    in0 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu')
+    in1 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu')
+    v = torch.randn((VBLOCK, ), dtype=dtype, device='npu')
+    out = torch.zeros((XBLOCK, YBLOCK), dtype=dtype, device='npu')
+
+    triton_reduce_deadcode[(1, )](v, in0, in1, out, VBLOCK=VBLOCK, XBLOCK=XBLOCK, YBLOCK=YBLOCK)
+    expected = torch_reduce_deadcode(in0, in1, v)
+    test_common.validate_cmp(sigtype, out, expected)
diff --git a/third_party/ascend/unittest/pytest_ut/test_zeros.py b/third_party/ascend/unittest/pytest_ut/test_zeros.py
index cc7fd067f1..3dab18af39 100644
--- a/third_party/ascend/unittest/pytest_ut/test_zeros.py
+++ b/third_party/ascend/unittest/pytest_ut/test_zeros.py
@@ -27,15 +27,11 @@
 
 
 @triton.jit
-def fn_npu_f32(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
+def fn_npu_f32(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
     xidx = tl.arange(0, XB)
     yidx = tl.arange(0, YB)
     zidx = tl.arange(0, ZB)
-    idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-
     ret = tl.zeros((XB, YB, ZB), dtype=tl.float32)
 
     oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
 
@@ -44,15 +40,11 @@ def fn_npu_f32(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.con
 
 
 @triton.jit
-def fn_npu_f16(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
+def fn_npu_f16(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
     xidx = tl.arange(0, XB)
     yidx = tl.arange(0, YB)
     zidx = tl.arange(0, ZB)
-    idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-
     ret = tl.zeros((XB, YB, ZB), dtype=tl.float16)
 
     oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
 
@@ -61,15 +53,11 @@ def fn_npu_f16(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.con
 
 
 @triton.jit
-def fn_npu_i8(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
+def fn_npu_i8(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr):
     xidx = tl.arange(0, XB)
     yidx = tl.arange(0, YB)
     zidx = tl.arange(0, ZB)
-    idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
-
-    X = tl.load(x_ptr + idx)
-
     ret = tl.zeros((XB, YB, ZB), dtype=tl.int8)
 
     oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
 
@@ -87,17 +75,14 @@ def fn_npu_i8(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.cons
 ])
 def test_case(param_list):
     dtype, shape, ncore, XB, YB, ZB = param_list
-    x0 = test_common.generate_tensor(shape, dtype)
     y_ref = torch.full((XB, YB, ZB), 0, dtype=eval('torch.' + dtype)).npu()
-    print(f"y_ref = {y_ref[0, 0, 0:4]}")
     y_cal = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' + dtype)).npu()
     if dtype == "float32":
-        fn_npu_f32[ncore, 1, 1](y_cal, x0, XB, YB, ZB)
+        fn_npu_f32[ncore, 1, 1](y_cal, XB, YB, ZB)
     elif dtype == "float16":
-        fn_npu_f16[ncore, 1, 1](y_cal, x0, XB, YB, ZB)
+        fn_npu_f16[ncore, 1, 1](y_cal, XB, YB, ZB)
     else:
-        fn_npu_i8[ncore, 1, 1](y_cal, x0, XB, YB, ZB)
-    print(f"y_cal = {y_cal[0, 0, 0:4]}")
+        fn_npu_i8[ncore, 1, 1](y_cal, XB, YB, ZB)
     test_common.validate_cmp(dtype, y_cal, y_ref)
 
 
diff --git a/third_party/ascend/unittest/pytest_ut/test_zeroslike.py b/third_party/ascend/unittest/pytest_ut/test_zeroslike.py
index 76ddf08f7e..6d6b4a822f 100644
--- a/third_party/ascend/unittest/pytest_ut/test_zeroslike.py
+++ b/third_party/ascend/unittest/pytest_ut/test_zeroslike.py
@@ -53,7 +53,7 @@ def fn_npu_(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.conste
 ])
 def test_case(param_list):
     dtype, shape, ncore, XB, YB, ZB = param_list
-    x0 = test_common.generate_tensor(shape, dtype)
+    x0 = test_common.generate_tensor(shape, dtype).npu()
     y_ref = torch.zeros_like(x0, dtype=eval('torch.' + dtype)).npu()
     print(f"y_ref = {y_ref[0, 0, 0:4]}")
     y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()