From b7ceafcc530e0cb687269d897983196fe6239dac Mon Sep 17 00:00:00 2001 From: Stardep <1486216685@qq.com> Date: Mon, 11 May 2026 02:08:39 +0800 Subject: [PATCH 1/5] sync: import triton-ascend release/3.2.1 updates --- bin/RegisterTritonDialects.h | 37 +- include/triton/Tools/Sys/GetEnv.hpp | 1 + python/src/llvm.cc | 6 +- python/triton/__init__.py | 2 + python/triton/backends/__init__.py | 5 + python/triton/extension/__init__.py | 20 + python/triton/extension/buffer/__init__.py | 20 + .../extension/buffer/language/__init__.py | 44 + .../extension/buffer/language/builder.py | 75 + .../triton/extension/buffer/language/core.py | 363 ++ .../extension/buffer/language/semantic.py | 158 + .../triton/extension/buffer/src/buffer_ir.cc | 169 + python/triton/tools/get_ascend_devices.py | 55 + third_party/ascend/CMakeLists.txt | 47 +- third_party/ascend/ascend_ir.cc | 375 +- .../ascend/backend/backend_register.py | 52 +- third_party/ascend/backend/compiler.py | 525 +- third_party/ascend/backend/driver.py | 120 +- .../ascend/backend/lib/libdevice.10.bc | Bin 0 -> 84908 bytes third_party/ascend/backend/npu_utils.cpp | 289 +- .../ascend/backend/runtime/autoparser.py | 344 +- .../ascend/backend/runtime/autotuner.py | 295 +- .../ascend/backend/runtime/tile_generator.py | 115 +- third_party/ascend/backend/runtime/utils.py | 4 +- .../spec/include/runtime/libentry/libentry.h | 1 - .../triton/Dialect/Triton/IR/OpInterfaces.h | 1 - .../spec/lib/runtime/libentry/libentry.cpp | 23 + .../backend/spec/triton/compiler/compiler.py | 17 +- .../backend/spec/triton/language/__init__.py | 134 +- .../backend/spec/triton/language/core.py | 17 +- .../backend/spec/triton/language/semantic.py | 55 +- .../spec/triton/runtime/_async_compile.py | 55 + .../spec/triton/runtime/ascend_interpreter.py | 735 +++ .../backend/spec/triton/runtime/autotuner.py | 253 +- .../backend/spec/triton/runtime/code_cache.py | 2 +- .../spec/triton/runtime/interpreter.py | 356 +- .../ascend/backend/spec/triton/runtime/jit.py | 
67 +- .../backend/spec/triton/runtime/libentry.py | 4 +- third_party/ascend/backend/testing.py | 8 +- third_party/ascend/backend/utils.py | 45 +- .../include/AutoBlockify/AutoBlockify.h | 118 + .../include/AutoBlockify/CMakeLists.txt | 3 + .../ascend/include/AutoBlockify/Passes.h | 37 + .../ascend/include/AutoBlockify/Passes.td | 21 + .../ascend/include/AutoBlockify/Utils.h | 66 + third_party/ascend/include/CMakeLists.txt | 15 +- .../include/TritonAffinityOpt/CMakeLists.txt | 3 + .../ascend/include/TritonAffinityOpt/DAG.h | 330 + .../ascend/include/TritonAffinityOpt/Passes.h | 47 + .../include/TritonAffinityOpt/Passes.td | 29 + .../include/TritonAffinityOpt/Utils.hpp | 26 + third_party/ascend/language/cann/__init__.py | 8 +- .../language/cann/extension/__init__.py | 43 +- .../ascend/language/cann/extension/aux_ops.py | 65 +- .../ascend/language/cann/extension/builder.py | 1 + .../ascend/language/cann/extension/core.py | 680 +- .../language/cann/extension/custom_op.py | 104 +- .../ascend/language/cann/extension/mem_ops.py | 1187 ++-- .../language/cann/extension/semantic.py | 283 +- .../ascend/language/cann/extension/vec_ops.py | 1097 ++-- third_party/ascend/language/cann/libdevice.py | 1138 ++-- .../ascend/lib/AutoBlockify/AutoBlockify.cpp | 363 ++ .../ascend/lib/AutoBlockify/CMakeLists.txt | 22 + .../lib/AutoBlockify/RewriteOperation.cpp | 509 ++ third_party/ascend/lib/AutoBlockify/Utils.cpp | 211 + third_party/ascend/lib/CMakeLists.txt | 43 +- .../Conversion/TritonToLLVM/TritonToLLVM.cpp | 141 +- .../lib/TritonAffinityOpt/CMakeLists.txt | 19 + .../ascend/lib/TritonAffinityOpt/DAG.cpp | 534 ++ .../lib/TritonAffinityOpt/DAGSSBuffer.cpp | 5534 +++++++++++++++++ .../ascend/lib/TritonAffinityOpt/DAGScope.cpp | 1139 ++++ .../ascend/lib/TritonAffinityOpt/DAGSync.cpp | 1333 ++++ third_party/ascend/python/src/ir.cc | 25 +- third_party/ascend/triton_ascend.cc | 649 +- .../tutorials/03-matrix-multiplication.py | 217 + .../ascend/tutorials/04-low-memory-dropout.py | 139 + 
third_party/ascend/tutorials/05-layer-norm.py | 127 + .../ascend/tutorials/06-demo-autotune.py | 80 - .../ascend/tutorials/06-fused-attention.py | 365 ++ .../ascend/tutorials/07-extern-functions.py | 89 + third_party/ascend/tutorials/07-profiler.py | 184 - .../ascend/tutorials/08-grouped-gemm.py | 282 + .../ascend/tutorials/09-persistent-matmul.py | 337 + .../tutorials/15-embedding_gather_demo.py | 118 - .../TritonToLinalg/copy_use_analysis.mlir | 263 + .../TritonToLinalg/fixpipe_use_analysis.mlir | 421 ++ .../950PR/TritonToLinalg/if_use_analysis.mlir | 479 ++ .../General/AutoBlockify/auto_blockify.mlir | 134 + .../General/DiscreteMaskAccess/atomic.mlir | 201 + .../General/DiscreteMaskAccess/loadstore.mlir | 67 + .../simplify_for_loop.mlir | 121 + .../General/TritonToHFusion/fp_to_fp_rtz.mlir | 18 + .../General/TritonToHFusion/mod.mlir | 11 + .../sync_block_op_conversion.mlir | 20 + .../General/TritonToLinalg/atomic_rmw.mlir | 109 + .../TritonToLinalg/atomic_rmw_block.mlir | 46 + .../General/TritonToLinalg/legal_stride.mlir | 30 + .../General/TritonToLinalg/parse_select.mlir | 37 + .../TritonToStructured/CmpConverter.mlir | 24 + .../PromotePointerIterArgsPattern.mlir | 73 + .../TritonToStructured/SplatCmpConverter.mlir | 16 + .../General/TritonToStructured/parseCmp.mlir | 117 + .../TritonToStructured/parseConstant.mlir | 28 + .../TritonToStructured/parseMakeRange.mlir | 24 + .../General/TritonToStructured/parseRem.mlir | 81 + .../bubbleupoperation.mlir | 127 + .../TritonToUnstructure/if_simplifier.mlir | 45 + .../TritonToUnstructure/nested_loop.mlir | 207 + .../General/TritonToUnstructure/splat.mlir | 14 + .../TritonToUnstructure/unstructure_mix.mlir | 82 + .../ascend/unittest/affine_map/affine_map.py | 40 + .../affine_map/affine_map_buffer_type_demo.py | 27 + .../affine_map_complex_expr_demo.py | 36 + .../affine_map_indexing_map_demo.py | 33 + .../affine_map/affine_map_parse_demo.py | 29 + .../unittest/autotune_ut/01-vector-add.py | 91 + 
.../unittest/autotune_ut/02-fused-softmax.py | 107 + .../unittest/autotune_ut/03-layer-norm.py | 140 + .../unittest/autotune_ut/04-libentry.py | 101 + .../autotune_ut/test_autotune_param_valid.py | 175 + .../unittest/autotune_ut/test_common.py | 21 +- .../autotune_ut/test_customized_config.py | 99 + .../autotune_ut/test_low_dim_axes_parse.py | 59 + .../unittest/autotune_ut/test_mask_parse.py | 164 + .../autotune_ut/test_no_tiling_axis_parse.py | 99 + .../autotune_ut/test_reduction_axes_parse.py | 118 +- .../autotune_ut/test_split_axis_parse.py | 167 + .../autotune_ut/test_tiling_axis_parse.py | 135 + .../unittest/custom_op/builtin_ops_demo.py | 71 + .../unittest/custom_op/custom_op_demo.py | 119 + .../custom_op/custom_op_extra_buffer_demo.py | 117 + .../custom_op_indexing_map_complex_demo.py | 71 + .../custom_op_indexing_map_compose_demo.py | 73 + .../unittest/custom_op/test_gather_load.py | 38 + .../unittest/custom_op/test_index_select.py | 49 + .../unittest/generalization_cases/acc_util.py | 144 - .../unittest/generalization_cases/test_abs.py | 157 - .../generalization_cases/test_advance.py | 221 - .../unittest/generalization_cases/test_and.py | 161 - .../generalization_cases/test_argmax.py | 362 -- .../generalization_cases/test_argmin.py | 360 -- .../test_associative_scan.py | 523 -- .../generalization_cases/test_atomic_add.py | 576 -- .../generalization_cases/test_atomic_and.py | 562 -- .../generalization_cases/test_atomic_cas.py | 484 -- .../generalization_cases/test_atomic_max.py | 258 - .../generalization_cases/test_atomic_min.py | 264 - .../generalization_cases/test_atomic_or.py | 438 -- .../generalization_cases/test_atomic_xchg.py | 434 -- .../generalization_cases/test_atomic_xor.py | 441 -- .../generalization_cases/test_broadcast.py | 324 - .../generalization_cases/test_broadcast_to.py | 327 - .../generalization_cases/test_cast.py | 391 -- .../generalization_cases/test_cdiv.py | 162 - .../generalization_cases/test_ceil.py | 157 - 
.../generalization_cases/test_common.py | 343 - .../unittest/generalization_cases/test_cos.py | 157 - .../generalization_cases/test_count_dim0.py | 152 - .../generalization_cases/test_count_dim1.py | 151 - .../generalization_cases/test_cumprod.py | 254 - .../generalization_cases/test_cumsum.py | 253 - .../test_debug_barrier.py | 168 - .../generalization_cases/test_device_print.py | 106 - .../generalization_cases/test_div_rn.py | 154 - .../generalization_cases/test_dot_scaled.py | 281 - .../unittest/generalization_cases/test_eq.py | 128 - .../unittest/generalization_cases/test_erf.py | 172 - .../unittest/generalization_cases/test_exp.py | 155 - .../generalization_cases/test_exp2.py | 153 - .../generalization_cases/test_expand_dims.py | 223 - .../generalization_cases/test_fdiv.py | 155 - .../generalization_cases/test_full_op.py | 1096 ---- .../generalization_cases/test_ge_op.py | 158 - .../generalization_cases/test_general_add.py | 438 -- .../test_general_clamp.py | 179 - .../generalization_cases/test_general_div.py | 139 - .../test_general_floor.py | 148 - .../test_general_floordiv.py | 163 - .../generalization_cases/test_general_fma.py | 162 - .../test_general_gather.py | 150 - .../test_general_interleave.py | 216 - .../generalization_cases/test_general_join.py | 229 - .../generalization_cases/test_general_log.py | 148 - .../generalization_cases/test_general_log2.py | 153 - .../test_general_maximum.py | 153 - .../test_general_minimum.py | 153 - .../generalization_cases/test_general_mul.py | 155 - .../test_general_ravel.py | 189 - .../test_general_reshape.py | 160 - .../test_general_rsqrt.py | 151 - .../test_general_sigmoid.py | 175 - .../generalization_cases/test_general_sin.py | 173 - .../test_general_softmax.py | 179 - .../test_general_split.py | 224 - .../generalization_cases/test_general_sub.py | 155 - .../test_general_tensor_descriptor.py | 268 - .../generalization_cases/test_general_view.py | 159 - .../generalization_cases/test_gt_op.py | 158 - 
.../generalization_cases/test_invert.py | 166 - .../generalization_cases/test_le_op.py | 158 - .../generalization_cases/test_load_store.py | 168 - .../test_logical_and_op.py | 157 - .../test_logical_or_op.py | 156 - .../generalization_cases/test_lshift_op.py | 183 - .../generalization_cases/test_lt_op.py | 158 - .../test_make_blkptr_matmul.py | 74 - .../test_make_block_ptr.py | 212 - .../generalization_cases/test_matmul.py | 186 - .../unittest/generalization_cases/test_max.py | 313 - .../unittest/generalization_cases/test_min.py | 313 - .../unittest/generalization_cases/test_mod.py | 227 - .../unittest/generalization_cases/test_ne.py | 128 - .../unittest/generalization_cases/test_neg.py | 158 - .../unittest/generalization_cases/test_not.py | 173 - .../unittest/generalization_cases/test_or.py | 159 - .../test_permute_1d_2d.py | 95 - .../generalization_cases/test_permute_3d.py | 101 - .../test_permute_4d_5d.py | 223 - .../generalization_cases/test_rand.py | 304 - .../generalization_cases/test_range.py | 187 - .../generalization_cases/test_reduce.py | 338 - .../generalization_cases/test_relu.py | 66 - .../generalization_cases/test_rshift_op.py | 184 - .../test_scalar_tensor.py | 82 - .../generalization_cases/test_sort.py | 229 - .../generalization_cases/test_sqrt.py | 174 - .../generalization_cases/test_sqrt_rn.py | 174 - .../test_static_print_and_assert_op.py | 73 - .../unittest/generalization_cases/test_sum.py | 332 - .../generalization_cases/test_sum_dim0.py | 74 - .../generalization_cases/test_sum_dim1.py | 66 - .../generalization_cases/test_swizzle2d.py | 72 - .../generalization_cases/test_trans_1d_2d.py | 114 - .../generalization_cases/test_trans_3d.py | 144 - .../generalization_cases/test_trans_4d_5d.py | 223 - .../generalization_cases/test_umulhi.py | 147 - .../generalization_cases/test_where.py | 135 - .../unittest/generalization_cases/test_xor.py | 174 - .../generalization_cases/test_xorsum.py | 281 - .../generalization_cases/test_zeros_op.py | 534 -- 
.../generalization_cases/test_zeroslike.py | 170 - third_party/ascend/unittest/kernels/README.md | 62 - .../ascend/unittest/kernels/common_kernel.py | 7 - .../ascend/unittest/kernels/test_common.py | 111 - .../unittest/kernels/test_triton_kernel.py | 73 - .../unittest/kernels/vllm/expand_kernel.py | 33 - .../vllm/rejection_random_sample_kernel.py | 55 - .../vllm/sample_recovered_tokens_kernel.py | 77 - .../unittest/pytest_ut/test_01_vector_add.py | 87 + .../pytest_ut/test_02_fused_softmax.py | 129 + .../test_03_matrix_multiplication.py | 228 + .../pytest_ut/test_04_low_memory_dropout.py | 134 + .../unittest/pytest_ut/test_05_layer_norm.py | 120 + .../pytest_ut/test_06_fused_attention.py | 352 ++ .../pytest_ut/test_07_extern_functions.py | 84 + .../pytest_ut/test_08_grouped_gemm.py | 288 + .../pytest_ut/test_09_persistent_matmul.py | 326 + .../pytest_ut/test_10_gather_sorted.py | 195 + .../unittest/pytest_ut/test_11_rab_time.py | 407 ++ .../pytest_ut/test_12_hstu_attention.py | 824 +++ ...test_13_matrix_multiplication_optimized.py | 239 + .../pytest_ut/test_14_accuracy_comparison.py | 149 + .../pytest_ut/test_15_demo_autotune.py | 90 + .../unittest/pytest_ut/test_16_profiler.py | 125 + .../pytest_ut/test_17_demo_libentry.py | 131 + .../unittest/pytest_ut/test_18_gather.py | 134 + .../ascend/unittest/pytest_ut/test_add.py | 17 + .../unittest/pytest_ut/test_address_check.py | 69 + .../unittest/pytest_ut/test_advance_ptr.py | 65 + .../pytest_ut/test_affine_map_binding.py | 114 + .../ascend/unittest/pytest_ut/test_alloc.py | 11 +- .../ascend/unittest/pytest_ut/test_arch.py | 93 + .../ascend/unittest/pytest_ut/test_argmax.py | 65 + .../ascend/unittest/pytest_ut/test_argmin.py | 65 + .../ascend/unittest/pytest_ut/test_asm.py | 153 +- .../unittest/pytest_ut/test_asm_scalar.py | 38 + .../ascend/unittest/pytest_ut/test_assume1.py | 34 + .../unittest/pytest_ut/test_atomic_add.py | 73 +- .../unittest/pytest_ut/test_atomic_and.py | 15 +- 
.../unittest/pytest_ut/test_atomic_cas.py | 109 +- .../unittest/pytest_ut/test_atomic_max.py | 25 +- .../unittest/pytest_ut/test_atomic_min.py | 20 +- .../pytest_ut/test_atomic_rmw_useanalysis.py | 70 + .../unittest/pytest_ut/test_block_ptr.py | 107 +- .../unittest/pytest_ut/test_boundary_check.py | 317 + .../unittest/pytest_ut/test_cat_help_func.py | 720 +++ .../unittest/pytest_ut/test_celoss_indices.py | 108 + .../unittest/pytest_ut/test_compile_hint.py | 9 +- .../unittest/pytest_ut/test_complex_mask.py | 69 + .../ascend/unittest/pytest_ut/test_copy.py | 11 +- .../ascend/unittest/pytest_ut/test_cumprod.py | 2 +- .../ascend/unittest/pytest_ut/test_cumsum.py | 2 +- .../ascend/unittest/pytest_ut/test_custom.py | 218 +- .../pytest_ut/test_discrete_mask_atomic.py | 48 + .../pytest_ut/test_discrete_mask_loadstore.py | 379 +- .../test_discrete_mask_tail_block_mte_oob.py | 249 + .../pytest_ut/test_discrete_overlap_mask.py | 232 + .../ascend/unittest/pytest_ut/test_dot.py | 112 +- .../ascend/unittest/pytest_ut/test_erfinv.py | 26 + .../ascend/unittest/pytest_ut/test_expm1.py | 8 +- .../test_fast_dividef.py} | 124 +- .../test_fast_expf.py} | 124 +- .../ascend/unittest/pytest_ut/test_gamma.py | 18 + .../unittest/pytest_ut/test_if_advance.py | 55 + .../ascend/unittest/pytest_ut/test_if_load.py | 84 + .../pytest_ut/test_implicit_atomic.py | 140 + .../pytest_ut/test_implicit_permute.py | 1081 ++++ .../test_indirect_scalar_load_offset.py | 92 + .../pytest_ut/test_interleave_optimizaiton.py | 130 + .../ascend/unittest/pytest_ut/test_lgamma.py | 28 + .../unittest/pytest_ut/test_linearize_mask.py | 52 +- .../test_makeblockptr_negative_padding.py | 140 + .../ascend/unittest/pytest_ut/test_mod.py | 18 +- .../unittest/pytest_ut/test_mul_reduce.py | 62 + .../unittest/pytest_ut/test_multibuffer.py | 75 + .../test_negative_mask_dim.py} | 60 +- .../unittest/pytest_ut/test_nextafter.py | 10 +- .../pytest_ut/test_paged_kvcache_krope.py | 80 + .../unittest/pytest_ut/test_parallel.py | 1 + 
.../pytest_ut/test_permuted_boundary_check.py | 60 + .../unittest/pytest_ut/test_reduce_maximum.py | 332 + ...reduce_min_4_keepdim_True_with_index_op.py | 98 + .../unittest/pytest_ut/test_reduce_minimum.py | 332 + .../unittest/pytest_ut/test_runtime_utils.py | 33 + .../unittest/pytest_ut/test_scalar_calc.py | 2 +- .../pytest_ut/test_select_analysis.py | 142 + .../test_select_analysis_for_invert.py | 170 + .../ascend/unittest/pytest_ut/test_signbit.py | 17 + .../pytest_ut/test_simplify_iterargs.py | 80 + .../unittest/pytest_ut/test_sink_broadcast.py | 91 + .../unittest/pytest_ut/test_sync_block.py | 3 +- .../unittest/pytest_ut/test_use_analysis.py | 73 + .../ascend/unittest/pytest_ut/test_zeros.py | 27 +- .../unittest/pytest_ut/test_zeroslike.py | 2 +- 334 files changed, 35135 insertions(+), 27810 deletions(-) create mode 100644 python/triton/extension/__init__.py create mode 100644 python/triton/extension/buffer/__init__.py create mode 100644 python/triton/extension/buffer/language/__init__.py create mode 100644 python/triton/extension/buffer/language/builder.py create mode 100644 python/triton/extension/buffer/language/core.py create mode 100644 python/triton/extension/buffer/language/semantic.py create mode 100644 python/triton/extension/buffer/src/buffer_ir.cc create mode 100644 python/triton/tools/get_ascend_devices.py create mode 100644 third_party/ascend/backend/lib/libdevice.10.bc create mode 100644 third_party/ascend/backend/spec/triton/runtime/_async_compile.py create mode 100644 third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py create mode 100644 third_party/ascend/include/AutoBlockify/AutoBlockify.h create mode 100644 third_party/ascend/include/AutoBlockify/CMakeLists.txt create mode 100644 third_party/ascend/include/AutoBlockify/Passes.h create mode 100644 third_party/ascend/include/AutoBlockify/Passes.td create mode 100644 third_party/ascend/include/AutoBlockify/Utils.h create mode 100644 
third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt create mode 100644 third_party/ascend/include/TritonAffinityOpt/DAG.h create mode 100644 third_party/ascend/include/TritonAffinityOpt/Passes.h create mode 100644 third_party/ascend/include/TritonAffinityOpt/Passes.td create mode 100644 third_party/ascend/include/TritonAffinityOpt/Utils.hpp create mode 100644 third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp create mode 100644 third_party/ascend/lib/AutoBlockify/CMakeLists.txt create mode 100644 third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp create mode 100644 third_party/ascend/lib/AutoBlockify/Utils.cpp create mode 100644 third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt create mode 100644 third_party/ascend/lib/TritonAffinityOpt/DAG.cpp create mode 100644 third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp create mode 100644 third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp create mode 100644 third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp create mode 100644 third_party/ascend/tutorials/03-matrix-multiplication.py create mode 100644 third_party/ascend/tutorials/04-low-memory-dropout.py create mode 100644 third_party/ascend/tutorials/05-layer-norm.py delete mode 100644 third_party/ascend/tutorials/06-demo-autotune.py create mode 100644 third_party/ascend/tutorials/06-fused-attention.py create mode 100644 third_party/ascend/tutorials/07-extern-functions.py delete mode 100644 third_party/ascend/tutorials/07-profiler.py create mode 100644 third_party/ascend/tutorials/08-grouped-gemm.py create mode 100644 third_party/ascend/tutorials/09-persistent-matmul.py delete mode 100644 third_party/ascend/tutorials/15-embedding_gather_demo.py create mode 100644 third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir create mode 100644 third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir create mode 100644 
third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir create mode 100644 
third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir create mode 100644 third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir create mode 100644 third_party/ascend/unittest/affine_map/affine_map.py create mode 100644 third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py create mode 100644 third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py create mode 100644 third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py create mode 100644 third_party/ascend/unittest/affine_map/affine_map_parse_demo.py create mode 100644 third_party/ascend/unittest/autotune_ut/01-vector-add.py create mode 100644 third_party/ascend/unittest/autotune_ut/02-fused-softmax.py create mode 100644 third_party/ascend/unittest/autotune_ut/03-layer-norm.py create mode 100644 third_party/ascend/unittest/autotune_ut/04-libentry.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_customized_config.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_mask_parse.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py create mode 100644 third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py create mode 100644 third_party/ascend/unittest/custom_op/builtin_ops_demo.py create mode 100644 third_party/ascend/unittest/custom_op/custom_op_demo.py create mode 100644 
third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py create mode 100644 third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py create mode 100644 third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py create mode 100644 third_party/ascend/unittest/custom_op/test_gather_load.py create mode 100644 third_party/ascend/unittest/custom_op/test_index_select.py delete mode 100644 third_party/ascend/unittest/generalization_cases/acc_util.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_abs.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_advance.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_and.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_argmax.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_argmin.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_associative_scan.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_add.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_and.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_cas.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_max.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_min.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_or.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_atomic_xor.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_broadcast.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_broadcast_to.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_cast.py delete mode 100644 
third_party/ascend/unittest/generalization_cases/test_cdiv.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_ceil.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_common.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_cos.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_count_dim0.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_count_dim1.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_cumprod.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_cumsum.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_debug_barrier.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_device_print.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_div_rn.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_dot_scaled.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_eq.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_erf.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_exp.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_exp2.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_expand_dims.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_fdiv.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_full_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_ge_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_add.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_clamp.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_div.py delete mode 100644 
third_party/ascend/unittest/generalization_cases/test_general_floor.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_floordiv.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_fma.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_gather.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_interleave.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_join.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_log.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_log2.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_maximum.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_minimum.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_mul.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_ravel.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_reshape.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_sin.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_softmax.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_split.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_sub.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_general_view.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_gt_op.py delete mode 100644 
third_party/ascend/unittest/generalization_cases/test_invert.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_le_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_load_store.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_logical_and_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_logical_or_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_lshift_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_lt_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_matmul.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_max.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_min.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_mod.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_ne.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_neg.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_not.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_or.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_permute_3d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_rand.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_range.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_reduce.py delete mode 100644 
third_party/ascend/unittest/generalization_cases/test_relu.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_rshift_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sort.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sqrt.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sum.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sum_dim0.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_sum_dim1.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_swizzle2d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_trans_3d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_umulhi.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_where.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_xor.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_xorsum.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_zeros_op.py delete mode 100644 third_party/ascend/unittest/generalization_cases/test_zeroslike.py delete mode 100644 third_party/ascend/unittest/kernels/README.md delete mode 100644 third_party/ascend/unittest/kernels/common_kernel.py delete mode 100644 third_party/ascend/unittest/kernels/test_common.py delete mode 100644 third_party/ascend/unittest/kernels/test_triton_kernel.py delete mode 100644 
third_party/ascend/unittest/kernels/vllm/expand_kernel.py delete mode 100644 third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py delete mode 100644 third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_01_vector_add.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_11_rab_time.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_16_profiler.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_17_demo_libentry.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_18_gather.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_address_check.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_advance_ptr.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py create mode 
100644 third_party/ascend/unittest/pytest_ut/test_arch.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_argmax.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_argmin.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_asm_scalar.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_assume1.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_boundary_check.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_cat_help_func.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_celoss_indices.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_complex_mask.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py rename third_party/ascend/unittest/{generalization_cases/test_tan.py => pytest_ut/test_fast_dividef.py} (66%) rename third_party/ascend/unittest/{generalization_cases/test_log1p.py => pytest_ut/test_fast_expf.py} (66%) create mode 100644 third_party/ascend/unittest/pytest_ut/test_if_advance.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_if_load.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_implicit_permute.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_mul_reduce.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_multibuffer.py rename 
third_party/ascend/unittest/{generalization_cases/test_general_arange.py => pytest_ut/test_negative_mask_dim.py} (52%) create mode 100644 third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_runtime_utils.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_select_analysis.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py create mode 100644 third_party/ascend/unittest/pytest_ut/test_use_analysis.py diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index ff1f96c1d1..d4c6294b06 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,12 +1,16 @@ #pragma once - -#ifdef __AMD__ -#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" -#include "amd/include/TritonAMDGPUTransforms/Passes.h" -#endif -#ifdef __NVIDIA__ -#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" -#endif +#include "ascend/include/TritonToLinalg/Passes.h" +#include "ascend/include/DiscreteMaskAccessConversion/Passes.h" +#include "ascend/include/TritonToStructured/Passes.h" +#include "ascend/include/TritonToAnnotation/Passes.h" +#include "ascend/include/TritonToUnstructure/Passes.h" +#include "ascend/include/TritonToHIVM/Passes.h" +#include "ascend/include/TritonToHFusion/Passes.h" +#include "ascend/include/TritonToLLVM/Passes.h" +#include "ascend/include/AutoBlockify/Passes.h" +// #include 
"amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +// #include "amd/include/TritonAMDGPUTransforms/Passes.h" +// #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #ifdef __NVIDIA__ @@ -66,11 +70,18 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { #endif mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerAllocateSharedMemoryPass(); -#ifdef __NVIDIA__ - mlir::triton::registerConvertTritonGPUToLLVMPass(); - mlir::triton::registerConvertNVGPUToLLVMPass(); - mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); -#endif + // mlir::triton::registerConvertTritonGPUToLLVMPass(); + // mlir::triton::registerConvertNVGPUToLLVMPass(); + // mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); + mlir::triton::registerTritonToLinalgPasses(); + mlir::triton::registerDiscreteMaskAccessConversion(); + mlir::triton::registerTritonToStructuredPasses(); + mlir::triton::registerTritonToAnnotationPasses(); + mlir::triton::registerTritonToUnstructurePasses(); + mlir::triton::registerTritonToHIVMPasses(); + mlir::triton::registerTritonToHFusionPasses(); + mlir::triton::registerTritonToLLVMPasses(); + mlir::triton::registerAutoBlockifyPasses(); mlir::registerLLVMDIScope(); #ifdef __AMD__ diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index e5132b6d36..e13de65d94 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -20,6 +20,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "DISABLE_PTXAS_OPT", "LLVM_IR_ENABLE_DUMP", "LLVM_ENABLE_TIMING", + "MLIR_DISABLE_MULTITHREADING", "LLVM_PASS_PLUGIN_PATH", "MLIR_ENABLE_DIAGNOSTICS", "MLIR_ENABLE_DUMP", diff --git a/python/src/llvm.cc b/python/src/llvm.cc index f9b98a2540..ee4b222eaa 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -25,6 +25,7 @@ #include 
"llvm/Transforms/InstCombine/InstCombine.h" #include #include +#include #include #include #include @@ -172,7 +173,7 @@ void init_triton_llvm(py::module &&m) { [](llvm::Module::FunctionListType &s) { return py::make_iterator(s.begin(), s.end()); }, - py::keep_alive<0, 1>()); + py::keep_alive<0, 1>(), py::call_guard()); // Module Flag behavior. See // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 @@ -388,7 +389,8 @@ void init_triton_llvm(py::module &&m) { // (optional) parameters py::arg("arch") = "", py::arg("features") = "", py::arg("flags") = std::vector{}, - py::arg("enable_fp_fusion") = false); + py::arg("enable_fp_fusion") = false, + py::call_guard()); m.def( "translate_to_asm", diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 6b6228f0b4..83d599e0cd 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -23,6 +23,7 @@ MockTensor, ) from .runtime.jit import jit +from .runtime._async_compile import AsyncCompileMode from .compiler import compile, CompilationError from .errors import TritonError @@ -31,6 +32,7 @@ from . import tools __all__ = [ + "AsyncCompileMode", "autotune", "cdiv", "CompilationError", diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba9..9dd06e623f 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -35,7 +35,12 @@ class Backend: def _discover_backends(): backends = dict() root = os.path.dirname(__file__) + # The package does not ship the files required to load the + # upstream nvidia and amd backends, so skip discovering them here. 
+ ignored_dirs = {"nvidia", "amd"} for name in os.listdir(root): + if name in ignored_dirs: + continue if not os.path.isdir(os.path.join(root, name)): continue if name.startswith('__'): diff --git a/python/triton/extension/__init__.py b/python/triton/extension/__init__.py new file mode 100644 index 0000000000..006c0ba6ab --- /dev/null +++ b/python/triton/extension/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
\ No newline at end of file diff --git a/python/triton/extension/buffer/__init__.py b/python/triton/extension/buffer/__init__.py new file mode 100644 index 0000000000..006c0ba6ab --- /dev/null +++ b/python/triton/extension/buffer/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
\ No newline at end of file diff --git a/python/triton/extension/buffer/language/__init__.py b/python/triton/extension/buffer/language/__init__.py new file mode 100644 index 0000000000..9d02b82c93 --- /dev/null +++ b/python/triton/extension/buffer/language/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +__all__ = [ + # core + "builtin", + "is_builtin", + + # buffer + "buffer", + + # base address space + "address_space", + + # alloc + "alloc", + + # to_buffer + "to_buffer", + + # to_tensor + "to_tensor", + "subview", +] + +from .core import builtin, is_builtin, address_space, buffer, alloc, to_buffer, to_tensor, subview diff --git a/python/triton/extension/buffer/language/builder.py b/python/triton/extension/buffer/language/builder.py new file mode 100644 index 0000000000..f94ed54f15 --- /dev/null +++ b/python/triton/extension/buffer/language/builder.py @@ -0,0 +1,75 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +buffer-specific builder utilities for code generation. 
__all__ = [
    "create_builder_method_wrapper_with_buffer_builder",
    "attach_builder_methods_with_buffer_builder",
    "setup_unified_builder_with_buffer_builder",
]


def create_builder_method_wrapper_with_buffer_builder(main_builder, delegate_builder, method_name):
    """Return a proxy for ``delegate_builder.<method_name>``.

    On every call the proxy reads the insertion point and location from
    ``main_builder``, pushes them onto ``delegate_builder``, runs the delegate
    method, and finally re-applies the saved insertion point and location to
    ``main_builder``. A falsy location is never forwarded to ``set_loc``.

    :param main_builder: builder whose insertion point/location is the source of truth.
    :param delegate_builder: builder that actually implements ``method_name``.
    :param method_name: attribute name looked up once on ``delegate_builder``.
    :return: the wrapping callable; it carries ``method_name`` as ``__name__``
        and the delegate method's ``__doc__``.
    """
    target = getattr(delegate_builder, method_name)

    def _apply_position(builder, insertion_point, location):
        # Shared by the "before" (delegate) and "after" (main) synchronisation steps.
        builder.restore_insertion_point(insertion_point)
        if location:
            builder.set_loc(location)

    def wrapper(*args, **kwargs):
        insertion_point = main_builder.get_insertion_point()
        location = main_builder.get_loc()
        _apply_position(delegate_builder, insertion_point, location)
        outcome = target(*args, **kwargs)
        _apply_position(main_builder, insertion_point, location)
        return outcome

    wrapper.__name__ = method_name
    wrapper.__doc__ = getattr(target, '__doc__', None)
    return wrapper


def attach_builder_methods_with_buffer_builder(main_builder, delegate_builder, method_names):
    """Install one synchronising proxy on ``main_builder`` per name in ``method_names``."""
    for name in method_names:
        setattr(
            main_builder,
            name,
            create_builder_method_wrapper_with_buffer_builder(main_builder, delegate_builder, name),
        )


def setup_unified_builder_with_buffer_builder(main_builder, buffer_builder):
    """Expose the buffer builder's operations directly on ``main_builder``.

    Stores ``buffer_builder`` as ``main_builder._buffer_builder`` and attaches
    proxies for the buffer-specific methods listed below.
    """
    main_builder._buffer_builder = buffer_builder
    forwarded = [
        'get_null_attr',
        'get_str_array_attr',
        'alloc',
        'to_buffer',
        'to_tensor',
        'subview',
    ]
    attach_builder_methods_with_buffer_builder(main_builder, buffer_builder, forwarded)
@@ -0,0 +1,363 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = [ + "address_space", + "buffer_type", + "subview", + "alloc", + "buffer", + "to_buffer", + "to_tensor", +] + +import importlib +from typing import TypeVar, List +from functools import wraps + +from triton._C.libtriton import ir +import triton.language.core as tl +from triton.language import semantic as real_semantic + + +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" +BUFFER_BUILTIN = "__buffer_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a buffer language builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? 
class buffer_type(tl.dtype):
    """Type descriptor for a buffer: element type, shape, address space and strides.

    ``tl.dtype.__init__`` is intentionally not invoked here; every attribute the
    class relies on is assigned directly in ``__init__``.
    """

    def __init__(self, element_ty: tl.dtype, shape: List, space: address_space = None, strides: List = None):
        """Store the buffer description.

        :param element_ty: element type of the buffer.
        :param shape: dimensions; any non-list iterable is copied into a list.
        :param space: optional address space; ``None`` means "unspecified".
        :param strides: optional explicit strides; ``None`` becomes an empty
            list, which ``to_ir`` treats as "no explicit layout".
        """
        self.element_ty = element_ty
        self.shape = shape if isinstance(shape, list) else list(shape)
        self.space = space
        self.strides = strides if strides is not None else []
        self.name = self._make_name()

    def _make_name(self):
        """Return the printable name, e.g. ``<buffer 16x32xfloat32, space>``.

        NOTE(review): the body found in the patch was ``res = ''`` with no
        return statement, so ``self.name`` was ``None`` and ``str()`` on the
        type raised ``TypeError``. The format below mirrors ``buffer.__str__``
        with a ``buffer`` tag; confirm it against the upstream sources.
        """
        res = '<buffer ' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.element_ty)
        if self.space:
            res += ', ' + str(self.space)
        return res + '>'

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower this type to its IR form using ``builder``.

        A missing address space is encoded with ``builder.get_null_attr()``.
        """
        element_ty_ir = self.element_ty.to_ir(builder)
        addr_space_attr = self.space.to_ir(builder) if self.space else builder.get_null_attr()

        # use the method with strides if strides is not empty
        if self.strides:
            return builder.get_buffer_ty_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr)
        else:
            return builder.get_buffer_ty(self.shape, element_ty_ir, addr_space_attr)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other) -> bool:
        """Two buffer types are equal when element type, shape, space and strides all match."""
        if not isinstance(other, buffer_type):
            return False
        return (self.element_ty == other.element_ty and
                self.shape == other.shape and
                self.space == other.space and
                self.strides == other.strides)

    def __ne__(self, other) -> bool:
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None and makes instances
        # unhashable. Hash a subset of the fields compared in __eq__ so equal
        # objects always hash equally; `space` is left out because its
        # hashability is not known here.
        return hash((self.element_ty, tuple(self.shape), tuple(self.strides)))

    @property
    def scalar(self):
        """Element type of the buffer."""
        return self.element_ty
"<16x32xfloat32, address_space>" + res = '<' + 'x'.join(str(s) + for s in self.shape) + 'x' + str(self.dtype) + if self.space: + res += ', ' + str(self.space) + return res + '>' + + @builtin + def subview( + self, + offsets: List[tl.constexpr], + sizes: List[tl.constexpr], + strides: List[tl.constexpr], + _builder=None + ) -> 'buffer': + return subview(self, offsets, sizes, strides, _builder=_builder) + + @builtin + def to_tensor(self, writable=True, target_shape=None, _builder=None): + """Convert this buffer to a tl.tensor""" + return to_tensor(self, writable=writable, target_shape=target_shape, _builder=_builder) + + +semantic = importlib.import_module(".semantic", package=__package__) + + +@builtin +def alloc( + etype: tl.dtype, + shape: List[tl.constexpr], + _address_space: address_space = None, + is_mem_unique: bool = False, + _builder=None +) -> buffer: + """ + Allocates a region of local memory with the specified shape and type. + + :param etype: the element type of the buffer. + :type etype: tl.dtype + :param shape: A list of non-negative integers representing the shape of the buffer. + :type shape: List[tl.constexpr] + :param _address_space: (Optional) backend-specific local memory address space + :type _address_space: bl.address_space + """ + return semantic.alloc(etype, shape, _address_space, is_mem_unique, _builder) + + +@builtin +def to_buffer( + tensor: tl.tensor, + space: address_space = None, + bind_buffer: buffer = None, + _builder=None +) -> buffer: + """ + Convert a tensor to a buffer. + + :param tensor: the tensor to convert. + :type tensor: tl.tensor + :param space: the address space for the buffer (optional). + :type space: address_space + """ + return semantic.to_buffer( + tensor, space, bind_buffer, _builder + ) + + +@builtin +def to_tensor( + memref: buffer, + writable: bool = True, + target_shape=None, + _builder=None +) -> tl.tensor: + """ + Create a tl.tensor from a bl.buffer. + + :param memref: the input bl.buffer object. 
def check_subview(src, offsets, sizes, strides):
    """
    Validate that a subview request satisfies the 32-byte alignment rules.

    The conditions checked are:
    1. the linearised start offset must be 32-byte aligned;
    2. all strides must be 1 (only enforced for buffers with two or more dimensions);
    3. when ``sizes[1] > 1``, the offset of the first element of the second row
       must be 32-byte aligned.

    Offsets are counted in elements; alignment is tested against the number of
    elements that fit in 32 bytes for ``src.dtype``. If any offset is a runtime
    ``tl.tensor`` the check is skipped because it cannot be evaluated statically.

    For instance, the following example fails to satisfy the criteria.
        %subview = memref.subview %arg0[1, 1][4, 4][2, 2]
            : memref<8x8xf32, strided<[8, 1], offset: 0>> to
              memref<4x4xf32, strided<[16, 2], offset: 9>>
        offsets = [1, 1] | sizes = [4, 4] | strides = [2, 2]
        result_offset = 9
        second_row_start_offset = 25
    It is rejected because none of the conditions are met:
    1) result_offset is not 32-byte aligned;
    2) strides = [2, 2], so not all strides are equal to 1;
    3) second_row_start_offset is not 32-byte aligned.

    :param src: source buffer; ``src.dtype.primitive_bitwidth`` and ``src.shape`` are read.
    :param offsets: per-dimension start offsets (static values or ``tl.tensor``).
    :param sizes: per-dimension extents of the subview.
    :param strides: per-dimension strides of the subview.
    :raises TypeError: if a statically known request violates the rules above.
    """
    bytes_per_block = 32
    bits_per_byte = 8
    # Number of elements of src.dtype that fit in one 32-byte block.
    base_byte = bytes_per_block // (src.dtype.primitive_bitwidth // bits_per_byte)
    error_msg = "all strides should be 1 and the offset value should be 32-bytes aligned."
    length = len(strides)

    def _is_dynamic(value):
        # Plain Python ints are always static, so they never need the tensor test.
        return not isinstance(value, int) and isinstance(value, tl.tensor)

    # Runtime offsets cannot be validated at trace time; accept them as-is.
    if any(_is_dynamic(offset) for offset in offsets[:length]):
        return

    if length == 1:
        # Fixed: this branch used the undefined name `offset`, so every 1-D
        # subview raised NameError instead of being validated.
        if offsets[0] % base_byte != 0:
            raise TypeError(error_msg)
        return

    # Row-major element strides of the source buffer.
    src_strides = [1] * length
    for i in range(length - 2, -1, -1):
        src_strides[i] = src_strides[i + 1] * src.shape[i + 1]

    result_offset = 0
    for i in range(length):
        result_offset = result_offset + offsets[i] * src_strides[i]
    second_row_start_offset = result_offset + src_strides[-2] * strides[-2]

    is_unaligned = False
    # NOTE(review): the guard reads sizes[1] while the row step above uses
    # index -2; these refer to different dimensions — confirm which is intended.
    if sizes[1] > 1:
        is_unaligned = second_row_start_offset % base_byte != 0
    stride_1 = all(s == 1 for s in strides)
    is_unaligned = result_offset % base_byte != 0 or is_unaligned or not stride_1
    if is_unaligned:
        raise TypeError(error_msg)
+ :rtype: buffer + ''' + # Validate that sizes and strides contain only constexpr values + new_sizes = [] + for i, size in enumerate(sizes): + if isinstance(size, int): + # Convert regular integers to constexpr + new_sizes.append(tl.constexpr(size)) + elif isinstance(size, tl.constexpr): + new_sizes.append(size) + else: + raise TypeError(f"sizes[{i}] must be constexpr, got {type(size).__name__}") + + new_strides = [] + for i, stride in enumerate(strides): + if isinstance(stride, int): + # Convert regular integers to constexpr + new_strides.append(tl.constexpr(stride)) + elif isinstance(stride, tl.constexpr): + new_strides.append(stride) + else: + raise TypeError(f"strides[{i}] must be constexpr, got {type(stride).__name__}") + + check_offsets = [] + new_offsets = [] + for offset in offsets: + if isinstance(offset, tl.constexpr): + # Check that constexpr offset values cannot be negative + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(real_semantic.to_tensor(offset, _builder)) + check_offsets.append(offset) + elif isinstance(offset, int): + # Convert regular integers to constexpr and then to tensor + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(real_semantic.to_tensor(tl.constexpr(offset), _builder)) + check_offsets.append(tl.constexpr(offset)) + else: + # Assume it's already a tensor + new_offsets.append(offset) + check_offsets.append(offset) + + check_subview(src, check_offsets, new_sizes, new_strides) + return semantic.subview(src, new_offsets, new_sizes, new_strides, _builder) diff --git a/python/triton/extension/buffer/language/semantic.py b/python/triton/extension/buffer/language/semantic.py new file mode 100644 index 0000000000..5dd526366a --- /dev/null +++ b/python/triton/extension/buffer/language/semantic.py @@ -0,0 +1,158 @@ +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of 
charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from typing import ( + TypeVar, List +) + +from triton._C.libtriton import ir +import triton.language.core as tl + +from . 
import core as bl + + +T = TypeVar('T') + + +def alloc( + etype: tl.dtype, + shape: List[tl.constexpr], + address_space: bl.address_space, + is_mem_unique, + builder: ir.builder +) -> bl.buffer: + shape = tl._unwrap_shape(shape) + if etype == tl.int1: + raise TypeError("Unsupported alloc int1 type") + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + etype = tl._constexpr_to_value(etype) + address_space = tl._constexpr_to_value(address_space) + element_ty_ir = etype.to_ir(builder) + addr_space_attr = ( + address_space.to_ir(builder) if address_space else builder.get_null_attr() + ) + memref_ty = builder.get_buffer_ty(shape, element_ty_ir, addr_space_attr) + handle = builder.alloc(memref_ty) + if is_mem_unique: + builder.create_annotation_mark(handle, "mem_unique", builder.get_unit_attr()) + builder.create_annotation_mark( + handle, "effects", builder.get_str_array_attr(["write", "read"]) + ) + + buffer_ty = bl.buffer_type(element_ty=etype, shape=shape, space=address_space) + return bl.buffer(handle, buffer_ty) + + +def to_buffer( + tensor: tl.tensor, + address_space: bl.address_space, + bind_buffer: bl.buffer, + builder: ir.builder, +) -> bl.buffer: + if not isinstance(tensor.shape, (tuple, list)) or not tensor.shape: + raise TypeError("scalar type cannot be converted to buffer") + if isinstance(bind_buffer, bl.buffer): + builder.create_bind_buffer(tensor.handle, bind_buffer.handle) + return bind_buffer + if not (bind_buffer is None): + raise ValueError("bind_buffer must be a buffer or None") + address_space = tl._constexpr_to_value(address_space) + addr_space_attr = ( + address_space.to_ir(builder) if address_space else builder.get_null_attr() + ) + handle = builder.to_buffer(tensor.handle, addr_space_attr) + buffer_ty = bl.buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) + return bl.buffer(handle, buffer_ty) + + +def to_tensor( + memref: bl.buffer, + writable: bool, + builder: ir.builder, + 
target_shape=None +) -> tl.tensor: + if not isinstance(memref, bl.buffer): + raise TypeError("memref must be bl.buffer") + + need_convert_layout = False + shape = memref.shape + if target_shape: + need_convert_layout = True + shape = tl._unwrap_shape(target_shape) + assert shape != memref.shape, "target shape is the same as source shape" + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + tensor_type = tl.block_type(memref.dtype, shape) + + memref_value = memref.handle + if need_convert_layout: + buffer_ty = bl.buffer_type( + element_ty=memref.dtype, + shape=shape, + space=memref.space, + ) + memref_value = builder.create_convert_layout( + memref_value, buffer_ty.to_ir(builder)) + + return tl.tensor(builder.to_tensor(memref_value, writable), tensor_type) + + +def subview( + src: bl.buffer, + offsets: List[tl.tensor], + sizes: List[tl.constexpr], + strides: List[tl.constexpr], + builder: ir.builder +) -> bl.buffer: + + new_offsets = [offset.handle for offset in offsets] + sizes_int = tl._unwrap_shape(sizes) + strides_int = tl._unwrap_shape(strides) + + result_handle = builder.subview(src.handle, new_offsets, sizes_int, strides_int) + + # calculate the memory layout strides of the source buffer + if src.strides: + # use the strides of the source buffer + src_memory_strides = src.strides + else: + # calculate the default row-major strides + src_memory_strides = [] + stride = 1 + for dim_size in reversed(src.shape): + if dim_size < 0: + raise ValueError("Cannot compute strides for buffer with dynamic dimensions") + src_memory_strides.insert(0, stride) + stride *= dim_size + + result_memory_strides = [] + for src_stride, subview_stride in zip(src_memory_strides, strides_int): + result_memory_strides.append(src_stride * subview_stride) + + # create buffer_type with strides + buffer_ty = bl.buffer_type( + element_ty=src.dtype, + shape=sizes_int, + space=src.space, + strides=result_memory_strides + ) + return bl.buffer(result_handle, 
buffer_ty) diff --git a/python/triton/extension/buffer/src/buffer_ir.cc b/python/triton/extension/buffer/src/buffer_ir.cc new file mode 100644 index 0000000000..f1f07dda52 --- /dev/null +++ b/python/triton/extension/buffer/src/buffer_ir.cc @@ -0,0 +1,169 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include + +#include "ir.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +using namespace mlir; +namespace py = pybind11; + +constexpr unsigned kIntegerAttrBitWidth = 64; + +struct BufferOpBuilder : public TritonOpBuilder {}; + +void init_buffer_ir(py::module &&m) +{ + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_( + m, "buffer_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + .def("get_null_attr", [](BufferOpBuilder &self) { return Attribute(); }) + .def("get_str_array_attr", + [](BufferOpBuilder &self, const std::vector &array) -> ArrayAttr { + auto strRefVec = to_vector(llvm::map_range(array, [](const auto &s) { return llvm::StringRef(s); })); + return self.getBuilder().getStrArrayAttr(llvm::ArrayRef {strRefVec}); + }) + .def("alloc", + [](BufferOpBuilder &self, Type memrefType) -> Value { + return self.create( + mlir::cast(memrefType)); + }) + .def("to_buffer", + [](BufferOpBuilder &self, Value &src, const Attribute &addressSpace) -> Value { + auto tensorType = dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType(), MemRefLayoutAttrInterface {}); + // TODO: We need to add a pass before OneShotBufferize to generate MemorySpaceCastOp + Operation *memref = self.create(memrefType, src); + if (addressSpace) { + memref = self.create( + MemRefType::get(memrefType.getShape(), memrefType.getElementType(), memrefType.getLayout(), + addressSpace), + memref->getResult(0)); + } + return memref->getResult(0); + }) + .def("to_tensor", + [](BufferOpBuilder 
&self, Value &src, bool writable) -> Value { + const auto &memrefType = mlir::cast(src.getType()); + auto hasAddressSpace = memrefType.getMemorySpace(); + if (hasAddressSpace) { + return self.create( + self.create( + MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout()), + src), + true, writable); + } + return self.create(src, true, writable); + }) + .def("subview", + [](BufferOpBuilder &self, Value source, std::vector &offsets, + const std::vector &sizes, + const std::vector &strides) -> Value { + SmallVector mixedOffsets; + auto *context = self.getBuilder().getContext(); + auto &builder = self.getBuilder(); + + // Get source memref type for validation + auto sourceType = mlir::cast(source.getType()); + int64_t rank = sourceType.getRank(); + // Verify the number of parameters + if (offsets.size() != rank || sizes.size() != rank || + strides.size() != rank) { + throw std::runtime_error("Number of offsets, sizes, and strides " + "must match memref rank"); + } + + for (const auto &offset : offsets) { + auto indexType = builder.getIndexType(); + if (offset.getType() != indexType) { + Value offset_val = + self.create(indexType, offset); + mixedOffsets.push_back(offset_val); + } else { + mixedOffsets.push_back(offset); + } + } + + SmallVector mixedSizes; + SmallVector mixedStrides; + for (int64_t i = 0; i < rank; ++i) { + int64_t size = sizes[i]; + int64_t stride = strides[i]; + int64_t srcDim = sourceType.getDimSize(i); + + // verify sizes cannot be negative or zero + if (size <= 0) { + throw std::runtime_error("Expected sizes to be positive"); + } + + // verify strides cannot be negative or zero + if (stride <= 0) { + throw std::runtime_error("Expected strides to be positive"); + } + + // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic dimensions + if (!ShapedType::isDynamic(srcDim)) { + // verify the subview size does not exceed the source dimension + if (size > srcDim) { + throw std::runtime_error( + "Subview size 
cannot exceed source dimension size"); + } + + // verify strides cannot exceed the source dimension size + if (stride > srcDim) { + throw std::runtime_error( + "Stride cannot exceed source dimension size"); + } + } + + mixedSizes.push_back( + IntegerAttr::get(IntegerType::get(context, kIntegerAttrBitWidth), size)); + mixedStrides.push_back( + IntegerAttr::get(IntegerType::get(context, kIntegerAttrBitWidth), stride)); + } + + return self.create(source, mixedOffsets, + mixedSizes, mixedStrides); + }); +} \ No newline at end of file diff --git a/python/triton/tools/get_ascend_devices.py b/python/triton/tools/get_ascend_devices.py new file mode 100644 index 0000000000..13c28cda81 --- /dev/null +++ b/python/triton/tools/get_ascend_devices.py @@ -0,0 +1,55 @@ +import os +import glob +import logging +import subprocess + +logger = logging.getLogger(__name__) + + +def get_ascend_devices(): + devices = [] + pci_path = '/sys/bus/pci/devices/*' + + for dev in glob.glob(pci_path): + try: + vendor_path = os.path.join(dev, 'vendor') + device_path = os.path.join(dev, 'device') + + if os.path.exists(vendor_path): + with open(vendor_path, 'r') as f: + vendor = f.read().strip() + + if vendor == "0x19e5" and os.path.exists(device_path): + with open(device_path, 'r') as f: + device = f.read().strip() + devices.append(device) + except (IOError, OSError) as e: + logger.warning(f"can not fetch device {dev}: {e}") + continue + + return devices + + +def check_npu_smi_device(): + try: + result = subprocess.run( + ["npu-smi", "info"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + shell=False, + timeout=100 + ) + if result.returncode == 0: + output = result.stdout.lower() + return "ascend910_95" in output or "ascend950" in output or "910_958b" in output + return False + except Exception as e: + logger.warning(f"can not use command: npu-smi info: {e}") + return False + + +ascend_devices = get_ascend_devices() +pci_condition = any("0xd806" in dev for dev in ascend_devices)
+npu_smi_condition = check_npu_smi_device() +is_compile_on_910_95 = pci_condition or npu_smi_condition \ No newline at end of file diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt index 1bf6cb570a..5672ac1e5c 100644 --- a/third_party/ascend/CMakeLists.txt +++ b/third_party/ascend/CMakeLists.txt @@ -12,9 +12,20 @@ include_directories(${CMAKE_BINARY_DIR}/third_party/flir/include) # set(BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES ON) # set(BISHENGIR_BUILD_STANDALONE_IR_ONLY ON) -# add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR}) -# include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) -# include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files +# Temporarily save and clear the RULE_LAUNCH_* attributes in the current directory to ensure AscendNPU-IR is not affected. +get_property(_saved_launch_compile DIRECTORY PROPERTY RULE_LAUNCH_COMPILE) +get_property(_saved_launch_link DIRECTORY PROPERTY RULE_LAUNCH_LINK) +set_property(DIRECTORY PROPERTY RULE_LAUNCH_COMPILE "") +set_property(DIRECTORY PROPERTY RULE_LAUNCH_LINK "") + +add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR}) + +# restore properties +set_property(DIRECTORY PROPERTY RULE_LAUNCH_COMPILE ${_saved_launch_compile}) +set_property(DIRECTORY PROPERTY RULE_LAUNCH_LINK ${_saved_launch_link}) + +include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) +include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files add_subdirectory(backend/spec/lib) @@ -36,7 +47,6 @@ endif() add_triton_plugin(TritonAscend ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cc ${CMAKE_CURRENT_SOURCE_DIR}/ascend_ir.cc - LINK_LIBS TritonToLinalgIncubated BiShengIRScopeDialect @@ -44,7 +54,34 @@ add_triton_plugin(TritonAscend ${_MLIRMeshDialect_LIB} ) -# target_link_libraries(TritonAscend PRIVATE Python3::Module pybind11::headers) +option(TRITON_ENABLE_COVERAGE_LLVM_COV "Enable code llvm-cov coverage tool for Ascend 
plugin " OFF) +if(TRITON_ENABLE_COVERAGE_LLVM_COV) + message(STATUS "Enabling llvm-cov coverage tool flags for TritonAscend") + target_compile_options(TritonAscend PRIVATE + -fprofile-arcs + -ftest-coverage + -O0 + -fprofile-update=atomic + --coverage + ) + target_link_options(TritonAscend PRIVATE + --coverage + -lgcov + ) + # branch coverage + target_compile_definitions(TritonAscend PRIVATE + COVERAGE_ENABLED=1 + ) +endif() + + +# To enable hitest coverage tool +if(TRITON_ENABLE_COVERAGE_HITEST) + set_target_properties(TritonAscend PROPERTIES + RULE_LAUNCH_COMPILE "hitestwrapper" + RULE_LAUNCH_LINK "hitestwrapper" + ) +endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/third_party/ascend/ascend_ir.cc b/third_party/ascend/ascend_ir.cc index 7cab69c452..25f5597a2d 100644 --- a/third_party/ascend/ascend_ir.cc +++ b/third_party/ascend/ascend_ir.cc @@ -24,6 +24,7 @@ #include "ir.h" #include "pybind11/pybind11.h" +#include #include #include "bishengir/Dialect/Annotation/IR/Annotation.h" @@ -33,6 +34,8 @@ #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" @@ -44,20 +47,43 @@ namespace py = pybind11; struct AscendNPUIROpBuilder : public TritonOpBuilder { std::string target; static constexpr char kTarget910_95[] = "Ascend910_95"; + static constexpr char kTarget950[] = "Ascend950"; explicit AscendNPUIROpBuilder(MLIRContext *context, std::string target = "") : TritonOpBuilder(context), target(target) {} - bool is_910_95() { + bool is_910_95() const + { // TODO: Use enum instead of strings after enabling HACC in satandalone // build - constexpr size_t kTargetLen = sizeof(kTarget910_95) - 1; - return target.size() >= kTargetLen && - target.compare(0, kTargetLen, kTarget910_95) == 0; + constexpr size_t kLen910 = 
sizeof(kTarget910_95) - 1; + bool match_910 = target.size() >= kLen910 && + target.compare(0, kLen910, kTarget910_95) == 0; + + constexpr size_t kLen950 = sizeof(kTarget950) - 1; + bool match_950 = target.size() >= kLen950 && + target.compare(0, kLen950, kTarget950) == 0; + + return match_910 || match_950; } }; namespace { +MLIRContext *gDefaultAscendContext = nullptr; + +MLIRContext *resolveContext(const py::object &contextObj) +{ + if (!contextObj.is_none()) { + return &py::cast(contextObj); + } + if (gDefaultAscendContext) { + return gDefaultAscendContext; + } + throw std::invalid_argument( + "No default MLIR context. Pass context explicitly or call " + "ascend_ir.load_dialects(context) first."); +} + struct ModeAndPipes { hivm::SyncBlockModeAttr modeAttr = {}; hivm::PipeAttr cubePipe = {}; @@ -143,6 +169,247 @@ ModeAndPipes GetSyncBlockModeAndPipes(MLIRContext *ctx, } // namespace void init_ascend_ir(py::module &&m) { + auto affineExprClass = + py::class_(m, "affine_expr", py::module_local()); + affineExprClass + .def("__str__", [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) + .def("is_symbolic_or_constant", &AffineExpr::isSymbolicOrConstant) + .def("is_pure_affine", &AffineExpr::isPureAffine) + .def("is_function_of_dim", &AffineExpr::isFunctionOfDim) + .def("compose", + [](AffineExpr self, AffineMap map) { return self.compose(map); }) + .def("get_largest_known_divisor", &AffineExpr::getLargestKnownDivisor) + .def("floordiv", + [](AffineExpr self, AffineExpr other) { return self.floorDiv(other); }) + .def("ceildiv", + [](AffineExpr self, AffineExpr other) { return self.ceilDiv(other); }) + .def("mod", + [](AffineExpr self, AffineExpr other) { return self % other; }) + .def("__hash__", [](AffineExpr self) { + return py::int_(static_cast(mlir::hash_value(self))); 
+ }) + .def("__eq__", + [](AffineExpr lhs, AffineExpr rhs) { return lhs == rhs; }) + .def(py::self + py::self) + .def(py::self - py::self) + .def(py::self * py::self) + .def(py::self % py::self); + affineExprClass + .def_static("get_constant", + [](int64_t val, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineConstantExpr(val, context); + }, + py::arg("value"), py::arg("context") = py::none()) + .def_static("get_dim", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineDimExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()) + .def_static("get_symbol", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineSymbolExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()); + + py::class_(m, "affine_constant_expr", + py::module_local()) + .def("get_value", &AffineConstantExpr::getValue); + py::class_(m, "affine_dim_expr", + py::module_local()) + .def("get_position", &AffineDimExpr::getPosition); + py::class_(m, "affine_symbol_expr", + py::module_local()) + .def("get_position", &AffineSymbolExpr::getPosition); + py::class_(m, "affine_binary_op_expr", + py::module_local()) + .def("get_lhs", &AffineBinaryOpExpr::getLHS) + .def("get_rhs", &AffineBinaryOpExpr::getRHS); + + auto affineMapClass = py::class_(m, "affine_map", py::module_local()); + affineMapClass + .def("__str__", [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) + .def("is_identity", &AffineMap::isIdentity) + .def("is_permutation", &AffineMap::isPermutation) + .def("get_num_dims", &AffineMap::getNumDims) + .def("get_num_symbols", &AffineMap::getNumSymbols) + .def("get_num_results", &AffineMap::getNumResults) + .def("is_empty", 
&AffineMap::isEmpty) + .def("is_single_constant", &AffineMap::isSingleConstant) + .def("is_constant", &AffineMap::isConstant) + .def("get_constant_result", [](AffineMap &self) -> int64_t { + if (!self.isSingleConstant()) { + throw std::runtime_error("affine map is not a single constant map"); + } + return self.getSingleConstantResult(); + }) + .def("get_result", + [](AffineMap &self, uint32_t pos) { + if (pos >= self.getNumResults()) { + throw py::index_error("result index out of range"); + } + return self.getResult(pos); + }) + .def("get_sub_map", + [](AffineMap &self, const std::vector &resultPos) { + return self.getSubMap(resultPos); + }) + .def("replace", + [](AffineMap &self, AffineExpr expr, AffineExpr replacement, + uint32_t numResultDims, uint32_t numResultSymbols) { + return self.replace(expr, replacement, numResultDims, + numResultSymbols); + }) + .def("compose", [](AffineMap &self, AffineMap map) { + return self.compose(map); + }) + .def("get_results", [](AffineMap &self) -> std::vector { + auto results = self.getResults(); + return std::vector(results.begin(), results.end()); + }) + .def("__hash__", [](AffineMap &self) { + return py::int_(static_cast(mlir::hash_value(self))); + }) + .def("__eq__", [](AffineMap &lhs, AffineMap &rhs) { return lhs == rhs; }) + .def("inverse_permutation", [](AffineMap &self) -> py::object { + // Validate it's a permutation first + if (!self.isPermutation()) { + throw py::value_error("AffineMap must be a valid permutation to compute inverse"); + } + + // Returns AffineMap directly, not a pointer + AffineMap inverse = mlir::inversePermutation(self); + + // Check if result is valid (null AffineMap) + if (!inverse) { + throw py::value_error("Failed to compute inverse permutation"); + } + + return py::cast(inverse); + }) + .def("to_dict", [](AffineMap &self) -> py::dict { + py::list results; + for (AffineExpr result : self.getResults()) { + if (auto dimExpr = dyn_cast(result)) { + results.append(dimExpr.getPosition()); + } else 
{ + std::string exprStr; + llvm::raw_string_ostream os(exprStr); + result.print(os); + results.append(py::str(exprStr)); + } + } + + py::dict ret; + ret["num_dims"] = self.getNumDims(); + ret["num_symbols"] = self.getNumSymbols(); + ret["results"] = std::move(results); + return ret; + }); + affineMapClass + .def_static("get", + [](int64_t numDims, int64_t numSymbols, const py::iterable &resultsIn, + py::object contextObj) -> AffineMap { + MLIRContext *context = nullptr; + if (numDims < 0 || numSymbols < 0) { + throw std::invalid_argument( + "num_dims and num_symbols must be non-negative"); + } + llvm::SmallVector results; + for (const auto &item : resultsIn) { + if (py::isinstance(item)) { + auto expr = py::cast(item); + if (!context) { + context = expr.getContext(); + } + results.push_back(expr); + continue; + } + if (py::isinstance(item)) { + if (!context) { + context = resolveContext(contextObj); + } + int64_t pos = py::cast(item); + if (pos < 0 || pos >= numDims) { + throw std::invalid_argument( + "result dim index is out of range for num_dims"); + } + results.push_back(getAffineDimExpr(pos, context)); + continue; + } + throw std::invalid_argument( + "results must contain affine_expr or int dim indices"); + } + if (!context) { + context = resolveContext(contextObj); + } + return AffineMap::get(numDims, numSymbols, results, context); + }, + py::arg("num_dims"), py::arg("num_symbols"), py::arg("result_dims"), + py::arg("context") = py::none()) + .def_static("get_identity", + [](int64_t numDims, py::object contextObj) -> AffineMap { + auto *context = resolveContext(contextObj); + if (numDims < 0) { + throw std::invalid_argument( + "num_dims must be non-negative"); + } + return AffineMap::getMultiDimIdentityMap(numDims, context); + }, + py::arg("num_dims"), py::arg("context") = py::none()) + .def_static("get_minor_identity", + [](int64_t dims, int64_t results, py::object contextObj) { + auto *context = resolveContext(contextObj); + if (dims < 0 || results < 0) { + 
throw std::invalid_argument( + "dims/results must be non-negative"); + } + return AffineMap::getMinorIdentityMap(dims, results, context); + }, + py::arg("dims"), py::arg("results"), + py::arg("context") = py::none()) + .def_static("get_empty", [](py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::get(0, 0, {}, context); + }, py::arg("context") = py::none()) + .def_static("get_permutation", + [](const std::vector &permutation, + py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getPermutationMap(permutation, context); + }, + py::arg("permutation"), py::arg("context") = py::none()) + .def_static("get_constant", + [](int64_t value, py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getConstantMap(value, context); + }, + py::arg("value"), py::arg("context") = py::none()); + py::enum_(m, "AddressSpace", py::module_local()) .value("L1", hivm::AddressSpace::L1) .value("UB", hivm::AddressSpace::UB) @@ -175,6 +442,21 @@ void init_ascend_ir(py::module &&m) { .value("MIX", hivm::VFMode::MIX) .export_values(); + py::enum_(m, "IteratorType", py::module_local()) + .value("Parallel", hivm::IteratorType::kParallel) + .value("Broadcast", hivm::IteratorType::kBroadcast) + .value("Transpose", hivm::IteratorType::kTranspose) + .value("Reduction", hivm::IteratorType::kReduction) + .value("Interleave", hivm::IteratorType::kInterleave) + .value("Deinterleave", hivm::IteratorType::kDeinterleave) + .value("Inverse", hivm::IteratorType::kInverse) + .value("Pad", hivm::IteratorType::kPad) + .value("Concat", hivm::IteratorType::kConcat) + .value("Gather", hivm::IteratorType::kGather) + .value("Cumulative", hivm::IteratorType::kCumulative) + .value("Opaque", hivm::IteratorType::kOpaque) + .export_values(); + py::enum_(m, "FixpipeDMAMode", py::module_local()) .value("NZ2DN", hivm::FixpipeDMAMode::NZ2DN) .value("NZ2ND", hivm::FixpipeDMAMode::NZ2ND) @@ -209,9 +491,7 @@ void 
init_ascend_ir(py::module &&m) { .export_values(); m.def("load_dialects", [](MLIRContext &context) { - // Allow unregistered dialects so we can parse HACC attributes without - // registering the dialect - context.allowUnregisteredDialects(); + gDefaultAscendContext = &context; DialectRegistry registry; registry.insert(); @@ -223,6 +503,10 @@ void init_ascend_ir(py::module &&m) { m, "ascendnpu_ir_builder", py::module_local(), py::dynamic_attr()) .def(py::init(), py::arg("context"), py::arg("target") = "") + .def("get_int_attr", + [](AscendNPUIROpBuilder &self, int64_t value) -> Attribute { + return IntegerAttr::get(self.getBuilder().getI64Type(), value); + }) .def("get_core_type_attr", [](AscendNPUIROpBuilder &self, hivm::TCoreType core_type) -> Attribute { @@ -236,6 +520,13 @@ void init_ascend_ir(py::module &&m) { [](AscendNPUIROpBuilder &self, hivm::VFMode mode) -> Attribute { return self.getBuilder().getAttr(mode); }) + .def("get_iterator_types_attr", + [](AscendNPUIROpBuilder &self, const std::vector& array) { + auto attrs = llvm::to_vector(llvm::map_range(array, [&self](hivm::IteratorType type) { + return cast(self.getBuilder().getAttr(type)); + })); + return self.getBuilder().getArrayAttr(attrs); + }) .def("get_t_core_type_attr_name", [](AscendNPUIROpBuilder &self) -> std::string { return hivm::TCoreTypeAttr::name.str(); @@ -253,8 +544,32 @@ void init_ascend_ir(py::module &&m) { .def("parse_attr", [](TritonOpBuilder &self, std::string value) -> Attribute { auto *ctx = self.getBuilder().getContext(); + // Enable parsing of HACC attributes by allowing unregistered dialects. 
+ ctx->allowUnregisteredDialects(); return mlir::parseAttribute(value, ctx); }) + .def("get_affine_map_attr", + [](AscendNPUIROpBuilder &self, AffineMap affineMap) -> Attribute { + return AffineMapAttr::get(affineMap); + }) + .def("get_affine_map_array_attr", + [](AscendNPUIROpBuilder &self, + const std::vector &affineMaps) -> Attribute { + auto *ctx = self.getBuilder().getContext(); + llvm::SmallVector attrs; + attrs.reserve(affineMaps.size()); + for (const auto &map : affineMaps) { + attrs.push_back(AffineMapAttr::get(map)); + } + return ArrayAttr::get(ctx, attrs); + }) + .def("get_buffer_ty_with_affine_map", + [](AscendNPUIROpBuilder &self, std::vector &shape, + Type &elementType, AffineMap affineMap, + const Attribute &memorySpace) -> Type { + auto layout = AffineMapAttr::get(affineMap); + return MemRefType::get(shape, elementType, layout, memorySpace); + }) .def("create_fixpipe", [](AscendNPUIROpBuilder &self, Value src, Value dst, hivm::FixpipeDMAMode dma_mode, @@ -281,6 +596,13 @@ void init_ascend_ir(py::module &&m) { mlir::TypeRange{}, src, dst, dma_mode_attr, dual_dst_mode_attr, pre_quant_mode_attr, pre_relu_mode_attr, channel_split); }) + .def("create_annotation_mark", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }) .def("create_bind_buffer", [](TritonOpBuilder &self, Value &src, Value &alloc) -> void { auto ctx = self.getBuilder().getContext(); @@ -296,19 +618,48 @@ void init_ascend_ir(py::module &&m) { attrVal); }) .def("create_custom_op", - [](AscendNPUIROpBuilder &self, const std::string &name, - const py::dict &attrs, const std::vector &ins, - const std::vector &outs) -> std::vector { + [](AscendNPUIROpBuilder &self, + const std::string &name, + const py::dict &attrs, + const std::vector &ins, + const std::vector &outs, + const std::vector &arg_attrs) -> std::vector { ValueRange 
inputs{ins}; ValueRange outputs{outs}; + ValueRange temp_buffers{}; TypeRange res_types{outputs}; - auto op = - self.create(res_types, name, inputs, outputs); + auto op = self.create(res_types, name, inputs, outputs, temp_buffers); for (auto &attr : attrs) { std::string attr_name = py::cast(attr.first); Attribute attr_value = py::cast(attr.second); op->setAttr(attr_name, attr_value); } + + Attribute emptyDict = self.getBuilder().getDictionaryAttr({}); + SmallVector dictAttrs(arg_attrs.size(), emptyDict); + for (const auto &[idx, attrs] : llvm::enumerate(arg_attrs)) { + if (idx >= op.getNumOperands()) + continue; + + if (attrs.is_none()) { + dictAttrs[idx] = emptyDict; + continue; + } + + llvm::SmallVector namedAttrs; + for (const auto &attr : attrs) { + std::string attr_name = py::cast(attr.first); + Attribute attr_value = py::cast(attr.second); + namedAttrs.push_back( + NamedAttribute(self.getBuilder().getStringAttr(attr_name), attr_value)); + } + + dictAttrs[idx] = self.getBuilder().getDictionaryAttr(namedAttrs); + } + + ArrayAttr arg_attrs_array = self.getBuilder().getArrayAttr(dictAttrs); + op->setAttr("arg_attrs", arg_attrs_array); + auto results = op->getResults(); return std::vector(results.begin(), results.end()); }) diff --git a/third_party/ascend/backend/backend_register.py b/third_party/ascend/backend/backend_register.py index 480f2a7fed..ca2d39c401 100644 --- a/third_party/ascend/backend/backend_register.py +++ b/third_party/ascend/backend/backend_register.py @@ -33,7 +33,7 @@ def decorator(func: Callable): if category not in self.strategies: self.strategies[category] = {} if method in self.strategies[category]: - raise ValueError(f"Strategy {name} already registered") + raise ValueError(f"Strategy {method} already registered") self.strategies[category][method] = func return func @@ -164,7 +164,7 @@ def get_empty_tensor(size): @backend_strategy_registry.register("mindspore", "get_tensor_params_shape") -def
get_tensor_params_shape(*args): import mindspore tensor_params = [arg for arg in args if isinstance(arg, mindspore.Tensor)] tensor_params_shape = [] @@ -174,7 +174,7 @@ def get_tensor_params_shape(args): @backend_strategy_registry.register("torch_npu", "get_tensor_params_shape") -def get_tensor_params_shape(args): +def get_tensor_params_shape(*args): import torch tensor_params = [arg for arg in args if isinstance(arg, torch.Tensor)] tensor_params_shape = [] @@ -188,10 +188,13 @@ def get_cc_cmd(build_pch): import mindspore mindspore_path = os.path.dirname(os.path.realpath(mindspore.__file__)) cc_cmd = [ + f"-I{mindspore_path}", + f"-I{os.path.join(mindspore_path, 'include/')}", f"-I{os.path.join(mindspore_path, 'include/third_party')}", f"-I{os.path.join(mindspore_path, 'include/third_party/robin_hood_hashing')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/core')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/core/include')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/core/mindrt/include')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc/include')}", f"-I{os.path.join(mindspore_path, 'include/mindspore/ops')}", @@ -255,23 +258,31 @@ def set_current_device(device_id): @backend_strategy_registry.register("mindspore", "get_current_stream") def get_current_stream(device): import mindspore - return mindspore.current_stream().id + try: + return mindspore.current_stream().stream_ptr() + except Exception: + return mindspore.current_stream().id @backend_strategy_registry.register("torch_npu", "get_current_stream") def get_current_stream(device): import torch import torch_npu - from torch_npu._C import _npu_getCurrentRawStream if device is None: device = torch.npu.current_device() - return _npu_getCurrentRawStream(device) + if hasattr(torch_npu._C, "_npu_getCurrentRawStreamNoWait"): + from torch_npu._C import _npu_getCurrentRawStreamNoWait + return 
_npu_getCurrentRawStreamNoWait(device) + else: + from torch_npu._C import _npu_getCurrentRawStream + return _npu_getCurrentRawStream(device) @backend_strategy_registry.register("mindspore", "header_file") def header_file(enable_taskqueue): return f'''#include "include/utils/device_manager_conf.h" #include "include/runtime/hardware_abstract/device_context/device_context_manager.h" +#include "include/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h" {'#include "include/pynative/utils/runtime/op_executor.h"' if {enable_taskqueue} else ''} {'#include "include/runtime/pipeline/pipeline.h"' if {enable_taskqueue} else ''}''' @@ -285,34 +296,43 @@ def header_file(enable_taskqueue): @backend_strategy_registry.register("mindspore", "allocate_memory") def allocate_memory(size, stream): - return f"device_context->device_res_manager_->AllocateMemory({size}, reinterpret_cast({stream}));" + return f'''auto work_ptr = std::make_shared(device_context, {size}, reinterpret_cast({stream})); + workspace_addr_ptr = work_ptr->ptr_;''' @backend_strategy_registry.register("torch_npu", "allocate_memory") -def allocate_memory(size, option): - return f"const_cast(at::empty({size}, {option}).storage().data());" +def allocate_memory(size, stream): + return f"workspace_addr_ptr = const_cast(at::empty({size}, at::TensorOptions().device(at::kPrivateUse1).dtype(at::kByte)).storage().data());" + + +@backend_strategy_registry.register("mindspore", "allocate_sync_block_lock") +def allocate_sync_block_lock(size, stream): + return f'''auto sync_ptr = std::make_shared(device_context, {size}, reinterpret_cast({stream})); + syncBlockLock_ptr = sync_ptr->ptr_;''' @backend_strategy_registry.register("torch_npu", "allocate_sync_block_lock") def allocate_sync_block_lock(size, stream): - return f"const_cast(at_npu::native::allocate_workspace({size}, {stream}).storage().data());" + return f"syncBlockLock_ptr = const_cast(at_npu::native::allocate_workspace({size}, {stream}).storage().data());"
@backend_strategy_registry.register("mindspore", "pre_launch") -def pre_launch(): - return '''static auto device_context = mindspore::device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({mindspore::device::DeviceType::kAscend, mindspore::DeviceManagerConf::GetInstance()->device_id()}); - device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' +def pre_launch(first_call): + if first_call: + return '''static auto device_context = mindspore::device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({mindspore::device::DeviceType::kAscend, mindspore::DeviceManagerConf::GetInstance()->device_id()}); + device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' + else: + return '''device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' @backend_strategy_registry.register("torch_npu", "pre_launch") -def pre_launch(): +def pre_launch(first_call): return "" @backend_strategy_registry.register("mindspore", "async_launch") def async_launch(func): - return f'''mindspore::runtime::OpExecutor::DispatchLaunchTask({func}); - mindspore::runtime::Pipeline::Get().launch_stage()->Wait();''' + return f'''mindspore::runtime::OpExecutor::DispatchLaunchTask({func});''' @backend_strategy_registry.register("torch_npu", "async_launch") diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 2bb1c550af..1f04b4ec9f 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -21,6 +21,7 @@ import ctypes import functools import hashlib +import glob import os import re import subprocess @@ -37,6 +38,7 @@ _check_bishengir_is_regbased, _enable_unpublished_feature, _enable_print_ub_bits, + _enable_dump_memory_info, _get_kernel_target, _get_llvm_path, _get_mlir_path, @@ -47,6 +49,7 @@ _is_auto_map_parallel_blocks_enabled, downgrade_llir, force_disable_ffts, + triton_enable_libdevice_simt, ) from triton.backends.ascend.driver import (NPUUtils) 
from triton.backends.compiler import ( @@ -57,12 +60,8 @@ ) from triton.runtime import driver from triton.runtime.cache import get_dump_manager +from triton.tools.get_ascend_devices import is_compile_on_910_95 -try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") -except Exception as e: - is_compile_on_910_95 = False # TODO: materialize the concrete min shape @@ -96,6 +95,73 @@ def make_ttir(mod, metadata, opt): def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): # use triton_adapter to lower Triton-MLIR to linalg # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ttir.mlir") + dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + Path(src_path).write_text(ttir_code) + triton_adapter_opt_path = _get_triton_adapter_opt_path() + + enable_nd2nz_on_vector = metadata["enable_nd2nz_on_vector"] + enable_select_analysis = metadata["enable_select_analysis"] + compile_on_910_95 = metadata["compile_on_910_95"] + force_simt_template = metadata["force_simt_template"] + enable_sync_block_lock = metadata["enable_sync_block_lock"] + enable_mask_fallback_conversion = metadata["enable_mask_fallback_conversion"] + optimize_dynamic_offset = metadata["optimize_dynamic_offset"] + auto_blockify_size = metadata["auto_blockify_size"] + if not _is_auto_map_parallel_blocks_enabled(): + auto_blockify_size = 1 + pm = ir.pass_manager(mod.context) + pm.enable_debug() + ascend.passes.ttir.add_auto_blockify( + pm, + auto_blockify_size + ) + if (metadata["add_auto_scheduling"]): + ascend.passes.ttir.add_dag_sync(pm) + ascend.passes.ttir.add_dag_scope(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + ascend.passes.ttir.add_dag_ssbuffer(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + + ascend.passes.ttir.add_triton_to_structure( + pm, + enable_mask_fallback_conversion, + optimize_dynamic_offset + ) + 
ascend.passes.ttir.add_discrete_mask_access_conversion( + pm, + compile_on_910_95, + force_simt_template, + enable_sync_block_lock + ) + ascend.passes.ttir.add_triton_to_annotation(pm) + ascend.passes.ttir.add_triton_to_unstructure( + pm, + compile_on_910_95, + force_simt_template + ) + ascend.passes.ttir.add_triton_to_hivm(pm) + ascend.passes.ttir.add_triton_to_hfusion(pm) + ascend.passes.ttir.add_triton_to_llvm(pm) + ascend.passes.ttir.add_bubble_up_operation(pm) + ascend.passes.ttir.add_triton_to_structure( + pm, + enable_mask_fallback_conversion, + optimize_dynamic_offset + ) + ascend.passes.ttir.add_triton_to_linalg( + pm, + False, + named_ops, + enable_nd2nz_on_vector, + enable_select_analysis, + compile_on_910_95 + ) + pm.run(mod) enable_nd2nz_on_vector = metadata["enable_nd2nz_on_vector"] enable_select_analysis = metadata["enable_select_analysis"] @@ -129,104 +195,6 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): return str(mod) -def linalg_to_llir(linalg: str, metadata, opt): - with tempfile.TemporaryDirectory() as tmpdir: - ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") - llmlir_path = os.path.join(tmpdir, "kernel.llir.mlir") - llir_path = os.path.join(tmpdir, "kernel.ll") - Path(ttadapter_path).write_text(linalg) - mlir_opt_path = _get_mlir_path("bin", "mlir-opt") - # TritonAdapter-MLIR to LLVM-MLIR - subprocess.check_call([ - mlir_opt_path, - ttadapter_path, - "--convert-linalg-to-affine-loops", - "--eliminate-empty-tensors", - "--empty-tensor-to-alloc-tensor", - "--one-shot-bufferize=allow-return-allocs-from-loops=true", - "--lower-affine", - "--convert-linalg-to-loops", - "--convert-scf-to-cf", - "--convert-cf-to-llvm", - "--convert-arith-to-llvm", - "--convert-math-to-llvm", - "--convert-complex-to-llvm", - "--convert-vector-to-llvm", - "--convert-index-to-llvm", - "--memref-expand", - "--expand-strided-metadata", - "--finalize-memref-to-llvm", - "--convert-func-to-llvm", - # Lowering memrefs creates more affine.apply 
ops. - # Lowering these affine ops again creates further arith ops, - # so we have to run these two passes again here. - "--lower-affine", - "--convert-arith-to-llvm", - # Remove all unrealized casts created - "--reconcile-unrealized-casts", - "-o", - llmlir_path, - ]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(llmlir_path).read_text(), "kernel.llir.mlir", binary=False) - - # LLVM-MLIR to LLVM-IR - mlir_translate_path = _get_mlir_path("bin", "mlir-translate") - subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(llir_path).read_text(), "kernel.ll", binary=False) - - return Path(llir_path).read_text() - - -def llir_to_cpuasm(llir: str, metadata, opt): - # add metadata at final stage - # Note: Compiled Kernel requires to estimate size of shared memory to occupy - # Currently, CPU backend requires no limit on shared memory size - metadata["shared"] = 1 - # We can get a function name (C naming) from - # LLVM-IR by getting the first "define void @". 
- fn_name = llir.split("define void @")[1].split("(")[0].strip() - metadata["name"] = fn_name + " cpu" - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "kernel.ll") - linked_path = os.path.join(tmpdir, "kernel_linked.ll") - dst_path = os.path.join(tmpdir, "kernel.s") - - llir = downgrade_llir(llir) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(llir, "kernel_downgrade.ll", binary=False) - - Path(src_path).write_text(llir) - - linker_path = _get_llvm_path("bin", "llvm-link") - libclc_path = _get_llvm_path("lib", "clc", "libspirv-aarch64--.bc") - subprocess.check_call([ - linker_path, - src_path, - libclc_path, - "--only-needed", - "-S", - "-o", - linked_path, - ]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(linked_path).read_text(), "kernel_linked.ll", binary=False) - - llc_path = _get_llvm_path("bin", "llc") - subprocess.check_call([llc_path, linked_path, "-o", dst_path]) - if opt.debug: - dump_manager = get_dump_manager(metadata["hash"]) - dump_manager.put(Path(dst_path).read_text(), "kernel.s", binary=False) - - # Actually it's text-format assembly. Use read_text(). 
- return Path(dst_path).read_text() - - def __get_metadata_attr_by_callback(lib, postfix: str, metadata, meta_key: str): func_symbol = metadata["kernel_name"] + postfix if hasattr(lib, func_symbol): @@ -264,8 +232,8 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): # Example: %arg1: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32} -> ('1', '0') TENSOR_KIND_REGEX = r'%arg(\d+):[^,)]*?\{[^}]*?tt\.tensor_kind\s*=\s*([^:\s}]+)\s*:[^}]*?\}' - # Example removal: ', mix_mode = "aiv"' → '' - REMOVE_MIX_MODE_REGEX = r', mix_mode\s*=\s*"[^"]*"' + # Example: bitcode = "a.bc" + BITCODES_REGEX = r'bitcode\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|(\w+))' # Note: Compiled Kernel requires to estimate size of shared memory to occupy # Currently, NPU backend does not limit on shared memory @@ -276,15 +244,17 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): metadata["mix_mode"] = re.search(MIX_MODE_REGEX, linalg).group(1) metadata["parallel_mode"] = re.search(PARALLEL_MODE_REGEX, linalg).group(1) metadata["kernel_name"] = re.search(KERNEL_NAME_REGEX, linalg).group(1) - # Use while space to split kernel_name and mix_mode. + # Use while "_" to split kernel_name and mix_mode. # Check the function load_binary in npu_driver.py. 
- metadata["name"] = metadata["kernel_name"] + " " + metadata["mix_mode"] + metadata["name"] = metadata["kernel_name"] + "_" + metadata["mix_mode"] # Parse all tensor kinds from arguments metadata["tensor_kinds"] = [int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, linalg)] # init the ub bits of triton kernel for inductor autotune using metadata["required_ub_bits"] = 0 - # remove the mix_mode attribute - linalg = re.sub(REMOVE_MIX_MODE_REGEX, "", linalg) + + # Parse all bitcode paths + bitcodes = re.findall(BITCODES_REGEX, linalg) + metadata["bitcodes"] = [val for group in bitcodes for val in group if val] return linalg, metadata @@ -308,7 +278,7 @@ def _parse_ttir_metadata(ttir: str, metadata: dict): # Note: Currently, for TTIR inputs, we only support vector kernels. metadata["mix_mode"] = "aiv" metadata["kernel_name"] = re.search(KERNEL_NAME_REGEX, ttir).group(1) - metadata["name"] = metadata["kernel_name"] + " " + metadata["mix_mode"] + metadata["name"] = metadata["kernel_name"] + "_" + metadata["mix_mode"] # Parse all tensor kinds from arguments metadata["tensor_kinds"] = [int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, ttir)] return metadata @@ -320,6 +290,35 @@ def get_common_bishengir_compile_options(metadata): return [bishengir_target_opt] +def get_auto_bind_sub_block_option(metadata): + # auto_tile_and_bind_subblock is read from the module. + # enable_auto_bind_sub_block is set by the user and has a higher priority. + enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] + return ( + metadata["auto_tile_and_bind_subblock"] + if enable_auto_bind_sub_block is None + else enable_auto_bind_sub_block + ) + + +def _save_npuir_debug_output(stdout_bytes: bytes, stderr_bytes: bytes, tmpdir: str, metadata_hash: str): + stdout = stdout_bytes.decode('utf-8') if stdout_bytes else '' + stderr = stderr_bytes.decode('utf-8') if stderr_bytes else '' + combined = stdout + stderr + if not combined.strip(): + combined = "No output captured." 
+ output_path = os.path.join(tmpdir, "kernel.npuir.mlir") + with open(output_path, 'w', encoding='utf-8') as f: + f.write(combined) + + dump_manager = get_dump_manager(metadata_hash) + dump_manager.put( + Path(output_path).read_text(encoding='utf-8'), + "kernel.npuir.mlir", + binary=False + ) + + def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): linalg, metadata = _parse_linalg_metadata(linalg, metadata) with tempfile.TemporaryDirectory() as tmpdir: @@ -340,17 +339,14 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): f"--enable-auto-multi-buffer={multibuffer}", ] - enable_ubuf_saving = metadata["enable_ubuf_saving"] - if enable_ubuf_saving is not None: - _compile_option_list += [ - f"--enable-ubuf-saving={enable_ubuf_saving}", - ] + disable_tightly_coupled_buffer_reuse = metadata["disable_tightly_coupled_buffer_reuse"] + if disable_tightly_coupled_buffer_reuse: + _compile_option_list += ["--disable-tightly-coupled-buffer-reuse"] + + _compile_option_list += [ + f"--enable-auto-bind-sub-block={get_auto_bind_sub_block_option(metadata)}", + ] - enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] - if enable_auto_bind_sub_block is not None: - _compile_option_list += [ - f"--enable-auto-bind-sub-block={enable_auto_bind_sub_block}", - ] if force_disable_ffts(): _compile_option_list += ["--disable-ffts"] if _is_ascend_sanitizer_enabled(): @@ -366,6 +362,11 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--enable-hivm-auto-cv-balance={enable_hivm_auto_cv_balance}"] + sync_solver = metadata["sync_solver"] + if sync_solver is not None: + _compile_option_list += \ + [f"--enable-hivm-graph-sync-solver={sync_solver}"] + unit_flag = metadata["unit_flag"] if unit_flag is not None: _compile_option_list += \ @@ -376,6 +377,11 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ 
[f"--enable-hivm-inject-barrier-all-sync={inject_barrier_all}"] + inject_block_all = metadata["inject_block_all"] + if inject_block_all is not None: + _compile_option_list += \ + [f"--enable-hivm-inject-block-all-sync={inject_block_all}"] + limit_auto_multi_buffer_only_for_local_buffer = metadata["limit_auto_multi_buffer_only_for_local_buffer"] if limit_auto_multi_buffer_only_for_local_buffer is not None: _compile_option_list += \ @@ -386,16 +392,6 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--set-workspace-multibuffer={set_workspace_multibuffer}"] - tile_mix_vector_loop = metadata["tile_mix_vector_loop"] - if tile_mix_vector_loop is not None: - _compile_option_list += \ - [f"--tile-mix-vector-loop={tile_mix_vector_loop}"] - - tile_mix_cube_loop = metadata["tile_mix_cube_loop"] - if tile_mix_cube_loop is not None: - _compile_option_list += \ - [f"--tile-mix-cube-loop={tile_mix_cube_loop}"] - auto_multi_buffer = metadata["limit_auto_multi_buffer_of_local_buffer"] if auto_multi_buffer is not None: _compile_option_list += \ @@ -404,33 +400,55 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): enable_mixed_cv = metadata["enable_mixed_cv"] if enable_mixed_cv is not None: _compile_option_list += \ - [f"--enable_mixed_cv={enable_mixed_cv}"] + [f"--enable-mixed-cv={enable_mixed_cv}"] enable_cce_vf_auto_sync = metadata["enable_cce_vf_auto_sync"] if enable_cce_vf_auto_sync is not None: _compile_option_list += \ - [f"--apend-bisheng-options=-mllvm --cce-vf-auto-sync={enable_cce_vf_auto_sync}"] + [f"--append-bisheng-options=-mllvm --cce-vf-auto-sync={enable_cce_vf_auto_sync}"] enable_cce_vf_remove_membar = metadata["enable_cce_vf_remove_membar"] if enable_cce_vf_remove_membar is not None: _compile_option_list += \ - [f"--apend-bisheng-options=-mllvm --cce-vf-remove-membar={enable_cce_vf_remove_membar}"] + [f"--append-bisheng-options=-mllvm 
--cce-vf-remove-membar={enable_cce_vf_remove_membar}"] + + if metadata["enable_vf_fusion"]: + _compile_option_list += ["--enable-vf-fusion"] enable_drop_unit_dims = metadata["enable_drop_unit_dims"] if enable_drop_unit_dims is not None: _compile_option_list += \ [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + enable_flatten = metadata["enable_flatten"] + if enable_flatten is not None: + _compile_option_list += \ + [f"--enable-flatten={enable_flatten}"] + enable_auto_vectorize_v2 = metadata["enable_auto_vectorize_v2"] if enable_auto_vectorize_v2 is not None: _compile_option_list += \ [f"--enable-auto-vectorize-v2={enable_auto_vectorize_v2}"] + auto_vectorize_v2_max_fused_ops_num = metadata["auto_vectorize_v2_max_fused_ops_num"] + if auto_vectorize_v2_max_fused_ops_num is not None: + _compile_option_list += \ + [f"--hfusion-max-fused-ops-in-auto-vectorize-v2={auto_vectorize_v2_max_fused_ops_num}"] + prevec_max_fused_ops_num = metadata["prevec_max_fused_ops_num"] + if prevec_max_fused_ops_num is not None: + _compile_option_list += \ + [f"--hfusion-max-fused-elementwise-ops={prevec_max_fused_ops_num}"] disable_auto_inject_block_sync = metadata["disable_auto_inject_block_sync"] if disable_auto_inject_block_sync is not None: _compile_option_list += \ [f"--disable-auto-inject-block-sync={disable_auto_inject_block_sync}"] + bitcodes = metadata["bitcodes"] + if bitcodes is not None: + for bitcode in bitcodes: + _compile_option_list += \ + [f"--link-aicore-bitcode={bitcode}"] + if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] npu_compiler_path, env = _get_npucompiler_path() @@ -445,20 +463,54 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): mix_mode = opt.mix_mode if mix_mode in ["aic"]: _compile_option_list += ["--disable-hfusion-vectorize=true"] - cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - # TODO both bishengir-compile and triton-compile use 
passing attr by module - auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] - if auto_tile_and_bind_subblock is False: - cmd_list += ["--enable-auto-bind-sub-block=false"] + + if opt.debug: + _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] + + cmd_list = ( + [npu_compiler_path, ttadapter_path] + + _compile_option_list + + ["-o", bin_file] + ) vf_merge_level = metadata["vf_merge_level"] - if vf_merge_level: + if vf_merge_level is not None: cmd_list += [f"--enable-vf-merge-level={vf_merge_level}"] - ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) - match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) + hfusion_enable_multiple_consumer_fusion = metadata["hfusion_enable_multiple_consumer_fusion"] + if hfusion_enable_multiple_consumer_fusion: + cmd_list += [f"--hfusion-enable-multiple-consumer-fusion={hfusion_enable_multiple_consumer_fusion}"] + + if opt.debug: + print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") + + try: + ret = subprocess.run( + cmd_list, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True + ) + except subprocess.CalledProcessError as e: + if opt.debug: + _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) + raise + + if opt.debug: + _save_npuir_debug_output(ret.stdout, ret.stderr, tmpdir, metadata["hash"]) + + stdout_str = ret.stdout.decode('utf-8') if ret.stdout else '' + match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', stdout_str) if match: # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) + + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') if ret.stderr else '' + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) + if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) 
__get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") @@ -501,11 +553,16 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): f"--enable-ubuf-saving={enable_ubuf_saving}", ] - enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] - if enable_auto_bind_sub_block is not None: + enable_preload = metadata["enable_preload"] + if enable_preload is not None: _compile_option_list += [ - f"--enable-auto-bind-sub-block={enable_auto_bind_sub_block}", + f"--enable-preload={enable_preload}", ] + + _compile_option_list += [ + f"--enable-auto-bind-sub-block={get_auto_bind_sub_block_option(metadata)}", + ] + if _is_ascend_sanitizer_enabled(): _compile_option_list += ["--enable-sanitizer=true"] if not _is_debug_line_info_disabled(): @@ -514,6 +571,9 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): if _enable_print_ub_bits(): _compile_option_list += ["--enable-print-memory-allocated-size"] + if _enable_dump_memory_info(): + _compile_option_list += ["--enable-memory-display=true"] + enable_hivm_auto_cv_balance = metadata["enable_hivm_auto_cv_balance"] if enable_hivm_auto_cv_balance is not None: _compile_option_list += \ @@ -521,8 +581,10 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): sync_solver = metadata["sync_solver"] if sync_solver is not None: - _compile_option_list += \ - [f"--enable-hivm-graph-sync-solver={sync_solver}"] + _compile_option_list += [ + f"--enable-hivm-graph-sync-solver={sync_solver}", + f"--enable-hivm-cross-core-gss={sync_solver}", + ] unit_flag = metadata["unit_flag"] if unit_flag is not None: @@ -534,6 +596,11 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): _compile_option_list += \ [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + enable_flatten = metadata["enable_flatten"] + if enable_flatten is not None: + _compile_option_list += \ + [f"--enable-flatten={enable_flatten}"] + enable_auto_vectorize_v2 
= metadata["enable_auto_vectorize_v2"] if enable_auto_vectorize_v2 is not None: _compile_option_list += \ @@ -579,6 +646,21 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): _compile_option_list += \ [f"--disable-auto-inject-block-sync={disable_auto_inject_block_sync}"] + bitcodes = metadata["bitcodes"] + if bitcodes is not None: + for bitcode in bitcodes: + _compile_option_list += \ + [f"--link-aicore-bitcode={bitcode}"] + + enable_libdevice = os.getenv("TRITON_ENABLE_LIBDEVICE", False) + if enable_libdevice: + _compile_option_list += [f"--link-aicore-bitcode={get_libdevice()}"] + + disable_size_align_for_cast = metadata["disable_size_align_for_cast"] + if disable_size_align_for_cast is not None: + _compile_option_list += \ + [f"--disable-size-align-for-cast={disable_size_align_for_cast}"] + if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] npu_compiler_path, env = _get_npucompiler_path() @@ -588,15 +670,44 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): bishengir_hivm_opt, "--enable-triton-kernel-compile=true", ] - cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] - if auto_tile_and_bind_subblock is False: - cmd_list += ["--enable-auto-bind-sub-block=false"] - ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) - match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) + + if opt.debug: + _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] + cmd_list = ( + [npu_compiler_path, ttadapter_path] + + _compile_option_list + + ["-o", bin_file] + ) + if opt.debug: + print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") + + try: + ret = subprocess.run( + cmd_list, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True + ) + except subprocess.CalledProcessError as e: + if 
opt.debug: + _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) + raise + + if opt.debug: + _save_npuir_debug_output(ret.stdout, ret.stderr, tmpdir, metadata["hash"]) + + stdout_str = ret.stdout.decode('utf-8') if ret.stdout else '' + match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', stdout_str) if match: - # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) + + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') if ret.stderr else '' + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) + if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) __get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") @@ -607,6 +718,11 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): return Path(bin_path).read_bytes() +def get_libdevice(): + current = os.path.dirname(__file__) + return os.path.join(current, "lib/libdevice.10.bc") + + @dataclass(frozen=True) class NPUOptions: debug: bool = False @@ -616,7 +732,7 @@ class NPUOptions: arch: str = "" cluster_dims: tuple = (1, 1, 1) - num_warps: int = 4 + num_warps: int = 32 num_ctas: int = 1 num_stages: int = 1 warp_size: int = 32 @@ -625,6 +741,7 @@ class NPUOptions: reg_dec_producer: int = 0 reg_inc_consumer: int = 0 + auto_blockify_size: int = 1 compile_on_910_95: bool = is_compile_on_910_95 optimize_dynamic_offset: bool = False enable_mask_fallback_conversion: bool = False @@ -639,14 +756,17 @@ class NPUOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15", "fp8e4nv", "fp8e4b8", "fp8e5b16") deprecated_fp8_dtypes: Tuple[str] = () vf_merge_level: int = 1 + default_dot_input_precision: str = "ieee" allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") max_num_imprecise_acc_default: int = 0 extern_libs: dict = None - 
bisheng_options: str = None + bisheng_options: str = "-cce-link-aicore-ll-module " + get_libdevice() multibuffer: bool = not is_compile_on_910_95 enable_ubuf_saving: bool = None - enable_auto_bind_sub_block: bool = not is_compile_on_910_95 + enable_preload: bool = None + enable_auto_bind_sub_block: bool = None + disable_tightly_coupled_buffer_reuse: bool = False enable_select_analysis: bool = True enable_hivm_auto_cv_balance: bool = None sync_solver: bool = None @@ -654,9 +774,13 @@ class NPUOptions: enable_cce_vf_auto_sync: bool = None enable_cce_vf_remove_membar: bool = None enable_drop_unit_dims: bool = None + enable_flatten: bool = None enable_auto_vectorize_v2: bool = None + auto_vectorize_v2_max_fused_ops_num: int = None + prevec_max_fused_ops_num: int = None inject_barrier_all: bool = None inject_block_all: bool = None + disable_size_align_for_cast: bool = None limit_auto_multi_buffer_only_for_local_buffer: bool = None limit_auto_multi_buffer_of_local_buffer: str = None set_workspace_multibuffer: int = None @@ -664,13 +788,17 @@ class NPUOptions: tile_mix_cube_loop: int = None disable_auto_inject_block_sync: bool = None enable_mixed_cv: bool = None + enable_vf_fusion: bool = False + add_auto_scheduling: bool = False + hfusion_enable_multiple_consumer_fusion: bool = False stream: int = None parallel_mode: str = "simd" force_simt_only: bool = False force_simt_template: bool = False + enable_sync_block_lock: bool = False # only take effect on the simt-only & simd-simt-mix scenarios - shared_mem_dynamic_size: int = 221184 + shared_mem_dynamic_size: int = None # enable_bishengir_simt_optimization is passed as # -enable-bishengir-simt-optimization flag to bishengir-compile. enable_bishengir_simt_optimization: int = 000 @@ -679,6 +807,10 @@ class NPUOptions: compile_mode: str = "simd" mix_mode: str = "" simt_stack_limit: int = None + # take effect on the reorder instruction pattern for SIMT. The pattern is disabled by default. 
+ enable_simt_reorder_instruction: bool = False + # disable simt fma optimization to get high precision + disable_fma: bool = False def __post_init__(self): # Parse compile_mode and set related fields @@ -690,31 +822,12 @@ def __post_init__(self): elif self.compile_mode == "simt_only": object.__setattr__(self, "force_simt_only", True) object.__setattr__(self, "parallel_mode", "simt") - object.__setattr__(self, "shared_mem_dynamic_size", 122880) - - def hash(self): - key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) - return hashlib.sha256(key.encode("utf-8")).hexdigest() - - -@dataclass(frozen=True) -class CPUOptions: - debug: bool = False - llvm_version: int = 15 - kernel_name: str = "triton_" - cluster_dims: tuple = (1, 1, 1) - num_warps: int = -1 - num_ctas: int = -1 - num_stages: int = -1 - - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: int = 0 - extern_libs: dict = None + if self.force_simt_only: + if self.shared_mem_dynamic_size is None: + object.__setattr__(self, "shared_mem_dynamic_size", 122880) + else: + object.__setattr__(self, "shared_mem_dynamic_size", 221184) def hash(self): key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) @@ -744,6 +857,7 @@ def ttir_to_npubin(mod, metadata, opt): # build compile options _compile_option_list = get_common_bishengir_compile_options(metadata) if opt.force_simt_only: + _compile_option_list += ["--enable-hivm-compile=false"] _compile_option_list += ["--enable-triton-ir-compile"] _compile_option_list += ["--pure-simt"] _compile_option_list += [f"--num-warps={opt.num_warps}"] @@ -754,12 +868,33 @@ def ttir_to_npubin(mod, metadata, opt): ] if opt.simt_stack_limit: _compile_option_list += [f"--simt-stack-limit={opt.simt_stack_limit}"] - if opt.shared_mem_dynamic_size: + if opt.shared_mem_dynamic_size is not None: 
_compile_option_list += [f"--shared-mem-dynamic-size={opt.shared_mem_dynamic_size}"] + if opt.enable_simt_reorder_instruction: + _compile_option_list += ["--enable-simt-reorder-instruction=true"] + if opt.disable_fma: + _compile_option_list += [f"--disable-fma"] + + enable_libdevice_simt = triton_enable_libdevice_simt() + if (enable_libdevice_simt): + bisheng_options = metadata["bisheng_options"] + if bisheng_options is not None: + _compile_option_list += [ + f"--append-bisheng-options={bisheng_options}" + ] npu_compiler_path, env = _get_npucompiler_path() - cmd_list = ([npu_compiler_path, src_path] + _compile_option_list + ["-o", bin_file]) - ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) + cmd_list = ( + [npu_compiler_path, src_path] + + _compile_option_list + + ["-o", bin_file] + ) + ret = subprocess.run(cmd_list, env = env, capture_output = True, check = True) + if not Path(bin_path).exists(): + error_msg = ret.stderr.decode('utf-8') + print(f"[DEBUG] {bin_path} is not found") + print(f"[DEBUG] Stderr:\n{error_msg}") + raise subprocess.CalledProcessError(ret.returncode, cmd_list, ret.stdout, ret.stderr) return Path(bin_path).read_bytes() @@ -767,13 +902,11 @@ class AscendBackend(BaseBackend): @staticmethod def supports_target(target: GPUTarget): - return target.backend == "cpu" or target.backend == "npu" + return target.backend == "npu" def __init__(self, target: GPUTarget) -> None: super().__init__(target) - if target.backend == "cpu": - self.binary_ext = "cpuasm" - elif target.backend == "npu": + if target.backend == "npu": self.binary_ext = "npubin" def parse_options(self, opts) -> Any: @@ -783,8 +916,10 @@ def parse_options(self, opts) -> Any: args.setdefault("arch", self.target.arch) options = NPUOptions(**args) else: - args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} - options = CPUOptions(**args) + raise NotImplementedError( + f"Backend '{self.target.backend}' is not supported. 
" + "Please ensure the target backend is set to 'npu'." + ) return options def pack_metadata(self, metadata): @@ -795,7 +930,7 @@ def pack_metadata(self, metadata): # CANN runtime limits the length of kernel name <= 50. # Considering '\n' is appended, thus the real kernel name <= 49. KERNEL_NAME_MAX_LEN = 49 - kernel_name_orig, mix_mode = metadata.name.split() + kernel_name_orig, _ = metadata.name.rsplit("_", 1) if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN: kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] else: @@ -835,10 +970,10 @@ def add_stages(self, stages, options): stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile_A2_A3(src, metadata, options)) else: - stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) - stages["ttadapter"] = lambda src, metadata: ttir_to_linalg(src, metadata, options) - stages["llir"] = lambda src, metadata: linalg_to_llir(src, metadata, options) - stages["cpuasm"] = lambda src, metadata: llir_to_cpuasm(src, metadata, options) + raise NotImplementedError( + f"Backend '{self.target.backend}' is not supported. " + "Please ensure the target backend is set to 'npu'." 
+ ) @functools.lru_cache() def hash(self): diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index e1fe48bb41..803a66595a 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -28,7 +28,7 @@ from typing import Optional import functools import hashlib -from triton.runtime.cache import get_cache_manager, get_dump_manager +from triton.runtime.cache import get_cache_manager, get_dump_manager, default_cache_dir from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget from triton.backends.ascend.utils import (_precompile_npu_hash, _precompile_npu_ext, _build_npu_ext, _check_cxx11_abi, @@ -69,7 +69,7 @@ def __init__(self): env_arch = get_ascend_arch_from_env() def load_binary(self, name, kernel, shared, device): - fnname, mix_mode = name.split() + fnname, mix_mode = name.rsplit("_", 1) return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode) @functools.lru_cache() @@ -94,19 +94,6 @@ def get_aicore_num(self): def get_aivector_core_num(self): return self.get_device_properties("npu")["num_vectorcore"] - @functools.lru_cache() - def set_device_limit(self, device, ty, val): - """ - Set npu device limit - - Args: - device: Device id - ty: The type of the limit, valid types include: - "LOW_POWER_TIMEOUT", "WARP_STACK_SIZE", "DVG_WARP_STACK_SIZE", "STACK_SIZE" - val: The specific meaning of the value depends on the type of limit. - """ - self.npu_utils_mod.set_device_limit(device, ty, val) - class NPULauncher(object): @@ -225,32 +212,51 @@ def get_empty_cache_for_benchmark(self): return get_backend_func("get_empty_tensor", cache_size // 4) -# fixed the issue of corrupted gch header files in multi-threaded scenarios. 
-def _precompile_npu_ext_with_lock(header_path): +def _precompile_npu_ext_with_lock(header_src, enable_precompile): import fcntl - src_path = os.path.dirname(header_path) - lock_path = os.path.join(src_path, "precompiled.lock") + precompile_hash = _precompile_npu_hash(header_src) + cache = get_cache_manager(precompile_hash) + gch_path = cache.get_file("precompiled.h.gch") + header_path = cache.get_file("precompiled.h") + if enable_precompile: + if header_path is not None and gch_path is not None: + return header_path + else: + if header_path is not None: + return header_path + cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + lock_path = os.path.join(cache_dir, f"{precompile_hash}.lock") with open(lock_path, "a+") as f: try: fcntl.flock(f, fcntl.LOCK_EX) - _precompile_npu_ext(header_path) + header_path = cache.get_file("precompiled.h") + if enable_precompile: + gch_path = cache.get_file("precompiled.h.gch") + if header_path is not None and gch_path is not None: + return header_path + else: + if header_path is not None: + return header_path + header_path = cache.put(header_src, "precompiled.h", binary=False) + if not enable_precompile: + return header_path + src_dir = os.path.dirname(header_path) + gch_path = os.path.join(src_dir, "precompiled.h.gch") + _precompile_npu_ext(header_path, gch_path) + return header_path finally: fcntl.flock(f, fcntl.LOCK_UN) - + def make_npu_launcher_stub(header_src, wrapper_src, debug=False): """ Generate the launcher stub to launch the kernel """ - precompile_hash = _precompile_npu_hash(header_src) - cache = get_cache_manager(precompile_hash) - header_path = cache.get_file("precompiled.h") - gch_path = cache.get_file("precompiled.h.gch") + enable_precompile = not os.getenv("TRITON_DISABLE_PRECOMPILE", 'false').lower() in ('true', '1') # if precompile header file and its gch file not exist, do precompile - if header_path is None and gch_path is None: - header_path = cache.put(header_src, "precompiled.h", 
binary=False) - _precompile_npu_ext_with_lock(header_path) - + header_path = _precompile_npu_ext_with_lock(header_src, enable_precompile) + assert header_path is not None, "the precompiled.h path is empty." + # try to get cached file so_cache_key = hashlib.sha256(wrapper_src.encode("utf-8")).hexdigest() so_cache_manager = get_cache_manager(so_cache_key) @@ -274,15 +280,13 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): return cache_path kernel_launcher_type = "torch" - enable_taskqueue = os.getenv("TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1') - if not enable_taskqueue: - kernel_launcher_type = None + with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, f"{name}.cxx") with open(src_path, "w") as f: f.write(wrapper_src) - so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, precompile=True) + so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, precompile=enable_precompile) if debug: with open(so_path, "rb") as f: dump_manager.put(f.read(), so_name, binary=True) @@ -555,10 +559,32 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); if(!ptr_info.dev_ptr) return ptr_info; - Py_DECREF(ret); // Thanks ChatGPT! + aclrtPtrAttributes attributes; + aclError status = aclrtPointerGetAttributes(ptr_info.dev_ptr, &attributes); + + if (status == ACL_SUCCESS) { + if (attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE && attributes.location.type != 4) { + Py_DECREF(ret); + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + return ptr_info; + } + } else { + Py_DECREF(ret); + PyErr_Format(PyExc_RuntimeError, + "Failed to query pointer attributes at argument %d. " + "Error code: %d. 
This may indicate invalid memory address " + "or NPU device error.", + idx, status); + ptr_info.valid = false; + return ptr_info; + } + Py_DECREF(ret); return ptr_info; } PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; return ptr_info; } """ @@ -746,12 +772,13 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): name.append(kernelName); void *workspace_addr_ptr = NULL; uint32_t blockNum4Workspace = gridX * gridY * gridZ; + {get_backend_func("pre_launch", True)} {f''' uint64_t totalWorkSpaceSize = {workspace_size} * blockNum4Workspace; - auto optionsWorkspace = at::TensorOptions().device(at::kPrivateUse1).dtype(at::kByte); - workspace_addr_ptr = {get_backend_func("allocate_memory", "totalWorkSpaceSize", "optionsWorkspace")} + {get_backend_func("allocate_memory", "totalWorkSpaceSize", "stream")} ''' if workspace_size > 0 else ''} {'auto launch_call = [=]() -> rtError_t' if enable_taskqueue else ''} {{ + {get_backend_func("pre_launch", False)} uint32_t blockNum = gridX * gridY * gridZ; #ifdef ENABLE_GRID_WARN_PRINT @@ -761,14 +788,13 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): warned = true; }} #endif - {get_backend_func("pre_launch")} {'blockNum = std::min(blockNum, (uint32_t)' + str(num_physical_blocks) + ');' if enable_auto_map_parallel_blocks else ''} // set mixBlockNumRation for nodeBasicBlockDim for msprof report uint32_t mixBlockNumRation = {mix_block_dim_ratio}; uint32_t nodeBasicBlockDim = (mixBlockNumRation << 16) + blockNum; {'cce::internal::DebugTunnelData *DTData = cce::internal::DebugTunnel::Open(blockNum);' if enable_device_print else ''} - rtError_t ret; + rtError_t ret = RT_ERROR_NONE; {'void *ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len);' if target_support_ffts else ''} {'if (ret != RT_ERROR_NONE) return ret;' if (target_support_ffts and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) 
return;' if (target_support_ffts and (not enable_taskqueue)) else ''} // stub argument for workspace @@ -776,7 +802,7 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): uint16_t ModuleId = 0; {f''' uint64_t syncBlockLockSize = {lock_num} * sizeof(int64_t); - syncBlockLock_ptr = {get_backend_func("allocate_sync_block_lock", "syncBlockLockSize", "stream")} + {get_backend_func("allocate_sync_block_lock", "syncBlockLockSize", "stream")} if (!syncBlockLock_ptr) {{ {alloc_success_code if enable_taskqueue else sync_lock_fail_code} }} @@ -880,8 +906,12 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): } }} - if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ - return NULL; + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; }} // get kernel_name @@ -904,8 +934,12 @@ def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): if (PyErr_Occurred()) {{ return NULL; }} - if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ - return NULL; + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; }} Py_RETURN_NONE; }} diff --git a/third_party/ascend/backend/lib/libdevice.10.bc b/third_party/ascend/backend/lib/libdevice.10.bc new file mode 100644 index 0000000000000000000000000000000000000000..76bcbc63a5f07f3696fb15ce79dc32e2b323a1c3 GIT binary patch literal 84908 zcmdSB4OoT#7S>o$T6$-N8*O-p`Fy|EfST^!&vP8l|M(xr z!*1}IFF&s9yv~>3`TL#cd6{l+`f!3#5R`%-EY?~^JhCc9yH|vv(sTxyO zNYdinGo8(G_hdv=?9tnv4O-fwdCe%=r|%Yz$&wbAo5sYCsZ88nm%mRu(t_(gDGI{9 z_?a|(cd)p4x)68Yh-y>ZNPLSPrQLa;vS9Jjd~w9`QX%Xym33;+NzJF4CaWw?qa6OJ zMpRDTW7XdyY*5q|sFun0SQB=Mno>)Uq1q#A#wVOuT)r$ZU2lpLwFZT)0KeN~HBV2~ 
zjW-`H+$9QYtgQyca#3?id*V=-B3ssK5TDYYGPs*9rF!i>8o#B~s+9>jCPC1kQilF$ zc<&8-$@IcJp-*tjJHv(GTK5B*LsqfX7JF!a%<=tUon0~Scw;W@H(!uBD>geTiiO>q zW~w)@ly8c!*q~8wS}ETaUr|(BvAMEhVnc;XRByGa1)Xpx=EMDAmv9kTY-e1Izsvl7 zZOkE0Y@4iBz2zV;vZc1dX{y+`QeIM7vAMRQ*reWCqTaGnUUX2s!BSC-t5sLH+F;o& zh?n-q`uCgPFEDrRk9BW|S8tA2Z<$$9Buvz&3zdgdD*{ikQr>A4rU~B*mF@>Rta@3K zdg$(S;f1i#g?)TVLE|1RJR$VSM47QoZ8K@gC$-p9>+O^3WoeD}$qn+URb^?ym)*KA z9XjE(?sBJYZIkYzQ`gm`TU!u*(W2|=*8QVR^=+l@f;Ie-&?K`q_@>m$rs5(^^0WhG zQ)=u(>+Pv+^2tJv82$||k{I5j*L_l_TPuYBqg&VG(0yVJ|4IyBE9$-}(0%IQ)xNIO z{WmOQ1;y0%vME*ep?lB$WZi1F!twvi)9rdT?|YPehoEQ*cTW>WxI_M*>3gp#zV5H~wL*1J zr@CpF5YeU}@p zZ!)@Wi*HJsebOHJ)W)*3dQWP5*`#VaI`+_3`{YZoED`iQj&OA8wH95sBfM!+6}nkj z+5umxFtl>IFxYJrCU;mehwIJ9-JYXTgt^3Ev7?>ahiJ-N6M`)>JtX3rmwtL^W=D230$DCL?uL(Cd%*cOt&m| zhd!dYH}_Jd?y?a6O@Z3grUIk)U#Zl6f+5$Y`oyBUSZ9lb$9X4C*CT}Y6zDG0hP@*T zzr}I7E7tIjn{@rZ^6D;G!>_n??>Ct#iu~=TR(y=@|5zJ#rI3SF5@Wh={%zo|3I^kE zt77%TN0e)$M<)fXEVoSh&rknx_34)eKX>AvJq>*ueV_k6(fqe1N&Cj@RgcdM5jb(U zYiA1uN89gRVbOk9iTOr>=d}G?XnH2NY3Jh3ir)x=Rd5TM_Z7mX>;`w}k+$%8G6+JD zzCKvEQr&rIaHV(HE4YwMP>CxC38sfmKlq#==y4eh{<8}8!IlFdQBTW)toRuCS~;j- z=EiYGxx(Eq2twPvO_h%=`X& zlzHX{Psw_99K*5r?4iDTCf81ZN%{|t@wQm(H`?4o+ z?OO#q`z&hA!9T)(ff^e!YDSPyi;t|xK*1vC?v)Gf9~A8PFNdfFL4JF|&OSKdt{SU( z?5-MHc6Njypt<=}{WW&sxBCPk9v^QN?E3TLWp^~?V`EnaNmX;JU}qnBTOn9~P_U=o zstp$0xV9D7>u<^(lLkrQjGI!R#v=cU`>_Q07X%#u~%jaF_l*7IOVSsNlwhtdyia_Wbt?5Y_## zU?wdq78X|g*I%(Fm0l%SaSxH2)MG<*Wi3IH-L0EOKIL5^4c*(iS?S>dJi-qfeeV)0 z;!50BYMH=j`ZBRP;m!&f6K%b#L0{7T;?8!T|6^4+hV>7NH?C~OosG5eXWlZI)PT3D z?vE#hKL#F)#R-h!7gq?&1OcVFt$0_)mmofcLlaQV{T*rKkTke^Xl_XT%{nCPuHv1o zzPor2&XM2Mmqz|pRxT5oaH(6xd-kkwSM5EyS|}6*3+(X310!pK`>x^*A8Ng;crS+C zHEN#kVB>O#i@03BJ6b0{erMw@`-|*(nNSInTg7|l{N1(poAr0sUQbHXUByc>7od1> zsczmjFeoPIUcR$&&&H}x+&T7-SCuWhvvDI&coOd%uz$Jee(|opx1n;SASC0Ks6Ffb zf%|UK9Sq}B^?yn8!c6t)Tlc+?F3mDeT^8@^El+OLFSv8qemZ&gU7h(rOvasK?YHmi zhKB|U!aE0slHlD@mU z_Rc@9pMU4X>X^2Bj3A&--d4On56g&@I;oqt4Yb-f;V*|s)0LH)RC{N?SIMW$lh5f132h;pd`C<}F-x|MjetE6O*Ie(>hg>o*^!&b*o499&wVs2 
z_^r36Rc$pSZ(Pv(^`eW*_Fvy}6=PwTK{`-ICJZAp7;XkeC z4qOhY@b>*d`PTP^9}I4r_p0}igEw+t9lPNdkLvH+kpE`l&(>s3{LZc)H{sF9S;e26 z9+v&*;+n7j^1GrZ{$tvhBSAKqVAD5`*DGuexlKP}0w5QVOZXAMMaTo}AuIsY5Pl4J zmGB6lk&q8K0`L+}QCSFhmkQxSDvJQ0QXzavT` z5#TD}QGjg|U@0JnunbU0SPm#6tN?fkYXGMSYXMz^b%3jc^#I#ufD4dA*a#>jYyuP! zHUqqb?SRvS#{pf0CjeIoI{>y4xy?}ym<4Dfah6OMr7!2whZu0k}+saFxn00oSMyY+Iom0L%ik5pt;f8nA#0 zp^(ZefK^loMO6L+u!Rc2OXXF-F)D=9RQ?lijtZfR%I^V}sSvJG`2*k@6@qOWlzPA{ zKpP>4${T(^TF9oTEbMqVhMuWh#WLRQ?XQ zMulM8F1J~l0hs_NVHTi{FdNWK$f3<)z}|P%}Kyn+7QmshH#!Xgf7~g0$ijG;WBLq zU(kkdl{Rk!zM&1_8f^&IX+y9*4x4sBCcsIU1*jv;26Pj0Xmb`YpEiUAv>_~{4WW=W z?*f+5hOmk@gtfFG6w&5Az((2-w$O&~IBf`C+V}y7X+tz=qBXQ z<|1G|Z3qi!Ls&=~LLqHF1uUZtVHIr%YiUC$qRnT3jkF0M62eaE>;F^Ryv!(dJ9QMcNQ9(}wT`Z3tIs^A+G5+7PbMhH#xW1ltbS z^Z+sePQole9bq=0n~+1BZvgXYLs&o?!a~{*3Tg9Cz%tqpR?&vAmNtYU+I$PxNE^Zy z+7KS64Z%yBKEPqx5RTD?aFRBJ)3mt`I7=JCIoc4;(}vJRo9_V^X+yY78^RZ~AzY=+ z4Zt_FAzY&k;W}*ywo19pq?n{v*a4}8Oh6uC2B3s63s6nS0yGe213CzE0lkDAfJq6M z2S_E%2jmeR29yvM0ICU(02&Aj0Ud-#0lkDmfJp^d0!Sq+1LP4_07?j}0M&%mfCj=^ zKnGzxpqEetFsT6>0I7tHfIPxxKnYKKnLLlpqF6VhwB&tc0ej26Oc!k0VpBN0#p;S01brMfDXc3KrbN&V2T9H1Edn> z1M&zD14;-B0M&#?01bqNfDXc=fL=l&z%&@J1dvKt2FN3<0F)3`0jdeB0S$z;fDXcX zKrf*PV2T240HhK&0`dr(0VRYjfNH`vKm*}%KnGz5pqJnUn9KklAeC?!kViNQC?OmJ zR1;1B8VDx=9fY?5y@b;MQw-n?AeC?ykVkkAP(nBds3v>>Xds*ibPzrQ^b)!NrdYry zfKUKndY8pqlVGpn>oOpo8!gpqFqJU>XMa8jwo(29QVi7EnUC2B;={2WTK% z2Xqi_0D1|wXK|eo06QR+kO{~m%m9=SW&x@RS%3z@Y(NKLE})l?12ByQ%mbtn<^%Ew z4+BaF3jo!GM*t0kg@6viqkvvQA;5GmU3Fsia4d^AD2AIYI&Hz#gX90PH_W&h?bAW2X2Y?2`c|Zr@BS0^q3t+k*@ChK5 za1oG4_zX}&xD2Q!d=6+Jd;#bndtWSN;nJ1BfJMFA)EtL6FvYm5Y7WS2p<7@30(lw1AtEesf3Gw zJi=#y62fIbHQ{qW1K|rm2jMF~FX1Y{G!5`IAeHb9Adm1ZpoDM@P)+y_&_K8j=pftx z^b#2V7!{NBxBeXOs%K*7l2>Dc209;fE+o`Mq zR8b++P+1MAr$T6?vKG)zh2W>M9?(OD&_|^RU{pgPm;q9NmrR9`UK*8+fLtnsd@7p( zE-HlWRJH)Bs1RzXYy;F&Av98X9MDdM;HRh$k;-{MI~9VT%13}6Dug~NT>zsF z3c(DJ0@q|Jj9}BKTm>o}4B(DdL1-PgXwo|zVsG>rsq4FJ|o(iFn%5^|H z6@s724L}bSLLU{z&qgB@f*BwM(#cd9QKwO1M4d~8kWYo7wTlX2I~B&*Ra6KyR2Xd6 
zQz0}`VbtADh2W>c@Vkc!p^pmVaAPDCf*BwMcFusQhp^r)-z&IER!3>ZB`(!GN@YAR)1LRU6NaZA;oeIHE7I~u${^|Kou224V4c7^;8IrRL%q1sSx~BJ_7Vm zA@ot{0vKbV5X=B6@kyq_3@DAtML;eULOzwx04^$o?NlxUs;CfZsC*8nr$T6?@&%xs z3c*k1D?kqwLLZf@0OK$y1T#QNppvODlS-rV4Iq~aA)m^(02dX)b}H8ZRa6KyRK5e$ zQz0}`xejQjLhw_$0qCJZ=%d2?%QylG!3>ZRv1BUD$kM1VBg>^i$fv^8%teK;oeFcb zDk_8;DoobusSp~eFl%e4Lhw^z`qo2*&_{)NoN*)+f*Bwsbjehh*`-loW|vEakWYmv zo{I`$I~C@7Ra6KyRG9SDQz0}`VfNQfh2W>cG_Z#Xp^r)-z<4hdf*BwshRIZzA*NAT z2FRsC$fvRb;G#mNaZ}B zoeIHEGk}WZuSKseA!wr$X>k`3lfOh0sUkD!@1n3c(DJ67Xay%*4~Ed;`d(Ldd7`Ex<*Eu${^^ zKou224VCWz^;8IrRIUTssSx~BZUB0y5c;Sve>WyWA(#PDBA-l!8GRZRX7ssK2>Dc) z+PkO_wo_q_UqyvbLxstHJrzPD6=wbIR0w`5O#geR5c;Ta9l$sN3c(DJ76g*1aAhEk z3ReblsSxt1aEZW0g|M9p*9xkr5NfD!(V(6Rp^*w#589~^{8YG%&_jjLN2L&8oCt+r z21pAF$yB()kVa)0AeRaupUMh=iwa>ol~sT$Dufy;s{!>?2#r+M0@|q%{8ZKhdZ-Zk zs1yN=DNqPzfV4o7Oob~IX;d}>a;XsVscZ(gs1UYO*#fAdLa3p#4Ny;o&`9NRKsyzJ zpUMtE4;4Zm6)(V;28Cb-NDCdwRJh`iM&&RdmkJ@D%29xe3Sm2yV}L3ugc>R*0QFP| zjZ{to+NluyRNe;kP$BeDISnvQg+ed`qy>{?DqJ~9qjDCIONEe6Rz0P3j_8mXKIv{NDYseAv`G*bBj&`yQmr}7n`hYF#O%2j}I8We&VAT8h|Q{hTZ8kKJV zxl{=GRK5kcs1UYOxdy1BLa3qg9iW~Hp^?gUKsyzJpUMqD4;4b+v`I-vf|P>7rkC0D zoOH(r|_d&nDeYQK58MJbt;4Q-Xo~m#5)<5xjuX z+HKrapx(ApowBl`7~kC$&dLgBfqGNCdQ-0wm)-2HcvMqSWU45xuGmuTN#F}oow{$D zbXP3$#VvR(Y`?ix`tC0@ACEIPdt#0ks@|>SJ1BuSQ6@F`QmgE#yX8}=eB*G_zS1e} z@+oyb?a@^g&b*3ZXPIlIx}>1OWvX!1Rz#ng7yc#QmvZPX;LQ+gco$ylz*C6vF@A6C zAzSP_ab~|K*57Hq&=vDvve@OOoYXq|4^mLExkO!3os->D8O|DreJu8P zq4`o=%!l5XCuFfKa9m7xp(=N0S!%0qN}YWwDx~jF+6DWR-T3I+b2tP~TCo}B(Wtk^ ztBb2W`Jd=@SMWBNL-)1i9`8H0*b7}z9bGXWq5w0ar__~AtuGr|?n~WcpVaD0Yw%9O zOF8ZGwAPvRMR-H5KwX3i%&Un0;E3)@LHHM)s%{bQlZE32uYWpp>+p&gUVX8If8D9V zt3nqXx?amYA73dnAM(U>bVXf>i}^Gzra3MKEqEzz*mCiuDUJ524f075qr9V>{_p1)e zJwXpXYkspkZE(e;!-BAR$;!gwM>nk6y38@lE@m&=v}}!IR%QS;Zgv$eUb8I8k~DFw zC2?%x_*~1R$;lHZkDqv-C3UjJA}(1|xNe2GX<6}xRqNM@i4*R(+&@8#m(;NnmP|}7 zOkA=!F=>3FxSWPRkVdaxyKa?py=&Pzm&hW_Tea4?W|?Te?>;d%-DX>}X4BgH*Q{Uf zynpS+4K8Q#`W3~?Hf*3EuGm;uyi{CXynd~ijn5CQcNMyD`&)mpab5BHHEZr)x@rSU 
zzqq(?-IA4$ZeF%(#Y)$P4J+3dyH>4p393v~7QQT{);^ymNEUIACS&FvQ9nGfv~Ewc zCdd{sPtzWq9hbYpva{eNukp}E)jUlHgryK_DQtq^hv0(HNnx{Uymn4^tHqk0>=VLH zYvmak;}wUs-b7zPl_%NmZq<(%uQ;yteVjdByC-_)co4}c`pkc}9^Nzd%9?e?kL6_+eOYct$CdZ< z-)TASzcfAg2%d8v{2xx{3!x!}LeP#pi$b_x_zkp?l2#ngPPR4AS&jb9t~sgASNf^> zNZ>V6FNZ;33=dBJq(JwX1Fj()FL&YlXAWIYo!g_{5-)kKk_wDEj@vwEjV~3i+8rpH zTICB`4JT;Ei>Q9le4&uvFL`1Q6`I@jn_B{dtej#ZJq4%Y2`+ z%(BlLd?wa*-eq@a&xF}++7!*;=o*V^nHl6ra-l=`3KrBqL*3~x?$P^v8gG+ zo9J)aFN@BvP~o9V4!O#P)Om5=gRh5nM1w2Qf5e1SS(piZM*)8&}$b>>l}5yPF*t^zfcUkZ!kz;Z zHD-9wa2JOOe~QN)-cuBLUmlfEV*3^EWh>*oJii*ai*zq5zenqxo&!Urdp>&GJ=tID zjS7R(mKChe5$laIgVLTAZA>3+v`^CWEpb7(Y6-ksF;q3)B1gv=Z&+;Yv&Zi-;dOvg z+1C-XD;*l2s$8#V$_!hsPmfZ+T)Zo6w^if*mAO?czh9^Rr*^2`qZc+$t+r_vnybQw zOcaCSgrv~@Ps`6_`cxUBqG_#(e_5f{zEUV3t`F*L)Q;Eef4VGLb+%4s(8!#HwS}rU zSJA=N&cGk9)4pT6k`sjK5Ip9f-g?t3k!{aV6W8Lq7|Fw zz+e32p4KFNCS8eLl~GXOmIjYyR&zpG5v-5QTWZy)LNu)rX8DA2lTe^9DNWbxG8qGX zgnfLVU(#n~RD}zp_1ipdbl5^M&>!%7RWlxnpV^OYw})>LCM7)U8*b7Jcgj4`dN`qf zZ@14EVbVn0Zb$z#a6P`)lRU@Z<~u->_B`v0@4se3|26-;?LMEW|GNLmR(dX@LP$|m zcq(Nxv=y>_GLv>V-Q`_nWf?N6&;oJ+_FKzpqcndUuw%IfDZIOFvU1ThV&2Wa|A9X9 zp_}^c<>fV+_X@&HZJcFsfe@!rEsWHd^dgN5#lYNyYdd}7Zm+>^&S`d)>9tYDo7akw z?roR+65q_h+|w?3kozR(VWZbj!(&tplllzjopgP{%jHeVrx(>WDF-d8yA)>2Z@Lt0 zT+|t_l56uT<5WS~eEnsmVKMxzGD7VB#X>PKt_S)WnpemO46>pLU|(37 zpt?t;(w7#r<&;~_TNQ^?afyNcAG%98IPf+UuS5Sma2@uS`ds18Yjn=;zXpCS-TU@& zbF+`3lX0xoMl1e0(8uup7_R-`&AA)?LZ9QoOs&88(D(9-!z@8oujDbl56!6ht@X8# zc9#8VxP%e5O6_Y=30n`9MkP%4>n)hJD>YtcSyLtah83PbDOa5dix|`zF#>!K8fKP3 zT+%Egd4u!oE;)U}J>CTQVns^YcF(VTVKyZUS`14KSCoVC7h4olvMROhiYJ6#n>*|Y ztG>d$MtZI_QhDb{<>e&uZ~WU0^%4dCZIj*xbeIU#DjPJ3qXnXoEj zYecdy?5D=oq-2CC?uE?>_IY^p=)n-C!-wXyCipxGd88-C6q@9)9rU~=`oe;yqxK%O zIBe$KLrwBXue;Y%B)d1YMy5sdVjh3ODSEBSq7kPO9R}@_uKey`U5p05S!@mbnPFwc zJ`9-m3rCD{%lri@XSy{?k(5;j?NV*IH_2fzJ@}e7SF!t#2{Sa$W)$3`QlpnSPG_{K z);aCghZU7UuSL~)mh0CyAFwxI+@V1DCmR$RyK!{c)cm z=bi8gmLNwr8azBQM{jHm#2*}&_m!v2EQ<^>AHbmajbWSTm1xBjzkW?~f;lC@rfpV)XRpa@)s)&YBRub&*SAH=4p^@!mt$o8T=81e0aWf4pP+{!NVPE6o>l#kSNI 
z<;P;wp0a$auK39W^u{JP+{6X6gKSLvneh%=)6S%+qk2p8yF-g*Fzyngc9-Q_bj7ks zEy|1HPkVzk9 z-X~HC>(wLPMG0H%rBi)jkG1LG*Q~R>)=Wosd9uT1d~W^BQSD_lO%bT2Mj3A8m}hA- zR%$v2DWA}+Npv>{sVXCF&nB+b>)E&HyB|?y|MpGoesRGgsvn7E<2#b-gQ9D^FUZUX zR-b6oMZCGJL46G2U0mE-^vZ$p@+W}04he!Al#kN82K{XKu&fD2=!_q`#9{|ZhxE)W<4j7WM zr6HbS1$8RXSvm||uUpNo7w0vsseN5>(rN!pg$5Nzd$%}+Z}+nA9$?=+6_Iah@|=pW zpr4r&2BDwT`jQeywRrL^>S9?;i*lnF)#1r^;zMkUa*O!W3o`UGEt)?y8~v;{C}{$^ z^GtExhSD)D%HN4mb*1@jA;q5iT9khfqnb+d^`XVHMx#5oan@^Rcg~8BMK{}B7Nlx* z+TROdsiv}N){e|IdQJ2N zxAs)Q2}j5~Z!D`+27i#$U3vl8#)irdVTi9n@m6f^?T^i;Lm#BblKLp#*#CEP@%K`$ z@$uiKj;k5yqlnlr4|_+y|H>D84NYJD{TCk%IO_YcT(Emb|0JzJOo)`DOq*hzBXzY0 zv0{L_I8lpO!#6_ty7o8(hXaBhld=-RgN6f{4wWZnaI@3CEBJNo5p+&XI}&k|NwgrU zNO74HPq;@QR4Iy#)|YyMf?ih~v#R%-f`+SSs66LQio}3}oueVVP7gF&wi#8#W6`f*70gMTvM?)s4%_^3A= z(g3|jxcCrPNFdF0+>l+wzg)#76a8VT-RekK_SeY`1c!)AGFp?;FdlD*HBG@vG?8pb zc1n}ikYCW`Dl_Y`wsVL>w=-E5o31er3MrW>KNFVa7HuJ;oI&u|>2B31ljXcaamZyK zCi;|N*(rUV95JH2sqnY%*A*C4N@M_ks67*QK(8@5M5ia=!d_XtUs)?exl3aWR+aZu zM4Ei6Y?e@9sN5-rD<+#u?SqdEA(4KI_~I`*kD2(3H}6 zK*)psKSOF?j;C9(s5Rr-FFrpt4)F@ZPsSkd#5bgI@Q)p6y9!Thk_m0+He--l6oEXT zV8M)}tO9LHd~o840(55Ivx?WX(+wl+SCsQUjwnZYW1Cx)rTw$~gm7HoYe z+#_ISj~V#+wItZ1NoF3e6!u8 zT(-=O@gYw-Ksegwo{*-BHA#j3q?T)wUmjx`l0)>LGPj&zM&y;62B zQy%Hw-{R*W#Ukn6ao{^*=s;v}d$RK}d@2`v3^`0|xC#U(-1di2)pIh8(M3`o#MrpM zU-kzUzW(6)W@vrowIN^iUkEv1kKK!?VOnl~u<);91;o&{Kn%SS`B0SN96tYt6jPVr z^C&$2;!eZs+7ocxd%RHyERi}{Rl~-gaF{Rw)`m1MhHt2~CLD1hM{W)zP|byvLClAC zCFrAedph+Ds>K>v;BS;g<~jUjMQ(nrs?-k%*X8TrSR13^@F`ijo3#io%u2l-WBBVaUC(7 zwm9t$jpolmNZ1sQ3rKSx*K0BHM2r~Rp?pH@Pq_ZPvn<4$;7w{XXr7v;>9gc_h9HdT zlO>*LOT^^U7SfVne-D)*kH@#>cm*@d&NTeS5PHZZAKj8}$!d#9=ijs)LeB{Owcf9{ zsuaugdCGo&e)dsCU3Y5)V%hS?(uO4?Xg&NoKFAiPsjKV-4=Y@DF_7(5l?Adrw~pCf zIc8I2d*M@=?bU`N+q3A9?dg%=mhIFnN4z`B`q%oc&jqwQZ zP5%-Yf!~Kp!!psd$AEvB^rL{NQNiK#&ow(T+O3(!5aoEJ#t19E9jDLX^0E}uEAbP# zZYp!#8gEFP?wx|2qIw0Y{wHOvNmlqLVJiF+V|&AhD#e+wd17|T%+nZX8Hc!}Gb$q~ zONcQXL&$)zYPtb6sMv=EjPhvxlZjci7_=20RkJnIVhSm;jZCeS)-ztRoG5@NE!&4E 
zqpVD^xxl$Q8ezFpg}J?Hgm?7MI0+*nFn)lAi`w*@u^5nJjAx&<7pRr43~`3SU14TF z71;vL1Y3ybrO38uPmMvCxRJY2$E}&2Av0re4NrNicf8;7XGY7Zv)T4xUX&?Qu4@f0 zt?5>yyAeVT_?1YOIIGXuJtB4k}yH zR0Q|aMx5DDDn==aZ9n`L9FNk~v2;yoUh6FUvzK!fJ(iWH0)2vW#3Q@1ip$2zr>j=f zEo*7%)V`@*CBEf-AufDGYDZF2sqyx;KCWpE%uUifY|Oki4Qu4H^hH?hPB^_!Hu70_ zseP%nIeIsiP_3%0p2hAbN7cGb=S_=;KNv3_QOfhXJ8xT$8J<@*wk)FF`+_I{%b8_HeY$EUapblorw8>?0L(Q?hdzAf_yx zkJ)j7H`sG3$uh%&@Z2{#cy3C;0{pTyDL%BQ_0XBH{3u*Ytr~U`X2vdIqZ$Q5htg;h7^LjC)Ys;tXu{>Fw~-QmU^!q^GIIYW`Hg6$rc&@9$= z>ZA2V(@pVMN?%%+F>Y7Awagu~*tNtFI$@^N^himZ_DWStf)7Un!_tu@huMcn&*!MO zIXy_#%k&p)$c|1%Xyq{;ai-h26bn~QHP;9fAuV6~^S2okkG{bo!@VTyqpT$|uxv`zf4Yj6LnZ|6qVtA!?lw#ZSF`dfS zYB!(vARXDE$2M|NR^7vjr)COF9h_xH)Qge4C3x3vceki*TC=I!v#3e^^mc1=Q?LvR z0o%$wn6X}YE#uxSq!Xi@-uP~1IWF}{|D}GacuFRRBRF9V^@b?Mm>zR$o*FkI+FU2) zxq&swe?11cJUK`)>$WlQFZas&*OJ9}nD>tXv40H6ivwdo>m3_%GD7r)Wk+B?dlnK^ zUzl!?HYF`v%2_jP=>3lIelhbMJlGar)*@z!LJv8hy7*K3(~` zqGjz6TuVOIW724CrbAt-rvAP?Q7Ug>pB*VY`t}dTtlWT{@x~1#iMEgal<@n#LLkMK zQV+?+_`;4M6^@@Fm5Dp=2@mps0G1eaDXHPb>>U;Qr5RV0&mdZu)4WK(N`>F(BX}U- z5J#rj?o`{bsJRnOkC;tdi6o_#gW2y`DvcgfW~EvhJp~v&>2YBZIl6)LB*G)kN9_$Z z)p}h>dAL3~OAPVE3#DG{pm`=8%Q?-6HssOUii_coJ$REe`(eI_gj2P`zI zeKK2Axp$#nRcWj1j}_l2k5d&6Wa(LEo2kFrzKNKr$2O6)fbY$SDEIcSqb2TF+%(Tl zxM@CNPc)-l_8OY8W!xW!^snil?pm>5%v`oD0lxHg$=4u^pf5G0S4OZm_^q>>adN}sb=mdU zG%*;=gELgki!4(2QW%l#(hXY?Q)Hm7OI|`!pW{^qw?>$Lqj$-MU@x!RgJ~mr4>H=i zRbH3er105f__chjEJ3K#Tgx50J^g-~<6>a`NGve4 zOMl-jFx1P2_V@3|!Rbu{{rkzo`U~Dj&(esF;PnGJ)Qda2RajC^atx|dxI;?EX+@hR zc$_#?(Tuua1!2FbnDxT1LQxRbLW4QyXTWdUsLGutx)5VnXci8f=6d>l#LA3qT zWvvkjzObFYhIg6n+!eQ3QJgNOD2g+X#0~7gJ!mwK?_ZpJ(3n5qFmB=5TNs%IgRI(` z=!_z-3DGTDCPDtJ0@>Z`OK%LUjQjv*JZ3X3(_ zX%xTJP%Lq{yN5+;E;&pWZ_I=G2i3)g{#(YCTt}a*9{b}@sxtE3s%mdUOmxS~VL>6f zeKNy}I?q73Ui7f*&S1Di?f`sR2*{J-Rk}x6cmdELaw4{w_@mHmG^ul4m(SoNEsWC$=py=hy}qK zeM^|fY-owRxfD2f=P=8_#FN=|v+>sD^bR!M*++~|3^blHYUeP_*T?s#8>6a9kwE`O zWT?ko|N% zQmRmI)jYLY@ukR+_OI@!E2SUSg48q-;b;`wuCk^^8ME zo9wUbrCI7+XK1SQxvr0&tvSxQlRA~1wI;V*3TbaJq&=J3sT{L^^S&~~p>c*EhO}p# 
zWiRLu(u%XaYcIR2bqr}=Eniftd-}g_4{0xzOChb8hLF~l76@t6`Y$yQ(w<>Rs~wa2 zm=OHzxQv02wtt`fU&pkS(%5-O8aq7udUyVFb5fUuV0Un@p;1~#p!W+_NwX!hkALSa z|3C4SzOW8NP#K6BDi>nn-z24j`e0je^Pv;&9`1apb~ zWw_cxL3ISVz2D*XB;k~UH+q-f&=TzN>#?4mZy1=UQgfk{Pm!n(F0xc zi;P0mK-WymskT|Ny8S~liflvSaN96W?Vr%$dWJM9|Kp~^wZ-&1+!8q4Q&@jWHf{LV z4maPp!#Ut^qYe)nxarBvt2f82E2nQ@%+6nE+%ZtbWl^4CHHQ+s*kSDO4o@9$xU%H7 zh~0Ku=#$;u<#AC*UrX1cLakZCU|F8<535F>r(r-;Cwb>ID}HDKL`J!>9F48%jd zdkvRfo%Gp?;ES3SW))BWkSa1Oq@E#Rv^PS^GO>gJoi zx+OGC^3|%h+HU%4jQK6@K>rmlpc7jz*rY^ZpT1_l8tbd*d9jF$PBT#223HBk?5`{@ zMNAg59j7-kjv~SJG1czDslXKN<9gL7#pLtCHTKeIoEpBMUhasm1PP-xGcmnd@6{g}Dce~mzGK)f5`(5q!oSFewJO5GR ze+=}Mi&2M14|J{TFvq;Sp-z8T`FrHEo^g1Ry@PnT<#$5VK{*bJM9UV6P0BaLpDJAG zLez!Q)`(O%`T-Lyp_tKOM|p zr}(>DhaE|4W`X+Y&RmS^^llX%*_tKllw(I6`ef*?ndbL&Hqaq`;uw<>TDT5!MIS-ED0^GkxS0bByNqGtGzc1kaCtpdhP4rhDq?u#BGjxjd zx$aed&U@H7H?327rPfpzaQ~M5bpKO2l~eX__R{^|cZd5w=PdhCINU#uLahD9ofl5` z-(J3`ApGfn-tPWi-5GHI)9C)E1>FDIJKX;{y8nnVj>igutH@~gZ+p$lt&S-Da^yfc;gLcS{a9iYrgW%r_Ff?#*#%Z4*WytOYz~z#w7fnb9 zHwF8`hTe;1xEAa;7_soW3C_L47GTL_ z;`xI4aNRE~wyt?r_F7cFG~Q&>1M|D1V7cWyR_R@3(c2F>ef#%&T|N<^OULM6qT@0N z!baxT7(~t{74A3$=drTfFPWQV%g|$Zcncexl^#sp+zD|bUbefG6PmckpI4z$(r8`4 zs1Zh%`L;+W&Z8;Xt;(*5Ga1-@XzOucC``)^-B~uqFQ_&OLD=5Qv_7nG9*e7P3es{I zJrTG6WI&!4r%l0oCHg4S z*0TGXT-ej}gWlFVES1ST^{{x$)a& z2Nm*&XPi9N#C+?KT@P2jIo6&!v;Uv`FBAnqi`E~cDAb3dlX7pof6m|@`0g76=aqP7 z57R?W%5>}Z#Y1tFPwJA$yDp6TQ(Wqu$(S^e1XAnTKz}X0W0?zkn*(P%`?otW$J|_n zy>6F+ng8>tfs+PvyY#E-wukVbzrp-a8<-zH3>ugp8g9-H4WUy4^Fw3P&H3S#SbdQ# zwm_BTTH5{eG);VJhAN}xwE8K&ERk^fX>am_anf@AoKK7qA-jvb0{fky#LClPbj^LE z@$K2sfi>F$c!Q_lK)N)&lwuK4{H=&HSHHugGB9j>V{l3@%RL*1^~Pv!j+{OBAzb=x zC0w~|BF8XcJj1HI=!P~wTEWJFS?1EtZ~DGllOAv*e#5Q@?B zM#xIbc?zZrcLWj8!bU(+|#x>SxEUJvf4k_H>8z|dM2YoI2 zrGB9;%$b29^6tfS8;!;Vj5=YIJ>Q|MF|;JeCOB+uN0W1~ZJ9h3dsjm2ne?u<@D#Ua zlH01{nXpJ%ezE4>dt`EM#7+rZp&M6l%*P&2XT=3aML4E_e5^q1&3JBt!*I#glZUJA z&UkvKJqyF|pSOl#fm3tu_`rpx1ui7U(aBrd3?-4zSf82TuMHiGdaf(Y(u`7U9c^pI zJ46^Rb>g7rXwQdr5tj3MJ?bX4qugoRKcRmxd9zX5$=ufWX6q$FwOHp_2bbp=A1Lb1 
zczMuMKM}@w-R1Y>qY`AI?l|BeFOxqUk)XAQx_VUG!L88-9H=}*$FZ5UD zr+T8wcYXN$Ak2IRGfmhi8tYwTh#xM)8yWc1vQPP^-5#6s=X1sX&d|%%st5K*^)|me zSD)xicSr5^#uj3^e*db3ea9kZ4#rt+$KdU<{dlxB*Uh2M9`ctn=T9gT#E4w{6DL7$KuExe!>BNqtH8#Aq@~-XuH?`NrB6x_~*WLnq zQGv6LcdSbW-dXA2`|Ll{jkt03f4=*1%Jj--7i+gJQN&FyoscrU?PnYQu*+xihJF3< z5Y_ERBRS0l&N{sM#*a_L8hE<>@HEa<&*AVk_q43DS`*5e(1_3Dz_&IV18;uDLbP`Q z&KG5d;na2}Lpi4pyMO(wG0#ua3?4WXHcj&*XY)XQyEDTX{8L5oEcEftOtv3c89x@PaUCl^T;{}^VZK@L6(0G*| z>wSuK1*%bpy$cJl6fp;pr8SUI$BZZ~jo3NM^cQD-yu!6kA5vVVABCglP1<@zlB_AC zGa@KKf9U4>DyPw&182M~;D3`dNHIlkAmAT3p|aP2W_kJ}lkLH$5-;sFlwZ=k>%Vyp zC8hQ0)!6X9{T#{?MQZEQUk8@ywqf%Q++(a$DUOltDg^_Q2=+bwNi}}Zk1{rGS5+@ zXp3zFedqa|b?V0l7WNisw&~4@1NrDmL$}8{3xPLVASK}mYim*hPKG!2B9@Rk)^j*g zB6X}eV{o^*V{Ub<_7L>q`x-D+)cG@bq@)Q)N}59MK2q}hBE@fNu~r-GsS#VkO7n|C z`^yr!^48)^0qOz8tMPlevGDQ&FMnbiB)|<{K6ZSgJgngz* z!g<1)RiJ*dwq?3(=&t+CD|E7#hE|vP#pYndSr#nmuNVGe=-efQme(9ixAenK zuH#4xG|6*NwY@>6ny%0i57u%1hRgm<3go?5$B|9&S4!(PX%1%=)^YH;&doqR<00qU zSjXwAHI)Yf`Fj}1M?$)P|K{fz$bWIi)?~Z0>=mx#SaQ5;zjrqT)^W-f)#;x8{PsY; zH?WRVz;zr?|2j^=9f5p11NpEqX^-iHpBp#x|7aa&l{A(D>)6rYyVh}Te~%BNh}#Ir zl*|9x^854urvnG;+oMMxshyGW=c*gltKsm455X5Epzm5TvELFY1?zPhOd9hI8@(>O z8zH+5UQnCiH{0bh30$(k1BvP0LaW1yTgcPSTyS1q5X&HKpupT0q&!bfauqv{l5S zDVr4wgiTQplP*xSVnC`@iv|QKSgRH>3)Fg|z=efr#NRms z&YVVdZNOswi|-TUuG~oMIA~Jm%mT*vXYsnQ#4kUXCf-P}Lp0y4(y9?>TbQ8IyW&?E z3clE-770b@EJx{=P0|f1GW9rnPS#o($!y9zNS#eF9YL8B93Hu5XHC)bwpq^n;end|k)!#qV@mrwqYtCZ6v^030?K9 z*aS?LuCYqE*IANpmDuFDT&`oH9d&%P|*&Ng%h6YoS@PKPfmEWK5fF|K3}$v2+uF9@EClY!c#%j zCK2Hw+U-APshgqQ4tXxO^=cBGfOb2P^mL%Uv!@?F&_Q%B*S?wq>a& zI&Oi>L`?N}`{e+Ft@R8!;I>M-7W3~iroUT?s z(gCrJKzXPWoHO?)@h6YLCMVpJIN@yq#DW@O@P?Zh{n>>RtnYXZDBp@9@1BZH((mmP z1pT;w;PtC)ZA~E#uwrkE4j>%v=`wM`7T^!`48&w~CIp<0WA);AAt%Rs6iiSt zpV;aVQ>-&4!y`P;W#f7>>N(In$U%H5Ah~WklGAPp^gtB75Xc*~rhsWx{OQkSsMw6P z_9{iXC5$kNLNLfm2BEkmDm6`JkEWJx2yqy5uMldTSwm_FyIlbeBg@U{qAaYutntto zb3<{K0T>e)H`Grod%q6s{VtV+y|1f`?$!YIzCp5*>DcUD;biX&r{g%y!_z&bB-zH#fQ+dw*Lc%ieeEPT2e7Ju!`z0W5oemnYBOKkxZ3_I_2glf8d2 
z$I0IR#l#ls5cYl(Fw{?-?0wsyt6P|!jiQeP$_AdjuSMm>Jn0inZ9L;s*wR{+RStq| zb-ETQzTQpVR}kVZV^Cd2$fdU4n#e~vXm0@yf!W8QF%+4s-j@Z+Jhy2B=y7h`8@m?5 zsC_t{^teD)SjD|7M-2U#Qw9^4%WMPqqVZsW{|4?BH|6GB%}5#YK5OMty$FGa^rnTSFbN2cZK^mKwFqvVQHwx% z?_ThwL-IO-O9dHf^a)|}F3w#2P<-}`-z{|tE&?TzhJoFINLM%yNG0|Es{Ta(gyWHO zJO^h*14?RLln5?H>j!t;ayRvNB<-<=e+Rwc&ho!-0OA8|7>EyVIB6u{I>7J$5C_g0 zP2QkA;H=nxw_%*=5=@pPr0FFr=K4pRFPF!^>hjg4-g;-A=DaMQ$6TQm%e2Kvf!W9c z-t`|Aa7`c+Rb0uvpoKBTanCH^#f-X~EZ{hP0h1KHa{+fJ3z(F2wJqRf`~q$katpWz zS-@ma`vu&CEMSke1w5UU`jQ1)RCQVJg9RK-7BD#5>;ldq1-Bye3uT4X+{=sTLl5Opi26dAErnZ=i4lnoKi&r^ zlo#>CyJpj!`QaPX)Cg{Pe#D3t{_1-Q6@@kJ7jU$~ocGTKOoAZ`oo-dEben?rW>+Ad z04{9IwFDZB1tUrDp*&taCiZBeHdN6p(i}QMe{YNzk`xV+!#(@3$5je)j2!RKSVmc! zH(l+__LvGtwx=JXtSIqW_}h zZjtnIM%64Ar&~g{pH{pb&CHOOT17K$N*OM7H;GFeDY}Wh(AQvwTTi-2O1bs}ri zEBVr*p?+XZ{#P*hQ9>B>ZA_)6+eMgsJ$hKcG!!Q>wkC@=CRKf!G%GgZ1Tric#r6%XnxOiB1+)4$JQfhj44SUr2m!`W2*5)I9I3UYt``WsH~k;$^7xSezR6UlVjbs6!?IVx1d{ z1oF|qB2_`mNEy^Mx84M#!dxtR^|{rL$Y`VgC#Hw)MGV+UfF>)vPX$A4T5G1wZm;#4nlF%YcYTp-nb=0X&R|re zH7%{0L<`0K(lm{7MP`fNE4Hy#(eScl8p)j_Af)R3TEDpJTRdb^Q<1yD;oDFahqUD? 
z-BgDtqb#-}f0pc`C^KByM2c``*h;m5eJ*KGy;+r3Tw~G+g|UxBX~w{%v*iii+@mge zn-)&t-rKUSai%~{RhONW*=sn>J~xh@8^fg^w)^X*al?;wO)~St51o~l^TQYB(DV4= z?{ud4;rF_x+Jmptuj>umEzXW`1Xe~fD$tnLhP!9ziKCx0s7|pNNSnHlIV+gqw}4dm z6Ta!xf_V65RT1#b%H*gM)r-e}G_%2N#E*T7bBn}N3zo3OqrUjC+V0YwyeLh67iQ?y<4bJ;+Fu zGqOfat!OWlGZO{Wie+Q$#kXF2Wu#6eJ^{KoAISo*aq`_L?EAosf!#^Xd={T4KcDz| zXM}4;w=z$3pEGr{1yo=V)q|9Y71Q7KWtaK+E(c6i-Wz+6rZ*Q*)?a<7Fdwm$m*-wI zrC9Ga&}#KQOfKCMXH`|w%RQ5T!;li_-BS2%8Dftrv%wRs3VJq)KDERi&D|+$w2}@P zy4plP$=6zuuJ1ykcz2#JCEB5qg94S>$zlWb&=WogoSVH!g*4Y#}W za%s4+ZkQt+CDGAPEW@D{&`l2o)}^Sej+Z0-{HMuF*%v%(-y|kQCS1RjKBpLf77w&I za!=lv*rs^3LCaW0(OYB1_l?R1w={$A$^?Q_B0#xIw}e}2wTiqUs-_~TY=j!mBvsZ0 z7uFEc)i6Bl3;swRiGmABSR|!(lVO- zOD{qwY#Rj?sr4gvaT;$IpAb)42kMM$1_PD5Eif&ExIYb~<}XhdSh}-XM12LSe3t=C zmAA(j@@Bq8RbTa??>F?vn+e%F%^tWucPie@5a3$WB=TlXi&TN+%`7Ev#=Vienc%82 zL7>_woYW+maVWMi-}z>UYED&3e#Vh&$!dEu0;TiKWb5!| zJm!!$qgu_H#^-~1XPOEAmopvz5?9Nf?{uaQz?uHZX&R5fncgjT*0L|Wx`mp9T6Ril zD&U;yJezj3p#K}g^?w-3;=7ud;fD1dNSii`4lv5Xf}gCx+jV+nQ9+s|1W`+|cuCyc z+ND)82*c5gU^lV5v{BH}NX@gCCCU2nhVui3^aK9A?UU-}@yfM(ygr0el!rXz<)=Aa zynj6X5kLGxm5m>MuY0m3m>Uj4O9emt(ChRuet1PUW>u+QN*v!;XJQZAeB=P91LAGX z@0_P|2E>|`GB)|P&>=cn7F)L-X7?C@P%{xt7;_C_aH{duM`g-}(_>;C$*fa-98UB+ zFt8j2kdI;}i2k_g^f^U2aGBRb;8z>Pb7YNwhAd|mbrXL|2_>Rcc^jMDZRzzP%XM3t zrkx_t=LQb?3}7M@D9vVq(e30Ha^JaTv29^=PD)3@@+mJ&cSLOZRoAX=sd$5v<+F;g zd{#NrXR{Sfz9>)foTstJo8yAYZ9&qx4$nD)RI0~j=AnQ?31qMfQ~%*(8;Fl>CqA|@ zE_gvullF^`y!NT*P8F{thQ{DaK1_?lhcm(njKKcMPzwYEarL8>0wD4Wqmj*4B{2#trHOsGR`50`DpvopM6 z=p4S()`20qMVI9sWG!J7k#sR4u+3DuM!d30HAk$QriD`;!aC)3?laZs7>cCvHeHaH zs`GI zPM_dZQ+sS;wb?08>K1~DzMH8MB1P))H8PubZknq=Wo~$$Fp7I$ zdjjLSAzT2p@(5kirCjiqM{21I5nZ~`W4SG=$@_hhWTepZ=McnK*Ig62n@=)&eBewn z9v3JpgqOOv%TWX~Ru4`c<`>8;e1d6b41I!+G3Lm0J-8Vp1}ED2;Uxv~7Jm5NEcyUH ze6HHon;V`VJfg^xyZW92MNw7z97Tn^F2^{>;hD{lpQ9aN4irqG_iCi1982%HT-US7 z`xv|xT~DM9*lwP2@LPh8nGk>B`@xs=o>W$fMcGh!joOgyd_MxC{`+a6Lap9fo41()2kBAg)}oiPuR2o-Y*P)V z`hsF3eFE8oN@!9R9mu4cUw!g93p@7BKS8X;P5@$iE%s}t@ArwesJodKP> 
z=q89~0hn$eo~)9UP0qSLvP}Dy`{!e;i2vskyX&7%9ineV@1c95MEsUltdlbz68O)i z5do=mV-P1SB?_Sua_`<@no-Qj(!`N68~^n+MWT_{ZJ*`IOL=9b$)L{|-8zZUDgp^k zfWOK@3YEtwkIWR*a1G*i$(lhVjzQ1EoUNf?MfO%zW$8j|N%;z;Fq1}J422@Hd8Bbh zEomh%p)6EcL2(lk1qxAtNiWnQ>{@8SxW%Mo^sIJ-Rw3x_C7iCXAp^zMe-*?Q4Z?FI ze=^j(UUXthNKx^6k?Weagbj0=U8Er=2FulYA>L|3XWN>pwK0b*ga3NU7Wg%F6DPp% zu^`kFn-twRIbHh16IFlzE+T4&%l`L_Hc`!_@e_~A3Iu1tL$Nw}D2@nT_gP)K&$`LU zK`b~~+g|rAM-)WlN_~|Y=Z43a2d}Rg%8qT)X#ubbOUmhq4K;@(D|uZqNzHWRRZrju zGjFU)@)b=s(I!aOnPvt>&bwhaA5mfd>OOcakm@#yyY6TC0uLKO={=U(y7NBe5*_^^ zcdvkH{PnR6)m4I;-`TATHD6G!krwu5-OQeFGtIGfbcv&n5(V`IU}ou2fq1rEn6C9)N%d@A z>NbF7>KQyqXcXJ|BTkRyIFF-0CdcRe8oxo&QLk>|;{xJfRu3Z%<}vVBMl?-334LoD z{H%kCq@OvTuvSog^y5(V#&|;0yJA)rR_z?no~GB|Nz*Iz0Zp&Sho$LJW4KFkG(BBk zdz>XF$76Hw0O3{7v0@3EwGeQa(zCf+hh#ikCr;K8;W6-iVM)IvRxi?^*`c3MH@33l zsjm|}ayE;F&_2}<L}DzaQ9>5m+_iSQhN9O;D9ax%C-h>L zKdf`{hs%1ms#*Ts z7Yx0fvDU0?wSgEPt$ST6Kf?&yp7u}C&ZjLwi%*DWl*uM~?N?dz(%c@j^MN0y&0@;= zzzGfC41@2nGM->lEglG}(9Qu(KRpe`)7>4S56#Mo$}zDOacnqNflT3fV|*nU2Ig_k z1MG3t1#*n;M?;UAl*eUtDl$U6z@!%%tp z44MKz-qj9)j75u}i`tk838cd(8}5-(4Zi3{2n`Br%wTEdTb>EU_g*_8MnOwY$OTlPR)77E5L9|g*D=TG!y!xmX z8eqSZ)tkfVMOeFLou)w30Kc&>vPJR1p!Fok)3%EHV26ZQ|GCKFq3k?4xE8L1-F7yP z@)SqQ+&YBz^gnE{3uElMT6|Q{10T&W1Qgn2nNnLAZoAS*dLU<0AyyGr%ACwrCsiXF zB)l!HBs*0bNgHh)YISH~qb=+QR(!zt|h{CLW}UggdqQOWsmP*15rO z1rpVd{AlK50Zgst!?5tb*CPJ#)qii-vciR8YW=}u{F zTtA>k;)GX;o|nNgC(;ta33C*Uzae^jRkR3fz;Rg^iPh!*jO5)Fm_?K9_*L z-2L=k&Th{p-);obG=UYY65B^mu5<-fsmGyvd!UMw+IOV8sF}TD*tix zi+53%2Olc2H`1}E<&Ec2l5O-_YFs(akY6f{XRT~n=! 
zGx%mog+oeZb`gV?%tjqqILv@rIlHp1sb67NmN~bSt@XC}Wb4&OykyhALQC^}|G{Ye z*-Yv!R*|v~*|%r`@s0Q7meH}+sv)HMz~qaV!J?&n?g)g*)qu^UeRWs|K7OU)zJn0H zO`wCHBb{{DX`WwF2&Qv(s`_I(dzAV+I(v-z0iEqW9Jw3P5<2%AB-@#_(Mec@c3Tdj z(etVJ7|Cz+Z@VKqhyHCt23=|FQtLhsL0=quArB z9CECmY#Z8(QSQ#0UZh~HS0h`G5zQz!7yMMOfc5%kJ%*vmdpld&kuUmqzv#*+V!bqe zD{Z5qzYw*vUK$ReA4CxP!A^w!;y)021F>ES<(dz4UD8e&oCv)St0#z$SLIh>d(-Ce zgW*e~XrX1MmDFO)(4iIsjo;DrSLQRrh-SYli%#~;9iFlXr?G!B)B?fFtR9)|!l+#F zG2c&%zi@JG~loVxjOIU&Qw z23=_*IJZSZozcz&Bz#5TGO!E7MJ0 zs+BBtasO{6=p{?}*l4VjYRfUkYxy|rAq{!w(g4H9(XQc>R8)sL#Wvd=CXSEl5^*&! z61l^zIINWO?PF&KVivQdd|#0%^egJn2)@u@p4`D0H}WSQLQ|*882-d-flNo~b=Qe$ z^eV8Y=B$F@o`l4$#U?~kt04SXi6-a>S(B9FGBut+Pw;bh z%KpvIeVnLA?HOqyA=}5SMsH4YVD0Nr4GWmi5zll9w|k&0Z3H3XqL?dP=83unbP2yp zk_)zIbP2DDCSAfwbGmO)h7cWpgpT#(yM+5Ibi)h;dW!=JpKB&uCvA*P{;`>LPv0U8 zE5ktBiW7nd4~d$3s)Yu*;FqRq&LOlVbO|5ghj#F-HP7iH@8+Z-rDg-u(4j`FvNW|% zwhgW6ud3fFBG#OA2}emB=@R~2BJib!n&+1aht`Vjp{00kn$b6#8dfWMs8%-RrdfQm z-G@NKA*anXqhCItBU%s^=DLYd>AjwC1q|()(fh4QiFD+a$^zk7v0cI^J&*g;o|%sO zFhimzA&wNQDa`x@(pEDu5#^}~Ch^CvM+L>M7fy>{x)YSC%*NY z9B#uLRnfX2;B?CvOrcci#OZ3uxMH0$ba+cVOJNC0yP_D-Xg*HEwOXAa!&^Udz3Wd} zQOxRXuP`r#`mCeeVrxmY+z|3h!c=sFAA#!fz^MpdAp}A%o8^cxm~uv&tnU5Glugw? 
znSwXQ8Br1s^hj4bLRD=egA!Cn*}o4c#IrSG*qDku^v$At4z$DUtbtCBHObkO8J-US)#)Fc>(Xo6kU33u_ZjM5ja%Zh@B5(|W*hoR8hSII2|ns0o#`Zl z%As?niYJo3?LCV~-}VKp1~2kK-*$@xecMfWq;Gpk(NEYVN2ZXdFiHx{Xp(2t zkA3nA$UYb3Plj3`z}8}jS+ZFDo`lbqqG5tSo;dD_FEqU!6$7+0uBde41EfqG-k(8b zA1}lgR~M%5JER?X$DedU;%XS!xW=jF0R{ZW>?rX%K|X+nW^2jWhQ!5p-H=90@J5rw zNBAN>k;9wdBZ_o3FKqV1QdQZghNQ=6WL3HT64G(e_)@w~k5q)zIo&^;=1H21v(Z1D zCUN=yWiy{WB$z@QAN-0oEgC0I5E4%AtObiRmg4Uo^nd}m z0db+bZ5}$(Q&l6GZOK^f_FPeauccavLhwMc){M0lM*NddT2~;HD^0>#mW1Sj3z~ec zndfdZHdrgNDLzw8L3>t@!t+~f{nEnAOe&=pGGB-yddQTX6;>FJXf)!^Vxrmnz;W+1 zEbk(gKt@5t0PI1zvt;tc^lA#=S%lo2r zt68Z_a>4c=oDQd_sea1s;Ng*z6)z}+HiJ6*g#<1_xxbLJD>Vi27}cEFQOb^LHM@Z@ zh~?uy2nOyglx$4+!&KKNjcg;b2kowmd*a!0W>T^-<35szN2`V+wfo(|W|#!sFqni^ zyU)K$Hnt|q&{P3MIS961UyYux+Ky*;V%EQI4NZNOCSg8Nq;Caw^wweAOalTbC2H3{nI=W7cng^y}fplD*jGTq!- zCU;;pTF`6le`AfxOk|~EN1$3G_O=zHkDquFdIXQI9Y6RuhyR{t^0*m1qvN>Dq>2yQ=*tT+53)@UdP&nx?xcB8o56rw(I~n`DT~*0G!Gbsv5rAX_l#4opRFqJc~>t!bzw(0 zx^JUQ*)$Gdc8-j2K5jC~^7TJyB4C&+)wk)h4YS| zIj1;|n_uaOD5gj}Pu3_58$8m1*YV{lAbeH97M*eI8i%N*Ol6g@4@OLF8}q_{w`?*KRWl zmUN7_ucuy4+hbO%Vrf(12Ucx3=b)4{9z7$wn_=lNC8(S9v^Qmbqq{~g54mzE?a7jgt`gZkC zTjX_*f8Pn{RQ=n|sU(am{BCgnBw6MCP>>P_c7OjUbW_MrT@m1!|C)B@ zBpi&&=8#u1={du0?utz^9=x{b2_nZEKks==MT#?JoN+3lgWv_Cs|x|X5SHcY3FVp~ z5pF>mcHbiGlV*{@EAKA#M4wwT#vwHhr6N0^hJJLT+?8g=p=Gxluz*AW-~wOzpMy^I z%I}-jBF>2qF@Zs=2A~G~0y?(8szxg+^G?bVhFZN}7DPI?8z4ZHB1uJ%&G*!@5PPD*>I3GQ2w&k!Xc`j;vW!$znp5BnJWXAER+qFcatJr0wT)l_)ba*J9aY5B6wLAgySXb(oJ0HaSLzZ>CGc0;k|f*8 z58o|MaSfkK`|-nP22+Xr@Eu9|t(DOf5(at?~Nj@JgV$0=oul}gA{kGol6R>;czom2^)< z|5Pwg)M8kxLJBq%m6S8pVmy_$D)Uq+iY$`7xMxtGD#Ic7bPnaoRI&d2I~s|-p+V2T zCPvOdJR4|ObUbNQX=dIviI=z4Sy*Jy21TT90_&ZWVxVbmi+C>56VXVU1{1rNO>Yw? 
zc^S|@-DukhW5@=}IdrUrddUIjM9r2_cGd-=QHunP6|TxjYjnL|Ib#;Py-?2RI_JRq zoOCYQEmQQO#}c@=4Ekt#O=P>NtuseHqS9MP8#ohzI46*{1}1_o3;{5$$7e9*LiBp0 z>y7ag0mt--R8EfEP3pnzpE;eBj$x+`?_E~q09QK)uIp7%H26LwDxGGf1_2hPYsLqBSQSUOm`CPzA1X&<1%_Qe2=z- zhP6i0^6DmZNI~2nLQ4jHHl?HpC5n?~&ZtVrYdKqe-?ucyIP^S|aNX~xvCE;{%o_@p z*{~HK(?+K{5Wd@OP+J!Ziw=jJrU(d$Gc@Y?w9^r}Ki0J>v{!VNZpbs2I3C`d(~CQ& zBqVoW6R$2ljcJNAiuh%CCrJxP*oJ44!QI4!`eTJznqT(JqRjiub5|lNJ%`WiilX_5ibTN&Z77% z4pg8+c%K_W!h5rbI13-C+sF5Zv9|o^7g)|`6wwa^?I0QvI6Zqc61~q2xeoe`KGw|}4o!rzZWe&_UqFVO`zFyX;672rzB!YJ3O_yv3%f78$fy8zqz~MU19M{;=TPWM%<8d|F2#NQ$c!VnG!G zvm}6G4qiyTJhSM<&CGO-7#o>~G>NLx_2N>U<;8Mpdag+zM&0~Yv1x3u$ts$QXk8af zmctb|3V9u{XwM>ARV^oxZ;9fnyg{^p7p%v5R(KS;Wg{YZNz1<=USN#FYX6E2_`jjQ zlg5nHU-O&f+O=mgy-c4bOXQDBrud_0jABA^J(U)qPiJb%KcLJJA{k8fbT=MYRuaD!l*So!k+ul2V zr4!mIY|OZkok=fkJYa}TGW>927>4zmcX}TV&h6a?Rq)&G4+r7kwa3I&ux+^B`{yn4 z(wlXa0Dv4H-u@dSX5V~o5awUH>*wv4RLs){S1w@^Yku1PL2Q!s{qhia|F*|BKR*Y) zdFKyvKhq=?ebXfvcX{&g%ubwk(X!9<3lk_&G-I1R{B>kr@j3aIKq zpEXcp*Kb|GW~$ov)KNHC_v#;?;N0d#)4si=vQ7PS9@y-Up1gYjttsC5R3pXHU;Ek3 za5TDw7gR(tkEMTqj^#&P`sm7!XtEhO>SNq;#ZU85i`w){Qw`YXdx`wmk`{D9KrX`lzlXIdeFMXuc6e@R1k=l&SD>;P zNy$_*QfOs#DjX8fLNxmFT0Cias%zi z^uuY>w19${9)X5>P1oUi-t&PPKF)@ zLOJR=ftfNs8@geGA#i3^JZ~Phx@X`pS^bLtDzax~>uw4uT12x`7`RLT4Nr4tg3DKe!+H$lJ|ntJ#9O0?AP-zbnAqQnx3 z9xYej)`#+5QY536V7ALv^3Kt%v&ppbV%zx-N2tt6cqZOUMSYoV)k2@wdM4~sGrGjd zUR`HAnr_+hJXNQt#U6aUu8Ltrjnj;msbG-3oJFd!p^XcWxU}7=>5sQqdUFAO+;Wmo zeSq3q5V6I{ElKINkd4%0t5tLrn{BLO78{X~jpP|!yW9Lb`_^r>uwMijWt33g7RuckBDuIn(H-m6ELkpV6m(zB?1C2hfGqfqS9K)xX~MQK&AN{0 zEzyi}II}q0EM8)&XM!83t&$vElz#KMFRc>PfPQG&fTQ;@zv8$sGCyj2ll$HrPH(o{^<*{cjaqm@A_xJg8 ze)qzI-~{0DIb-*>2q<6}6%HY;-2a8YvUMZkKoRe>&KxE@{^$!=pRfF3G!D6})QwK` zY43badmS65V=uAnh)`+X)cz8p;gtuP_;c<&SGef_cS@Jxxd+;EBms}vvXLdrpyR8Ax(@(Xe^bFrMz(l%5+vAMFA7HPH{fTh+9D+2- z=Hx=M=gCWmesOmaz?YTH>+i}KJBz@ zKr;1ndo7!ooc}Yl(N&}W&uv=vkvGw@Ip^B6?CQ>@b3eCf*|iDe(4}SVgP~;|o|mCz zD_-Tatiz4fvMrvDHZ5B=p3|~LRilq};$(CG_#cWK3v)F- 
zMv13my=Rs8aoKR>h6iqo*DTOp^?mFGCud;s6RhQdc&!Bk3rj=rciuq9dm&y&WDjzD z%<_;jdB{$ANMCtKmG^z&ASj?BpKfrh8ak zAzLlJ%|y1I64|=t@Uexbq}hk}9ogdhe(`HWwrqnTTaT{${Pn&MWb3j|@gUs1aIoRgx0bV|N@YTaU;Y4Zws+ z*xcgVPPX#L+q?qWIb_GiCwhyH&`*sQy=*QE-XJ5g^^?6#ww~+KSmceJuep2oOC*Ob z+0rIMwu*coLADCFan|=RoldZ=SIN5qsHF`m3PPR5o{HJV{OwKU+*C`$m z*(y{xWy|jSq)oOym`G&HHazX)Bu=&hramiM^;7e89m($^vXwA+8$HD2HnbsUXnEyO z6EjTzxAktW!}pY`B6pMqw@^AOX8okK9tk;qb^)o~qx}rIb26vs}ru zdlQZE{rhq~HePaPytk-AwbC5jy%*yR@!X?e0_6>M?h%(wq+P@31ofZ5kJ<9A{31VQ zcRt;ZyX3^VU3B~MS%t=Fh30FB)3S-_iCa`q*%@I@FII!vBZ`a>|$c5&fo$d~fk%j?Sb9^Rrsm!REt_ znK?DZ!e8>z3y0$eabp4&MKS!ClBpxj{Foboeog!(Pl!Z`1Nk4ZSS0jy^)P+>=0#fB zDoYFCe$-K|qPq}3idrM?tn*QGKtsR^_Y69Td&(u>$xM9zz+V$-12>QQ6#EZ*-uL7+58V&7iHtm7EK>%Am`IYj+Bfh17Tbiupp48Tcs_{0BcX&-w zN;#cgQ9aEFx^k|Zi6vIa2lcL#0iptcqJ!9EVdg7CE zd?FkfDh4YIV|w8eD?W)E+4%DF_#{Y%JY_`ynlqiO9z{o1mk<+6n8uluAJm}yMUukw z>w{gmszr+vkI#RrWVI zEqp`7YGH@p@NicgRWpUt!kcSGpX$QB>hn|nsf7a$WLybAubvucVY||)g=K-$!(DOI zohhsqo|N|8RDPixd{zrjJf3e`Nq$!h(ZZv)(W6anV;XYClvj>1F=O?^ZF%|Xz?~r{ z^B##~vp$zOmq^PSBev>xN{QRIt_*8sa&rynD}nRC@|_NZX0wVs)EdI2q&MJuw<)_5 zjXf_!WHE*3qnY%^P33Q~Ns=ir(OpA#i`*qMQ9mo2t1GDiV7xm&sH6sMLZ{z=NXJfK~1z<`qg|_^5ovjZ1QAfR)hZjJ<> zO}C(~Pm5?ybLE|!gV-9)J|AyK5V8gyV{6Vq2wLt-j}nd7vwLMx2hT@S3(`I<+Rnxq zb0BCdJ3JDlOI8s+Me)oHSHw{ zo#R9iWi8?rQQjxiuwTVO<+gXAh<;6iBHFbz7>X!-F<#!>ZRk{QXbiy$uqMqVib(e< z|7<^x+&z`t!KH{Mwnp82a}leETFDuu;(4qh5`?r+0WZ0~;^VsDYNw?mto;g5G*P zCCm@v?VH65)zt;PB{{t%iydDwJ)A|unZR477O3Z-4=o`iQKZ+Ig#eER0^ao1f!h>8 z5XcNQzSa4V5A6ieDHu)=9it~fbe6M_&meM^RYC@<$)&=(lEHO(1paxCZDOM&gYU-y z{)wMQ23x$z;NxVlE|$fL3dxmMvi;WjkiiR*Fxc%J8LSE+gI9GTlUzy$AAgk$zC?&w znqg${7i>4(zu0Mnk--X)2);p~j~{))k6>cw@fmns?JBhY zS3ncl^mrYLN}Z$+va9w!VNG{E*E^FmxB$aT%ApY+&ix`668xyLD+f)zMagICJ!$|H z5x@dF8~=;`@}D6owjA5pBjzaCfii---gdV(Xj<>Wq`^A@INd}VC5VB$RIhz{`bQO2 z`^OdBkaQFVa+W_V(8quG&fmyynYH@#_1Gj$)e=%`bo;~CZlKl(8~g&|AsBWqCmrXh z^Tz=T2T=HQyOxm71HdaUB)sxM>;X&udG75a_W_0N8@f*aYTb=r$wq{h<6BPPexFSL z=s{R*NkI#W1w`F(JIK6&flSJ<3k=9c#XT@BB@3UDh=J^UB66?t#h 
zYCtp>f7$pN_B)9x*z+X-p{x@*7&iC!zZ&sPdC-SIhX%g0kMv4DaqwoNiWz$Ucfur! zeE0L!V7~2p&hzg|6?5>@-_fk3^n-Ob!SFG@x^_Nx6?pRVrdKeDyDR4J1?A?wor_v= z+Jg_D0%_B~Uf~0zO*iJPIk-#f^{`C<#Pn7DHv+Ody!`O6{z>*vmpmdD?tSXz8(ZX# z`D5P!e%1KpU2Gp$`tiu$fheuX*^S5jXy6yuajxaWvdMtk7Jv2OXTbjcUX_Z6yYObE z8g*Z-_x7zfu-#na$Xhkue(FK6z#lg*2GaJf<`OWtUsoKwjCb zJp1HFzyj4L4o<`KI2n2+gFKJx(l=2hdyMb9ur6Lc5!um6cZqIN;Z0pU6s4Q2K^{k{!5yb^b z^{d-w7VBA>Nv6IZK%xHQ{NO7-Pjb+gV zN(6kZax#GkI^$n8_-v`$w9-TkHm{5IpB;`nedq6(tco$sCZ&p-d}5;)lG_IcG@1jQ zx%xC*CU>G}GQA3%gJ@exGRptIBZ=@=gNIDoEmn+k|8}x2O`y!C(o)3qR|;d5_iC9E zc?G&MfE`^Nr7EPP0&!Uz; zLl7#DC`?}7MiAN;heJLm-n~6RXukqnDMnTLU9gc^`OGB_M_d#y{pZ?|`} zk3Ff4EA@SGb? z9H=U{U6*u>UW&_-N-O$QjFfP1H>bd*8azYRxS}`(Zl|aVxOB`45bVfCFGf222z@uE zQFE6dclI-+MKZmT;Ldl#5I&p4ht+j`cH6?}e027QIRC~#p5{5kHuqk*47GuXQ&V>8~xQQtY{$7jN zC}X3ku+Z>MD-zdq6W^oUl+Ei&H1)tnqN!Kad^9yc*xDjNG!=$us)f+~t8E=H<%I1> zxn;9}*5yv{z8#V!G7?Qa<d`yG?JNZ>uM)KDv1ZfQb$Hz}vsH<|_n%ERPgjX+hUpDyHzN9Ig={n40Bv}HkOfjwu@!@ z4WstAhfQV<%ioVQ)gj}>lWnlcRSy!&-29>~mbu-jDfK`bY|;`>4qdT~p%2LvcpN}# z`|2hxmWgu0CRUIAZK>@WW4TzSx^na_9@D)y_Mb_WiPDTact_QhB&p)%OsceaSlUwC zC3+T|j!wJYi%Y81yFZ&$De=f>PLbcmg45khxls*5>|fd4MD@_`x1P)qcpiSi1KNXH zQ`e_1F4a_HjGtV)iRc9hX_zpptUtkuJwKAAO)1~f9u@2oFZ{9MD;Kbs7*FGnPxyCk z2P`g&C%|HD{8{v%uB*9(Mf-ChbRkpGSf&*U2g|oGPXd2++5A+g}dkD zlQZjkv=1Il#+48NGgSGA?NCZE}&ET*|ubQmCK z-;ov{<=|iBfIF#2bHq$p%dl8bpb4h*hEEQ`l*SKI-PFMRwy6D@TYPV+6}Rkuw*`vZ z4!_@9{C+pu{(8n}9qD}3_wP2lcPs&ioW_|svc26q?hm`C8fqY+?@36gEwnK7D;Mq)I@7gMC zEF5x7T}}GhCqdhy}-lmvs;O0YL`!ca(4 zwJ*sV4Q-4!7RZ%HvP(0>j}NhSiSc+A?%Exu|tJ4vOsv!@f& zMAIj}jKdwWn1nMR_iJnm>iJrzS=0t*xA-9-HzndwMYc@i=VhtAIUKk8#@5+!3XD zE_>QWpxPSGe*bmJ!?;o2Z+nr-$m@{eS_kmz~`~S^QqYb%>X*-w$ho3$`Pg+ zTHo72EkG)JIs9TS6^&5`6n}<1b}C*fomwN6|T`wi3c*&Rc!9QscR`cXxU4 zpW(rO#svR)C-BdhMpN%H0{r$U3xs=qDo=qoXsMorU913&v?}0{9Zdi0Mu2J3(+6MKIpkF0H~kC>YlST5dDgNt1HIiEJZ+wU6!AFp47Pn&H_MP zN5r@0R^>c|y1+x|*5JJSFU6TiPp4$)KDJGI2B3aC+4OuYmq!mMcejn^^ku-jc;oY& z2Hp@XZ@kD^gk3MtuHhd9#m?l%l>8uX;K%GOp!;)|tpA`3!(TFgUaH}F7eHOI&jnD| 
z?;ChLfB@9C)tK%aec&Q>pj!ZSTRbyEGQB-OJ+VSGte(H)+3m8s{2h<Q6!;krB7CnTY)&33?KPG>Eij}|QwJDX)59Y>{9Eh@w;X#G@BQ5-xJHdX< z{3SmVw*l1k2SssXT>y3dI{>I3`4fPe{ZU5^f!!nUqZoI|ES=AEF19WC#C>|e3*1xQ z@uMuSo_oqoNpuD`kBOfK=q7RV*s~}}_JRwb4mc6scJ{HU3fI}U_qxumKhe{5wrKGv z*-I{fy5!`*rvU_@#{6O}`pE#OtEOyA`X@jga7r|189&8q@5oB|DgM=;p3OhbQ*zs6 zZi+RF6OH`vd+*5Y^)7(=-r|%gW&GLbPc?`CfwfW86z=RU*&|K-*;_*W%>3}q-_c!O zb^+9HH}0BF0P3%jQ!F)OkLHg(E=oC>IM!ODn~$}%5Q1K7v;+#LmR7&Og3pWai36X+ zO-);t$$qjJpS0qWoT;IKFS4H)@kvpZC9ra8$x1k2-ilGMg zgah|;E1Lx^tWhpy0qW&RN9a9SDTS;SzF4Cb*-lf_%CiJYXknewKaFVNX$B*-@ZD|D z!f0c{YGHkFa^vLmvdoRltI)y*f(ucv5iNXqHm`*f-lK+P02j(5x6dHB(DKXV_Ls4? zL|6!XJt_>nu_}uqxTKIiDp3*ig{6R7~bZOz1`6N%E+zBmQ zyPea*SH!Fq76g~IY2mNZIW1g#dh}i%N&P$6(8zBS)gL4Fsjh10g>O*7@DHwa%Uk9)~P&C!?HHWz4v`*)}o zZIG;B0cx^DTEE%A0o1nR@b9YT-aQT9D!4*kYRD!D4Fa~cB+Y{^PWYVofAcDmSb*C0 z#rbGzapNYt6E~a>7vXS+3h@HT`vf=qP`BeW(rpz5K|4Sf7c9o3vIiev0cvV-Xj%4( z#1m2CTL7r*-XSjF?<_!F=W+oxl2&_&G)=o0Gq8t1woWE3Br|A$k+6h^zy-`$0*B5T za)>N}S(C35${84`N@<&wLblGfuLrD=|pkh|j@@$QPx ztR$hw6Ug1KOh?EawGMKp4!h3E-60puo2f<^HBues(o6LN#2cG;{LIST10r`f$~v-g zSDV~U?)GHLn$p|LU9-ynB;HKznQ;%>5{winQGI|zqiTVeU;4dU{#yk zSx`4|#H;yEN81M5)3aN2C_S9nl0!KYdaI7FcLgA#*L9qBsy#FMPyXj@nei{wdfDsL z%j=M>gvPKkK4LDXx4L@0D+W1+OM%D8tWHZ_$4l@?@w0V+qBnFZlL17-+A~D-92tsV zD4**Yo;S)oZc^e~x=$o-VR5Ui>l`7cx@|in%2Zj;j|R)?Pbjp)@ly)@rdTgdZm~~k zO`TgLS}grT_i9{f(XJxJ)o#1#fgVv=ulQ9XB-{NxvPE8ZWv$(csbLdTTC(-xD;>A}pZ2~4tf{MOHzDC9AZSF8gfS`#Dk77BV@U|}AX2r~ zDiDS!6flHAoM;k20-~iYT53^Ki$hxnTIZok0Fk+9tF@F`%4hv4^`~g7^|O9@*WNih z1pD>-hx^~>x%YW4PwmQi&))m&v(`F$uRX2B?4=5oABx;Hn>@$#JGHR9oL^l7zROBg zGTy!Kq$sEZ^iOm7TIiLw+zFMNIHTj}EmCyob9$09bXZ=cC=g#Sr;xpdpDl8ud?PC- zmmW?;-E3r~GmMIz%^&MZO}3x`RAKCI&>0IKtFB#@7bR{*=THZ(7UC|ma&o9L367#l zwfqVxgJoMQ)k3c+5>V(gYE6=ooO7y9Vmv-9Wk9e4DI7>u*`c%Q_0E%F9f)^^4>Lpm zxDS;o(v>rz(XQ3Jf%;C7Jw>FEK#A?q$<*?Z5%L=Xw=8}|lKyb2I~;g#XH1@Cg(jUv z<;H{LG}zonT`^r^IBN8PEVfSyf~;*>a*97PA$bof2dAJm!(knAO4amzReh<>i85@* z+?heKovBUyk!eZg4#bQKO&PS47x*oDeXZM}*H>0qmByg2SKcT3TH6wRoKL1$FAM&3 
z7Haz%WoHtuAT*V?k{SBrP&M3g?iJGN{+j@wTdi%$t231JlVJ2l3!6=pw0;`|pRb(E zTFs2^Ee#*w1UxRCrPRzQ=JG4K%O-b}aoyCr#9MA1?sm4lCARSI6Ngc2V@;y`*$}-~ zJHOG=OHiKzP6r0M%GuMAbJ3<0B)kO_jBM&B&$3z-P0bO&NxLC9X*XSfRxgImf_KEg z?&Lu5x6>R!p@H$_4U=J;jOIP|&Ua7%Cpb3nCepm7u!>U=8YEaF}RK{a%3aF+Q*Cqm~rgF+6 zRGT~jsOH?yrhsaGunVbXRoM=tnz{(0nm2Y(4G~B+-y2cQg?4W4g5p%Hn$fG!%hgX= zs8`MRj8yYRHN}h`rJC=h#ohvbTUE6#$q|gG<{hXUsOCKI-bYmP$Ac>EJEE8o)wHRh ziesp`f)UkB+>My=R8tMAxj{gyX60^LHLHuu461qH5K+yl*62(DQOzxOe^Skobmf&) zTtQIHZW*nb6=v^^sOAAXteRD=d0B>%4%v-X&7B`pQti8-nz7T6YIbPCZ*c<#G_!^^ zGvK6{(#&RBM>RJ^d`w)?a;$O+TG#^(8!$;nl95kaYqe?XjQ4N`(5X`+J z3h|Sj+gZWMHF|}311ZGtG>AY6Z54B`u|iy+m<6L0qV@Gy4S4;q>T6x{+(;p+)?kIW zYorhdj;j*)iZ&Q2#AT{@1NEbkLI^)Z%xDSU0}3(Kj#h}+5A_PsSkpbC5UHOLh3IRI z{<}ep)9wGH5T`Pf5A_v9W?dSs5Vur6P>4=Snz$bnLg~ti zR)yD^1)MCKb+K$ly&1p5EZDVT4|o0kE|d4sdf9oOcCG6#Cg`l+5BpRgfq|~QVCcxC zhnVno!{PhbhYqRsjB>ZK^9#ykFSXuoWJ*5Ko>ySko93xE2bbNVjHjZL?a?~0a^{?m zQF~>&*>1pWvt1ve@}TZRxD_bh2ago*Ly9r%ss+Wc-eQLqEIy{TfP%0OfXjm4(uC00 z9Ez?$z)4}{8mw0{Q7Yl=Q$f;*Pt*8@KReh<%Rv@|`6+VkNVpI8Zx zf{PgVAAR_5KcwvNK5PC7+rKwm@0aDbbYwvsyFNn%j$-;gf!Xi&(#6k09oZTsMLZIh|ce$;ZyfIatTyTh*&eRx6@eNGLMat>V5 zX^)9U7F|AF$AcXYx}!B0sFzn&em1I3HkVa(pQ+kL>U85N)zn7Ku%|~f!IzEewDTWy zo$?;p)tmZ^uG1g(+-G*bkviEtr-It3e~>!exUB0wZd9j#K%LM@KKab`)N?3n;pnoa zI&fdIf1EMmGiHxG@M%Wh31UyUIT-Kf`AE)#bt&CZ&_4Ci6{q%7C;K@kOvZc`2jcES z?kep!PwoX>tUnFWp|$hRS=IENRw21&*~4MMz28#$^%%wp)Xxu;ABV}BOlU|&JxI;_ zJT9PyG~FKhdDijYBG8gdUnX?&LyE8#3J;Ag!NK~vbKNu3J??;Vh<-EQ3Ch8|K6Aw} z*gv$;9v1u?k1HdwACw_Q(|6Q{k}j%T)dj9ln71f&$bLAZ9!5tc-&2NJoA1FJ$}LJL z7u`a3aB(L39DY!EpoyFF@Df6gJ5VY9huA67g^I|^&_vtHOtR(RTDGt<6I$W3?NiCl zfz+$2xYMfJlz#dq3w;znCVAjt(%DF(Vbc9<-pNd3|0p+{Iaj`g?m>0I;mY|D7b&V< zsGM3Da7Je|iQ~Y{v2V!)-%RvK^?)L8cC8Dd)-l?bLSBpJsDBIMsu=f@qiTA7s3A;p8^Mc{dnoZs`ZAF zOuJqfs&CKXQ77!7F&y%SnRWK@>Zaq$#+aKrR z0MniqU|O#DLifc=R}l9P9_|5mT0O@snLTV}cW#Q^IUl=ozfL^o(;GPZBk*MD9(1uj z3VciHFN!YKM&VP0s=p(f({uLkRyCCks1Z_n15fitR_j`6b` 
z?$`|5yHUgTu;=0t_P-;%%tW3Yd>stCYkV9Si+%@xdZxeIL(d_k>X#zV29fdDe2$~_R0j}+YfcrPBVQJ{!8Bhjq{#7<&#-t!`T^ZAkLoX5bW$PXE2)(CFF&uC?kjLXTeINo4 zT&_72#&O8p-)o#dO%BN&6BBU-eZT?4hYm{p*7B*z-iEE^Hmf;!q0lvFwV|Y)t2v_w z={vs`@oTf;ORecZgY<_m1!6)5sc6 z2B(v_=9u%Xiq$2O$HJSO1OJ0sUE5h^8z}w2+#?l^+72@{Oc>CsALks|GKVVSHECxS zM;zNss(iuGLL?wnzMmh+P9@d4@vUe{Gt+SI@A--B6msw3w>p^Jr+HvgzzGX96!?Rq z3nU}=b{23VJPr5$<_Np-$iv=RU(Ze^1unbIw6^|{c)IvEN48kNF-(g76h&`X+O+GFY7n2&SI{%gMQB;M{c@9EqU~fO`M=g!=r~Cc_KUV=qu}E*=gj_Q@>Y@ zug#W+l6LyqdgbBlk&F)=T}Co~vtB!rG4FewmEmbNg`B9$FLVMXSz7vDm2+JrkHIzW zRf;pgkFR@BM_u%v7KAh$rUPdksf|s)ig%H`!cGlnm*EwrIPrA$c%kcm(N&3bbrFdD?+~omD+~+KqBf)JP_GCtpz{G#vD?$DP|a zZ$@#@D$$QIVJ#htJ1Q1;TZDZTwfI(tCaiV!JrPVZ_WxQ4qhnoLs3(H;`xX>du2D7tAx;=L&G272nI%D z=e;O-Lp*X!*p!Fb+1HDFpwFN5HL^_ZRCJ-aFXmbfYyk2rtLo~-i=eVSMKiY8CX*W|%?Qk#dylgw8+Xzr^z z-!Pu6%4a;T&4ck|Z5}d%R&rq2(C3o^Z3^!+^2hel1V0#0e%gat^dpwH(0G#hnxVPe z$1!;gx(i*;$au2Q$t}}SlaI%f+cS{oMxkHw1T>!1^%#vOZ9d|NR&$K!zV>X4?ahPs z@kr;Ala3=33?mW!y0No+c-ZInG0=$L>5>;M`aCy~Dw~YA<|6e8`nIAAMCGTLP}A3stO*w$epa_1%{Gng8&B&k_*npU^`B#5kndp6DQ2e$*Z%}mm7mgSd=e7`o zVt>xd)#M9V*869JBF{N76;}`$6o)}-_^g_C!aRdPvFx_qpxB?Y$57IZ+b$(kn~sC&@ow{hXE&XSFlGe_dzloj$c z>v^|H6Z)YuS5!cn(6=f%+ExR1l6TLhDbTI$;M`zc43X*ZZ5C^b4~u`DZ5Ve)}2K=Cl(!R5I8EwF#`)40FhJ zJAMX5r7oe?N;&`N1hk?2oD*pZesboXL-IPCWAHtB^bPI6#^E6yUWj^*Uas-y`WX3k zFelHxvpCVvx8HYW_d6NyHx7O-5)FZ-E9ZxnY-SkNK;=VBv`oH#lP5*orrSK{n4IVQ zo70G)^URqy;8e2-`f}RInMo&NhH(C=2{rth%|iW3sIHTf2wwAHNj7M{m^h@gF*he$ z;^zDR`zB%x9M%45Me3p24tWA z^v1XVw_;`+eE7^EI34LLoj2{Rd41jUP7lxfm;w7>=Ue}~=jkcEV>RPy|2D+>c~#vb zro(q^Z_pdoYL(mP6#?fgqE1E{Aum6Qr-!3Yz|J9QkUT|3B&T&sVEqd?lTBWOasQ%aP$-@f<%odD0i``c< zet;_A0h$@9^9$4Y!SG7KG#WX;A$7+q1~U+_uFT9mk|3YA=i5-VRu zC0~z##EmOa;<}NOl+x2EF+NGTbv`=p*k^P&@-i+NAC7cKha*+@P~uvAIPwBYEXRi< zW6|MA)(}b@28Sb0Tl|a?>z<*+XA=<{r4$R{Qa)yHV<1XDgTl=5!i?I6y^rQ#Po$5N zaTXVx#Y=#)tY}AB8sl-66+T!HyYcbXUYwdEWCp_ zi?=1J*jH&d%Tk=>FwRna13%?F&eAsrKjk>C*e;xYw0>-Nso?zJap* z@DH5ji8IPlDZ(CKEwT64wpf%Uwp5L@XudUiO7(u6Wg`AM&bgs1!cdfj^%Bk^jzd|7 
z-^N)^;4HTe;4Hp4%SY2u7Hu@nvI%GDe;H@l2U*a;Klqah8H^P`C$L|}ev@tBHyt4k z{LDI6Fi@c2=lYnz-E;AY*aiUph}64CABNx0yd}uBs+f-U-tt_39QGL}nGJV;f2m?T zoYnn1Zxixgr(;q63&7l7Ae9GcPgun9T!QbJq7Ab){N7kN&L3L>ra+f>m{?wM0G1Ld zsx)z4d#nX`_o-KbW&SDbo?vA|9M!kste}MhSR`Q~@`@%~`9={xgMx|Xutv);zEbB1 zp}U0Pa;s)r4O;PbS5>9UlRnm!q4_voHMsW_P6v}Zk!o@;I?0~PepgophX~;Cm1i89 zCUTJU)fC^c>U38@ppjZ+#V&|ds+|G@=_jyZ2Ex6_+H}Q998VRJ#o8W*2=Cn4?2C+H zv2IMUhs#-QpE$^?Ws2l9V@W$O8}a#7BR)_3AK|kO82k_5Gdju%d@cs6{eObj5JJzu z>$AnVqjN| zNVk=s+l{9Ux(yaq2Fzh49)bJau@HpS3j$VkO-vEp^W=;Voc-yZCLcIr>d($Vw)Qn^ zm`KlMz_0hM`}!y$1m!xxwx9R`igh|4C0no!7UVNUeobR@c)C2RZzETnV1FxzxQ$>7 zDN~u|2N_H{C2un>W@4g>qc7&$z}I$LJ@MKfgq+!^e{46RsYUR+l#7Mr)><{R-P4)Z1EKh$%*hk=&o%ssfZ5JEJo%a=Q_ z-vTb^^DnbS$o=+Y5)U4=CXnQN2rb7$P7qM2kiTfdqyM4}_cdz5xqF@T3^^Wl`76o=#`MyWF%IsJ z{Ctako`$6yQ#{N1v@*8zQ?A7SkG2se#ZT)zR(i&c0#r>Bb$sE8H5>0`2C;n z3zz%_A0@_oY!9USLKJ`p)|hRnGm-_pybKv;Zm>1XC9DmUH=6E=_oUuOpY65qTksRK zx(cQhn^0elBnJ zf~wC@NEgg+Pt5Od%x`n}{ZDL<7wifyO2&8(4&%W$D>9+{UEu9p^xte4uQuo&b_Fnw z-#ZMmAj8Gjo)AKt3*RggzS%HHy+W^xMdKYN4WJ0pSHpmZD`XJb>{S>pd4JcJ*8-$h zn}G(S4mx>IgG!ehL{U1^B$#e(35-^9m`;oc9j`W#dzHfPu@S$=Yyavy{15cT#{~PS zr+tASE9$@2->m$L{$}l{{s!w0^f!9_X`uBdSFb-M6N&!pgMblOe|qGsmtkGDx2_r{ z?m%lA-(tU#xe14s(E|QmAGI#A!&|spo?tg%x*OYCyK>}7U9?^ySL#~SAN3(y3-w3- zy@3}2tdA0zyN>neLqpfSRs@dPp(kNY=m`%1Ki<$)qMn2i4*9fbLO>hL9J@i^`)H-v z=K7wbM_%*}N8$~I1gfwxSvB!a5bDlWxPG9M%mHWL{_OpzJF}gr9X99TP7K~gghI(^ zQ<)PF(gj*pQOudv1{=s#KJ_r&H3>f~_tF*^9>_5*=9<~@4U;S0^;1#-)0a_@L%06Q>`d@zr~Fpqd6JhHTg=JQ!2INqV7R*;x#(Aj{86U}qh_#ZbN zn}3po_mS&wpj+a#A2Dyl`_KDpp}>)Tng|3tkZ38M_Pv93a>uMq01 ze4)NAmTjoMOhsD|C=G1lv<^YBr`+*}220QoS_$e}3tnU#km1gRXixM}J z;HI9R(}7N4>-mY|mHi8ThDyK>)CRECK5%t_IywAP9=7UvSQccv0C?y=!b4AJQ4)oq z*~4h?rhqWE`evlf1|670=N#R_m8BeP;%Y6dP4_~u07hbQWuiur*sSNF5Z5&n%O3Yt z*btC_1E>C%&@frY4TRqh@bBj=+N~7Fh>xh7lF+>teBVO0PJnr+%5L2cqt4ixrb;Wx zG>|O+kOa5B&2w!!t@x~|%n3r4h}G zQ)X&-O>iKhMT`I|6jj1PH^-vW%B|6|3--7yh?I*V0&gWOK1pKW=Lrjm$ihE14)#?< zH<&RO`^UZvErRJ4h;GrI=+G<=%AeC{DVSEn8_q3~!3mwVJO>m9mokY 
zoegb=gcz$&!Hdxg_JBLG2Ywno=O97YdU)k}6dX~kg70Yw{I5u)(oOG<$0IbAPEfKJ zl$A9FjH|cN=`q%Lhm`xK?_*8(f$(W<(Sn4&pG=eLu<;HS6*u>lt1!zDpQ!IIJKTfz zN1Z8QCL~gD_WT69;}p}r+cZ-M)+y|ELD7Q!Ez*T9qIGk3(L)o_m_avLFk{eH+N#7vx?KjR?3s zpyj@BY$r0CW0A0oKS-MfuZ#rh7Y$% zUw;Tu3ia2w8ULmFhx{L^KYX)>`nwk`AyGTF7el;X2;&YN9_;>9!I~B{J1S>}SZZGE z13kJ_NNgVcyP)s#U+JQo`Ya!kfJ|CICJq#q33?~@FrVH!7c7$i80rqYVeSX+Fvuiq zF$7Uwh|v}A7-X`d%BNSp1j{641uYY8D#)Y~WCACmu7XUQZIOiv@8nO_I0)F;Vuo$C zs26^$;ZX{R1!l%vrKPf8=QY7M8Ajre*hINV=*6{VQEuet&M3DK8od$a9wHBMs5cGm zbzo-~@eZ-GgBZ3B@=zEGGzW%j$0}o6SSk(#>yio?486_sf3dt9Et7_US4j-m#;e<* zcW$0)l^RGTdwv&4sUS49MHPELb6WIE8+7W*LlWzk> zKpO*bv@X2V<{ki%ww}5D3gNE+Gz0L?>@(BBeIS7E^y$rD@cM_Qb!YSD+66|M@#gNr z-v3|#(PvL3n+dr?=lDwD%ws9Enf#o?xDSJA)Zou7_75nsF>$DYQFf!{_+owAFBHMF zAKMhQ{oP#X`)s^$Yqyl0A#`Yx2PMMTpkzGkv0*Ve4W`f!uqWyx00(KUnQ+0S1nAmrCvDNEUp~+mGJ-9CBHwa%eDcH->^ZLME*TaR@1hh6vK)uM#40e5DtM48&ne#PK(C)l4i7 zq=)}Z78gJ+#NHP|ntvh-G)5-2E1&0OLMu4UZ0GTF0MKpTI{0k-E3~iH{>%82;euFD z960bE3Zw1)0Cw-ri2+ZS`slyot3Ha_st*o_e1isbZ2R@Ez8RYmy`7x61iqPIEEBMk z}$ajhG?d%3V*F2G70yCJ%S_-p*`Iy}z zA(SQ)!jQ_7>M)ZDT`QSu0@-#SZ2jt4h*pIOVlvA0HXhescTI z7spvN1X?sJE!yfWx)_$%11)bVEqm%MZ!^aB2adg~96MM)c8D?VLEyMY%5hKX$30_M zi$&HUD(eWHbqv!cPGqxOWwTOe6VIHOB$_BwO-$EK1UK&rk!`-pcB9U=km=AMa%fgL zwCNnWn2y&)jyF|~JvztROs9U46P$Y<)Hw|?ogavtAE}(5=$xN1UBqIS5VcE0gG-Fb z)Hw0fYsaOiL0^lc}erH%!YiaaD+2^VO~!8(a%b+#AI1&1(0y2KO$LS=Yt0 zZmMVXG|akf;?XbmxU2RUZ15N|nf*XK`;mI~lZM&POnk*bz9B`v5skhvtT}N(bCws) zS=l%zp5>PmfgOKK-Fi=Q+te#TxZ4qh5kvNYn#(iqcaaly-$mn>U(Wm&xG@}%J9vXbTL zSC(g)zM=?zCBNjAjaObNG+osYysEimRoj(SU8bwA2d}f5IA{lW2fOX3Hw z#1EM!JP1yBRFd%IO2RYK6mdvO$kvpI=9CyJH7+D|`PS5x&8hK}EGa}L+bTL z^^IqoVsTh;h^9EAr8tIL5*JpoTvM{Lr6iuaH7RVXOtUq;Wos6O2$RCIB-Uk}@UQ?tFNW&3UJj{dM6cQrc(TXqa_-+U1E<|ECUPg>r5#;p>E zSA~>SMYLAMnC*-U-?_YO=gQWd@n*Y{!gtBacBQxO$})Rb5&mv|*}EHC-z_w&X$Y@r zE~{y4t?4q`b3J^|&9Xf`t$S{p?d=cWd$(-wVC&u?vwaW3_dP1x_oQ{-GqYN8L~TfU zZA4pbjQN4Mhy%;Z53FoE5O01kDdM24{9t<9!7TGbiikt`<%c%59V#?G(hzZ^x%^05 z+mSBwqt_#j-Yh@b({}W>`LX_pV|UAs4YnN{GC%$x;`pQT<4@X-KQlimjyxH%?PNsz 
z$(S*x;v!Ei-*#$c`>FUbr;{R2%eI|PZ$F(i=8PipO#Zer8{5wmj=9(nd9iuh#kTf~ zU1L7K9{KsrZJ+nFe|~$+rT)lEceh;{Y`-)#=JJEc%a68Qe$syV*_bAAR8vSrQ$$Bo zjK!6>s4L4WuB`015^vF*6xA%NXio2F&a$|wh`O3zadl(I)k2GohNzC_ijKC9jxLMN z>rtIID>{2RI&WKa^+$Ext>_x;=o+&4;z86Gk1D=+((%PJi?77dUxjS{Dx&kN7|R=R z(KnWFzp=9OM!e$t zNh7jX3ocGN{Hdebe*5H>`_t8qMbv}t7m-gcJ}QZYWH8=ZcVOMdwekc@&0pa&)x3)pmw$QpnCx8mfH8>tQ8Xm z)qdZ26ioO-9WXe%y6Cxj-k|5%6OKg-KK8lQ7+JLFBfrP*Y%U7?dM>}}NRhZ^K|t=U zBFUGFR^*x&hg=gEEPb^&?9;Gn$?oEaM-gXhFBL~!jK0-$yEvw7@#7}*lGsae{M6S< z;y!&f;Nh{7rRA@!NIhS&{MhP(sY4~N9?Cd;WglPp+Rbuii_HbTa`k}<=D?{gzB1wR zb|&}qPkg1ca|hFH*HgYS`Rz((SiXfonfkDjx%RX`pj^}Z7E`m;NuW&AzRf)B=OIwO z{^2{!7Qcl8W#-l@=D?#UfpT5(PA2#3RRU%9$z4ph^&138Hs~)EcifIiAXdV6$SsG`@;KU5WX4yZGeAYp^0Dk+X?B=ewmMHItwL+KXM=aM(N2P+Gm6Q z5GB9s;P-L_!rxNJkLGv?MRPa|n7EY!j`oWXig2zwgRvfPYrqL5_%^y6LJ^M2sRkU} z=Nl-&57IcH1V0HlTH7F$;Egm+D8a7-ZUyNHCHNh{(H;y!3H}FN4xt1$2YI095=w9Z zjT1_6HyS6D;QlmDD8WN$oKS)<103;(P=Y7ZIH3g31{}$YP=XiIIH3gJM&pDMyqd-d zCHMgvCzRly(Kw+5zYI94KcNKg1RSlu5K8dx07rZ#l;HR1^n?=pDP2CH1UH9|1Id?A zf=>b*l}{+aJpo7cAr#?Ap9AUigp%}&>GXsWJb_M6D8aL6oKS*qp>aY9UI94b8=(Z> zP2+?T{4n6?yCam~7wPf|C3rhs4xt49hE7i?!3O|G-w~k%e*!pa2M8rNkFDn?p#&Gw zIH3gh1bjS{Lny&TfFpSkO7K{~(ViPZ37!l%;sc=sUr(nel;GunBe@ew@ZEr;=MqZr zk7%4wf}f>vLJ5AA#t9|(O~BD}2_^Vl8Yh(CkLYp;CAg`nzMg~<+?vJ-CHNE?CzRmc zfTQ{lN^mia6H4$n8Yh(CNiCzRk^SinH_CzRjO7K~LqjrN(f(O#&5K8bR zG)^eNlW3e!g6GgUp#(3caY6~clg0@p_(2*cl;CG*oKS+d&^Vz4|Axj1CHNqX6H4$u z07v}}p#(PrdlKP<65N)?2_^Vc8Yh(CK7b?s6N+%u&xX)Ap#*;waHQvi5a|N7Ggrg4{+06(il;8jv>||mu^FTMzZ{vTEoq_N=pvaac z6wDQj17WO&7xMuC?EE*T!*#;Sl`&qPYo!^o1X;#9X|7*NYOXXb!)p$LyfKuJmMPB& zPM5u$>*b%8m7BUoAhPlL6h^Gb zShvR8=o0kF3rUr{B3v6?A?5&&^voKSNM8mPoDBFNO-b>Pkuq|!6~?Sqyl^c`Ny{{P zhGH#b9(hUqrGWvXZ{jP%y&P$7oiCn8Ag}#b?c3;;1gYIbUC6GUVeJD zkx6rM6H?Oh(^686(&Z}hvXcMFHOU1T5)>Pa638;wr5InLQb#j3D^;pU+Av!8lGn*M zpeh;V1YwPmjx;CD_%?k>>LE>D2eLA-g1$7|g||1Byp+Bs_o3*~%)yuBE);Eim%b!- zQALbzlWW8M7qX+Oq4b9PFGvQ}&Gyw!{3(lZp^DT@X5W{0JdLhEmiP)cfzY==^A4g+^&p*`G34iq 
zQ$Z_jxQ&0)cuuCYz;J`K0;oCln&aW=F=uUBnmo_Xch=gpoV1KIFg0c+C#U8pl4pZT VrKIMkC8s82$@4stA&M{h`(MxfC8q!Y literal 0 HcmV?d00001 diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp index 28c4ec1354..ec41487b84 100644 --- a/third_party/ascend/backend/npu_utils.cpp +++ b/third_party/ascend/backend/npu_utils.cpp @@ -38,7 +38,7 @@ static std::unordered_map registered_names; static std::unordered_map> func_stubs; static std::tuple -registerKernel(const char *name, const void *data, size_t data_size, int shared, +registerKernel(const char *name, const void *data, size_t data_size, int device, const char *kernel_mode_str) { rtError_t rtRet; @@ -55,14 +55,14 @@ registerKernel(const char *name, const void *data, size_t data_size, int shared, rtRet = rtSetDevice(device); if (rtRet != RT_ERROR_NONE) { printf("rtSetDevice failed, 0x%x\n", rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } - void *devbinHandle = NULL; + void *devbinHandle = nullptr; rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); if (rtRet != RT_ERROR_NONE) { printf("rtDevBinaryRegister failed, 0x%x\n", rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } std::string stubName = name; @@ -75,7 +75,7 @@ registerKernel(const char *name, const void *data, size_t data_size, int shared, if (rtRet != RT_ERROR_NONE) { printf("rtFunctionRegister failed(stubName = %s), 0x%x\n", stubName.c_str(), rtRet); - return {NULL, NULL}; + return {nullptr, nullptr}; } return std::make_tuple(devbinHandle, func_stub_handle); @@ -91,16 +91,16 @@ static PyObject *loadKernelBinary(PyObject *self, PyObject *args) { if (!PyArg_ParseTuple(args, "ss#iis", &name, &data, &data_size, &shared, &device, &kernel_mode)) { - return NULL; + return nullptr; } auto [module_handle, func_handle] = - registerKernel(name, data, data_size, shared, device, kernel_mode); + registerKernel(name, data, data_size, device, kernel_mode); uint64_t mod = reinterpret_cast(module_handle); uint64_t func = 
reinterpret_cast(func_handle); if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("(KKii)", mod, func, 0, 0); @@ -113,10 +113,10 @@ static PyObject *getArch(PyObject *self, PyObject *args) { if (rtRet != RT_ERROR_NONE) { printf("rtGetSocVersion failed, 0x%x", rtRet); - return NULL; + return nullptr; } if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("s", name); } @@ -128,10 +128,10 @@ static PyObject *getAiCoreNum(PyObject *self, PyObject *args) { if (rtRet != RT_ERROR_NONE) { printf("rtGetAiCoreCount failed, 0x%x", rtRet); - return NULL; + return nullptr; } if (PyErr_Occurred()) { - return NULL; + return nullptr; } return Py_BuildValue("I", aiCoreCnt); } @@ -141,19 +141,19 @@ static PyObject *createStream(PyObject *self, PyObject *args) { rtError_t rtRet = rtStreamCreate(&stream, 0); - if (rtRet != RT_ERROR_NONE) { - printf("rtStreamCreate failed, 0x%x", rtRet); - return NULL; - } - if (PyErr_Occurred()) { - return NULL; - } - uint64_t stream_uint64 = reinterpret_cast(stream); - PyObject *result = Py_BuildValue("K", stream_uint64); + if (rtRet != RT_ERROR_NONE) { + printf("rtStreamCreate failed, 0x%x", rtRet); + return nullptr; + } + if (PyErr_Occurred()) { + return nullptr; + } + uint64_t stream_uint64 = reinterpret_cast(stream); + PyObject* result = Py_BuildValue("K", stream_uint64); - if (result == NULL) { - rtStreamDestroy(stream); - } + if (result == nullptr) { + rtStreamDestroy(stream); + } return result; } @@ -193,21 +193,21 @@ std::vector readDataFromBinaryFile(const std::string &filename) { } static PyObject *readDataFromBinaryFileWrapper(PyObject *self, PyObject *args) { - const char *filename; - uint64_t arr_ptr; - if (!PyArg_ParseTuple(args, "sK", &filename, &arr_ptr)) { - return NULL; - } - - try { - std::vector data = readDataFromBinaryFile(filename); - char *arr = reinterpret_cast(arr_ptr); - std::copy(data.begin(), data.end(), arr); - return Py_None; - } catch (const std::exception &e) { - 
PyErr_SetString(PyExc_RuntimeError, e.what()); - return NULL; - } + const char *filename; + uint64_t arr_ptr; + if (!PyArg_ParseTuple(args, "sK", &filename, &arr_ptr)) { + return nullptr; + } + + try { + std::vector data = readDataFromBinaryFile(filename); + char *arr = reinterpret_cast(arr_ptr); + std::copy(data.begin(), data.end(), arr); + return Py_None; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } } void writeDataToBinaryFile(const std::string &filename, const char *data, @@ -225,136 +225,99 @@ void writeDataToBinaryFile(const std::string &filename, const char *data, } static PyObject *writeDataToBinaryFileWrapper(PyObject *self, PyObject *args) { - const char *filename; - uint64_t arr_ptr; - size_t num_bytes; - - if (!PyArg_ParseTuple(args, "sKn", &filename, &arr_ptr, &num_bytes)) { - return NULL; - } - - try { - const char *data = reinterpret_cast(arr_ptr); - writeDataToBinaryFile(filename, data, num_bytes); - return Py_None; - } catch (const std::exception &e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - return NULL; - } + const char *filename; + uint64_t arr_ptr; + size_t num_bytes; + + if (!PyArg_ParseTuple(args, "sKn", &filename, &arr_ptr, &num_bytes)) { + return nullptr; + } + + try { + const char* data = reinterpret_cast(arr_ptr); + writeDataToBinaryFile(filename, data, num_bytes); + return Py_None; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } } -static PyObject *allocateHostMemory(PyObject *self, PyObject *args) { - uint64_t num_bytes; - if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return NULL; - } +static PyObject* allocateHostMemory(PyObject* self, PyObject* args) { + uint64_t num_bytes; + if (!PyArg_ParseTuple(args, "K", &num_bytes)) { + return nullptr; + } - void *host_ptr = NULL; - rtError_t error = rtMallocHost(&host_ptr, num_bytes, RT_MEMORY_HOST); - if (error != RT_ERROR_NONE) { - 
PyErr_Format(PyExc_RuntimeError, - "rtMallocHost failed with error code: 0x%x", error); - return NULL; - } + void* host_ptr = nullptr; + rtError_t error = rtMallocHost(&host_ptr, num_bytes, RT_MEMORY_HOST); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, "rtMallocHost failed with error code: 0x%x", error); + return nullptr; + } PyObject *result = Py_BuildValue("K", (uint64_t)host_ptr); - if (result == NULL) { - rtFreeHost(host_ptr); - } + if (result == nullptr) { + rtFreeHost(host_ptr); + } return result; } -static PyObject *allocateDeviceMemory(PyObject *self, PyObject *args) { - uint64_t num_bytes; - if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return NULL; - } +static PyObject* allocateDeviceMemory(PyObject* self, PyObject* args) { + uint64_t num_bytes; + if (!PyArg_ParseTuple(args, "K", &num_bytes)) { + return nullptr; + } - void *device_ptr = NULL; - rtError_t error = rtMalloc(&device_ptr, num_bytes, RT_MEMORY_HBM, 0); - if (error != RT_ERROR_NONE) { - PyErr_Format(PyExc_RuntimeError, "rtMalloc failed with error code: 0x%x", - error); - return NULL; - } + void* device_ptr = nullptr; + rtError_t error = rtMalloc(&device_ptr, num_bytes, RT_MEMORY_HBM, 0); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, "rtMalloc failed with error code: 0x%x", error); + return nullptr; + } PyObject *result = Py_BuildValue("K", (uint64_t)device_ptr); - if (result == NULL) { - rtFree(device_ptr); - } + if (result == nullptr) { + rtFree(device_ptr); + } return result; } -static PyObject *copyMemory(PyObject *self, PyObject *args) { - uint64_t dst_ptr; - uint64_t src_ptr; - size_t count; - const char *direction_str; - rtMemcpyKind_t copy_direction; - - if (!PyArg_ParseTuple(args, "KKns", &dst_ptr, &src_ptr, &count, - &direction_str)) { - return NULL; - } - - if (strcmp(direction_str, "H2D") == 0) { - copy_direction = RT_MEMCPY_HOST_TO_DEVICE; - } else if (strcmp(direction_str, "D2H") == 0) { - copy_direction = RT_MEMCPY_DEVICE_TO_HOST; - } 
else { - PyErr_SetString(PyExc_ValueError, - "Invalid copy direction. Must be 'H2D' or 'D2H'."); - return NULL; - } - - void *dst = (void *)dst_ptr; - void *src = (void *)src_ptr; - - rtError_t error = rtMemcpy(dst, count, src, count, copy_direction); - if (error != RT_ERROR_NONE) { - PyErr_Format(PyExc_RuntimeError, "rtMemcpy failed with error code: 0x%x", - error); - return NULL; - } - - Py_INCREF(Py_None); - return Py_None; -} - -static const std::unordered_map LimitTypeMap = { - {"LOW_POWER_TIMEOUT", rtLimitType_t::RT_LIMIT_TYPE_LOW_POWER_TIMEOUT}, - {"WARP_STACK_SIZE", rtLimitType_t::RT_LIMIT_TYPE_SIMT_WARP_STACK_SIZE}, - {"DVG_WARP_STACK_SIZE", - rtLimitType_t::RT_LIMIT_TYPE_SIMT_DVG_WARP_STACK_SIZE}, - {"STACK_SIZE", rtLimitType_t::RT_LIMIT_TYPE_STACK_SIZE}}; - -static PyObject *setDeviceLimit(PyObject *self, PyObject *args) { - int device; // device ID - const char *type_str; - uint32_t val; - if (!PyArg_ParseTuple(args, "isI", &device, &type_str, &val)) { - return NULL; - } - - auto it = LimitTypeMap.find(type_str); - if (it == LimitTypeMap.end()) { - printf("Invalid limit type: %s.\n", type_str); - return NULL; - } - - rtError_t rtRet = rtDeviceSetLimit(device, it->second, val); - if (rtRet != RT_ERROR_NONE) { - printf("rtDeviceSetLimit failed, 0x%x\n", rtRet); - return NULL; - } - if (PyErr_Occurred()) { - return NULL; - } - return Py_None; +static PyObject* copyMemory(PyObject* self, PyObject* args) { + uint64_t dst_ptr; + uint64_t src_ptr; + size_t count; + const char* direction_str; + rtMemcpyKind_t copy_direction; + + if (!PyArg_ParseTuple(args, "KKns", &dst_ptr, &src_ptr, &count, &direction_str)) { + return nullptr; + } + + if (strcmp(direction_str, "H2D") == 0) { + copy_direction = RT_MEMCPY_HOST_TO_DEVICE; + } else if (strcmp(direction_str, "D2H") == 0) { + copy_direction = RT_MEMCPY_DEVICE_TO_HOST; + } else { + PyErr_SetString(PyExc_ValueError, "Invalid copy direction. 
Must be 'H2D' or 'D2H'."); + return nullptr; + } + + void *dst = (void*)dst_ptr; + void *src = (void*)src_ptr; + + rtError_t error = rtMemcpy(dst, count, src, count, copy_direction); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, "rtMemcpy failed with error code: 0x%x", error); + return nullptr; + } + + Py_INCREF(Py_None); + return Py_None; } static PyMethodDef NpuUtilsMethods[] = { @@ -363,19 +326,13 @@ static PyMethodDef NpuUtilsMethods[] = { {"get_arch", getArch, METH_VARARGS, "Get soc version of NPU"}, // sentinel {"get_aicore_num", getAiCoreNum, METH_VARARGS, "Get the number of AI core"}, - {"create_stream", createStream, METH_VARARGS, "Create a stream"}, - {"read_data_from_file", readDataFromBinaryFileWrapper, METH_VARARGS, - "Read binary file into the array already allocated"}, - {"write_data_to_file", writeDataToBinaryFileWrapper, METH_VARARGS, - "Write an array to a binary file"}, - {"allocate_device_memory", allocateDeviceMemory, METH_VARARGS, - "Allocate device memory"}, - {"allocate_host_memory", allocateHostMemory, METH_VARARGS, - "Allocate host memory"}, - {"copy_memory", copyMemory, METH_VARARGS, - "Copy data between host and device"}, - {"set_device_limit", setDeviceLimit, METH_VARARGS, "Set the limit of NPU"}, - {NULL, NULL, 0, NULL}}; + {"create_stream", createStream, METH_VARARGS, "Create a stream"}, + {"read_data_from_file", readDataFromBinaryFileWrapper, METH_VARARGS, "Read binary file into the array already allocated"}, + {"write_data_to_file", writeDataToBinaryFileWrapper, METH_VARARGS, "Write an array to a binary file"}, + {"allocate_device_memory", allocateDeviceMemory, METH_VARARGS, "Allocate device memory"}, + {"allocate_host_memory", allocateHostMemory, METH_VARARGS, "Allocate host memory"}, + {"copy_memory", copyMemory, METH_VARARGS, "Copy data between host and device"}, + {nullptr, nullptr, 0, nullptr}}; static PyModuleDef ModuleDef = { PyModuleDef_HEAD_INIT, "npu_utils", @@ -384,8 +341,8 @@ static PyModuleDef 
ModuleDef = { PyMODINIT_FUNC PyInit_npu_utils(void) { PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; + if (m == nullptr) { + return nullptr; } PyModule_AddFunctions(m, NpuUtilsMethods); diff --git a/third_party/ascend/backend/runtime/autoparser.py b/third_party/ascend/backend/runtime/autoparser.py index 2176d71e59..642ffc780e 100644 --- a/third_party/ascend/backend/runtime/autoparser.py +++ b/third_party/ascend/backend/runtime/autoparser.py @@ -97,6 +97,14 @@ def get_axis(self, var: str, node=None): axis = self.handle_lt_node(var, child_node) elif isinstance(child_node, ast.Assign): axis = self.handle_assign_node(var, child_node) + + elif isinstance(child_node, ast.BinOp) and \ + isinstance(child_node.op, ast.BitAnd): + + axis = self.handle_lt_node(var, child_node.left) + if axis is None: + axis = self.handle_lt_node(var, child_node.right) + if axis is not None: return axis self.checked_vars.append(var) @@ -178,6 +186,13 @@ def __init__(self, func_ast: ast.AST, keys: Dict[str, str], candidates_params: L super().__init__(func_ast, keys) self.split_axes = dict() self.program_id_vars = list() + self.program_id_var_dims = dict() + self.num_programs_var_dims = dict() + self.grid_stride_tiling_only = dict() + # axis_name -> program_id axis dim + self.split_axis_pid_dims = dict() + # axis_name -> program_id axis dim (includes axes inferred without split params) + self.axis_pid_dims = dict() self.candidates_params = candidates_params def parse(self) -> Dict[str, str]: @@ -185,41 +200,247 @@ def parse(self) -> Dict[str, str]: return self.split_axes def visit_Assign(self, node): - if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute): - if isinstance(node.value.func.value, ast.Name): - if node.value.func.value.id == "tl" and node.value.func.attr == "program_id": - if isinstance(node.targets[0], ast.Name) and \ - node.targets[0].id not in self.program_id_vars: - self.program_id_vars.append(node.targets[0].id) + 
pid_dim = self._get_program_id_dim(node.value) + if pid_dim is not None: + if ( + len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id not in self.program_id_vars + ): + self.program_id_vars.append(node.targets[0].id) + self.program_id_var_dims[node.targets[0].id] = pid_dim + num_programs_dim = self._get_num_programs_dim(node.value) + if num_programs_dim is not None: + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + self.num_programs_var_dims[node.targets[0].id] = num_programs_dim self.generic_visit(node) def visit_BinOp(self, node): if isinstance(node.op, ast.Mult): split_axes_val = None + split_axis_pid_dim = None if isinstance(node.left, ast.Name) and node.left.id in self.program_id_vars: if isinstance(node.right, ast.Name): split_axes_val = node.right.id + split_axis_pid_dim = self.program_id_var_dims.get(node.left.id) elif isinstance(node.left, ast.Call) and isinstance(node.left.func, ast.Attribute): if node.left.func.value.id == "tl" and \ node.left.func.attr == "program_id": if isinstance(node.right, ast.Name): split_axes_val = node.right.id + split_axis_pid_dim = self._get_program_id_dim(node.left) if isinstance(node.right, ast.Name) and node.right.id in self.program_id_vars: if isinstance(node.left, ast.Name): split_axes_val = node.left.id + split_axis_pid_dim = self.program_id_var_dims.get(node.right.id) elif isinstance(node.right, ast.Call) and isinstance(node.right.func, ast.Attribute): if node.right.func.value.id == "tl" and node.right.func.attr == "program_id": if isinstance(node.left, ast.Name): split_axes_val = node.left.id - + split_axis_pid_dim = self._get_program_id_dim(node.right) + if split_axes_val in self.candidates_params and \ split_axes_val not in self.split_axes.values(): split_axes_key = self.get_axis(split_axes_val) - if split_axes_key: + if split_axes_key and not self._is_tiling_only_split(split_axes_key, split_axes_val): self.split_axes[split_axes_key] = split_axes_val + if 
split_axis_pid_dim is not None: + self._record_axis_pid_dim(split_axes_key, split_axis_pid_dim) self.generic_visit(node) + def visit_For(self, node): + if not isinstance(node.iter, ast.Call): + self.generic_visit(node) + return + + iter_fn = node.iter.func + is_range = isinstance(iter_fn, ast.Name) and iter_fn.id == "range" + is_tl_range = ( + isinstance(iter_fn, ast.Attribute) + and isinstance(iter_fn.value, ast.Name) + and iter_fn.value.id == "tl" + and iter_fn.attr == "range" + ) + if not (is_range or is_tl_range): + self.generic_visit(node) + return + + if len(node.iter.args) == 0: + self.generic_visit(node) + return + + start = node.iter.args[0] if len(node.iter.args) >= 2 else None + stop = node.iter.args[1] if len(node.iter.args) >= 2 else node.iter.args[0] + pid_dim = self._extract_pid_dim_from_expr(start) + axis = self._axis_from_expr(stop) + if axis is not None and pid_dim is not None: + self._record_axis_pid_dim(axis, pid_dim) + if len(node.iter.args) >= 3: + step = node.iter.args[2] + loop_tiling_only_param = self._extract_grid_stride_split_param(start, step, pid_dim) + if loop_tiling_only_param is not None: + self._mark_tiling_only_param(axis, loop_tiling_only_param) + + self.generic_visit(node) + + def _get_program_id_dim(self, node): + if not ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "tl" + and node.func.attr == "program_id" + ): + return None + + axis_dim = 0 + if len(node.args) > 0: + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, int): + axis_dim = node.args[0].value + else: + return None + + for kw in node.keywords: + if kw.arg == "axis": + if isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, int): + axis_dim = kw.value.value + else: + return None + break + return axis_dim + + def _get_num_programs_dim(self, node): + if not ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + 
and isinstance(node.func.value, ast.Name) + and node.func.value.id == "tl" + and node.func.attr == "num_programs" + ): + return None + + axis_dim = 0 + if len(node.args) > 0: + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, int): + axis_dim = node.args[0].value + else: + return None + + for kw in node.keywords: + if kw.arg == "axis": + if isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, int): + axis_dim = kw.value.value + else: + return None + break + return axis_dim + + def _extract_pid_dim_from_expr(self, node): + if node is None: + return None + for child in ast.walk(node): + if isinstance(child, ast.Name) and child.id in self.program_id_var_dims: + return self.program_id_var_dims[child.id] + pid_dim = self._get_program_id_dim(child) + if pid_dim is not None: + return pid_dim + return None + + def _contains_pid_dim(self, node, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if isinstance(child, ast.Name): + if self.program_id_var_dims.get(child.id, None) == pid_dim: + return True + if self._get_program_id_dim(child) == pid_dim: + return True + return False + + def _contains_num_programs_dim(self, node, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if isinstance(child, ast.Name): + if self.num_programs_var_dims.get(child.id, None) == pid_dim: + return True + if self._get_num_programs_dim(child) == pid_dim: + return True + return False + + def _is_candidate_name(self, node, candidate_name): + return ( + isinstance(node, ast.Name) + and node.id == candidate_name + and candidate_name in self.candidates_params + ) + + def _extract_pid_multiplied_candidate(self, node, pid_dim): + if node is None: + return None + candidates = set() + for child in ast.walk(node): + if not isinstance(child, ast.BinOp) or not isinstance(child.op, ast.Mult): + continue + left = child.left + right = child.right + if isinstance(left, ast.Name) and left.id in self.candidates_params and \ 
+ self._contains_pid_dim(right, pid_dim): + candidates.add(left.id) + if isinstance(right, ast.Name) and right.id in self.candidates_params and \ + self._contains_pid_dim(left, pid_dim): + candidates.add(right.id) + if len(candidates) == 1: + return next(iter(candidates)) + return None + + def _contains_num_programs_multiplied_candidate(self, node, candidate_name, pid_dim): + if node is None: + return False + for child in ast.walk(node): + if not isinstance(child, ast.BinOp) or not isinstance(child.op, ast.Mult): + continue + if self._is_candidate_name(child.left, candidate_name): + if self._contains_num_programs_dim(child.right, pid_dim): + return True + if self._is_candidate_name(child.right, candidate_name): + if self._contains_num_programs_dim(child.left, pid_dim): + return True + return False + + def _extract_grid_stride_split_param(self, start, step, pid_dim): + if start is None or step is None: + return None + candidate_name = self._extract_pid_multiplied_candidate(start, pid_dim) + if candidate_name is None: + return None + if self._contains_num_programs_multiplied_candidate(step, candidate_name, pid_dim): + return candidate_name + return None + + def _mark_tiling_only_param(self, axis, candidate_name): + self.grid_stride_tiling_only.setdefault(axis, set()).add(candidate_name) + if self.split_axes.get(axis, None) == candidate_name: + del self.split_axes[axis] + self.split_axis_pid_dims.pop(axis, None) + + def _is_tiling_only_split(self, axis, candidate_name): + return candidate_name in self.grid_stride_tiling_only.get(axis, set()) + + def _axis_from_expr(self, node): + if node is None: + return None + for k, v in self.keys.items(): + if self.contains_target_var(node, v): + return k + return None + + def _record_axis_pid_dim(self, axis, pid_dim): + self.axis_pid_dims[axis] = pid_dim + if axis in self.split_axes: + self.split_axis_pid_dims[axis] = pid_dim + class TilingAxesParser(AxesKeyParser): """ @@ -262,12 +483,13 @@ def parse(self) -> Dict[str, str]: 
return self.tiling_axes def visit_For(self, node): - if isinstance(node.iter, ast.Call) and \ - len(node.iter.args) == 3 and \ - isinstance(node.iter.args[2], ast.Name): - for_loop_param = node.iter.args[2].id - if for_loop_param in self.candidates_params and \ - for_loop_param not in self.candidates_params_for_loop: + if isinstance(node.iter, ast.Call) and len(node.iter.args) == 3: + step_expr = node.iter.args[2] + for_loop_param = self._extract_unique_candidate(step_expr) + if ( + for_loop_param is not None + and for_loop_param not in self.candidates_params_for_loop + ): self.candidates_params_for_loop.append(for_loop_param) self.generic_visit(node) @@ -276,10 +498,10 @@ def visit_Assign(self, node): # handle FloorDiv if isinstance(node.value, ast.BinOp) and isinstance(node.value.op, ast.FloorDiv): denominator = node.value.right - if isinstance(denominator, ast.Name) and \ - denominator.id in self.candidates_params and \ - denominator.id not in self.candidates_params_for_loop: - self.candidates_params_for_loop.append(denominator.id) + denominator_param = self._extract_unique_candidate(denominator) + if denominator_param is not None and \ + denominator_param not in self.candidates_params_for_loop: + self.candidates_params_for_loop.append(denominator_param) self.visit(self.func_ast) tiling_axes_val = self.get_tiling_axes_val(node.value) @@ -312,6 +534,21 @@ def get_tiling_axes_val(self, node): return val return None + def _extract_unique_candidate(self, expr): + """ + Extract a unique tiling candidate from an expression. + Return None when no candidate or ambiguous (more than one candidate) appears. 
+ """ + if expr is None: + return None + candidates = [ + param for param in self.candidates_params + if self.contains_target_var(expr, param) + ] + if len(candidates) == 1: + return candidates[0] + return None + class ReductionAxesParser(AxesKeyParser): """ @@ -343,12 +580,39 @@ def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): """ super().__init__(func_ast, keys) self.reduction_axes = list() - self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx + self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx + self.ndim = 1 def parse(self) -> List[str]: super().parse() return self.reduction_axes + def visit_Assign(self, node): + self._scan_subscripts(node.value) + self.generic_visit(node) + + def _scan_subscripts(self, node): + if isinstance(node, ast.Subscript): + ndim = self._get_subscripts_ndim(node) + if ndim > self.ndim: + self.ndim = ndim + + for child in ast.iter_child_nodes(node): + self._scan_subscripts(child) + + def _get_subscripts_ndim(self, subscript_node): + slice_node = subscript_node.slice + + if isinstance(slice_node, ast.Tuple): + # e.g. [:, None] -> Tuple(elts=[Slice(), Constant(None)]) + return len(slice_node.elts) + elif isinstance(slice_node, (ast.Slice, ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)): + # e.g. 
[0], [:], [i], [-1], [i+1] + return 1 + else: + # Fallback: treat as 1D + return 1 + def visit_Call(self, node): if not isinstance(node.func, ast.Attribute): return @@ -358,23 +622,43 @@ def visit_Call(self, node): return if func.attr not in self.reduction_func: return - + + axis_dim = None args = node.args if len(args) == 1: - keywords = node.keywords - for keyword in keywords: + # Axis passed as keyword argument + for keyword in node.keywords: if keyword.arg == 'axis': - if isinstance(keyword.value, ast.Constant): - axis_dim = keyword.value.value + axis_dim = self.get_axis_dim(keyword.value) + break + elif len(args) == 2: - if isinstance(args[1], ast.Constant): # check the second param - axis_dim = args[1].value + # Axis passed as positional argument. Check the second param + axis_dim = self.get_axis_dim(args[1]) + else: - return + raise ValueError("Reduction funtions args error") + + if axis_dim is not None: + reduction_axis = self.get_axis(axis_dim) + if reduction_axis and reduction_axis not in self.reduction_axes: + self.reduction_axes.append(reduction_axis) + + def get_axis_dim(self, node): + if isinstance(node, ast.Constant): + axis_dim = node.value + elif isinstance(node, ast.UnaryOp) and \ + isinstance(node.op, ast.USub): + operand = node.operand + if isinstance(operand, ast.Constant): + axis_dim = self.ndim - operand.value + else: + raise ValueError(f"Reduction function axis error, got: {ast.dump(node)}") - reduction_axis = self.get_axis(axis_dim) - if reduction_axis and reduction_axis not in self.reduction_axes: - self.reduction_axes.append(reduction_axis) + if not isinstance(axis_dim, int): + raise ValueError("Reduction function axis must be an integer, " + f"got {type(node.value).__name__}: {node.value}") + return axis_dim def get_axis(self, axis_dim: int): """ diff --git a/third_party/ascend/backend/runtime/autotuner.py b/third_party/ascend/backend/runtime/autotuner.py index 3ea0381054..5af6fe2aea 100644 --- 
a/third_party/ascend/backend/runtime/autotuner.py +++ b/third_party/ascend/backend/runtime/autotuner.py @@ -23,16 +23,22 @@ from __future__ import annotations import builtins +import copy +import functools +import ast import os import time -import copy +from concurrent.futures import ThreadPoolExecutor from typing import Dict, List + from torch import Tensor +import triton from triton.runtime.autotuner import Autotuner, Config +from .autoparser import (LowDimsAxesParser, PtrNumsParser, ReductionAxesParser, + SplitAxesParser, TilingAxesParser) from .utils import get_byte_per_numel, is_valid_axis_name, valid_axis_names -from .autoparser import SplitAxesParser, TilingAxesParser, ReductionAxesParser, LowDimsAxesParser, PtrNumsParser class AutoTilingTuner(Autotuner): @@ -88,7 +94,13 @@ def __init__( tiling_params = self.hints.get("tiling_params", None) low_dim_axes = self.hints.get("low_dim_axes", None) reduction_axes = self.hints.get("reduction_axes", None) - self._init_axis_params(key, split_params, tiling_params, low_dim_axes, reduction_axes) + self._init_axis_params( + key, + split_params, + tiling_params, + low_dim_axes, + reduction_axes, + ) self.auto_gen_config = not configs or self.hints.get("auto_gen_config", False) self.gen_configs = [] # generated configs from TileGenerator @@ -98,8 +110,15 @@ def __init__( else: self.user_configs = configs self.is_simt_mode = False + self.simt_stack_limit = 8192 self.user_specified_warps = None self.print_autotuning = os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" + # Compile kernels in parallel by default for triton.runtime.JITFunction, + # but not for others, e.g., LibEntry, since it's not compatible with AsyncCompileMode + self.compile_parallel = ( + isinstance(self.fn, triton.runtime.JITFunction) + and os.getenv("TRITON_AUTOTUNE_PARALLEL_COMPILE", "1") == "1" + ) def _init_axis_params(self, key, split_params, tiling_params, low_dim_axes, reduction_axes): if isinstance(key, list): @@ -138,9 +157,15 @@ def 
_init_axis_params(self, key, split_params, tiling_params, low_dim_axes, redu set(self.keys.keys()))) self.split_params = split_params + self.all_split_params = {} + self.fixed_split_params = {} self.tiling_params = tiling_params self.low_dim_axes = low_dim_axes self.reduction_axes = reduction_axes + self.fixed_grid_dims = set() + self.fixed_grid_dim_values = {} + self.split_axis_pid_dims = {} + self.axis_pid_dims = {} self.dual_reduction = False self.persistent_reduction = False self.num_buffers = -1 @@ -167,8 +192,51 @@ def _autoparse_axis_params(self, all_args): self.persistent_reduction = True if not self.split_params: - self.split_params = self._autoparse_split_params(miss_params) - miss_params = [arg for arg in miss_params if arg not in self.split_params.values()] + all_split_params = self._autoparse_split_params( + self._get_constexpr_candidates() + ) + self.all_split_params = dict(all_split_params) + self.fixed_split_params = {} + self.fixed_grid_dim_values = self._get_fixed_grid_dim_values( + all_args.get("grid", None), + all_args, + ) + self.fixed_grid_dims = set(self.fixed_grid_dim_values.keys()) + + fixed_grid_axes = { + axis for axis, pid_dim in self.axis_pid_dims.items() + if pid_dim in self.fixed_grid_dims + } + + # Only missing constexpr params are tunable, and fixed-grid axes + # should not be tuned on split. + self.split_params = { + axis: param + for axis, param in all_split_params.items() + if param in miss_params and axis not in fixed_grid_axes + } + + # Fixed split is inferred only from fixed grid dims. 
+ for axis, pid_dim in self.axis_pid_dims.items(): + if pid_dim not in self.fixed_grid_dims: + continue + core_num = self.fixed_grid_dim_values.get(pid_dim, 0) + axis_len_name = self.keys.get(axis, None) + axis_len = all_args.get(axis_len_name, None) + if not isinstance(core_num, int) or core_num <= 0: + continue + if not isinstance(axis_len, int) or axis_len <= 0: + continue + + self.fixed_split_params[axis] = (axis_len + core_num - 1) // core_num + elif not self.axis_pid_dims: + # When split axes are provided by hints, parse axis->program_id mapping + # independently for fixed-grid semantics and diagnostics. + self._autoparse_axis_pid_dims() + miss_params = [ + arg for arg in miss_params + if arg not in self.split_params.values() + ] if not self.tiling_params: self.tiling_params = self._autoparse_tiling_params(miss_params) miss_params = [arg for arg in miss_params if arg not in self.tiling_params.values()] @@ -191,6 +259,7 @@ def _gen_tile_configs(self, kv_dict: Dict[str, int], dtype: torch.dtype) -> List kernel_meta = KernelMeta( axis_sizes, self.split_params, + self.fixed_split_params, self.tiling_params, self.low_dim_axes, dtype, @@ -272,6 +341,8 @@ def generate_key_and_configs(self, *args, **kwargs): def run(self, *args, **kwargs): key = self.generate_key_and_configs(*args, **kwargs) + if self.is_simt_mode and kwargs.get('simt_stack_limit', None) is None: + kwargs['simt_stack_limit'] = self.simt_stack_limit used_cached_result = True if key not in self.cache: # prune configs @@ -319,14 +390,40 @@ def _batch_bench(self, *args, configs, **kwargs): exc = None exc_stack = "" - for config, fn in kernels_call.items(): + if self.compile_parallel: + import psutil + + max_workers = min(psutil.cpu_count(logical=False) // 2, len(kernels_call)) + future_kernels = [] try: - fn() - run_fns[config] = fn - except (CompileTimeAssertionFailure, MLIRCompilationError, OutOfResources) as e: - import traceback - exc_stack = traceback.format_exc() - exc = e + with ( + 
ThreadPoolExecutor(max_workers=max_workers) as executor, + triton.AsyncCompileMode(executor), + ): + for config, fn in kernels_call.items(): + future_kernels.append((config, fn(warmup=True))) + + for config, fut in future_kernels: + try: + if hasattr(fut, "result"): + fut = fut.result() + run_fns[config] = functools.partial(kernels_call[config], warmup=False) + except (CompileTimeAssertionFailure, MLIRCompilationError) as e: + import traceback + exc_stack = traceback.format_exc() + exc = e + except Exception as e: + # ignore exception from __exit__() of AsyncCompileMode + triton.runtime._async_compile.active_mode.set(None) + else: + for config, fn in kernels_call.items(): + try: + fn(warmup=False) + run_fns[config] = functools.partial(fn, warmup=False) + except (CompileTimeAssertionFailure, MLIRCompilationError, OutOfResources) as e: + import traceback + exc_stack = traceback.format_exc() + exc = e if len(run_fns) == 0: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc} \nStack trace: {exc_stack}") @@ -356,15 +453,18 @@ def _make_kernel_call(self, *args, config, **meta): current = dict(meta, **config.all_kwargs()) full_nargs = {**self.nargs, **current} - def kernel_call(): + def kernel_call(warmup): if config.pre_hook: config.pre_hook(full_nargs) self.pre_hook(full_nargs) try: - self.fn.run( + current.update({"warmup": warmup}) + res = self.fn.run( *args, **current, ) + if warmup: + return res except Exception as e: try: self.post_hook(full_nargs, exception=e) @@ -380,8 +480,27 @@ def warmup(self, *args, **kwargs): _ = self.generate_key_and_configs(*args, **kwargs) pruned_configs = self.prune_configs(kwargs) ret = [] - for config in pruned_configs: - ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) + if self.compile_parallel: + import psutil + + max_workers = min(psutil.cpu_count(logical=False) // 2, len(pruned_configs)) + with ( + ThreadPoolExecutor(max_workers=max_workers) as executor, + triton.AsyncCompileMode(executor), 
+ ): + for config in pruned_configs: + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs() + )) + else: + for config in pruned_configs: + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs() + )) self.nargs = None return ret @@ -389,7 +508,10 @@ def _profile(self, *args, config, **meta): from ..testing import do_bench_npu kernel_call = self._make_kernel_call(*args, config=config, **meta) - do_bench_npu(kernel_call, prof_dir=self.auto_profile_dir, keep_res=True) + fn = functools.partial(kernel_call, warmup=False) + do_bench_npu( + fn, prof_dir=self.auto_profile_dir, keep_res=True + ) def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str]: """ @@ -398,10 +520,147 @@ def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str func_ast = self.fn.parse() parser = SplitAxesParser(func_ast, self.keys, candidates_params) split_axes = parser.parse() + self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) + self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) if self.print_autotuning: - print(f"Ascend autotuning parse split axes: {split_axes}") + print( + f"Ascend autotuning parse split axes: {split_axes}, " + f"split axis pid dims: {self.split_axis_pid_dims}, " + f"axis pid dims: {self.axis_pid_dims}" + ) return split_axes + def _autoparse_axis_pid_dims(self) -> Dict[str, int]: + """ + Extract axis -> program_id dim mapping without relying on split-parameter + classification, so fixed-grid semantics can always consume it. 
+ """ + func_ast = self.fn.parse() + parser = SplitAxesParser( + func_ast, + self.keys, + self._get_constexpr_candidates(), + ) + _ = parser.parse() + self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) + self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) + if self.print_autotuning: + print( + "Ascend autotuning parse axis pid dims (independent): " + f"{self.axis_pid_dims}" + ) + return self.axis_pid_dims + + def _get_constexpr_candidates(self) -> List[str]: + """ + Returns all constexpr parameter names from the kernel function definition. + """ + func_ast = self.fn.parse() + constexpr_names = [] + for node in ast.walk(func_ast): + if not isinstance(node, ast.FunctionDef): + continue + if not isinstance(node.args, ast.arguments): + continue + for arg in node.args.args: + if not isinstance(arg, ast.arg): + continue + ann = arg.annotation + if ( + isinstance(ann, ast.Attribute) + and isinstance(ann.value, ast.Name) + and ann.value.id == "tl" + and ann.attr == "constexpr" + ): + constexpr_names.append(arg.arg) + break + return constexpr_names + + def _get_fixed_grid_dim_values(self, grid, all_args: Dict[str, object] = None) -> Dict[int, int]: + """ + Returns fixed grid dim -> value. 
+ - Static tuple/list grid: direct extraction + - Callable grid: infer fixed dims by perturbing missing constexpr params + """ + if grid is None: + return {} + if callable(grid): + return self._infer_fixed_dims_from_callable_grid(grid, all_args or {}) + return self._extract_fixed_grid_dims(grid) + + def _extract_fixed_grid_dims(self, grid) -> Dict[int, int]: + if isinstance(grid, int): + grid = (grid,) + if not isinstance(grid, (tuple, list)): + return {} + fixed_dims = {} + for idx, dim in enumerate(grid): + if isinstance(dim, int) and dim > 0: + fixed_dims[idx] = dim + return fixed_dims + + def _normalize_grid_tuple(self, grid_out): + if isinstance(grid_out, int): + return (grid_out,) + if isinstance(grid_out, (tuple, list)): + return tuple(grid_out) + return None + + def _infer_fixed_dims_from_callable_grid(self, grid_fn, all_args: Dict[str, object]) -> Dict[int, int]: + constexpr_candidates = self._get_constexpr_candidates() + base_meta = dict(all_args or {}) + + # Fill missing constexpr with stable probe defaults so grid(meta) can execute. + for name in constexpr_candidates: + if name not in base_meta: + base_meta[name] = 128 + + try: + base_grid_raw = grid_fn(dict(base_meta)) + except Exception: + return {} + + base_grid = self._normalize_grid_tuple(base_grid_raw) + if base_grid is None: + return {} + + dynamic_dims = set() + # Missing constexpr are tunable candidates. 
+ tunable_probe_names = [name for name in constexpr_candidates if name not in (all_args or {})] + probe_values = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + for name in tunable_probe_names: + baseline = base_meta.get(name, 128) + for probe in probe_values: + if probe == baseline: + continue + probe_meta = dict(base_meta) + probe_meta[name] = probe + try: + probe_grid_raw = grid_fn(probe_meta) + except Exception: + continue + probe_grid = self._normalize_grid_tuple(probe_grid_raw) + if probe_grid is None: + continue + if len(probe_grid) != len(base_grid): + dynamic_dims.update(range(min(len(probe_grid), len(base_grid)))) + continue + for idx, (base_dim, probe_dim) in enumerate(zip(base_grid, probe_grid)): + if not (isinstance(base_dim, int) and isinstance(probe_dim, int)): + dynamic_dims.add(idx) + continue + if base_dim != probe_dim: + dynamic_dims.add(idx) + + fixed_dims = {} + for idx, dim in enumerate(base_grid): + if idx in dynamic_dims: + continue + if isinstance(dim, int) and dim > 0: + fixed_dims[idx] = dim + return fixed_dims + def _autoparse_tiling_params(self, candidates_params: List[str]) -> Dict[str, str]: """ Extracts the tiling axis parameters from triton kernel code. 
diff --git a/third_party/ascend/backend/runtime/tile_generator.py b/third_party/ascend/backend/runtime/tile_generator.py index 9adf013953..4e30eb6df3 100644 --- a/third_party/ascend/backend/runtime/tile_generator.py +++ b/third_party/ascend/backend/runtime/tile_generator.py @@ -52,7 +52,9 @@ class AxisInfo: split_name: str = "" tiling_name: str = "" is_split_axis: bool = False + is_tunable_split_axis: bool = False is_tiling_axis: bool = False + fixed_split_size: int = 0 @property def is_reduction(self): @@ -65,6 +67,7 @@ def __init__( self, axis_sizes: Dict[str, int], split_params: Dict[str, str], + fixed_split_params: Dict[str, int], tiling_params: Dict[str, str], low_dims: List[str], dtype: torch.dtype, @@ -90,7 +93,9 @@ def __init__( :param dual_reduction: performing reduction on more than one axis. :param persistent_reduction: there is no splitting in reduction axis. """ - self._validate_axis(axis_sizes, split_params, tiling_params, low_dims) + self._validate_axis( + axis_sizes, split_params, fixed_split_params, tiling_params, low_dims + ) axis_dict = {} idx = 0 @@ -99,9 +104,11 @@ def __init__( if name.startswith("r"): prefix = "r" - is_split_axis = name in split_params + is_tunable_split_axis = name in split_params + fixed_split_size = fixed_split_params.get(name, 0) + is_split_axis = is_tunable_split_axis or fixed_split_size > 0 is_tiling_axis = name in tiling_params - split_name = "" if name not in split_params else split_params[name] + split_name = "" if not is_tunable_split_axis else split_params[name] tiling_name = "" if name not in tiling_params else tiling_params[name] axis_dict[name] = AxisInfo( @@ -112,12 +119,17 @@ def __init__( split_name=split_name, tiling_name=tiling_name, is_split_axis=is_split_axis, + is_tunable_split_axis=is_tunable_split_axis, is_tiling_axis=is_tiling_axis, + fixed_split_size=fixed_split_size, ) idx += 1 self.axis_info = list(axis_dict.values()) self.split_axis = [x for x in axis_dict.values() if x.is_split_axis] + 
self.tunable_split_axis = [ + x for x in axis_dict.values() if x.is_tunable_split_axis + ] self.tiling_axis = [x for x in axis_dict.values() if x.is_tiling_axis] self.low_dims_axis = [x for x in axis_dict.values() if x.name in low_dims] self.dtype = dtype @@ -131,6 +143,7 @@ def _validate_axis( cls, axis_sizes: Dict[str, int], split_params: Dict[str, str], + fixed_split_params: Dict[str, int], tiling_params: Dict[str, str], low_dims: List[str], ) -> None: @@ -144,6 +157,7 @@ def check_keys(params: List[str], context="parameter"): raise KeyError(f"{context} '{k}' not found in known axes: {axis_sizes.keys()}") check_keys(split_params.keys(), "split axis") + check_keys(fixed_split_params.keys(), "fixed split axis") check_keys(tiling_params.keys(), "tiling axis") check_keys(low_dims, "low dim axis") @@ -180,12 +194,18 @@ def __init__(self, kernel_meta: KernelMeta): self.is_simt_mode = kernel_meta.is_simt_mode local_mem_size = (rf_size_in_kbytes if self.is_simt_mode else ub_size_in_kbytes) self.max_numel_threshold = local_mem_size * 1024 // self.dtype_bytes // self.num_buffers - self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size - for x in self.blocks]) if self.blocks else 1 - self.tiny_kernel = self.max_total_numel < 128 * 1024 - self.stop_numel = min(1024 // self.dtype_bytes, self.max_total_numel // - (num_vector_core * 2)) if self.tiny_kernel else 1024 // self.dtype_bytes + self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size for x in self.blocks]) if self.blocks else 1 + self.small_kernel = self.max_total_numel < 128 * 1024 + self.tiny_kernel = self.max_total_numel <= 32 * 1024 + self.stop_numel = min(1024 // self.dtype_bytes, self.max_total_numel // (num_vector_core * 2)) if self.small_kernel else 1024 // self.dtype_bytes self.max_programs_num = 65535 + self.tiny_program_threshold = num_vector_core // 8 + self.tiny_per_program_cap = 1 + self.tiny_low_program_hist = { + p: 0 for p in range(1, self.tiny_program_threshold + 
1) + } + self.tiny_low_program_active = False + self.tiny_low_program_tile_floor = 0 @classmethod def init_blocks_info(cls, kernel_meta: KernelMeta) -> List[BlockInfo]: @@ -193,7 +213,7 @@ def init_blocks_info(cls, kernel_meta: KernelMeta) -> List[BlockInfo]: for axis in kernel_meta.axis_info: block_name = axis.split_name sub_block_name = axis.tiling_name - block_size = axis.length + block_size = axis.fixed_split_size if axis.fixed_split_size > 0 else axis.length sub_block_size = block_size blocks.append(BlockInfo(block_name, sub_block_name, block_size, sub_block_size)) @@ -213,6 +233,7 @@ def calcu_last_split_blocks(self, axis_idx): break last_splits = num_vector_core // splits + last_splits = max(1, last_splits) last_blocks = (self.numels[axis_idx] + last_splits - 1) // last_splits return last_blocks @@ -244,10 +265,15 @@ def fill_config(self, cfg, candi_block): curr_numel = candi_block[axis.index] if not axis.is_tiling_axis: curr_numel = self.aligned_numel(curr_numel) - cfg[block_info.block_name] = curr_numel + if block_info.block_name: + cfg[block_info.block_name] = curr_numel if axis.is_tiling_axis: tiling_numel = self.aligned_numel(block_info.sub_block_size) - cfg[block_info.sub_block_name] = min(tiling_numel, candi_block[axis.index]) + cfg[block_info.sub_block_name] = ( + tiling_numel + if self.is_simt_mode + else min(tiling_numel, candi_block[axis.index]) + ) def find_config(self, cfg): for config_var in self.configs: @@ -255,13 +281,62 @@ def find_config(self, cfg): return True return False + def _try_add_tiny_low_program_config(self, total_programs): + if ( + not self.tiny_kernel + or total_programs < 1 + or total_programs > self.tiny_program_threshold + ): + return + + if self.tiny_low_program_hist.get(total_programs, 0) >= self.tiny_per_program_cap: + return + + + candi_block = tuple([x.block_size for x in self.blocks]) + if self.add_to_configs(list(candi_block)): + if candi_block not in self.candidate_blocks: + self.candidate_blocks.append(candi_block) 
+ if not self.tiny_low_program_active: + self.tiny_low_program_active = True + self.tiny_low_program_tile_floor = self.calculate_tile_numel() + self.tiny_low_program_hist[total_programs] = ( + self.tiny_low_program_hist.get(total_programs, 0) + 1 + ) + + def _calc_total_programs(self, candi_block=None): + grids = [] + for axis in self.kernel_meta.split_axis: + numel = self.numels[axis.index] + block_size = ( + self.blocks[axis.index].block_size + if candi_block is None + else candi_block[axis.index] + ) + programs = (numel + block_size - 1) // block_size + grids.append(programs) + + total_programs = functools.reduce(lambda x, y: x * y, grids) if grids else 1 + return total_programs + def add_to_configs(self, candi_block): newcfg = {} self.fill_config(newcfg, candi_block) tile_numel = self.calculate_tile_numel() - stop_numel_threshold = 0 if len(self.configs) < 10 or self.tiny_kernel else self.stop_numel + 100 - if (tile_numel <= self.max_numel_threshold and tile_numel >= stop_numel_threshold - and not self.find_config(newcfg)): + stop_numel_threshold = 0 if len(self.configs) < 10 or self.small_kernel else self.stop_numel + 100 + if self.tiny_low_program_active and self.tiny_low_program_tile_floor > 0: + total_programs = self._calc_total_programs(candi_block) + program_threshold = self.tiny_program_threshold if self.small_kernel else num_vector_core // 2 + if total_programs <= program_threshold: + tiny_low_program_threshold = max( + self.stop_numel, self.tiny_low_program_tile_floor // 2 + ) + stop_numel_threshold = max(stop_numel_threshold, tiny_low_program_threshold) + if ( + tile_numel <= self.max_numel_threshold + and tile_numel >= stop_numel_threshold + and not self.find_config(newcfg) + ): self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) return True return False @@ -330,19 +405,21 @@ def calc_total_programs(): self.candidate_blocks.append(tuple([x.block_size for x in self.blocks])) break - program_threshold = num_vector_core // 8 if 
self.tiny_kernel else num_vector_core // 2 + program_threshold = self.tiny_program_threshold if self.small_kernel else num_vector_core // 2 + if self.tiny_kernel and total_programs <= program_threshold: + self._try_add_tiny_low_program_config(total_programs) if total_programs > program_threshold or self.dual_reduction: if len(self.candidate_blocks) > 2: self.candidate_blocks.pop(0) self.candidate_blocks.append(tuple([x.block_size for x in self.blocks])) - if self.tiny_kernel: + if self.small_kernel: self.add_to_configs(list(tuple([x.block_size for x in self.blocks]))) slow_decend_split = (total_programs > num_vector_core_tile // 2) if not slow_decend_split: - self.blocks[axis_idx].block_size = numel // 2 + self.blocks[axis_idx].block_size = (numel + 1) // 2 else: - step = numel // 4 if numel // 4 > 1 else 1 + step = (numel + 3) // 4 if (numel + 3) // 4 > 1 else 1 self.blocks[axis_idx].block_size = numel - step self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size total_programs = calc_total_programs() @@ -404,7 +481,7 @@ def descend_split_tiling(self): tiling_not_low_dims = [x for x in self.kernel_meta.tiling_axis if x not in self.kernel_meta.low_dims_axis] def descend_split_axis(): - for axis in self.kernel_meta.split_axis: + for axis in self.kernel_meta.tunable_split_axis: if self.descend_one_axis(axis.index, is_split=True): return True diff --git a/third_party/ascend/backend/runtime/utils.py b/third_party/ascend/backend/runtime/utils.py index f470fc61da..b5e3eae719 100644 --- a/third_party/ascend/backend/runtime/utils.py +++ b/third_party/ascend/backend/runtime/utils.py @@ -41,11 +41,11 @@ def _init_npu_params(): ub_size_in_kbytes = 192 rf_size_in_kbytes = None - ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95"] + ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95", "Ascend950"] if any(variant in target.arch for variant in ASCEND_VARIANTS): num_vector_core = num_cube_core * 2 - if '910_95' in target.arch: + if 
target.arch.startswith("Ascend910_95") or target.arch.startswith("Ascend950"): ub_size_in_kbytes = 256 rf_size_in_kbytes = 128 diff --git a/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h b/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h index 2b7690ac01..730b3e6072 100644 --- a/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h +++ b/third_party/ascend/backend/spec/include/runtime/libentry/libentry.h @@ -1,5 +1,4 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI * diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h index 3dc7a3b844..457a751fbd 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -1,5 +1,4 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI * diff --git a/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp b/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp index ab9c81660e..1e2a9071e9 100644 --- a/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp +++ b/third_party/ascend/backend/spec/lib/runtime/libentry/libentry.cpp @@ -1,3 +1,26 @@ +/* + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + #include "runtime/libentry/libentry.h" using namespace libentry; diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index cc7ba30e7c..0630699daa 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -169,6 +169,11 @@ def triton_key(): return f'{__version__}' + '-'.join(contents) +def get_cache_key(src, backend, backend_options, env_vars): + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}" + return key + + def parse(full_name, ext, context): if ext == "ttir" or ext == "ttgir": module = ir.parse_mlir_module(full_name, context) @@ -217,7 +222,7 @@ def filter_traceback(e: BaseException): e.__traceback__ = frames[0] -def compile(src, target=None, options=None): +def compile(src, target=None, options=None, _env_vars=None): if target is None: target = driver.active.get_current_target() assert isinstance(target, GPUTarget), "target must be of GPUTarget type" @@ -230,8 +235,8 @@ def compile(src, target=None, options=None): extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager - env_vars = get_cache_invalidating_env_vars() - key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars + key = get_cache_key(src, backend, options, env_vars=env_vars) hash = hashlib.sha256(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) # For dumping/overriding only hash the source as we want it to be independent of triton @@ -291,7 +296,11 @@ def compile(src, target=None, options=None): else: stage_name = "MLIRCompile" error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) - error_detail += 
f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" + from ..runtime.cache import FileCacheManager + if isinstance(fn_cache_manager, FileCacheManager): + error_detail += f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" + else: + error_detail += f"\n\n[INFO]: The compiled kernel cache is {file_name}.{ext}\n\n" raise MLIRCompilationError(stage_name, error_detail) from e ir_filename = f"{file_name}.{ext}" if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): diff --git a/third_party/ascend/backend/spec/triton/language/__init__.py b/third_party/ascend/backend/spec/triton/language/__init__.py index ea7f8e7d66..d89c541442 100644 --- a/third_party/ascend/backend/spec/triton/language/__init__.py +++ b/third_party/ascend/backend/spec/triton/language/__init__.py @@ -1,10 +1,130 @@ -def language_extend_globals(globals_dict): - try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") - except Exception as e: - is_compile_on_910_95 = False - globals_dict["is_compile_on_910_95"] = is_compile_on_910_95 +"""isort:skip_file""" +# Import order is significant here. +from triton.tools.get_ascend_devices import is_compile_on_910_95 +from . import math +from . 
import extra +from .standard import ( + argmax, + argmin, + # cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + topk, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + add, + advance, + arange, + associative_scan, + assume, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dot_scaled, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + gather, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + nv_tma_desc_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil, cdiv) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) def language_extend_exports(globals_dict, all_list): diff --git a/third_party/ascend/backend/spec/triton/language/core.py b/third_party/ascend/backend/spec/triton/language/core.py index 9ed34a899d..8c6b1119bf 100644 --- a/third_party/ascend/backend/spec/triton/language/core.py +++ 
b/third_party/ascend/backend/spec/triton/language/core.py @@ -1542,15 +1542,18 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i specified (i.e. at least one must be :code:`None`). """ assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" - assert not allow_tf32, "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + assert not allow_tf32, "allow_tf32 is not supported as 'True', please use fp32 on Ascend instead." if input_precision is None: supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + # when setting allow_tf32, use input_precision='hf32' on Ascend instead. + if allow_tf32: + default_precision = "hf32" input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) else: - assert input_precision not in [ - "tf32", "tf32x3" - ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." 
+ if input_precision == "tf32": + input_precision = "hf32" + input_precision = _constexpr_to_value(input_precision) out_dtype = _constexpr_to_value(out_dtype) max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) @@ -2091,9 +2094,9 @@ def expand_ndims(t, ndims): ret = semantic.reduction(input, axis, make_combine_region, _builder) if keep_dims: if axis is not None: - ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + ret = builtins.tuple(expand_dims(t, axis, _builder=_builder) for t in ret) else: - ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + ret = builtins.tuple(expand_ndims(t, len(input[0].shape)) for t in ret) return ret @@ -2523,7 +2526,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): if not has_multiple_outputs: return tensor(call.get_result(0), res_tys[0]) - return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + return builtins.tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) # ----------------------- diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 7be3b70db7..50cd6a8f80 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -336,15 +336,16 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - floor = math.floor(fdiv(input, other, False, builder), _builder=builder) - ret = sub(input, mul(floor, other, True, builder), True, builder) - return ret + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) # % int elif scalar_ty.is_int(): if scalar_ty.int_signedness != other_scalar_ty.int_signedness: raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " "because they have different signedness;" 
"this is unlikely to result in a useful answer. Cast them to the same signedness.") + if hasattr(input, 'was_bool_to_int8'): + false_val = builder.get_int1(False) + return tl.tensor(false_val, tl.int1) if scalar_ty.is_int_signed(): return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) else: @@ -1114,6 +1115,9 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti # Check `boundary_check` argument boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + if boundary_check and padding is None: + padding = ir.PADDING_OPTION.PAD_ZERO + # Build IR return tl.tensor( builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) @@ -1177,11 +1181,16 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ # Build IR if mask is None: - ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + load_handle = builder.create_load(ptr.handle, cache, eviction, is_volatile) else: - ret = tl.tensor( - builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, - is_volatile), dst_ty) + load_handle = builder.create_masked_load( + ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile + ) + + if is_bool: + load_handle.set_attr("was_bool_to_int8", builder.get_bool_attr(True)) + + ret = tl.tensor(load_handle, dst_ty) # Do not cast back to int1 when is_bool=true. 
We directly use the int8 tensor given by tl.load if is_bool: ret.was_bool_to_int8 = True @@ -1608,7 +1617,8 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): - raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + # when input and result is not fp32, ignore input_precision (default is ieee) + input_precision = _str_to_dot_input_precision(builder.options.default_dot_input_precision, builder) if max_num_imprecise_acc is not None: print("max_num_imprecise_acc in tl.dot is not supported on Ascend yet. Thus it is ignored.") @@ -1653,8 +1663,12 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str, acc: Union[tl.tensor, None], out_dtype: tl.dtype, lhs_k_pack, rhs_k_pack, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() - assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" - assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + if is_compile_on_910_95: + assert lhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, tl.float8e4nv], f"lhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + assert rhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, tl.float8e4nv], f"rhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + else: + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) @@ -1663,26 +1677,33 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te
rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + if is_compile_on_910_95: + allowed_formats = {"bf16", "fp16", "e4m3", "e5m2"} + else: + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) - assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + assert isinstance(lhs_scale, tl.tensor) and (lhs_scale.dtype == tl.int8 or lhs_scale.dtype == tl.uint8), f"lhs_scale must be int8 or uint8 tensor" if not rhs_scale_is_none: - assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + assert isinstance(rhs_scale, tl.tensor) and (rhs_scale.dtype == tl.int8 or rhs_scale.dtype == tl.uint8), f"rhs_scale must be int8 or uint8 tensor" lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) - if lhs_k_pack == False: + assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + + lhs_k_pack_v = lhs_k_pack.value if isinstance(lhs_k_pack, tl.constexpr) else lhs_k_pack + rhs_k_pack_v = rhs_k_pack.value if isinstance(rhs_k_pack, tl.constexpr) else rhs_k_pack + + if lhs_k_pack_v is False: dims = (1, 0) - dims = core._unwrap_iterable(dims) tmp_lhs = permute(lhs, dims, builder) lhs = reshape(tmp_lhs, (lhs.shape[0], lhs.shape[1]), True, 
builder) - if rhs_k_pack == False: + if rhs_k_pack_v is False: dims = (1, 0) - dims = core._unwrap_iterable(dims) tmp_rhs = permute(rhs, dims, builder) rhs = reshape(tmp_rhs, (rhs.shape[0], rhs.shape[1]), True, builder) diff --git a/third_party/ascend/backend/spec/triton/runtime/_async_compile.py b/third_party/ascend/backend/spec/triton/runtime/_async_compile.py new file mode 100644 index 0000000000..a6c773123e --- /dev/null +++ b/third_party/ascend/backend/spec/triton/runtime/_async_compile.py @@ -0,0 +1,55 @@ +from __future__ import annotations +from typing import Callable, Optional +from concurrent.futures import Executor, as_completed, Future +from contextvars import ContextVar + +active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None) + + +class FutureKernel: + + def __init__(self, finalize_compile: Callable, future: Future): + self.finalize_compile = finalize_compile + self.kernel = None + self.future = future + + def result(self): + if self.kernel is not None: + return self.kernel + + kernel = self.future.result() + self.finalize_compile(kernel) + self.kernel = kernel + return kernel + + +class AsyncCompileMode: + + def __init__(self, executor: Executor): + self.executor = executor + self.raw_futures = [] + self.future_kernels = {} + + def submit(self, key, compile_fn, finalize_fn): + future = self.future_kernels.get(key) + if future is not None: + return future + + future = self.executor.submit(compile_fn) + future._key = key + self.raw_futures.append(future) + future_kernel = FutureKernel(finalize_fn, future) + self.future_kernels[key] = future_kernel + return future_kernel + + def __enter__(self): + if active_mode.get() is not None: + raise RuntimeError("Another AsyncCompileMode is already active") + active_mode.set(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Finalize any outstanding compiles + for future in as_completed(self.raw_futures): + 
self.future_kernels[future._key].result() + active_mode.set(None) diff --git a/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py new file mode 100644 index 0000000000..6cfb26e682 --- /dev/null +++ b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py @@ -0,0 +1,735 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Ascend-specific interpreter builder extensions. + +This module extends the base InterpreterBuilder with Ascend-specific operations +(extension ops) without modifying the public base class. All Ascend-related +features are isolated here and can be extended independently. 
+ +Author: Triton-Ascend Contributors +""" + +import warnings +import contextlib +import numpy as np +import triton.language as tl +from .interpreter import InterpreterBuilder, TensorHandle, ReduceOps, _get_np_dtype +from .._C.libtriton import interpreter as _interpreter + + +class AscendReduceOps(ReduceOps): + """ + Ascend reduce operations that override only the apply_impl logic. + All other methods (sum, min_max, generic_reduce, etc.) are inherited from ReduceOps. + """ + def apply_impl(self, input_param): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input_param[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + # Ta has modified the implementation of tl.max + elif self.combine_fn == tl.standard._elementwise_max_default: + return self.min_max(input_param[0], val_reduce_op=np.nanmax, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_max_propagate_nan: + return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input_param[0], val_reduce_op=np.nanmin, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input_param[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input_param) + + +def _compute_strides(shape): + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return strides + + +class AscendInterpreterBuilder(InterpreterBuilder): + """ + Extended InterpreterBuilder with Ascend-specific extension operations.
+ + This class inherits from InterpreterBuilder and adds support for: + - get_element (extract_scalar): Extract scalar from tensor using indices + - insert_slice: Insert sub-tensor into full tensor + - extract_slice: Extract slice from tensor + - index_select_simd: SIMD gather operation + - get_sub_vec_id: Get vector core ID for 1:2 ratio emulation + - Synchronization operations: sync_block_set/wait/all + + All extension operations handle both TensorHandle and Python int types + for interpreter mode compatibility. + """ + + def __init__(self) -> None: + super().__init__() + # Sub-vector core ID for simulating 1:2 hardware ratio + self.sub_vec_id = 0 + # Flag to track if sub_vec_id simulation is needed + self._sub_vec_simulation_enabled = False + + def to_int_val(self, val): + """ + Convert a value (int or TensorHandle) to Python int. + + :param val: Value to convert (int, TensorHandle, or other) + :return: Python integer + """ + if isinstance(val, TensorHandle): + return int(val.data.item()) + return int(val) + + def _patch_lang_ascend(self, fn): + + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_reduce(input_param, axis, combine_fn, keep_dims=False, **kwargs): + return AscendReduceOps(axis, combine_fn, keep_dims).apply(input_param) + + @contextlib.contextmanager + def _dummpy_scope(*args, **kwargs): + yield + + tl.extra.cann.extension.scope = _dummpy_scope + tl.extra.cann.extension.parallel = _new_range + tl.reduce = _new_reduce + tl.core.reduce = _new_reduce + + def get_additional_reserved_keywords(self): + """ + Return additional reserved keywords specific to Ascend backend. + + These keywords will be filtered out from kernel call arguments + and are not supported by the interpreter. 
+ + :return: List of additional reserved keyword strings + """ + return [ + "multibuffer", # Ascend-specific memory buffering + "debug", + "optimize_dynamic_offset", + "enable_mixed_cv", + "enable_auto_bind_sub_block", + "sync_solver", + # Add more Ascend-specific keywords here as needed + # "ascend_option1", + # "ascend_option2", + ] + + def patch_extensions(self, fn): + """ + Patch Ascend extension modules for the given function. + + This method handles all Ascend-specific extension module patching, + including CANN extensions and any other extension modules found in + the function's global namespace. + + :param fn: The kernel function to patch extensions for + """ + # Import _patch_builtin from parent module + from .interpreter import _patch_builtin + self._patch_lang_ascend(fn) + + # Patch all modules in fn's globals that might be extension modules + for name, value in list(fn.__globals__.items()): + if value is None: + continue + try: + # Check if it looks like an extension module (has builtin functions) + if hasattr(value, '__name__') and 'extension' in str(value.__name__): + _patch_builtin(value, self) + # Also try patching any module-like object that might have builtin functions + elif hasattr(value, '__dict__') and not isinstance(value, type): + # Try to patch it and ignore if it fails + try: + _patch_builtin(value, self) + except Exception: + pass + except Exception: + pass + + # Also try importing extension directly as fallback + try: + import triton.language.extra.cann.extension as extension + _patch_builtin(extension, self) + except (ImportError, AttributeError): + # Extension module not available (e.g., non-Ascend backend) + pass + + def execute_with_sub_vec_simulation(self, fn, args, grid): + """ + Execute function with optional 1:2 sub-vector core simulation. + + Sub-vector simulation is only activated when create_get_sub_vec_id() is + actually called during execution. 
This avoids unnecessary double execution + for code that doesn't use sub_vec_id functionality. + + :param fn: The kernel function to execute + :param args: Function arguments + :param grid: Grid dimensions (nx, ny, nz) + """ + # Reset simulation flag at the beginning of each execution + self._sub_vec_simulation_enabled = False + self.sub_vec_id = 0 + + # First, try a single execution to see if sub_vec_id is used + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + self.set_grid_idx(x, y, z) + fn(**args) + + # If sub_vec_id was accessed during execution, run again with sub_vec_id=1 + if self._sub_vec_simulation_enabled: + self.sub_vec_id = 1 + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + self.set_grid_idx(x, y, z) + fn(**args) + + # ======================================================================== + # Extension ops for Ascend + # ======================================================================== + + def create_extract_scalar(self, tensor_handle, indices): + """ + Extract a scalar from a tensor using indices (equivalent to get_element). + + Handles mixed types: Python int (from loops) and TensorHandle (from other ops). 
+ + :param tensor_handle: The tensor to extract from (TensorHandle) + :param indices: List of scalar indices (can be TensorHandle or Python int) + :return: Scalar value as TensorHandle + """ + # Convert indices from TensorHandle or Python int to integers + index_values = [] + for idx in indices: + if isinstance(idx, int): + # Python int passed directly (e.g., from loop counter) + index_values.append(idx) + elif isinstance(idx, TensorHandle): + # Interpreter TensorHandle + index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) + else: + # Fallback: try to extract data + index_values.append(int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') + else int(idx.data) if hasattr(idx, 'data') else int(idx)) + + # Extract the scalar value + scalar_data = tensor_handle.data[tuple(index_values)] + return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) + + def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): + """ + Insert a sub-tensor into a full tensor at specified offsets. + + Handles mixed types: Python int and TensorHandle for offsets. 
+ + :param full_tensor: The full tensor (destination, TensorHandle) + :param sub_tensor: The sub-tensor to insert (TensorHandle) + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Modified tensor with sub_tensor inserted (TensorHandle) + """ + result = full_tensor.data.copy() + + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append(int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') + else int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for insertion + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Insert the sub-tensor + result[tuple(slices)] = sub_tensor.data + + return TensorHandle(result, full_tensor.dtype.scalar) + + def create_extract_slice(self, full_tensor, offsets, sizes, strides): + """ + Extract a slice from a full tensor. + + Handles mixed types: Python int and TensorHandle for offsets. 
+ + :param full_tensor: The full tensor (TensorHandle) + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Extracted sub-tensor (TensorHandle) + """ + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append(int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') + else int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for extraction + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Extract the slice + extracted = full_tensor.data[tuple(slices)] + + return TensorHandle(extracted, full_tensor.dtype.scalar) + + def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): + """ + SIMD index_select operation (gather with indices along a dimension). + + This is a hardware-accelerated gather operation that selects elements + from a tensor using a set of indices along a specified dimension. 
+ + :param src_ptr: Source tensor pointer (TensorHandle), just ptr address, not value + :param index_tensor: 1D tensor of indices (TensorHandle or array) + :param dim: Dimension to select from (int) + :param src_shape: List of source shape (int or TensorHandle) + :param src_offset: List of source offset (int or TensorHandle) + :param read_shape: List of read shape (int or TensorHandle) + :param result_shape: List of result shape (int or TensorHandle) + :return: Result tensor with selected indices (TensorHandle) + """ + # Convert src_shape, src_offset, read_shape to integers + src_shape_vals = [self.to_int_val(s) for s in src_shape] + src_offset_vals = [self.to_int_val(s) if s != -1 else -1 for s in src_offset] + read_shape_vals = [self.to_int_val(r) if r != -1 else -1 for r in read_shape] + result_shape_vals = [self.to_int_val(r) for r in result_shape] + + # Get index values - handle both array and TensorHandle + if isinstance(index_tensor, TensorHandle): + indices = index_tensor.data.flatten() + else: + indices = np.asarray(index_tensor).flatten() + + # Ensure indices are integers + if indices.dtype not in [np.int32, np.int64]: + indices = indices.astype(np.int32) + + # Element type + dtype_tt = src_ptr.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + src_strides = _compute_strides(src_shape_vals) + base_addr = int(src_ptr.data.item()) + + # Create result tensor + result = np.empty(result_shape_vals, dtype=dtype_np) + + # Perform index_select: for each index, read the specified data + for out_idx, in_idx in enumerate(indices): + in_idx = int(in_idx) + # Generate all coordinates in the tile + ranges = [] + for d in range(len(src_shape_vals)): + if d == dim: + ranges.append([in_idx]) + else: + offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 + read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] + # Clamp to valid range + offset = max(0, min(offset, src_shape_vals[d] - 1)) + read_size = min(read_size, 
src_shape_vals[d] - offset) + ranges.append(list(range(offset, offset + read_size))) + from itertools import product + coords = list(product(*ranges)) + + # Compute address for each element in the tile + addresses = [] + for coord in coords: + offset = sum(coord[i] * src_strides[i] for i in range(len(coord))) + addr = base_addr + offset * np.dtype(dtype_np).itemsize + addresses.append(addr) + # load data + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.ones_like(addr_array, dtype=bool) + other_array = np.zeros_like(addr_array, dtype=dtype_np) + tile_data = _interpreter.load(addr_array, mask_array, other_array, dtype_np) + # Reshape tile_data to the tile that was read (size 1 along dim); resolve -1 exactly as when building ranges above + tile_shape = [] + for d in range(len(src_shape_vals)): + if d == dim: + tile_shape.append(1) + else: + offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 + read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] + offset = max(0, min(offset, src_shape_vals[d] - 1)) + read_size = min(read_size, src_shape_vals[d] - offset) + tile_shape.append(read_size) + tile_data = tile_data.reshape(tile_shape) + + # Build result slice + result_slices = [] + for d in range(len(result_shape_vals)): + if d == dim: + result_slices.append(slice(out_idx, out_idx + 1)) + else: + result_slices.append(slice(None)) + result[tuple(result_slices)] = tile_data + + return TensorHandle(result, dtype_tt) + + def create_get_sub_vec_id(self): + """ + Get the Vector Core index on the AI Core. + + In Interpreter mode, simulate multiple vector cores by maintaining + a sub_vec_id counter. This is used for 1:2 hardware ratio emulation + where different vector cores process different partitions of the data. + + The first call to this method enables sub_vec_simulation, causing + the kernel to be executed twice (once for each sub_vec_id value). 
+ + :return: Vector Core ID as TensorHandle (int64, scalar) + """ + # Enable sub_vec_id simulation when this method is called + self._sub_vec_simulation_enabled = True + + # Return the current sub_vec_id + vec_id = np.int64(self.sub_vec_id) + return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) + + def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Set synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Wait for synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_all(self, mode, event_id): + """ + Synchronize all compute or vector units globally. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. 
+ + :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") + :param event_id: Event ID (int, constexpr, or TensorHandle) + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def get_int1_ty(self): + return tl.int1 + + def get_all_ones_value(self, tt_type): + np_type = _get_np_dtype(tt_type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), tt_type.scalar) + elif np_type == np.bool_: + return TensorHandle(np.full(1, True, dtype=np_type), tt_type.scalar) + else: + raise TypeError(f"unsupported type {tt_type}") + + def is_simt_mode(self): + return False + + def create_sort(self, ptr_data, dim: int, descending: bool): + ndim = ptr_data.data.ndim + norm_dim = dim if dim >= 0 else dim + ndim + if not (0 <= norm_dim < ndim): + raise IndexError( + f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + + if descending: + sorted_asc = np.sort(ptr_data.data, axis=norm_dim) + sorted_desc = np.flip(sorted_asc, axis=norm_dim) + return TensorHandle(sorted_desc, ptr_data.dtype.scalar) + else: + return TensorHandle(np.sort(ptr_data.data, axis=norm_dim), ptr_data.dtype.scalar) + + def create_flip(self, ptr_data, dim): + ndim = ptr_data.data.ndim + norm_dim = dim if dim >= 0 else dim + ndim + if not (0 <= norm_dim < ndim): + raise IndexError( + f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + return TensorHandle(np.flip(ptr_data.data, axis=norm_dim), ptr_data.dtype.scalar) + + def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, src_stride, end_offset, start_offset, other=None): + # Convert src_stride, start_offset, end_offset to integers + src_stride_vals = [self.to_int_val(s) for s in src_stride] + start_offset_vals = [self.to_int_val(s) for s in start_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] + + # Element type + dtype_tt = src_ptr.get_element_ty() + dtype_np 
= _get_np_dtype(dtype_tt) + element_size = np.dtype(dtype_np).itemsize + base_addr = int(src_ptr.data.item()) + index_shape = index_tensor.data.shape + index_rank = len(index_shape) + total_elements = np.prod(index_shape) + + # Generate coordinates + all_coords = [] + for idx in range(total_elements): + coord = np.unravel_index(idx, index_shape) + all_coords.append(coord) + + # Compute the source tensor coordinates for each position in all_coords + src_coords = [] + for coord in all_coords: + src_coord = [] + for d in range(index_rank): + if d == dim: + index_value = index_tensor.data[coord] + if index_value >= index_boundary: + src_coord.append(-1) + else: + src_coord.append(start_offset_vals[d] + index_value) + else: + src_coord.append(start_offset_vals[d] + coord[d]) + src_coords.append(src_coord) + + # Compute address and mask + addresses = [] + valid_mask = [] + for _, src_coord in enumerate(src_coords): + if -1 in src_coord: + addresses.append(0) + valid_mask.append(False) + else: + offset = 0 + for d in range(index_rank): + offset += src_coord[d] * src_stride_vals[d] + address = base_addr + offset * element_size + addresses.append(address) + valid_mask.append(True) + + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.array(valid_mask, dtype=bool) + + # Create other value array + if other is not None: + if isinstance(other, TensorHandle): + other_value = other.data.item() + else: + other_value = other + other_array = np.full(addr_array.shape, other_value, dtype=dtype_np) + else: + other_array = np.zeros(addr_array.shape, dtype=dtype_np) + + # Load data + flat_result = _interpreter.load(addr_array, mask_array, other_array, dtype_np) + result = flat_result.reshape(index_shape) + return TensorHandle(result, dtype_tt) + + def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_boundary, dim, dst_stride, end_offset, start_offset): + # Convert dst_stride, start_offset, end_offset to integers + dst_stride_vals = 
[self.to_int_val(s) for s in dst_stride] + start_offset_vals = [self.to_int_val(s) for s in start_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] + + # Element type + dtype_tt = dst_ptr.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + element_size = np.dtype(dtype_np).itemsize + base_addr = int(dst_ptr.data.item()) + + index_shape = index_tensor.data.shape + index_rank = len(index_shape) + total_elements = np.prod(index_shape) + flat_values = value_tensor.data.flatten() + flat_indices = index_tensor.data.flatten() + + # Generate coordinates + all_coords = [] + for idx in range(total_elements): + coord = np.unravel_index(idx, index_shape) + all_coords.append(coord) + + # Compute address and mask + addresses = [] + valid_mask = [] + for _, coord in enumerate(all_coords): + index_value = index_tensor.data[coord] + if index_value >= index_boundary: + addresses.append(0) + valid_mask.append(False) + else: + dst_coord = [] + for d in range(index_rank): + if d == dim: + dst_coord.append(start_offset_vals[d] + index_value) + else: + dst_coord.append(start_offset_vals[d] + coord[d]) + offset = 0 + for d in range(index_rank): + offset += dst_coord[d] * dst_stride_vals[d] + address = base_addr + offset * element_size + addresses.append(address) + valid_mask.append(True) + + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.array(valid_mask, dtype=bool) + + _interpreter.store(addr_array, flat_values, mask_array) + + def create_index_put(self, dst_ptr, index_tensor, value_tensor, dim, index_boundary, end_offset, start_offset, dst_stride): + # Convert dst_stride, start_offset, end_offset_ to integers + dst_stride_vals = [self.to_int_val(s) for s in dst_stride] + start_offset_vals = [self.to_int_val(s) for s in start_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] + + # Element type + dtype_tt = dst_ptr.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + element_size = np.dtype(dtype_np).itemsize + base_addr = 
int(dst_ptr.data.item()) + + value_shape = value_tensor.data.shape + value_rank = len(value_shape) + + flat_values = value_tensor.data.flatten() + total_elements = flat_values.size + + # Generate coordinates + all_coords = [] + for idx in range(total_elements): + coord = np.unravel_index(idx, value_shape) + all_coords.append(coord) + + read_ranges = [] + for d in range(value_rank): + start = start_offset_vals[d] + end = end_offset_vals[d] + read_ranges.append((start, end)) + + # Compute address, mask and value per element; the three lists must stay the same length for _interpreter.store, so masked-off slots keep their value as a placeholder + addresses = [] + valid_mask = [] + values_to_store = [] + for i, coord in enumerate(all_coords): + index_pos = coord[dim] + index_value = index_tensor.data[index_pos] + values_to_store.append(flat_values[i]) + if index_value >= index_boundary: + addresses.append(0) + valid_mask.append(False) + else: + dst_coord = [] + for d in range(value_rank): + if d == dim: + dst_coord.append(index_value) + else: + dst_coord.append(start_offset_vals[d] + coord[d]) + offset = 0 + for d in range(value_rank): + offset += dst_coord[d] * dst_stride_vals[d] + address = base_addr + offset * element_size + addresses.append(address) + valid_mask.append(True) + + addr_array = np.array(addresses, dtype=np.uint64) + mask_array = np.array(valid_mask, dtype=bool) + values_array = np.array(values_to_store, dtype=dtype_np) + + _interpreter.store(addr_array, values_array, mask_array) + + def get_bool_attr(self, val): + return bool(val) + + def get_unit_attr(self): + return None # None value in compile_hint return uint + + def get_int32_attr(self, val): + return int(val) + + def get_str_attr(self, val): + return str(val) + + def get_i64_array_attr(self, val): + return [int(x) for x in val] + + def create_annotation_mark(self, ptr_data, hint_name: str, hint_val): + if hint_name == "overflow_mode": + raise ValueError(f"overflow_mode is not supported in interpreter mode, may have accuracy issues") + else: + warnings.warn( + f"compile_hint '{hint_name}' is not supported in interpreter mode, just pass it", + 
UserWarning, + stacklevel=2 + ) \ No newline at end of file diff --git a/third_party/ascend/backend/spec/triton/runtime/autotuner.py b/third_party/ascend/backend/spec/triton/runtime/autotuner.py index 5afa5228d5..5f4fd8f6b9 100644 --- a/third_party/ascend/backend/spec/triton/runtime/autotuner.py +++ b/third_party/ascend/backend/spec/triton/runtime/autotuner.py @@ -4,7 +4,9 @@ import os import time import inspect -from typing import Dict +import itertools + +from typing import Any, Dict, List from .jit import KernelInterface from .errors import OutOfResources @@ -116,7 +118,6 @@ def _post_hook(kwargs, exception): quantiles=quantiles, ) return - import triton.testing self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( kernel_call, @@ -406,3 +407,251 @@ def decorator(fn): return Heuristics(fn, fn.arg_names, values) return decorator + + +_ALL_PARAMS = { + "BM_list", "BN_list", + "multibuffer", "unit_flag", + "limit_auto_multi_buffer_only_for_local_buffer", + "limit_auto_multi_buffer_of_local_buffer", + "set_workspace_multibuffer", + "enable_hivm_auto_cv_balance", + "tile_mix_vector_loop", + "tile_mix_cube_loop" +} + +_DEFAULTS = { + "BM_list": [16, 32, 64, 128], + "BN_list": [16, 32, 64, 128], + "multibuffer": [False], + "unit_flag": [False], + "limit_auto_multi_buffer_only_for_local_buffer": [True], + "limit_auto_multi_buffer_of_local_buffer": ["no-l0c"], + "set_workspace_multibuffer": [2, 4], + "enable_hivm_auto_cv_balance": [True], + "tile_mix_vector_loop": [2, 4], + "tile_mix_cube_loop": [2, 4] +} + +_VALID_VALUES = { + "limit_auto_multi_buffer_of_local_buffer": ["no-limit", "no-l0c"], + "set_workspace_multibuffer": [2, 4], + "tile_mix_vector_loop": [2, 4, 8], + "tile_mix_cube_loop": [2, 4, 8] +} + +_CUBE_PARAMS = {"multibuffer", "unit_flag", "limit_auto_multi_buffer_of_local_buffer"} +_MIXCV_PARAMS = { + "multibuffer", "unit_flag", + "limit_auto_multi_buffer_only_for_local_buffer", + "limit_auto_multi_buffer_of_local_buffer", + 
"set_workspace_multibuffer", + "enable_hivm_auto_cv_balance", + "tile_mix_vector_loop", + "tile_mix_cube_loop" +} +_VECTOR_PARAMS = {"multibuffer"} + + +def _check_boolean_list(val: List[Any], param_name: str) -> bool: + return isinstance(val, (list, tuple)) and len(val) > 0 and all(isinstance(x, bool) for x in val) + + +def _check_string_in_set(val: List[Any], valid_set: set, param_name: str) -> bool: + return isinstance(val, (list, tuple)) and len(val) > 0 and all(v in valid_set for v in val) + + +def _check_int_in_set(val: List[Any], valid_set: set, param_name: str) -> bool: + return isinstance(val, (list, tuple)) and len(val) > 0 and all(isinstance(v, int) and v in valid_set for v in val) + + +_VALIDATION_RULES = { + "multibuffer": { + "desc": "must be non-empty list/tuple of boolean values", + "check": _check_boolean_list + }, + "unit_flag": { + "desc": "must be non-empty list/tuple of boolean values", + "check": _check_boolean_list + }, + "limit_auto_multi_buffer_only_for_local_buffer": { + "desc": "must be non-empty list/tuple of boolean values", + "check": _check_boolean_list + }, + "limit_auto_multi_buffer_of_local_buffer": { + "desc": f"must be one or more of: {_VALID_VALUES['limit_auto_multi_buffer_of_local_buffer']}", + "check": lambda val, param_name: _check_string_in_set(val, _VALID_VALUES['limit_auto_multi_buffer_of_local_buffer'], "limit_auto_multi_buffer_of_local_buffer") + }, + "set_workspace_multibuffer": { + "desc": f"must be one or more of: {_VALID_VALUES['set_workspace_multibuffer']}", + "check": lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['set_workspace_multibuffer'], "set_workspace_multibuffer") + }, + "enable_hivm_auto_cv_balance": { + "desc": "must be non-empty list/tuple of boolean values", + "check": _check_boolean_list + }, + "tile_mix_vector_loop": { + "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_vector_loop']}", + "check": lambda val, param_name: _check_int_in_set(val, 
_VALID_VALUES['tile_mix_vector_loop'], "tile_mix_vector_loop") + }, + "tile_mix_cube_loop": { + "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_cube_loop']}", + "check": lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['tile_mix_cube_loop'], "tile_mix_cube_loop") + } +} + + +class BaseAutotuner: + """ + Base Class: Used to generate auto-tuning configurations for Triton kernels. + Subclasses must define: + operator_name: The name of the operator. + supported_params: A set of supported parameter names. + default_params: A dictionary of default parameter values. + validation_rules: Validation rules for parameters (described in detail below). + """ + + def __init__( + self, + operator_name: str, + supported_params: set, + default_params: Dict[str, Any], + validation_rules: Dict[str, Dict[str, Any]] + ): + self.operator_name = operator_name + self.supported_params = supported_params + self.default_params = default_params + self.validation_rules = validation_rules + + SPECIAL_PARAMS_NO_WARNING = {"BM_list", "BN_list"} + + def validate_parameters(self, **kwargs: Any) -> bool: + invalid_params = [k for k in kwargs.keys() if k not in _ALL_PARAMS] + if invalid_params: + print(f"[ERROR] Invalid parameters for {self.operator_name}: {invalid_params}") + return False + + for param_name, rule in self.validation_rules.items(): + if param_name not in kwargs: + continue + + value = kwargs[param_name] + if not rule["check"](value, param_name): + print(f"[ERROR] Invalid value for '{param_name}' in {self.operator_name}: {value}") + print(f" Expected: {rule['desc']}") + return False + + return True + + def get_configs(self, **kwargs: Any) -> List[triton.Config]: + import triton + if not self.validate_parameters(**kwargs): + return [] + + params = self.default_params.copy() + bm_list = kwargs.get("BM_list") + bn_list = kwargs.get("BN_list") + + if bm_list is not None: + params["BM_list"] = bm_list + if bn_list is not None: + params["BN_list"] = bn_list + + for k, v 
in kwargs.items(): + if k in self.supported_params: + params[k] = v + + valid_kwargs = {k: v for k, v in kwargs.items() if k in self.supported_params} + + other_kwargs = {k: v for k, v in kwargs.items() if k not in self.supported_params and k not in self.SPECIAL_PARAMS_NO_WARNING} + if other_kwargs: + print( + f"[WARNING] Parameter(s) {list(other_kwargs.keys())} do not belong to {self.operator_name} and have been ignored.") + + configs = [] + + bm_list = params.get("BM_list", _DEFAULTS["BM_list"]) + bn_list = params.get("BN_list", _DEFAULTS["BN_list"]) + limit_flag = valid_kwargs.get("limit_auto_multi_buffer_only_for_local_buffer", [False])[0] + + dynamic_params = {} + + for param_name in sorted(self.supported_params): + if param_name == "limit_auto_multi_buffer_only_for_local_buffer": + dynamic_params[param_name] = valid_kwargs.get(param_name, _DEFAULTS[param_name]) + elif param_name in ["set_workspace_multibuffer", "enable_hivm_auto_cv_balance", "tile_mix_vector_loop", "tile_mix_cube_loop"]: + if not limit_flag: + dynamic_params[param_name] = valid_kwargs.get(param_name, _DEFAULTS[param_name]) + else: + dynamic_params[param_name] = valid_kwargs.get(param_name, _DEFAULTS[param_name]) + + other_params = {} + for param_name in sorted(dynamic_params.keys()): + if param_name in valid_kwargs: + other_params[param_name] = valid_kwargs[param_name] + else: + other_params[param_name] = _DEFAULTS.get(param_name, [True]) + + bm_bn_combos = list(itertools.product(bm_list, bn_list)) + other_combos = list(itertools.product(*other_params.values())) + all_combos = list(itertools.product(bm_bn_combos, other_combos)) + for (bm, bn), other_values in all_combos: + config_kwargs = { + "BLOCK_M": bm, + "BLOCK_N": bn, + } + for i, param_name in enumerate(sorted(other_params.keys())): + config_kwargs[param_name] = other_values[i] + configs.append(triton.Config(config_kwargs)) + return configs + + +CubeAutotuner = BaseAutotuner( + operator_name="cube", + supported_params=_CUBE_PARAMS, + 
default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES +) + +MixcvAutotuner = BaseAutotuner( + operator_name="mixcv", + supported_params=_MIXCV_PARAMS, + default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES +) + +VectorAutotuner = BaseAutotuner( + operator_name="vector", + supported_params=_VECTOR_PARAMS, + default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES +) + + +def get_autotune_cube_config(**kwargs: Any) -> List[triton.Config]: + """ + Generate autotune configuration for the cube operator. + Supported parameters: multibuffer, unit_flag, limit_auto_multi_buffer_of_local_buffer. + """ + import triton + return CubeAutotuner.get_configs(**kwargs) + + +def get_autotune_cv_config(**kwargs: Any) -> List[triton.Config]: + """ + Generate autotune configuration for the mixcv operator. + Supported parameters: multibuffer, unit_flag, limit_auto_multi_buffer_only_for_local_buffer, + limit_auto_multi_buffer_of_local_buffer, set_workspace_multibuffer, + enable_hivm_auto_cv_balance, tile_mix_vector_loop, tile_mix_cube_loop + """ + import triton + return MixcvAutotuner.get_configs(**kwargs) + + +def get_autotune_vector_config(**kwargs: Any) -> List[triton.Config]: + """ + Generate autotune configuration for the vector operator. + Supported parameters: multibuffer + """ + import triton + return VectorAutotuner.get_configs(**kwargs) diff --git a/third_party/ascend/backend/spec/triton/runtime/code_cache.py b/third_party/ascend/backend/spec/triton/runtime/code_cache.py index 43d841cd3d..563d46c8af 100644 --- a/third_party/ascend/backend/spec/triton/runtime/code_cache.py +++ b/third_party/ascend/backend/spec/triton/runtime/code_cache.py @@ -1,4 +1,4 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright (c) FlagOpen contributors # Copyright 2018-2020 Philippe Tillet # Copyright 2020-2022 OpenAI # Copyright © 2024 BAAI. All rights reserved. 
diff --git a/third_party/ascend/backend/spec/triton/runtime/interpreter.py b/third_party/ascend/backend/spec/triton/runtime/interpreter.py index 7ad9b1b9f0..d15232ef90 100644 --- a/third_party/ascend/backend/spec/triton/runtime/interpreter.py +++ b/third_party/ascend/backend/spec/triton/runtime/interpreter.py @@ -14,6 +14,24 @@ from .._C.libtriton import interpreter as _interpreter from .._C.libtriton import ir as _ir +# Import Ascend-specific interpreter builder (with deferred import to avoid circular dependency) +_has_ascend_support = False +AscendInterpreterBuilder = None + +def _try_import_ascend(): + global _has_ascend_support, AscendInterpreterBuilder + try: + from . import ascend_interpreter + AscendInterpreterBuilder = ascend_interpreter.AscendInterpreterBuilder + _has_ascend_support = True + except ImportError as e: + _has_ascend_support = False + AscendInterpreterBuilder = None + except Exception as e: + # Catch other exceptions (like circular import) and log them + _has_ascend_support = False + AscendInterpreterBuilder = None + class TensorHandle: @@ -80,7 +98,7 @@ class InterpreterOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" - allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", "hf32") max_num_imprecise_acc_default: int = 0 backend_name: str = "interpreter" @@ -140,6 +158,8 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode): bias_input = input_dtype.exponent_bias bias_output = output_dtype.exponent_bias exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + # mark NAN value + input_nan_index = (exponent == (1 << input_exponent_width) - 1) & (significand != 0) subnormal_index = exponent == 0 if np.any(subnormal_index): # Credit to Phil: phil@openai.com @@ 
-159,8 +179,13 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode): significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( (1 << input_dtype.fp_mantissa_width) - 1) # Prevent overflow and underflow - exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_unclamped = exponent - bias_input + bias_output + output_max_exponent = (1 << output_exponent_width) - 1 + exponent_output = np.maximum(0, np.minimum(exponent_unclamped, output_max_exponent)) exponent_output = exponent_output.astype(output_unint_dtype) + # mark overflow index + overflow_index = exponent_unclamped > output_max_exponent - 1 + sign_output = sign.astype(output_unint_dtype) if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( @@ -188,6 +213,8 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode): shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + # covert overflow value to inf + significand_output[overflow_index & ~input_nan_index] = 0 output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( exponent_output << output_dtype.fp_mantissa_width) | significand_output return output.reshape(input.shape) @@ -245,8 +272,6 @@ def __init__(self) -> None: # For interpreter mode, don't enforce GPU hardware shape constraints # NumPy matmul works with any size, including small matrices self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1) - # Sub-vector core ID for simulating 1:2 hardware ratio - self.sub_vec_id = 0 def set_grid_idx(self, x, y, z): if not x < self.grid_dim[0]: @@ -612,260 +637,6 @@ def 
create_splat(self, arg, shape): else: # scalar return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) - # Extension ops for Ascend - def create_extract_scalar(self, tensor_handle, indices): - """ - Extract a scalar from a tensor using indices (equivalent to get_element). - - :param tensor_handle: The tensor to extract from - :param indices: List of scalar indices (can be TensorHandle or Python int) - :return: Scalar value - """ - # Convert indices from TensorHandle or Python int to integers - index_values = [] - for idx in indices: - if isinstance(idx, int): - # Python int passed directly (e.g., from loop counter) - index_values.append(idx) - elif isinstance(idx, TensorHandle): - # Interpreter TensorHandle - index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) - else: - # Fallback: try to extract data - index_values.append( - int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') else - int(idx.data) if hasattr(idx, 'data') else int(idx)) - - # Extract the scalar value - scalar_data = tensor_handle.data[tuple(index_values)] - return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) - - def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): - """ - Insert a sub-tensor into a full tensor at specified offsets. 
- - :param full_tensor: The full tensor (destination) - :param sub_tensor: The sub-tensor to insert - :param offsets: List of offset TensorHandle objects or Python ints - :param sizes: List of size integers - :param strides: List of stride integers - :return: Modified tensor with sub_tensor inserted - """ - result = full_tensor.data.copy() - - # Convert offsets from TensorHandle or Python int to integers - offset_values = [] - for off in offsets: - if isinstance(off, int): - # Python int passed directly - offset_values.append(off) - elif isinstance(off, TensorHandle): - # Interpreter TensorHandle - offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) - else: - # Fallback - offset_values.append( - int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else - int(off.data) if hasattr(off, 'data') else int(off)) - - # Build slices for insertion - slices = [] - for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): - end = offset + size * stride - if stride == 1: - slices.append(slice(offset, end)) - else: - slices.append(slice(offset, end, stride)) - - # Insert the sub-tensor - result[tuple(slices)] = sub_tensor.data - - return TensorHandle(result, full_tensor.dtype.scalar) - - def create_extract_slice(self, full_tensor, offsets, sizes, strides): - """ - Extract a slice from a full tensor. 
- - :param full_tensor: The full tensor - :param offsets: List of offset TensorHandle objects or Python ints - :param sizes: List of size integers - :param strides: List of stride integers - :return: Extracted sub-tensor - """ - # Convert offsets from TensorHandle or Python int to integers - offset_values = [] - for off in offsets: - if isinstance(off, int): - # Python int passed directly - offset_values.append(off) - elif isinstance(off, TensorHandle): - # Interpreter TensorHandle - offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) - else: - # Fallback - offset_values.append( - int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else - int(off.data) if hasattr(off, 'data') else int(off)) - - # Build slices for extraction - slices = [] - for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): - end = offset + size * stride - if stride == 1: - slices.append(slice(offset, end)) - else: - slices.append(slice(offset, end, stride)) - - # Extract the slice - extracted = full_tensor.data[tuple(slices)] - - return TensorHandle(extracted, full_tensor.dtype.scalar) - - def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): - """ - SIMD index_select operation (gather with indices along a dimension). 
- - :param src_ptr: Source tensor pointer - :param index_tensor: 1D tensor of indices - :param dim: Dimension to select from - :param src_shape: List of source shape (int or TensorHandle) - :param src_offset: List of source offset (int or TensorHandle) - :param read_shape: List of read shape (int or TensorHandle) - :param result_shape: List of result shape (int or TensorHandle) - :return: Result tensor with selected indices - """ - - # Convert src_shape, src_offset, read_shape to integers - def to_int(val): - if isinstance(val, TensorHandle): - return int(val.data.item()) - return int(val) - - src_shape_vals = [to_int(s) for s in src_shape] - src_offset_vals = [to_int(o) if o != -1 else -1 for o in src_offset] - read_shape_vals = [to_int(r) if r != -1 else -1 for r in read_shape] - result_shape_vals = [to_int(r) for r in result_shape] - - # Get index values - handle both array and TensorHandle - if isinstance(index_tensor, TensorHandle): - indices = index_tensor.data.flatten() - else: - indices = np.asarray(index_tensor).flatten() - - # Ensure indices are integers - if indices.dtype not in [np.int32, np.int64]: - indices = indices.astype(np.int32) - - # Create result tensor - result = np.empty(result_shape_vals, dtype=src_ptr.data.dtype) - - # Perform index_select: for each index, read the specified data - for out_idx, in_idx in enumerate(indices): - in_idx = int(in_idx) - - # Validate index bounds - if not (0 <= in_idx < src_shape_vals[dim]): - # Out of bounds - fill with zeros - result_slices = [slice(None)] * len(result_shape_vals) - result_slices[dim] = slice(out_idx, out_idx + 1) - result[tuple(result_slices)] = 0 - continue - - # Build source slice - src_slices = [] - for d in range(len(src_shape_vals)): - if d == dim: - src_slices.append(slice(in_idx, in_idx + 1)) - else: - offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 - read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] - # Clamp to valid range - offset = 
max(0, min(offset, src_shape_vals[d] - 1)) - read_size = min(read_size, src_shape_vals[d] - offset) - src_slices.append(slice(offset, offset + read_size)) - - # Build result slice - result_slices = [] - for d in range(len(result_shape_vals)): - if d == dim: - result_slices.append(slice(out_idx, out_idx + 1)) - else: - result_slices.append(slice(None)) - - # Copy data with proper shape handling - try: - src_data = src_ptr.data[tuple(src_slices)] - # Handle shape mismatch by resizing - target_shape = [result_shape_vals[d] if d != dim else 1 for d in range(len(result_shape_vals))] - if src_data.shape != tuple(target_shape): - # Pad or trim as needed - pad_width = [(0, target_shape[d] - src_data.shape[d]) for d in range(len(target_shape))] - src_data = np.pad(src_data, pad_width, mode='constant', constant_values=0) - result[tuple(result_slices)] = src_data - except Exception as e: - # On error, fill with zeros - result[tuple(result_slices)] = 0 - - return TensorHandle(result, src_ptr.dtype.scalar) - - def create_get_sub_vec_id(self): - """ - Get the Vector Core index on the AI Core. - - In Interpreter mode, simulate multiple vector cores by maintaining - a sub_vec_id counter. This is used for 1:2 hardware ratio emulation - where different vector cores process different partitions of the data. - - :return: Vector Core ID as TensorHandle (int64, scalar) - """ - # Return the current sub_vec_id (set by GridExecutor) - vec_id = np.int64(self.sub_vec_id) - return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) - - def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): - """ - Set synchronization event between compute and vector units. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. 
- - :param sender: Source unit ("cube" or "vector") - :param receiver: Destination unit ("cube" or "vector") - :param event_id: Event ID (TensorHandle) - :param sender_pipe_value: Sender pipe value - :param receiver_pipe_value: Receiver pipe value - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass - - def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): - """ - Wait for synchronization event between compute and vector units. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. - - :param sender: Source unit ("cube" or "vector") - :param receiver: Destination unit ("cube" or "vector") - :param event_id: Event ID (TensorHandle) - :param sender_pipe_value: Sender pipe value - :param receiver_pipe_value: Receiver pipe value - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass - - def sync_block_all(self, mode, event_id): - """ - Synchronize all compute or vector units globally. - - In Interpreter mode, this is a no-op since we execute single-threaded. - Synchronization is not needed in CPU emulation. 
- - :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") - :param event_id: Event ID (int, constexpr, or TensorHandle) - """ - # No-op in interpreter mode: single-threaded execution doesn't need sync - pass def create_atomic_cas(self, ptr, cmp, val, sem, scope): if sem not in self.ir_sem_to_interpreter_sem: @@ -1265,32 +1036,10 @@ def _patch_lang(fn): _patch_builtin(lang.math, interpreter_builder) _patch_lang_tensor(lang.tensor) _patch_lang_core(lang) - - # Patch all modules in fn's globals that might be extension modules - for name, value in list(fn.__globals__.items()): - if value is None: - continue - try: - # Check if it looks like an extension module (has builtin functions) - if hasattr(value, '__name__') and 'extension' in str(value.__name__): - _patch_builtin(value, interpreter_builder) - # Also try patching any module-like object that might have builtin functions - elif hasattr(value, '__dict__') and not isinstance(value, type): - # Try to patch it and ignore if it fails - try: - _patch_builtin(value, interpreter_builder) - except Exception: - pass - except Exception: - pass - - # Also try importing extension directly as fallback - try: - import triton.language.extra.cann.extension as extension - _patch_builtin(extension, interpreter_builder) - except (ImportError, AttributeError): - # Extension module not available (e.g., non-Ascend backend) - pass + + # Patch Ascend extensions if using AscendInterpreterBuilder + if hasattr(interpreter_builder, 'patch_extensions'): + interpreter_builder.patch_extensions(fn) # TODO: wrap everything in triton tensors @@ -1317,10 +1066,19 @@ def _implicit_cvt(arg): return arg -interpreter_builder = InterpreterBuilder() +# Use AscendInterpreterBuilder if available, otherwise fall back to base InterpreterBuilder +_try_import_ascend() +if _has_ascend_support and AscendInterpreterBuilder is not None: + interpreter_builder = AscendInterpreterBuilder() +else: + interpreter_builder = InterpreterBuilder() # 
These keywords are not supported by the interpreter -RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg", "multibuffer"] +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + +# Allow Ascend interpreter to extend reserved keywords +if hasattr(interpreter_builder, 'get_additional_reserved_keywords'): + RESERVED_KWS.extend(interpreter_builder.get_additional_reserved_keywords()) class GridExecutor: @@ -1379,26 +1137,14 @@ def __call__(self, *args_dev, **kwargs): assert len(grid) <= 3, "grid must have at most 3 dimensions" grid = grid + (1, ) * (3 - len(grid)) interpreter_builder.set_grid_dim(*grid) - - # Infer the number of sub-vector cores from kernel parameters - # Check for M and sub_M parameters (common pattern for 1:2 ratio) - num_sub_vec_ids = 1 - if 'M' in args and 'sub_M' in args: - M = args['M'] - sub_M = args['sub_M'] - # Extract scalar values if they're TensorHandle - if isinstance(M, TensorHandle): - M = int(M.data.item() if hasattr(M.data, 'item') else M.data) - if isinstance(sub_M, TensorHandle): - sub_M = int(sub_M.data.item() if hasattr(sub_M.data, 'item') else sub_M.data) - # Number of vector cores = M / sub_M - if isinstance(M, int) and isinstance(sub_M, int) and sub_M > 0: - num_sub_vec_ids = max(1, M // sub_M) - + try: - # Loop over sub-vector IDs to simulate parallel vector core execution - for sub_vec_id in range(num_sub_vec_ids): - interpreter_builder.sub_vec_id = sub_vec_id + # Execute kernels - sub_vec_id simulation handled by AscendInterpreterBuilder + if hasattr(interpreter_builder, 'execute_with_sub_vec_simulation'): + # Ascend builder with sub-vector simulation + interpreter_builder.execute_with_sub_vec_simulation(self.fn, args, grid) + else: + # Standard execution for base interpreter for x in range(grid[0]): for y in range(grid[1]): for z in range(grid[2]): diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py 
b/third_party/ascend/backend/spec/triton/runtime/jit.py index 45178a40bb..17271aac93 100644 --- a/third_party/ascend/backend/spec/triton/runtime/jit.py +++ b/third_party/ascend/backend/spec/triton/runtime/jit.py @@ -9,9 +9,12 @@ from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple -from ..runtime.driver import driver from types import ModuleType +from triton._C.libtriton import get_cache_invalidating_env_vars +from .driver import driver +from . import _async_compile + TRITON_MODULE = __name__[:-len(".runtime.jit")] T = TypeVar("T") @@ -616,17 +619,9 @@ def run(self, *args, grid, warmup, **kwargs): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + kernel = self._do_compile(key, signature, device, backend, target, constants, options, configs[0], warmup) + if kernel is None: return None - # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) - self.cache[device][key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. 
not_present = object() @@ -647,6 +642,8 @@ def run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 + if hasattr(kernel, "result"): + kernel = kernel.result() # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) @@ -728,7 +725,7 @@ def warmup(self, *args, grid, **kwargs): return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) def preload(self, specialization_data): - from ..compiler import compile, ASTSource + from ..compiler import make_backend from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl @@ -742,14 +739,54 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() } key = deserialized_obj['key'] - kernel = compile(src, None, options) - self.cache[device][key] = kernel + target = driver.active.get_current_target() + backend = make_backend(target) + options = backend.parse_options(options) + attrs = AttrsDescriptor.from_dict(deserialized_obj['attrs']) + return self._do_compile( + key, + signature, + device, + backend, + target, + constants, + options, + attrs, + warmup=True, + ) + + def _do_compile(self, key, signature, device, backend, target, constants, options, attrs, warmup): + kernel_cache = self.cache[device] + + if self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=True): + return None + src = self.ASTSource(self, signature, constants, attrs) + + async_mode = _async_compile.active_mode.get() + if async_mode is not None: + from triton.compiler.compiler import get_cache_key + + env_vars = 
get_cache_invalidating_env_vars() + cache_key = get_cache_key(src, backend, options, env_vars) + + def async_compile(): + return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars) + + def finalize_compile(kernel): + kernel_cache[key] = kernel + self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False) + + kernel = async_mode.submit(cache_key, async_compile, finalize_compile) + else: + kernel = self.compile(src, target=target, options=options.__dict__) + kernel_cache[key] = kernel + self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False) return kernel # we do not parse `src` in the constructor because diff --git a/third_party/ascend/backend/spec/triton/runtime/libentry.py b/third_party/ascend/backend/spec/triton/runtime/libentry.py index a358b9ae8c..3a4a0231e6 100644 --- a/third_party/ascend/backend/spec/triton/runtime/libentry.py +++ b/third_party/ascend/backend/spec/triton/runtime/libentry.py @@ -1,4 +1,3 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. # Copyright 2018-2020 Philippe Tillet # Copyright 2020-2022 OpenAI # Copyright © 2024 BAAI. All rights reserved. 
@@ -285,6 +284,9 @@ def libentry(): """ def decorator(fn): + from triton.runtime.interpreter import InterpretedFunction + if isinstance(fn, InterpretedFunction): + return fn return LibEntry(fn) return decorator diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py index 65d5968dd7..5e588c5631 100644 --- a/third_party/ascend/backend/testing.py +++ b/third_party/ascend/backend/testing.py @@ -57,7 +57,8 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None if clear_l2_cache: buffer = runtime.driver.active.get_empty_cache_for_benchmark() - buffer.zero_() + buffer = buffer.float() # to avoid type cast + buffer.sum() torch.npu.synchronize() # shake out of any npu error total = warmup + active @@ -74,7 +75,8 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None for fn in funcs: for _ in builtins.range(total): if clear_l2_cache: - buffer.zero_() + buffer.sum() # use buffer read to clear l2 cache + torch.npu.synchronize() fn() torch.npu.synchronize() if clear_l2_cache: @@ -172,7 +174,7 @@ def _collect_prof_result(base_dir: str, funcs, num_warmup: int, num_active: int, df = pd.read_csv(kernel_details_file) # filter out l2 cache clearing operation - filter_cond = ~df["Type"].str.contains(r"^ZerosLike$", case=False, na=False) + filter_cond = ~df["Type"].str.contains(r"^ReduceSum$", case=False, na=False) filter_df = df[filter_cond] if key is not None: key_rows = filter_df[filter_df["Name"].str.contains(key, na=False)] diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py index 8f90b5fa4e..efd42780e1 100644 --- a/third_party/ascend/backend/utils.py +++ b/third_party/ascend/backend/utils.py @@ -28,6 +28,7 @@ from pathlib import Path import logging from platform import python_version +from triton.tools.get_ascend_devices import is_compile_on_910_95 from triton.backends.ascend.backend_register import backend_strategy_registry import pybind11 @@ -152,7 
+153,10 @@ def _get_llvm_path(path: str, *paths) -> str: def _get_npucompiler_path() -> str: ascend_dir = os.path.dirname(os.path.abspath(__file__)) env = os.environ.copy() - npu_compiler_path = os.path.join(ascend_dir, "bishengir", "bin", "bishengir-compile") + if is_compile_on_910_95: + npu_compiler_path = os.path.join(ascend_dir, "bishengir-a5", "bin", "bishengir-compile") + else: + npu_compiler_path = os.path.join(ascend_dir, "bishengir", "bin", "bishengir-compile") if os.path.exists(npu_compiler_path): npuir_env_path = os.path.dirname(npu_compiler_path) env["PATH"] = npuir_env_path + ":" + env["PATH"] @@ -263,6 +267,10 @@ def _enable_print_ub_bits() -> bool: return os.getenv("ENABLE_PRINT_UB_BITS", "false").lower() in ("true", "1") +def _enable_dump_memory_info() -> bool: + return os.getenv("TRITON_MEMORY_DISPLAY", "false").lower() in ("true", "1") + + def _get_cxx(): cxx = os.environ.get("CC") if cxx is None: @@ -302,11 +310,8 @@ def _precompile_npu_hash(header_src): return hash_txt -def _precompile_npu_ext(header_path): - src_dir = os.path.dirname(header_path) - gch_path = os.path.join(src_dir, "precompiled.h.gch") +def _precompile_npu_ext(header_path, gch_path): cxx = _get_cxx() - cc_cmd = [cxx, "-x", "c++-header", header_path] # disable all warnings cc_cmd += [f"-w"] @@ -344,12 +349,13 @@ def _precompile_npu_ext(header_path): cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", gch_path] - ret = subprocess.check_call(cc_cmd) + result = subprocess.run(cc_cmd, capture_output=True, text=True) - if ret != 0: - print(f"Unable to precompile header file, ret is: {ret}") + if result.returncode == 0: + return header_path + else: + raise RuntimeError(f"Failed to compile {gch_path}, error: {result.stderr},cmd={cc_cmd}") - return header_path def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="torch", precompile=False) -> str: @@ -399,8 +405,8 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor "-lascendcl", ] # 
FIXME: check why this condition works wrong in parall scene - # if kernel_launcher == "torch": - cc_cmd += get_backend_func("get_cc_cmd", build_pch=False) + if kernel_launcher == "torch": + cc_cmd += get_backend_func("get_cc_cmd", build_pch=False) cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-Winvalid-pch", "-o", so_path] @@ -413,7 +419,7 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor # only for clang++, when precompile invalid, fallback to normal compile - return _build_npu_ext(obj_name, header_path, src_path, precompile=False) + return _build_npu_ext(obj_name, header_path, src_path, kernel_launcher=kernel_launcher, precompile=False) else: - raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr}") + raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr},cmd={cc_cmd}") def _get_kernel_target(metadata: dict): @@ -531,8 +537,11 @@ def is_ffts_supported(arch: str): Cases: - empty str: User does not specify arch, thus it runs on 910B/910D both of which support ffts. Return True. - Ascend310B4: 310B4 does not support ffts. Return False. + - Ascend910_95*: 910_95 does not support ffts. Return False. - Other arch: 910B/910D supports ffts. Return True. 
''' + if is_compile_on_910_95: + return False if arch in ["Ascend910A", "Ascend310B4"]: return False return True @@ -541,5 +550,17 @@ def is_ffts_supported(arch: str): def force_disable_ffts(): ''' ''' + if is_compile_on_910_95: + return True disable_ffts = os.getenv("TRITON_DISABLE_FFTS", "false").lower() in ("true", "1") return disable_ffts + + +def triton_support_ffts(): + arch = get_ascend_arch_from_env() + return is_ffts_supported(arch) and (not force_disable_ffts()) + + +def triton_enable_libdevice_simt(): + """Return True only when TRITON_ENABLE_LIBDEVICE_SIMT is 'true' or '1' (case-insensitive); unset means False.""" + return os.getenv("TRITON_ENABLE_LIBDEVICE_SIMT", "false").lower() in ("true", "1") \ No newline at end of file diff --git a/third_party/ascend/include/AutoBlockify/AutoBlockify.h b/third_party/ascend/include/AutoBlockify/AutoBlockify.h new file mode 100644 index 0000000000..b3633b76ab --- /dev/null +++ b/third_party/ascend/include/AutoBlockify/AutoBlockify.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#pragma once + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DECL_AUTOBLOCKIFY +#include "ascend/include/AutoBlockify/Passes.h.inc" + +#define GEN_PASS_DEF_AUTOBLOCKIFY +#include "ascend/include/AutoBlockify/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> +createAutoBlockifyPass(const AutoBlockifyOptions &options = {}); + +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace triton; + +class PropagateUnrealizedCastDown + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit PropagateUnrealizedCastDown(MLIRContext *context, + Value logicalBlockId, + Value logicalBlockNum, + int autoBlockifySize); + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const override; + +private: + void handleBlockifyLoop(scf::ForOp blockifyLoop, Operation *op, PatternRewriter &rewriter) const; + void rewriteSplat(UnrealizedConversionCastOp op, triton::SplatOp splatOp, + PatternRewriter &rewriter) const; + void rewriteExpandDims(UnrealizedConversionCastOp op, + triton::ExpandDimsOp expandDimsOp, + PatternRewriter &rewriter) const; + void rewriteReduce(UnrealizedConversionCastOp op, triton::ReduceOp reduceOp, + PatternRewriter &rewriter) const; + void rewriteScan(UnrealizedConversionCastOp op, triton::ScanOp scanOp, + PatternRewriter &rewriter) const; + void rewriteLoad(UnrealizedConversionCastOp op, triton::LoadOp loadOp, + PatternRewriter &rewriter) const; + void rewriteStore(UnrealizedConversionCastOp op, triton::StoreOp storeOp, + PatternRewriter &rewriter) const; + void 
rewriteAtomicRMW(UnrealizedConversionCastOp op, + triton::AtomicRMWOp atomicRMWOp, + PatternRewriter &rewriter) const; + void rewriteAssert(UnrealizedConversionCastOp op, triton::AssertOp assertOp, + PatternRewriter &rewriter) const; + void rewriteExtractSlice(UnrealizedConversionCastOp op, + tensor::ExtractSliceOp extractSliceOp, + PatternRewriter &rewriter) const; + void rewriteInsertSlice(UnrealizedConversionCastOp op, + tensor::InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const; + void rewriteWhile(UnrealizedConversionCastOp op, scf::WhileOp whileOp, + PatternRewriter &rewriter) const; + void rewriteLoop(UnrealizedConversionCastOp op, LoopLikeOpInterface loopOp, + PatternRewriter &rewriter) const; + void rewriteIf(UnrealizedConversionCastOp &op, scf::IfOp ifOp, ArrayRef indices, + PatternRewriter &rewriter) const; + void rewriteYield(UnrealizedConversionCastOp &op, scf::YieldOp yieldOp, + PatternRewriter &rewriter) const; + void rewriteCondition(UnrealizedConversionCastOp op, + scf::ConditionOp conditionOp, + PatternRewriter &rewriter) const; + void rewriteGeneraleOp(UnrealizedConversionCastOp op, Operation *generalOp, + PatternRewriter &rewriter) const; + + Value logicalBlockId; + Value logicalBlockNum; + int autoBlockifySize; +}; + +class AutoBlockifyPass : public ::impl::AutoBlockifyBase { +public: + explicit AutoBlockifyPass(const AutoBlockifyOptions &options); + void runOnOperation() override; + +private: + bool checkBlockifiable(Value v); + void preProcess(triton::FuncOp func); + + DenseSet checkedValues; + Value logicalBlockId; + Value logicalBlockNum; +}; diff --git a/third_party/ascend/include/AutoBlockify/CMakeLists.txt b/third_party/ascend/include/AutoBlockify/CMakeLists.txt new file mode 100644 index 0000000000..abae4f6e9f --- /dev/null +++ b/third_party/ascend/include/AutoBlockify/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name AutoBlockify) 
+add_public_tablegen_target(AutoBlockifyPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/include/AutoBlockify/Passes.h b/third_party/ascend/include/AutoBlockify/Passes.h new file mode 100644 index 0000000000..7d5147ef92 --- /dev/null +++ b/third_party/ascend/include/AutoBlockify/Passes.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#ifndef TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H +#define TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H + +#include "AutoBlockify.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "ascend/include/AutoBlockify/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_AUTO_BLOCKIFY_PASSES_H diff --git a/third_party/ascend/include/AutoBlockify/Passes.td b/third_party/ascend/include/AutoBlockify/Passes.td new file mode 100644 index 0000000000..56ab1587be --- /dev/null +++ b/third_party/ascend/include/AutoBlockify/Passes.td @@ -0,0 +1,21 @@ +#ifndef AUTO_BLOCKIFY_PASSES +#define AUTO_BLOCKIFY_PASSES + +include "mlir/Pass/PassBase.td" + +def AutoBlockify : Pass<"auto-blockify", "mlir::ModuleOp"> { + let summary = "Apply auto blockify v2"; + let constructor = "triton::createAutoBlockifyPass()"; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::tensor::TensorDialect", + "mlir::triton::TritonDialect" + ]; + let options = [ + Option<"autoBlockifySize", "auto-blockify-size", "int", "1", + "Apply auto blockify v2 when TRITON_ALL_BLOCKS_PARALLEL is 1. " + "Expand highest dimension with blockify size"> + ]; +} + +#endif // AUTO_BLOCKIFY_PASSES diff --git a/third_party/ascend/include/AutoBlockify/Utils.h b/third_party/ascend/include/AutoBlockify/Utils.h new file mode 100644 index 0000000000..639922c8ee --- /dev/null +++ b/third_party/ascend/include/AutoBlockify/Utils.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#pragma once + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +using namespace mlir; +using namespace triton; + +constexpr llvm::StringLiteral autoBlockifySizeAttr = "auto_blockify_size"; +constexpr llvm::StringLiteral logicalBlockIdAttr = "logical_block_id"; +constexpr llvm::StringLiteral autoBlockifyLoopAttr = + "auto_blockify_loop"; +constexpr llvm::StringLiteral autoBlockifyRegionOpAttr = + "auto_blockify_region_op"; + +RankedTensorType getExpandedType(Type type, UnrealizedConversionCastOp op); + +Value rewriteValue(Value value, UnrealizedConversionCastOp op, + OpBuilder &builder); + +void replaceValue(Operation *newOp, Operation *oldOp, Value newMask, + RewriterBase &rewriter, + ArrayRef replaceIndices = {}); + +Value createMask(Value mask, Value uccMask, ArrayRef targetShape, + RewriterBase &rewriter); + +void mapRegionIterArg(IRMapping &mapping, ValueRange oldArgs, + ValueRange newArgs, ArrayRef indices, Value mask, + OpBuilder &builder); + +void mapYieldedValue(IRMapping &mapping, scf::YieldOp yieldOp, + ArrayRef indices, UnrealizedConversionCastOp op, + OpBuilder &builder); + +Operation *createBlockifyLoop(Operation *targetOp, + UnrealizedConversionCastOp op, + Value logicalBlockId, Value logicalBlockNum, + int autoBlockifySize, RewriterBase &rewriter); + +std::optional getBlockifyLoop(Operation *op); \ No newline at end of file diff --git a/third_party/ascend/include/CMakeLists.txt b/third_party/ascend/include/CMakeLists.txt index 9cd93fe4ae..2f77a4a5c4 100644 --- a/third_party/ascend/include/CMakeLists.txt +++ b/third_party/ascend/include/CMakeLists.txt @@ -1,3 +1,12 @@ -add_subdirectory(TritonToLLVM) -add_subdirectory(TritonToHIVM) -add_subdirectory(TritonToHFusion) +add_subdirectory(Dialect) +add_subdirectory(TritonToAnnotation) +add_subdirectory(TritonToHFusion) +add_subdirectory(TritonToHIVM) 
+add_subdirectory(TritonToLinalg) +add_subdirectory(Utils) +add_subdirectory(DiscreteMaskAccessConversion) +add_subdirectory(TritonToUnstructure) +add_subdirectory(TritonToLLVM) +add_subdirectory(TritonToStructured) +add_subdirectory(AutoBlockify) +add_subdirectory(TritonAffinityOpt) diff --git a/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt new file mode 100644 index 0000000000..4a804f0784 --- /dev/null +++ b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonAffinityOpt) +add_public_tablegen_target(TritonAffinityOptConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/include/TritonAffinityOpt/DAG.h b/third_party/ascend/include/TritonAffinityOpt/DAG.h new file mode 100644 index 0000000000..ecebcf4e3e --- /dev/null +++ b/third_party/ascend/include/TritonAffinityOpt/DAG.h @@ -0,0 +1,330 @@ +#ifndef AffinityDAGDEF +#define AffinityDAGDEF +#include "Utils.hpp" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include + +namespace mlir { namespace AffinityDAG { + +enum class OpAbility { + PREFER_VECTOR = 1 << 0, + CUBE_ONLY = 1 << 1, + CUBE_AND_VECTOR = PREFER_VECTOR | CUBE_ONLY + +}; + +enum CoreType { + UNDETERMINED = 0, + VECTOR_ONLY = 1 << 0, + CUBE_ONLY = 1 << 1, + CUBE_AND_VECTOR = VECTOR_ONLY | CUBE_ONLY +}; + +inline constexpr CoreType 
toCoreType(OpAbility ct) { + using U = std::underlying_type_t; + return static_cast(static_cast(ct)); +} + +constexpr inline CoreType operator| (CoreType lhs, CoreType rhs) { + return enumOp(std::bit_or<>(), lhs, rhs); +} + +inline CoreType operator& (CoreType lhs, CoreType rhs) { + return enumOp(std::bit_and<>(), lhs, rhs); +} + +inline bool intersects(CoreType lhs, CoreType rhs) { + return (lhs & rhs) != CoreType::UNDETERMINED; +} + +inline CoreType operator& (OpAbility lhs, CoreType rhs) { + return toCoreType(lhs) & rhs; +} + +inline CoreType operator!(CoreType ct) +{ + CoreType newCt = UNDETERMINED; + if ((ct & CoreType::CUBE_ONLY) == UNDETERMINED) { + newCt = newCt | CoreType::CUBE_ONLY; + } + + if ((ct & CoreType::VECTOR_ONLY) == UNDETERMINED) { + newCt = newCt | CoreType::VECTOR_ONLY; + } + + return newCt; +} + +inline hivm::TCoreType toHivm(CoreType ct) +{ + switch (ct) { + case UNDETERMINED: + return hivm::TCoreType::CUBE_OR_VECTOR; + case CUBE_ONLY: + return hivm::TCoreType::CUBE; + case VECTOR_ONLY: + return hivm::TCoreType::VECTOR; + case CUBE_AND_VECTOR: + return hivm::TCoreType::CUBE_AND_VECTOR; + default: + llvm_unreachable("Invalid CoreType that cannot convert to hivm"); + } +} + +inline bool intersects(OpAbility lhs, CoreType rhs) { + return (lhs & rhs) != CoreType::UNDETERMINED; +} + +inline bool exactlyOneType(CoreType ct) { + return (ct == CUBE_ONLY) || (ct == VECTOR_ONLY); +} + +const char* literalCoreType(CoreType ct); + +class MoveOnly { +protected: + MoveOnly() = default; + ~MoveOnly() = default; + + MoveOnly(const MoveOnly &) = delete; + MoveOnly &operator=(const MoveOnly &) = delete; + + MoveOnly(MoveOnly &&) = default; + MoveOnly &operator=(MoveOnly &&) = default; +}; + +class Node; +class OpNode; +class ValueNode; + +ValueNode *getDataSource(OpNode *op); + +class Graph : MoveOnly { +public: + using OpMapRaw = llvm::DenseMap>; + using ValueMapRaw = llvm::DenseMap>; + using OpMap = std::shared_ptr; + using ValueMap = std::shared_ptr; + + 
Graph( + Block* block, + Graph* parent = nullptr, + OpMap opMap = nullptr, + ValueMap valueMap = nullptr, + bool inheritParent = true + ); + + static std::unique_ptr fromMultiBlockFunc(triton::FuncOp funcOp); + + OpMapRaw& getOpMap() const { + return *opMap; + } + + ValueMapRaw& getValueMap() const { + return *valueMap; + } + + // [DEBUG] start + std::unique_ptr> legacyOpMap = nullptr; + std::unique_ptr> legacyValueTypes = nullptr; + + inline llvm::DenseMap& getOpMapLegacy() { + if (!legacyOpMap) { + legacyOpMap = std::move(std::make_unique>()); + for(auto& [key, val] : *opMap) { + (*legacyOpMap)[key] = val.get(); + } + } + + return *legacyOpMap; + } + + llvm::DenseMap& getValueTypes() ; + + // [DEBUG] end + +private: + friend class Node; + friend class OpNode; + OpMap opMap; + ValueMap valueMap; + Block* block; + Graph* parent; + OpNode* terminator = nullptr; + size_t opCount = 0; + llvm::SmallVector blockArgs; +}; + +class Node : MoveOnly { +protected: + friend class Graph; + friend class ValueNode; + bool isUpstreamOfCubeMem = false; + virtual CoreType absorbImpl() = 0; + llvm::SmallVector outputs; + +public: + CoreType isOnPrivate = UNDETERMINED; + + enum NodeKind { + NK_Op, + NK_Value + }; + + inline CoreType isOn() const { + return isOnPrivate; + } + + bool absorb() { + auto newCoreType = absorbImpl(); + auto changed = newCoreType != isOnPrivate; + isOnPrivate = newCoreType; + + return changed; + }; + + virtual llvm::SmallVector getAffected() const = 0; + virtual OpNode* getSourceOpNode() = 0; + + ArrayRef getOutputs() const { + return outputs; + } + + CoreType absorbCommon(); + +private: + const NodeKind kind; + +public: + NodeKind getKind() const { + return kind; + } + +protected: + Node(NodeKind kind) : kind(kind) {} +}; + +class OpNode : public Node { + friend class Graph; + friend class ValueNode; + llvm::SmallVector inputs; + llvm::SmallVector subgraphs; + virtual CoreType absorbImpl() override; + +public: + Operation* op; + + OpNode(Operation* op, 
Graph* graph); + OpAbility canRunOn() const; + inline ArrayRef getInputs() const { + return inputs; + } + + static bool classof(const Node* node) { + return node->getKind() == NK_Op; + } + + virtual llvm::SmallVector getAffected() const override { + llvm::SmallVector result(inputs.begin(), inputs.end()); + result.append(outputs.begin(), outputs.end()); + + return result; + } + + virtual OpNode* getSourceOpNode() override { + return this; + } +}; + +class ValueNode : public Node { + friend class Graph; + friend class OpNode; + virtual CoreType absorbImpl() override; +public: + + Node* source = nullptr; + Value value; + // ValueNode(OpResult value); + // ValueNode(BlockArgument value); + + ValueNode(Value value) : Node(NK_Value), value(value) {}; + virtual OpNode* getSourceOpNode() override { + if (!source) { + return nullptr; + } + + return source->getSourceOpNode(); + } + static bool classof(const Node* node) { + return node->getKind() == NK_Value; + } + + virtual llvm::SmallVector getAffected() const override { + llvm::SmallVector result(outputs.begin(), outputs.end()); + if (source) + result.push_back(source); + + return result; + } +}; + +class GraphManager { +private: + llvm::DenseMap> graphs; + +public: + static GraphManager &getInstance() { + static GraphManager instance; + return instance; + } + + void registerGraph(llvm::StringRef funcName, std::shared_ptr graph) { + graphs[funcName] = graph; + } + + AffinityDAG::Graph* getGraph(llvm::StringRef funcName) { + auto it = graphs.find(funcName); + return it != graphs.end() ? 
it->second.get() : nullptr; + } + + void removeGraph(llvm::StringRef funcName) { + graphs.erase(funcName); + } +}; + + +inline llvm::DenseMap& Graph::getValueTypes() { + static std::mutex mtx; + std::lock_guard lock(mtx); + if (!legacyValueTypes) { + legacyValueTypes = std::move(std::make_unique>()); + for(auto& [key, val] : *valueMap) { + llvm::dbgs() << key << "\n"; + llvm::dbgs().flush(); + (*legacyValueTypes)[key] = val.get()->isOn(); + } + } + + return *legacyValueTypes; +} + +} } +#endif diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.h b/third_party/ascend/include/TritonAffinityOpt/Passes.h new file mode 100644 index 0000000000..5c9a63225f --- /dev/null +++ b/third_party/ascend/include/TritonAffinityOpt/Passes.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#ifndef TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H +#define TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +// Forward declarations. +class ModuleOp; + +namespace triton { + +/// Creates a pass to convert Triton dialect to Annotation dialect. +std::unique_ptr> createDAGSSBufferPass(); + +std::unique_ptr> createDAGSyncPass(); + +std::unique_ptr> createDAGScopePass(); + +#define GEN_PASS_REGISTRATION +#include "ascend/include/TritonAffinityOpt/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.td b/third_party/ascend/include/TritonAffinityOpt/Passes.td new file mode 100644 index 0000000000..b2a72f58db --- /dev/null +++ b/third_party/ascend/include/TritonAffinityOpt/Passes.td @@ -0,0 +1,29 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. 
+ */ + +#ifndef TRITON_AFFINITY_OPTIMIZATION_PASSES +#define TRITON_AFFINITY_OPTIMIZATION_PASSES + +include "mlir/Pass/PassBase.td" + +def DAGSSBuffer : Pass<"dag-ssbuf", "mlir::ModuleOp"> { + let summary = "Convert vector operations to shared storage buffer operations"; + let constructor = "triton::createDAGSSBufferPass()"; + let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "scope::ScopeDialect", "annotation::AnnotationDialect"]; +} + +def DAGScope : Pass<"dag-scope", "mlir::ModuleOp"> { + let summary = "Convert native triton code to NPU-affine code"; + let constructor = "triton::createDAGScopePass()"; + let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "scope::ScopeDialect", "annotation::AnnotationDialect"]; +} + +def DAGSync : Pass<"dag-sync", "mlir::ModuleOp"> { + let summary = "DAG sync"; + let constructor = "triton::createDAGSyncPass()"; + let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "annotation::AnnotationDialect"]; +} + +#endif // TRITON_AFFINITY_OPTIMIZATION_PASSES \ No newline at end of file diff --git a/third_party/ascend/include/TritonAffinityOpt/Utils.hpp b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp new file mode 100644 index 0000000000..42b8c64a3f --- /dev/null +++ b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp @@ -0,0 +1,26 @@ +#ifndef TRITON_AFFINITY_UTILS_HPP +#define TRITON_AFFINITY_UTILS_HPP + +#include + +namespace mlir::AffinityDAG { + +template +constexpr inline T enumOp(F&& func, T lhs, T rhs) { + static_assert(std::is_enum_v, "T must be an enum type"); + + using U = std::underlying_type_t; + + return static_cast( + std::invoke( + std::forward(func), + static_cast(lhs), + static_cast(rhs) + ) + ); +} + +} // namespace TritonAffinity::Utils + + +#endif \ No newline at end of file diff --git a/third_party/ascend/language/cann/__init__.py b/third_party/ascend/language/cann/__init__.py index d7feaad57a..d599b4569b 
100644 --- a/third_party/ascend/language/cann/__init__.py +++ b/third_party/ascend/language/cann/__init__.py @@ -18,17 +18,19 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +from triton.language import math +from triton.backends.ascend.utils import triton_enable_libdevice_simt + from . import libdevice from . import extension extension.parallel = extension.aux_ops.parallel -libdevice.atan2 = extension.math_ops.atan2 +if not triton_enable_libdevice_simt(): + libdevice.atan2 = extension.math_ops.atan2 libdevice.isfinited = extension.math_ops.isfinited libdevice.finitef = extension.math_ops.finitef libdevice.flip = extension.flip -from triton.language import math - libdevice.umulhi = math.umulhi libdevice.exp = math.exp libdevice.exp2 = math.exp2 diff --git a/third_party/ascend/language/cann/extension/__init__.py b/third_party/ascend/language/cann/extension/__init__.py index 20c339bc8a..1d8c4dc31b 100644 --- a/third_party/ascend/language/cann/extension/__init__.py +++ b/third_party/ascend/language/cann/extension/__init__.py @@ -1,14 +1,27 @@ -try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") -except Exception as e: - is_compile_on_910_95 = False +from triton.tools.get_ascend_devices import is_compile_on_910_95 +from triton._C.libtriton.ascend import ir as _ascend_ir + +# MLIR affine bindings (same objects as triton._C.libtriton.ascend.ir). 
+affine_expr = _ascend_ir.affine_expr +affine_constant_expr = _ascend_ir.affine_constant_expr +affine_dim_expr = _ascend_ir.affine_dim_expr +affine_symbol_expr = _ascend_ir.affine_symbol_expr +affine_binary_op_expr = _ascend_ir.affine_binary_op_expr +affine_map = _ascend_ir.affine_map + +AffineExpr = affine_expr +AffineConstantExpr = affine_constant_expr +AffineDimExpr = affine_dim_expr +AffineSymbolExpr = affine_symbol_expr +AffineBinaryOpExpr = affine_binary_op_expr +AffineMap = affine_map from .core import ( ascend_address_space, builtin, CORE, copy_from_ub_to_l1, + copy, debug_barrier, fixpipe, FixpipeDMAMode, @@ -19,6 +32,7 @@ is_builtin, MODE, PIPE, + IteratorType, sub_vec_id, sub_vec_num, sync_block_all, @@ -55,7 +69,6 @@ ) from .mem_ops import ( - index_select, index_put, gather_out_to_ub, scatter_ub_to_out, @@ -66,6 +79,7 @@ # core "builtin", "copy_from_ub_to_l1", + "copy", "CORE", "debug_barrier", "fixpipe", @@ -77,6 +91,7 @@ "is_builtin", "MODE", "PIPE", + "IteratorType", "sub_vec_id", "sub_vec_num", "sync_block_all", @@ -85,6 +100,20 @@ # address space "ascend_address_space", + # ascend IR affine (MLIR) + "affine_expr", + "affine_constant_expr", + "affine_dim_expr", + "affine_symbol_expr", + "affine_binary_op_expr", + "affine_map", + "AffineExpr", + "AffineConstantExpr", + "AffineDimExpr", + "AffineSymbolExpr", + "AffineBinaryOpExpr", + "AffineMap", + # scope "scope", @@ -114,9 +143,9 @@ "cast", # mem ops - "index_select", "index_put", "gather_out_to_ub", "scatter_ub_to_out", "index_select_simd", ] + diff --git a/third_party/ascend/language/cann/extension/aux_ops.py b/third_party/ascend/language/cann/extension/aux_ops.py index d872e44ac6..f7a8d55e8a 100644 --- a/third_party/ascend/language/cann/extension/aux_ops.py +++ b/third_party/ascend/language/cann/extension/aux_ops.py @@ -1,23 +1,33 @@ import triton.language as tl from triton.language import semantic, core, standard -from triton.language.core import (_constexpr_to_value, _tensor_member_fn, 
_unwrap_iterable, builtin, constexpr, dtype, - tensor, check_bit_width, _unwrap_if_constexpr, range) +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype, + tensor, + check_bit_width, + _unwrap_if_constexpr, + range +) from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, + wrap_tensor, + _str_to_rounding_mode, + not_equal, _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, bitcast, bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, + to_tensor, + _str_to_load_cache_modifier, _str_to_eviction_policy, - _str_to_padding_option, + _str_to_padding_option, _canonicalize_boundary_check, ) @@ -57,10 +67,8 @@ def sync_block_set(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) receiver = _constexpr_to_value(receiver) event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" if sender == receiver: raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') @@ -80,10 +88,8 @@ def sync_block_wait(sender, receiver, 
event_id, _builder=None): sender = _constexpr_to_value(sender) receiver = _constexpr_to_value(receiver) event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" if sender == receiver: raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') @@ -100,9 +106,7 @@ class parallel(range): This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of iteration in this loop. Currently on 910B, max 2 vector cores could be used. 
""" - - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, - bind_sub_block: bool = False): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, bind_sub_block: bool = False): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) self.bind_sub_block = bind_sub_block @@ -112,10 +116,11 @@ def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder # FIXME: is_simt_mode # if builder.is_simt_mode(): # return - if not hint_val: - hint_val = builder.get_unit_attr() - elif isinstance(hint_val, bool): + # Check isinstance(hint_val, bool) first to handle False explicitly + if isinstance(hint_val, bool): hint_val = builder.get_bool_attr(hint_val) + elif not hint_val: + hint_val = builder.get_unit_attr() elif isinstance(hint_val, int): hint_val = builder.get_int32_attr(hint_val) elif isinstance(hint_val, core.constexpr): @@ -125,8 +130,7 @@ def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder hint_val = builder.get_i64_array_attr(hint_val) else: raise ValueError(f"Unsupported hint value type: {type(hint_val)}") - builder.create_annotation(ptr.handle, hint_name, hint_val) - + builder.create_annotation_mark(ptr.handle, hint_name, hint_val) @builtin def compile_hint(ptr, hint_name, hint_val=None, _builder=None): @@ -146,7 +150,6 @@ def _unwrap(val): hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val compile_hint_impl(ptr, hint_name, hint_val, _builder) - @builtin def multibuffer(src: tensor, size, _builder=None): """ @@ -156,4 +159,4 @@ def multibuffer(src: tensor, size, _builder=None): """ buffer_size = _constexpr_to_value(size) assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" - compile_hint_impl(src, "multi_buffer", buffer_size, _builder) + compile_hint_impl(src, "hivm.multi_buffer", buffer_size, _builder) diff --git a/third_party/ascend/language/cann/extension/builder.py 
b/third_party/ascend/language/cann/extension/builder.py index cfd4be3b0b..8cf699f63a 100644 --- a/third_party/ascend/language/cann/extension/builder.py +++ b/third_party/ascend/language/cann/extension/builder.py @@ -73,6 +73,7 @@ def setup_unified_builder(main_builder, ascend_builder): 'create_copy_buffer', 'create_copy_tensor', 'create_fixpipe', + 'create_annotation_mark', 'create_bind_buffer', 'create_debug_barrier', 'is_910_95', diff --git a/third_party/ascend/language/cann/extension/core.py b/third_party/ascend/language/cann/extension/core.py index 1710d7b960..333615c870 100644 --- a/third_party/ascend/language/cann/extension/core.py +++ b/third_party/ascend/language/cann/extension/core.py @@ -1,312 +1,368 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# Copyright 2018-2020 Philippe Tillet -# Copyright 2020-2022 OpenAI -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -__all__ = [ - "ascend_address_space", "builtin", "CORE", "copy_from_ub_to_l1", "debug_barrier", "fixpipe", "FixpipeDMAMode", - "FixpipeDualDstMode", "FixpipePreQuantMode", "FixpipePreReluMode", "int64", "is_builtin", "MODE", "PIPE", - "sub_vec_id", "sub_vec_num", "sync_block_all", "sync_block_set", "sync_block_wait", "SYNC_IN_VF" -] - -import enum -from typing import TypeVar, List, Union -from functools import wraps - -from triton._C.libtriton import ir -from triton._C.libtriton.ascend import ir as ascend_ir -import triton.language.core as tl - -import triton.extension.buffer.language as bl -from triton.language.core import _constexpr_to_value -from triton.backends.ascend.driver import NPUUtils - -from . import semantic as semantic - -PIPE = semantic.PIPE - -T = TypeVar("T") - -TRITON_BUILTIN = "__triton_builtin__" -ASCEND_BUILTIN = "__ascend_builtin__" - - -def builtin(fn: T) -> T: - """Mark a function as a buffer language builtin.""" - assert callable(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - if "_builder" not in kwargs or kwargs["_builder"] is None: - raise ValueError("Did you forget to add @triton.jit ? " - "(`_builder` argument must be provided outside of JIT functions.)") - return fn(*args, **kwargs) - - # also set triton_builtin to true so that CodeGenerator will recognize this function - setattr(wrapper, TRITON_BUILTIN, True) - setattr(wrapper, ASCEND_BUILTIN, True) - - return wrapper - - -def is_builtin(fn) -> bool: - """Is this a registered ascend language builtin function?""" - return getattr(fn, ASCEND_BUILTIN, False) - - -class int64(int): - """ - For custom op, python int argument will be converted to int32 by default, - if a device-side int64 is required, you can pass an al.int64(x) to it. 
- """ - - def __new__(cls, value): - obj = int.__new__(cls, value) - obj.type = tl.int64 - return obj - - -class CORE(enum.Enum): - VECTOR = ascend_ir.CoreType.VECTOR - CUBE = ascend_ir.CoreType.CUBE - CUBE_OR_VECTOR = ascend_ir.CoreType.CUBE_OR_VECTOR - CUBE_AND_VECTOR = ascend_ir.CoreType.CUBE_AND_VECTOR - - -class PIPE(enum.Enum): - PIPE_S = ascend_ir.PIPE.PIPE_S - PIPE_V = ascend_ir.PIPE.PIPE_V - PIPE_M = ascend_ir.PIPE.PIPE_M - PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 - PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 - PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 - PIPE_ALL = ascend_ir.PIPE.PIPE_ALL - PIPE_FIX = ascend_ir.PIPE.PIPE_FIX - - -class MODE(enum.Enum): - SIMD = ascend_ir.MODE.SIMD - SIMT = ascend_ir.MODE.SIMT - MIX = ascend_ir.MODE.MIX - - -class ascend_address_space_base(bl.address_space): - - def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: - super().__init__() - self.real_address_space = address_space_value - - def to_ir(self, builder: ir.builder) -> ir.attribute: - return builder.get_target_attribute(self.real_address_space) - - -class ascend_address_space_group: - - def __init__(self): - for k, v in {k: v - for k, v in ascend_ir.AddressSpace.__dict__.items() - if isinstance(v, ascend_ir.AddressSpace)}.items(): - setattr(self, k, ascend_address_space_base(v)) - - -ascend_address_space = ascend_address_space_group() - - -@builtin -def sub_vec_id(_builder=None) -> tl.tensor: - """ - Get the Vector Core index on the AI Core. - """ - return semantic.sub_vec_id(_builder) - - -@builtin -def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: - """ - Copies data from the Unified Buffer (UB) to the L1 Buffer. - - :param src: The source data located in the Unified Buffer. - :type src: tl.tensor | bl.buffer - :param dst: The destination buffer located in L1 memory. 
- :type dst: tl.tensor | bl.buffer - """ - return semantic.copy_from_ub_to_l1(src, dst, _builder) - - -def create_sync_block(sender, receiver, event_id, is_set: bool, sender_pipe=None, receiver_pipe=None, _builder=None): - sender = _constexpr_to_value(sender) - receiver = _constexpr_to_value(receiver) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" - if isinstance(event_id, int): - assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - if sender == receiver: - raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') - if sender_pipe is None and receiver_pipe is None: - if sender == "cube": - sender_pipe = PIPE.PIPE_FIX - receiver_pipe = PIPE.PIPE_MTE2 - if sender == "vector": - sender_pipe = PIPE.PIPE_MTE3 - receiver_pipe = PIPE.PIPE_MTE2 - if not isinstance(sender_pipe, PIPE) or not isinstance(receiver_pipe, PIPE): - raise TypeError("sender_pipe and receiver_pipe must be instances of PIPE enum") - if is_set: - return semantic.create_sync_block_set(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) - return semantic.create_sync_block_wait(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_set(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): - return create_sync_block(sender, receiver, event_id, True, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_wait(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): - return create_sync_block(sender, receiver, event_id, False, sender_pipe, receiver_pipe, _builder) - - -@builtin -def sync_block_all(mode, event_id, _builder=None): - mode = _constexpr_to_value(mode) - event_id = 
_constexpr_to_value(event_id) - assert isinstance(mode, str), f"mode: {mode} is not string" - assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - assert mode in ("all_cube", "all_vector", "all", - "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" - _builder.sync_block_all(mode, event_id) - - -class FixpipeDMAMode(enum.Enum): - NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN - NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND - NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ - - -class FixpipeDualDstMode(enum.Enum): - NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL - COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT - ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT - - -class FixpipePreQuantMode(enum.Enum): - NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT - F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 - F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 - S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 - - -class FixpipePreReluMode(enum.Enum): - LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU - NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU - NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU - P_RELU = ascend_ir.FixpipePreReluMode.P_RELU - - -@builtin -def fixpipe( - src: tl.tensor, - dst: bl.buffer, - dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, - dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, - _builder=None, -) -> None: - """ - Directly store a tensor on L0C to a local buffer via fixpipe. - Fixpipe is pipeline that performing data movement from L0C to other memory hierarchies. - Currently support: - - L0C to UB (for Ascend910_95 sereies) - - :param src: the source tensor, Must be located in the l0C memory region. - :type src: tl.tensor - :param dst: The destination buffer, Must be located in the UB memory region. 
- :type dst: bl.buffer - :param dma_mode: DMA transfer mode, "nz2nd" enables NZ to ND layout transformation - :type dma_mode: str - """ - if not _builder.is_910_95(): - raise RuntimeError("this feature is only supported on Ascend910_95") - if not isinstance(src, tl.tensor): - raise TypeError("src is not of tensor type") - elif not isinstance(dst, bl.buffer): - raise TypeError("dst is not of buffer type") - if dst.space != ascend_address_space.UB: - raise TypeError("dst must be located in the UB memory region") - - if len(dst.shape) == 2 and (dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32): - N = dst.shape[1] - if N % 8 != 0: - raise ValueError("32b Fixpipe last dim must be aligned to 8") - if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0): - raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16") - if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0): - raise ValueError("32b Column split dual Fixpipe last dim must be aligned to 32") - M = dst.shape[0] - if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0): - raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8") - dst16bits = (dst.type.element_ty == tl.float16 or dst.type.element_ty == tl.int16 - or dst.type.element_ty == tl.bfloat16) - if len(dst.shape) == 2 and dst16bits: - N = dst.shape[1] - if N % 16 != 0: - raise ValueError("16b Fixpipe last dim must be aligned to 16") - M = dst.shape[0] - if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0): - raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 16") - - return semantic.fixpipe(src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU, - _builder) - - -class SYNC_IN_VF(enum.Enum): - VV_ALL = enum.auto() - VST_VLD = enum.auto() - VLD_VST = enum.auto() - VST_VST = enum.auto() - VS_ALL = enum.auto() - VST_LD = enum.auto() - VLD_ST = enum.auto() - VST_ST = enum.auto() - SV_ALL = enum.auto() - ST_VLD = enum.auto() - LD_VST = enum.auto() 
- ST_VST = enum.auto() - - -@builtin -def debug_barrier( - sync_mode: SYNC_IN_VF, - _builder=None, -) -> None: - return semantic.debug_barrier(sync_mode.name, _builder) - - -@builtin -def sub_vec_num(_builder=None) -> tl.constexpr: - """ - Get the Vector Core Num on one AI Core. - """ - npuUtils = NPUUtils() - cube_num = npuUtils.get_aivector_core_num() - vector_num = npuUtils.get_aicore_num() - const_val = cube_num // vector_num - return tl.constexpr(const_val) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +__all__ = [ + "ascend_address_space", + "builtin", + "CORE", + "copy_from_ub_to_l1", + "copy", + "debug_barrier", + "fixpipe", + "FixpipeDMAMode", + "FixpipeDualDstMode", + "FixpipePreQuantMode", + "FixpipePreReluMode", + "int64", + "is_builtin", + "MODE", + "PIPE", + "IteratorType", + "sub_vec_id", + "sub_vec_num", + "sync_block_all", + "sync_block_set", + "sync_block_wait", + "SYNC_IN_VF" +] + +import enum +from typing import TypeVar, List, Union +from functools import wraps + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl + +import triton.extension.buffer.language as bl +from triton.language.core import _constexpr_to_value +from triton.backends.ascend.driver import NPUUtils + +from . import semantic as semantic +PIPE = semantic.PIPE + + +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" +ASCEND_BUILTIN = "__ascend_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a buffer language builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + # also set triton_builtin to true so that CodeGenerator will recognize this function + setattr(wrapper, TRITON_BUILTIN, True) + setattr(wrapper, ASCEND_BUILTIN, True) + + return wrapper + + +def is_builtin(fn) -> bool: + """Is this a registered ascend language builtin function?""" + return getattr(fn, ASCEND_BUILTIN, False) + + +class int64(int): + """ + For custom op, python int argument will be converted to int32 by default, + if a device-side int64 is required, you can pass an al.int64(x) to it. 
+ """ + def __new__(cls, value): + obj = int.__new__(cls, value) + obj.type = tl.int64 + return obj + + +class CORE(enum.Enum): + VECTOR = ascend_ir.CoreType.VECTOR + CUBE = ascend_ir.CoreType.CUBE + CUBE_OR_VECTOR = ascend_ir.CoreType.CUBE_OR_VECTOR + CUBE_AND_VECTOR = ascend_ir.CoreType.CUBE_AND_VECTOR + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +class MODE(enum.Enum): + SIMD = ascend_ir.MODE.SIMD + SIMT = ascend_ir.MODE.SIMT + MIX = ascend_ir.MODE.MIX + + +class IteratorType(enum.Enum): + Parallel = ascend_ir.IteratorType.Parallel + Broadcast = ascend_ir.IteratorType.Broadcast + Transpose = ascend_ir.IteratorType.Transpose + Reduction = ascend_ir.IteratorType.Reduction + Interleave = ascend_ir.IteratorType.Interleave + Deinterleave = ascend_ir.IteratorType.Deinterleave + Inverse = ascend_ir.IteratorType.Inverse + Pad = ascend_ir.IteratorType.Pad + Concat = ascend_ir.IteratorType.Concat + Gather = ascend_ir.IteratorType.Gather + Cumulative = ascend_ir.IteratorType.Cumulative + Opaque = ascend_ir.IteratorType.Opaque + + +class ascend_address_space_base(bl.address_space): + def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: + super().__init__() + self.real_address_space = address_space_value + + def to_ir(self, builder: ir.builder) -> ir.attribute: + return builder.get_target_attribute(self.real_address_space) + + +class ascend_address_space_group: + + def __init__(self): + for k, v in { + k: v + for k, v in ascend_ir.AddressSpace.__dict__.items() + if isinstance(v, ascend_ir.AddressSpace) + }.items(): + setattr(self, k, ascend_address_space_base(v)) + + +ascend_address_space = ascend_address_space_group() + + +@builtin +def sub_vec_id(_builder=None) -> tl.tensor: + """ 
+ Get the Vector Core index on the AI Core. + """ + return semantic.sub_vec_id(_builder) + + +@builtin +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: + """ + Copies data from the Unified Buffer (UB) to the L1 Buffer. + + :param src: The source data located in the Unified Buffer. + :type src: tl.tensor | bl.buffer + :param dst: The destination buffer located in L1 memory. + :type dst: tl.tensor | bl.buffer + """ + from warnings import warn + warn("copy_from_ub_to_l1 is deprecated, please use copy instead.") + return semantic.copy_from_ub_to_l1(src, dst, _builder) + + +@builtin +def copy(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: + """ + Copies data from the Unified Buffer (UB) to the Unified Buffer (UB) or L1 Buffer. + + :param src: The source data located in the Unified Buffer. + :type src: tl.tensor | bl.buffer + :param dst: The destination buffer located Unified Buffer (UB) or L1 memory. 
+ :type dst: tl.tensor | bl.buffer + """ + return semantic.copy(src, dst, _builder) + + +def create_sync_block(sender, receiver, event_id, is_set: bool, + sender_pipe=None, receiver_pipe=None, + _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + if isinstance(event_id, int): + assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + if sender_pipe is None and receiver_pipe is None: + if sender == "cube": + sender_pipe = PIPE.PIPE_FIX + receiver_pipe = PIPE.PIPE_MTE2 + if sender == "vector": + sender_pipe = PIPE.PIPE_MTE3 + receiver_pipe = PIPE.PIPE_MTE2 + if not isinstance(sender_pipe, PIPE) or not isinstance(receiver_pipe, PIPE): + raise TypeError("sender_pipe and receiver_pipe must be instances of PIPE enum") + if is_set: + return semantic.create_sync_block_set(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + return semantic.create_sync_block_wait(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_set(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, True, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_wait(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, False, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + mode = _constexpr_to_value(mode) + event_id = 
_constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode in ("all_cube", "all_vector", "all", "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" + _builder.sync_block_all(mode, event_id) + + +class FixpipeDMAMode(enum.Enum): + NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN + NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND + NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ + + +class FixpipeDualDstMode(enum.Enum): + NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL + COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT + ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT + + +class FixpipePreQuantMode(enum.Enum): + NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT + F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 + F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 + S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 + + +class FixpipePreReluMode(enum.Enum): + LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU + NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU + NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU + P_RELU = ascend_ir.FixpipePreReluMode.P_RELU + + +@builtin +def fixpipe( + src: tl.tensor, + dst: bl.buffer, + dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, + dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, + _builder=None, +) -> None: + """ + Directly store a tensor on L0C to a local buffer via fixpipe. + Fixpipe is pipeline that performing data movement from L0C to other memory hierarchies. + Currently support: + - L0C to UB (for Ascend910_95 sereies) + + :param src: the source tensor, Must be located in the l0C memory region. + :type src: tl.tensor + :param dst: The destination buffer, Must be located in the UB memory region. 
+ :type dst: bl.buffer + :param dma_mode: DMA transfer mode, "nz2nd" enables NZ to ND layout transformation + :type dma_mode: str + """ + if not _builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if not isinstance(src, tl.tensor): + raise TypeError("src is not of tensor type") + elif not isinstance(dst, bl.buffer): + raise TypeError("dst is not of buffer type") + if dst.space != ascend_address_space.UB: + raise TypeError("dst must be located in the UB memory region") + + if len(dst.shape) == 2 and ( + dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32 + ): + N = dst.shape[1] + if N % 8 != 0: + raise ValueError("32b Fixpipe last dim must be aligned to 8") + if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0): + raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16") + if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0): + raise ValueError( + "32b Column split dual Fixpipe last dim must be aligned to 32" + ) + M = dst.shape[0] + if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0): + raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8") + dst16bits = ( + dst.type.element_ty == tl.float16 + or dst.type.element_ty == tl.int16 + or dst.type.element_ty == tl.bfloat16 + ) + if len(dst.shape) == 2 and dst16bits: + N = dst.shape[1] + if N % 16 != 0: + raise ValueError("16b Fixpipe last dim must be aligned to 16") + M = dst.shape[0] + if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0): + raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 16") + + return semantic.fixpipe( + src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU, _builder + ) + + +class SYNC_IN_VF(enum.Enum): + VV_ALL = enum.auto() + VST_VLD = enum.auto() + VLD_VST = enum.auto() + VST_VST = enum.auto() + VS_ALL = enum.auto() + VST_LD = enum.auto() + VLD_ST = enum.auto() + VST_ST = enum.auto() + SV_ALL = enum.auto() + ST_VLD = enum.auto() 
+ LD_VST = enum.auto() + ST_VST = enum.auto() + + +@builtin +def debug_barrier( + sync_mode: SYNC_IN_VF, + _builder=None, +) -> None: + return semantic.debug_barrier(sync_mode.name, _builder) + + +@builtin +def sub_vec_num(_builder=None) -> tl.constexpr: + """ + Get the Vector Core Num on one AI Core. + """ + npuUtils = NPUUtils() + cube_num = npuUtils.get_aivector_core_num() + vector_num = npuUtils.get_aicore_num() + const_val = cube_num // vector_num + return tl.constexpr(const_val) diff --git a/third_party/ascend/language/cann/extension/custom_op.py b/third_party/ascend/language/cann/extension/custom_op.py index b3352c26b2..e1b149e2b9 100644 --- a/third_party/ascend/language/cann/extension/custom_op.py +++ b/third_party/ascend/language/cann/extension/custom_op.py @@ -152,22 +152,122 @@ def _args_to_operands(op, builder, args, kwargs): return operands +def _bind_op_arguments(op, args, kwargs): + if not op.signature.parameters: + return None + return op.signature.bind(*args, **kwargs) + + +def _make_align_dim_attrs(op, builder, arg_attrs): + # Find op argument by name using op.align_dim's key + # We want to return a dict mapping for each align_dim key -> int attribute for the actual bound argument value. + name = 'align_dim' + if not hasattr(op, name): + return + + # To find argument indices matching each align_dim key, check the op.signature parameters + # and map align_dim key (argument name) to its index position. 
+ align_arg_indices = {} + if hasattr(op, "signature"): + param_names = list(op.signature.parameters.keys()) + for arg_name in op.align_dim.keys(): + if arg_name in param_names: + align_arg_indices[arg_name] = param_names.index(arg_name) + + for arg, align_val in op.align_dim.items(): + if isinstance(arg, str) and arg in align_arg_indices: + arg_attrs[align_arg_indices[arg]] = { name : builder.get_int_attr(align_val) } + print(arg_attrs[align_arg_indices[arg]]) + elif isinstance(arg, int): + arg_attrs[arg] = { name : builder.get_int_attr(align_val) } + print(arg_attrs[arg]) + else: + assert False, f"{name}'s keys should be string or int" + + +def _make_arg_attrs(op, builder): + num_args = len(op.signature.parameters) if hasattr(op, "signature") else 0 + arg_attrs = [{} for _ in range(num_args)] + + _make_align_dim_attrs(op, builder, arg_attrs) + return arg_attrs + + def _add_optional_attr(op, name, builder, attrs): if hasattr(op, name): attrs[name] = builder.get_str_attr(getattr(op, name)) +def _add_bitcode_attr(op, builder, attrs): + name = 'bitcode' + if not hasattr(op, name): + return + + from pathlib import Path + bitcode = Path(getattr(op, name)) + assert bitcode.exists(), f"Provided bitcode ({name}) not exist" + attrs[name] = builder.get_str_attr(str(bitcode.absolute())) + + +def _add_optional_extra_buffer_attr(op, builder, attrs): + name = 'extra_buffers' + if not hasattr(op, name): + return + + extra_buffers = getattr(op, name) + if isinstance(extra_buffers, tuple): + extra_buffers = [ extra_buffers ] + + extra_buffer_types, extra_buffer_sizes = zip(*extra_buffers) + attrs[name + "_types"] = builder.get_type_array_attr([ty.to_ir(builder) for ty in extra_buffer_types]) + attrs[name + "_sizes"] = builder.get_i64_array_attr(list(extra_buffer_sizes)) + + +def _add_optional_indexing_map_attr(op, builder, attrs): + # Optional indexing map attribute: + # `indexing_map` should be an iterable of al.affine_map (MLIR AffineMap) objects. 
+ name = 'indexing_map' + if not hasattr(op, name): + return + + indexing_map = getattr(op, name) + attrs[name] = builder.get_affine_map_array_attr(indexing_map) + + +def _add_optional_iterator_types_attr(op, builder, attrs): + name = 'iterator_types' + if not hasattr(op, name): + return + + attrs[name] = builder.get_iterator_types_attr([iterator_type.value for iterator_type in getattr(op, name)]) + + def _make_attrs(op, builder): attrs = { 'hivm.tcore_type': builder.get_core_type_attr(op.core.value), 'hivm.pipe': builder.get_pipe_attr(op.pipe.value), 'hivm.vf_mode': builder.get_vf_mode_attr(op.mode.value), } + + if not op.name.startswith('__builtin_'): + assert hasattr(op, 'symbol'), f"Non builtin custom op, symbol is required." + assert hasattr(op, 'bitcode'), f"Non builtin custom op, bitcode path is required." + + # Add bit code path attribute, formalize to abosulte path. + _add_bitcode_attr(op, builder, attrs) + + + _add_optional_indexing_map_attr(op, builder, attrs) + _add_optional_iterator_types_attr(op, builder, attrs) + + _add_optional_extra_buffer_attr(op, builder, attrs) + _add_optional_attr(op, 'symbol', builder, attrs) _add_optional_attr(op, 'source', builder, attrs) _add_optional_attr(op, 'compile', builder, attrs) # Extra attributes can be added here, such as op.extra_attr="attr_a=xx" _add_optional_attr(op, 'extra_attr', builder, attrs) + return attrs @@ -207,8 +307,9 @@ def custom_semantic(name: str, *args, _builder=None, **kwargs): inputs = _args_to_operands(op, _builder, args, kwargs) # Setup attributes. attrs = _make_attrs(op, _builder) + arg_attrs = _make_arg_attrs(op, _builder) # Build IR for the custom op. - res = _builder.create_custom_op(name, attrs, inputs, outputs) + res = _builder.create_custom_op(name, attrs, inputs, outputs, arg_attrs) # Results with same types as outputs. 
res_types = [out.type for out in outs] return _to_result(res, res_types) @@ -228,6 +329,7 @@ def register_custom_op(op): setattr(op, 'name', op.__name__) # The op name should not be used. assert op.name not in _custom_op_registry, f"Custom op name '{op.name}' already used." + # Check required core, pipe, mode fields. assert hasattr(op, 'core'), "'core' field is required." assert hasattr(op, 'pipe'), "'pipe' field is required." diff --git a/third_party/ascend/language/cann/extension/mem_ops.py b/third_party/ascend/language/cann/extension/mem_ops.py index 859bbecd67..a59b71add9 100644 --- a/third_party/ascend/language/cann/extension/mem_ops.py +++ b/third_party/ascend/language/cann/extension/mem_ops.py @@ -1,551 +1,636 @@ -import numbers -import triton.language as tl -from triton.language import semantic as real_semantic -from triton.language.core import ( - _constexpr_to_value, - _tensor_member_fn, - _unwrap_iterable, - builtin, - constexpr, - dtype, - tensor, - check_bit_width, - _unwrap_if_constexpr, -) -from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, - _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, - bitcast, - bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, - _str_to_eviction_policy, - _str_to_padding_option, - _canonicalize_boundary_check, -) - -from typing import Optional, Tuple, List, overload, Union -from triton._C.libtriton import ir - -from ._utils import _convert_elem_to_ir_value - - -@_tensor_member_fn -@builtin -def index_select(src: tensor, idx: tensor, bound, lstdim_blksiz, offsets, numels, _builder=None): - """ - Embedding - :src_ptr: - :idx: - """ - - def embedding_gather_impl(src: tl.tensor, idx: tl.tensor, bound: int, blksiz: int, offsets: Tuple, numels: Tuple, - builder: ir.builder) -> tl.tensor: - assert idx.dtype.is_int(), "index must be an integer tensor" - if not 
src.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype.element_ty}") - - require_i64 = idx.dtype.is_int64() - # require_i64 = True - offsets = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in offsets] - numels = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in numels] - ret = builder.create_embedding_gather(src.handle, idx.handle, bound, blksiz, offsets, numels) - ret_shape = [_unwrap_if_constexpr(s) for s in idx.shape] - ret_shape.append(blksiz) - return wrap_tensor(ret, src.dtype.element_ty, ret_shape) - - bound = _constexpr_to_value(bound) - lstdim_blksiz = _constexpr_to_value(lstdim_blksiz) - - return embedding_gather_impl(src, idx, bound, lstdim_blksiz, offsets, numels, _builder) - - -@_tensor_member_fn -@builtin -def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, - start_offset: tuple, dst_stride: tuple, _builder=None): - """ - Index put values from a tensor into a destination tensor. - - Index put operation for different tensor ranks: - 1. 2D index scatter (0 <= dim < 1): - 1.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] - 2. 
3D index scatter (0 <= dim < 2): - 2.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] - = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] - 2.2 dim = 1 - out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] - = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] - - - :param ptr: pointer type, the destination tensor pointer (in GM) - :param index: tensor, a index to scatter (in UB) - :param value: tensor, a value to store (in UB) - :param dim: int32, the dimension to scatter along - :param index_boundary: int64, the upper boundary for index values - :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region - :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region - :param dst_stride: tuple of int, the stride of each dimension of destination tensor - - Constraints - *********** - - `ptr` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. - - `index.numel` must equal `value.shape[dim]`. - - `value` support 2~5D tensors. - - `dim` must be valid (0 <= dim < rank(value) - 1). - - Example - ******* - .. 
code-block:: python - - import torch - import triton - import triton.language as tl - from triton.language.extra.cann.extension import index_put - - @triton.jit - def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): - # index tile shape: [2] - index_local = tl.arange(0, 2) - x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) - - index_tile = tl.load(index_ptr + index_local) - value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) - - index_put( - ptr=dst_ptr, - index=index_tile, - value=value_tile, - dim=0, - index_boundary=4, - end_offset=(2, 2), - start_offset=(0, 0), - dst_stride=(2, 1) - ) - - dst = torch.zeros((4,2), device='npu', dtype=torch.float32) - value = torch.tensor([[1.,2.], [3.,4.]], device='npu') - index = torch.tensor([2, 0], device='npu') - - simple_index_put_kernel[(1,)](value, index, dst) - print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] - """ - - def index_put_impl(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, - end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, builder: ir.builder): - assert index.dtype.is_int(), "index must be an integer tensor" - if not ptr.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") - if not isinstance(dim, int): - raise ValueError("dim must be of type tl.constexpr") - - v_rank = len(value.shape) - idx_rank = len(index.shape) - if v_rank < 2 or v_rank > 5: - raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") - if dim < 0 or dim >= v_rank - 1: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 0 - - dim = _constexpr_to_value(dim) - 
index_boundary = _constexpr_to_value(index_boundary) - value = _constexpr_to_value(value) - - if not _is_ranked_tensor(value) or isinstance(value, constexpr): - element_ty = ptr.type.scalar.element_ty - value = real_semantic.full(index.shape, value, element_ty, _builder) - return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, - _builder) - - -@_tensor_member_fn -@builtin -def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: - """ - Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). - - Selects data from multiple indices along a specified dimension and loads - them as tiles from GM directly to UB with zero-copy semantics. - - :param src: Source tensor pointer (in GM) - :type src: tensor (pointer type) - :param dim: The dimension along which to select indices - :type dim: int or constexpr - :param index: 1D tensor of indices to select (in UB) - :type index: tensor - :param src_shape: Complete shape of the source tensor (can be int or tensor) - :type src_shape: List[Union[int, tensor]] - :param src_offset: Starting offset for reading (can be int or tensor) - :type src_offset: List[Union[int, tensor]] - :param read_shape: Size to read (tile shape, can be int or tensor) - :type read_shape: List[Union[int, tensor]] - - **Constraints:** - - - ``read_shape[dim]`` must be ``-1`` - - ``src_offset[dim]`` can be ``-1`` (will be ignored) - - Boundary handling: ``src_offset + read_shape > src_shape`` automatically - truncates to ``src_shape`` boundary - - Does not check if ``index`` contains out-of-bounds values - - **Example:** - - .. 
code-block:: python - - @triton.jit - def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): - # Load indices (e.g., [5, 10, 15, 20]) - indices = tl.load(indices_ptr + tl.arange(0, 4)) - - # Example 1: Static shapes (constants) - # Index select from dimension 1 - # src: [8, 100, 256], index_select at dim=1 - # Read: [4, ?, 128] starting from [4, ?, 128] - result = extension.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[8, 100, 256], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - # result shape: [4, 4, 128] - - # Example 2: Dynamic shapes (variables) - result2 = extension.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[M, N, D], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - - tl.store(output_ptr + ..., result) - - :return: Result tensor in UB with shape where ``dim`` is replaced - by the length of ``index`` - :rtype: tensor - """ - - def index_select_simd_impl(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], - src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], - builder: ir.builder) -> tl.tensor: - # Validate inputs - ndim = len(src_shape) - assert len(src_offset) == ndim, \ - f"src_offset length {len(src_offset)} must match src_shape length {ndim}" - assert len(read_shape) == ndim, \ - f"read_shape length {len(read_shape)} must match src_shape length {ndim}" - assert 0 <= dim < ndim, \ - f"dim={dim} must be in range [0, {ndim})" - assert len(index.shape) == 1, \ - f"index must be 1D tensor, got {len(index.shape)}D" - assert dim < ndim - 1, \ - f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" - - newsrc_shape = [o.handle for o in src_shape] - newsrc_offset = [o.handle for o in src_offset] - # Create output type - return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] - element_ty = src.type.element_ty - output_ty = tl.block_type(element_ty, return_shape) - out = 
builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, - return_shape) - return tl.tensor(out, output_ty) - - dim = _constexpr_to_value(dim) - - # Process shape parameters: convert constexpr to values, keep tensors as-is - def process_param(val): - """Convert constexpr to value, keep tensor or int as-is""" - if isinstance(val, tensor): - return val - else: - return _constexpr_to_value(val) - - newsrc_shape = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape] - newsrc_offset = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_offset] - assert len(index.shape) == 1, "index must be a 1D tensor" - - return index_select_simd_impl(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder) +import numbers +import triton.language as tl +from triton.language import semantic as real_semantic +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype, + tensor, + check_bit_width, + _unwrap_if_constexpr, +) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from typing import Optional, Tuple, List, overload, Union +from triton._C.libtriton import ir + +from ._utils import _convert_elem_to_ir_value + + +@_tensor_member_fn +@builtin +def index_put( + ptr: tensor, + index: tensor, + value: tensor, + dim: int, + index_boundary: int, + end_offset: tuple, + start_offset: tuple, + dst_stride: tuple, + _builder=None +): + """ + Index put values from a tensor into a destination tensor. 
+ + Index put operation for different tensor ranks: + 1. 2D index scatter (0 <= dim < 1): + 1.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] + 2. 3D index scatter (0 <= dim < 2): + 2.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] + = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] + 2.2 dim = 1 + out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] + = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] + + + :param ptr: pointer type, the destination tensor pointer (in GM) + :param index: tensor, a index to scatter (in UB) + :param value: tensor, a value to store (in UB) + :param dim: int32, the dimension to scatter along + :param index_boundary: int64, the upper boundary for index values + :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region + :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region + :param dst_stride: tuple of int, the stride of each dimension of destination tensor + + Constraints + *********** + - `ptr` and `value` must have the same rank. + - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. + - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. + - `index.numel` must equal `value.shape[dim]`. + - `value` support 2~5D tensors. + - `dim` must be valid (0 <= dim < rank(value) - 1). + + Example + ******* + .. 
code-block:: python + + import torch + import triton + import triton.language as tl + from triton.language.extra.cann.extension import index_put + + @triton.jit + def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): + # index tile shape: [2] + index_local = tl.arange(0, 2) + x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) + + index_tile = tl.load(index_ptr + index_local) + value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) + + index_put( + ptr=dst_ptr, + index=index_tile, + value=value_tile, + dim=0, + index_boundary=4, + end_offset=(2, 2), + start_offset=(0, 0), + dst_stride=(2, 1) + ) + + dst = torch.zeros((4,2), device='npu', dtype=torch.float32) + value = torch.tensor([[1.,2.], [3.,4.]], device='npu') + index = torch.tensor([2, 0], device='npu') + + simple_index_put_kernel[(1,)](value, index, dst) + print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] + """ + + def index_put_impl( + ptr: tl.tensor, + index: tl.tensor, + value: tl.tensor, + dim: int, + index_boundary: int, + end_offset: Tuple, + start_offset: Tuple, + dst_stride: Tuple, + _builder: ir.builder + ): + assert index.dtype.is_int(), "index must be an integer tensor" + if not ptr.dtype.element_ty.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") + if not isinstance(dim, int): + raise ValueError("dim must be of type tl.constexpr") + + v_rank = len(value.shape) + idx_rank = len(index.shape) + if v_rank < 2 or v_rank > 5: + raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") + if dim < 0 or dim >= v_rank - 1: + raise ValueError(f"dim must satisfy 0<=dim 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim 0 + + dim = 
_constexpr_to_value(dim) + index_boundary = _constexpr_to_value(index_boundary) + value = _constexpr_to_value(value) + + if not _is_ranked_tensor(value) or isinstance(value, constexpr): + element_ty = ptr.type.scalar.element_ty + value = real_semantic.full(index.shape, value, element_ty, _builder) + return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, + dst_stride, end_offset, start_offset, _builder) + + +@_tensor_member_fn +@builtin +def index_select_simd( + src, + dim, + index, + src_shape, + src_offset, + read_shape, + _builder=None +) -> tensor: + """ + Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). + + Selects data from multiple indices along a specified dimension and loads + them as tiles from GM directly to UB with zero-copy semantics. + + :param src: Source tensor pointer (in GM) + :type src: tensor (pointer type) + :param dim: The dimension along which to select indices + :type dim: int or constexpr + :param index: 1D tensor of indices to select (in UB) + :type index: tensor + :param src_shape: Complete shape of the source tensor (can be int or tensor) + :type src_shape: List[Union[int, tensor]] + :param src_offset: Starting offset for reading (can be int or tensor) + :type src_offset: List[Union[int, tensor]] + :param read_shape: Size to read (tile shape, can be int or tensor) + :type read_shape: List[Union[int, tensor]] + + **Constraints:** + + - ``read_shape[dim]`` must be ``-1`` + - ``src_offset[dim]`` can be ``-1`` (will be ignored) + - Boundary handling: ``src_offset + read_shape > src_shape`` automatically + truncates to ``src_shape`` boundary + - Does not check if ``index`` contains out-of-bounds values + + **Example:** + + .. 
code-block:: python + + @triton.jit + def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): + # Load indices (e.g., [5, 10, 15, 20]) + indices = tl.load(indices_ptr + tl.arange(0, 4)) + + # Example 1: Static shapes (constants) + # Index select from dimension 1 + # src: [8, 100, 256], index_select at dim=1 + # Read: [4, ?, 128] starting from [4, ?, 128] + result = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[8, 100, 256], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + # result shape: [4, 4, 128] + + # Example 2: Dynamic shapes (variables) + result2 = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[M, N, D], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + + tl.store(output_ptr + ..., result) + + :return: Result tensor in UB with shape where ``dim`` is replaced + by the length of ``index`` + :rtype: tensor + """ + + def index_select_simd_impl( + src: tl.tensor, + dim: int, + index: tl.tensor, + src_shape: List[Union[int, tl.tensor]], + src_offset: List[Union[int, tl.tensor]], + read_shape: List[Union[int, tl.tensor]], + _builder: ir.builder + ) -> tl.tensor: + # Validate inputs + ndim = len(src_shape) + assert len(src_offset) == ndim, \ + f"src_offset length {len(src_offset)} must match src_shape length {ndim}" + assert len(read_shape) == ndim, \ + f"read_shape length {len(read_shape)} must match src_shape length {ndim}" + assert 0 <= dim < ndim, \ + f"dim={dim} must be in range [0, {ndim})" + assert len(index.shape) == 1, \ + f"index must be 1D tensor, got {len(index.shape)}D" + assert dim < ndim - 1, \ + f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" + # Handle both tensor and int offsets (for interpreter mode) + newsrc_shape = [] + for s in src_shape: + if isinstance(s, tensor): + newsrc_shape.append(s.handle) + elif isinstance(s, int): + # For interpreter mode: keep as int + newsrc_shape.append(s) + else: + newsrc_shape.append(s.handle if 
hasattr(s, 'handle') else s) + newsrc_offset = [] + for s in src_offset: + if isinstance(s, tensor): + newsrc_offset.append(s.handle) + elif isinstance(s, int): + # For interpreter mode: keep as int + newsrc_offset.append(s) + else: + newsrc_offset.append(s.handle if hasattr(s, 'handle') else s) + + # Create output type + return_shape = [ + index.shape[0] if i == dim else read_shape[i] + for i in range(ndim) + ] + element_ty = src.type.element_ty + output_ty = tl.block_type(element_ty, return_shape) + out = _builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, return_shape) + return tl.tensor(out, output_ty) + + dim = _constexpr_to_value(dim) + + # Process shape parameters: convert constexpr to values, keep tensors as-is + def process_param(val): + """Convert constexpr to value, keep tensor or int as-is""" + if isinstance(val, tensor): + return val + else: + return _constexpr_to_value(val) + + newsrc_shape = [ + real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in src_shape + ] + newsrc_offset = [ + real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in src_offset + ] + assert len(index.shape) == 1, "index must be a 1D tensor" + + return index_select_simd_impl( + src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder + ) diff --git a/third_party/ascend/language/cann/extension/semantic.py b/third_party/ascend/language/cann/extension/semantic.py index 29df62e651..2f43733dca 100644 --- a/third_party/ascend/language/cann/extension/semantic.py +++ b/third_party/ascend/language/cann/extension/semantic.py @@ -1,129 +1,154 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# Copyright 2018-2020 Philippe Tillet -# Copyright 2020-2022 OpenAI -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -__all__ = [ - "fixpipe", - "create_address_space", -] - -import enum -from typing import (TypeVar, List, Union) - -from triton._C.libtriton import ir -from triton._C.libtriton.ascend import ir as ascend_ir -import triton.language.core as tl -import triton.language.extra.cann.extension as al -import triton.extension.buffer.language as bl - -from triton.language import semantic as real_semantic - -T = TypeVar('T') - - -def create_address_space(address_space: ascend_ir.AddressSpace, - builder: ascend_ir.ascendnpu_ir_builder) -> ir.attribute: - return builder.get_target_attribute(address_space) - - -class PIPE(enum.Enum): - PIPE_S = ascend_ir.PIPE.PIPE_S - PIPE_V = ascend_ir.PIPE.PIPE_V - PIPE_M = ascend_ir.PIPE.PIPE_M - PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 - PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 - PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 - PIPE_ALL = ascend_ir.PIPE.PIPE_ALL - PIPE_FIX = ascend_ir.PIPE.PIPE_FIX - - -def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): - if isinstance(event_id, int): - _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, - receiver_pipe.value) - elif isinstance(event_id, tl.constexpr): - _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, - receiver_pipe.value) - else: - _builder.sync_block_set(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) - - -def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): - if isinstance(event_id, int): - _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, - receiver_pipe.value) - elif isinstance(event_id, tl.constexpr): - _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, - receiver_pipe.value) 
- else: - _builder.sync_block_wait(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) - - -def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: - return tl.tensor(builder.create_get_sub_vec_id(), tl.int64) - - -def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): - if not builder.is_910_95(): - raise RuntimeError("this feature is only supported on Ascend910_95") - if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): - raise TypeError("tensor not support yet") - if src.shape != dst.shape: - raise TypeError("src and dst must have same shape") - if src.dtype != dst.dtype: - raise TypeError("src and dst need to have the same type") - if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): - if src.space != al.ascend_address_space.UB: - raise TypeError("src's AddressSpace must be UB") - if dst.space != al.ascend_address_space.L1: - raise TypeError("dst's AddressSpace must be L1") - builder.create_copy_buffer(src.handle, dst.handle) - else: - raise TypeError("src and dst must be tl.tensor or bl.buffer") - - -def fixpipe( - src: tl.tensor, - dst, - dma_mode, - dual_dst_mode, - pre_quant_mode, - pre_relu_mode, - builder: ascend_ir.ascendnpu_ir_builder, -) -> None: - builder.create_fixpipe( - src.handle, - dst.handle, - dma_mode.value, - dual_dst_mode.value, - pre_quant_mode.value, - pre_relu_mode.value, - ) - - -def debug_barrier(sync_mode: str, builder) -> None: - target = tl.tensor(builder.get_int64(0), tl.int64) - attr = builder.get_str_attr(sync_mode) - builder.create_debug_barrier(target.handle, "SYNC_IN_VF", attr) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +__all__ = [ + "fixpipe", + "create_address_space", +] + +import enum +from typing import ( + TypeVar, List, Union +) + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl +import triton.language.extra.cann.extension as al +import triton.extension.buffer.language as bl + +from triton.language import semantic as real_semantic + +T = TypeVar('T') + + +def create_address_space( + address_space: ascend_ir.AddressSpace, + builder: ascend_ir.ascendnpu_ir_builder +) -> ir.attribute: + return builder.get_target_attribute(address_space) + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, + sender_pipe.value, receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, + sender_pipe.value, receiver_pipe.value) + else: + _builder.sync_block_set(sender, receiver, + event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, + sender_pipe.value, receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, + sender_pipe.value, 
receiver_pipe.value) + else: + _builder.sync_block_wait(sender, receiver, + event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: + return tl.tensor(builder.create_get_sub_vec_id(), tl.int64) + + +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): + if not builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): + raise TypeError("tensor not support yet") + if src.shape != dst.shape: + raise TypeError("src and dst must have same shape") + if src.dtype != dst.dtype: + raise TypeError("src and dst need to have the same type") + if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): + if src.space != al.ascend_address_space.UB: + raise TypeError("src's AddressSpace must be UB") + if dst.space != al.ascend_address_space.L1: + raise TypeError("dst's AddressSpace must be L1") + builder.create_copy_buffer(src.handle, dst.handle) + else: + raise TypeError("src and dst must be tl.tensor or bl.buffer") + + +def copy(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): + if not builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): + raise TypeError("tensor not support yet") + if src.shape != dst.shape: + raise TypeError("src and dst must have same shape") + if src.dtype != dst.dtype: + raise TypeError("src and dst need to have the same type") + if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): + if src.space != al.ascend_address_space.UB: + raise TypeError("src's AddressSpace must be UB") + if dst.space not in (al.ascend_address_space.L1, al.ascend_address_space.UB): + raise TypeError("dst's AddressSpace must be UB or L1") + builder.create_copy_buffer(src.handle, dst.handle) + else: + raise TypeError("src and 
dst must be tl.tensor or bl.buffer") + + +def fixpipe( + src: tl.tensor, + dst, + dma_mode, + dual_dst_mode, + pre_quant_mode, + pre_relu_mode, + builder: ascend_ir.ascendnpu_ir_builder, +) -> None: + builder.create_fixpipe( + src.handle, + dst.handle, + dma_mode.value, + dual_dst_mode.value, + pre_quant_mode.value, + pre_relu_mode.value, + ) + + +def debug_barrier(sync_mode: str, builder) -> None: + target = tl.tensor(builder.get_int64(0), tl.int64) + attr = builder.get_str_attr(sync_mode) + builder.create_debug_barrier(target.handle, "SYNC_IN_VF", attr) diff --git a/third_party/ascend/language/cann/extension/vec_ops.py b/third_party/ascend/language/cann/extension/vec_ops.py index effbbc0fa7..ea2a5f7c41 100644 --- a/third_party/ascend/language/cann/extension/vec_ops.py +++ b/third_party/ascend/language/cann/extension/vec_ops.py @@ -1,535 +1,562 @@ -# insert_slice -# extract_slice -# get_element -# sort -# flip -# gather - -import triton.language as tl -from triton.language import semantic, core, standard -from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, - tensor, check_bit_width, _unwrap_if_constexpr, range) -from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, - _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, - bitcast, - bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, - _str_to_eviction_policy, - _str_to_padding_option, - _canonicalize_boundary_check, -) - -from . 
import is_compile_on_910_95 -from .aux_ops import compile_hint_impl - -from typing import Optional, Tuple, List, overload -from triton._C.libtriton import ir - - -@_tensor_member_fn -@builtin -def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. - - :param ful: The tensor to receive tensor. - :type ful: Tensor - :param sub: The tensor to be inserted. - :type sub: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - - def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], - builder: ir.builder) -> tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in sizes])) - assert (all([s >= 0 for s in strides])) - # Handle both tensor and int offsets (for interpreter mode) - new_offsets = [] - for o in offsets: - if isinstance(o, tensor): - new_offsets.append(o.handle) - elif isinstance(o, int): - # For interpreter mode: keep as int - new_offsets.append(o) - else: - new_offsets.append(o.handle if hasattr(o, 'handle') else o) - ret_type = tl.block_type(ful.type.scalar, ful.shape) - out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) - return tensor(out, ret_type) - - assert len(ful.shape) > 0 - assert len(ful.shape) == len(sub.shape) - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) - return out - - -@_tensor_member_fn -@builtin -def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Extract a tensor from another tensor as specified by the operation’s offsets, 
sizes and strides arguments. - - :param ful: The tensor to split. - :type ful: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - - def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], - builder: ir.builder) -> tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in sizes])) - assert (all([s >= 0 for s in strides])) - # Handle both tensor and int offsets (for interpreter mode) - new_offsets = [] - for o in offsets: - if isinstance(o, tensor): - new_offsets.append(o.handle) - elif isinstance(o, int): - # For interpreter mode: keep as int - new_offsets.append(o) - else: - new_offsets.append(o.handle if hasattr(o, 'handle') else o) - ret_type = tl.block_type(ful.type.scalar, sizes) - out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) - return tensor(out, ret_type) - - assert len(ful.shape) > 0 - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) - return sub - - -@_tensor_member_fn -@builtin -def get_element(src, indice, _builder=None, _generator=None): - """ - get_element op reads a ranked tensor and returns one element as specified by the given indices. - The result of the op is a value with the same type as the elements of the tensor. - The arity of indices must match the rank of the accessed value. - - :param src: The tensor to be accessed. 
- :type src: Tensor - :param indice: - :type indice: tuple of ints - """ - - def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): - if len(src.shape) != len(indice): - raise ValueError("Indice's rank must be equal to src tensor's rank") - - # Handle both tensor and int indices (for interpreter mode) - new_indice = [] - for i in indice: - if isinstance(i, tensor): - new_indice.append(i.handle) - elif isinstance(i, int): - # For interpreter mode: convert int to TensorHandle - new_indice.append(i) - else: - # Try to use .handle attribute if available - new_indice.append(i.handle if hasattr(i, 'handle') else i) - - result = builder.create_extract_scalar(src.handle, new_indice) - return wrap_tensor(result, src.type.scalar, None) - - assert len(src.shape) > 0 - new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] - return get_element_impl(src, new_indice, _builder) - - -@builtin -def flip(ptr, dim=-1, _builder=None, _generator=None): - - def flip_impl(ptr: tensor, dim: int, builder: ir.builder, generator=None): - """ - Flips a tensor `ptr` along the dimension `dim`. 
- - :param ptr: the first input tensor - :type ptr: tensor - :param dim: the dimension to flip along - :type dim: int - :param generator: the code generator (required for reduce operations) - :type generator: generator object - """ - - def _get_flip_dim(dim, shape): - dim = _unwrap_if_constexpr(dim) - shape = _unwrap_if_constexpr(shape) - if dim is None: - dim = len(shape) - 1 - if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index - dim += len(shape) - return constexpr(dim) - - def _log2(i: core.constexpr): - log2 = 0 - n = core.constexpr(i).value - while n > 1: - n >>= 1 - log2 += 1 - return core.constexpr(log2) - - def flip_simd(ptr: tensor, dim: int, builder: ir.builder): - """ - Triton flip operation for simd - - Args: - ptr: tensor, input tensor - dim: int, dimension to flip (can be negative, normalized here) - builder: ir.builder, underlying IR builder - Returns: - flipped: tensor, same type and shape as input - """ - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("ascend.flip requires tensor rank >= 1") - norm_dim = dim if dim >= 0 else dim + rank - if not (0 <= norm_dim < rank): - raise ValueError(f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}") - dim = norm_dim - else: - if dim < 0: - raise ValueError("ascend.flip with unknown rank requires non-negative dim") - - flipped_vals = builder.create_flip(ptr.handle, dim) - flipped = tensor(flipped_vals, type=ptr.type) - return flipped - - # If compile_mode is not simt, use the simd implementation - if not builder.is_simt_mode(): - return flip_simd(ptr, dim, builder) - core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) - _dim: core.constexpr = 
_get_flip_dim(dim, ptr.shape) - core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) - steps: core.constexpr = _log2(ptr.shape[_dim]) - # If steps is 0, return the original tensor - if steps == 0: - return ptr - # reshape the swap dimension to (2, 2, ..., 2) - idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) - y = core.reshape( - ptr.to(idtype, bitcast=True, _builder=builder), - ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), - _builder=builder) - for i in static_range(steps): - y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) - ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) - return ptr - - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e - - dim = len(ptr.shape) - 1 if dim == -1 else dim - return flip_impl(ptr, dim, _builder, _generator) - - -class static_range: - """ - Iterator for non-JIT Python functions that need to iterate over constexpr values. - This is used in functions like flip that are called during compilation. 
- """ - - def __init__(self, arg1, arg2=None, step=None): - if step is None: - self.step = core.constexpr(1) - else: - self.step = step - if arg2 is None: - self.start = core.constexpr(0) - self.end = arg1 - else: - self.start = arg1 - self.end = arg2 - - def __iter__(self): - # Extract actual values from constexpr objects for iteration - start_val = core._constexpr_to_value(self.start) - end_val = core._constexpr_to_value(self.end) - step_val = core._constexpr_to_value(self.step) - # Store as regular Python integers for iteration - self._current = start_val - self._end = end_val - self._step = step_val - return self - - def __next__(self): - if self._current >= self._end: - raise StopIteration - value = self._current - self._current += self._step - return value - - -@builtin -def sort(ptr, dim=-1, descending=False, _builder=None): - """ - sort the input tensor along 'dim' - - param: - ptr: tensor, input tensor - dim: int or tl.constexpr[int], dimension to sort - descending: bool or tl.constexpr[bool], the result is descending or not - _builder: ir.builder - return: - values: tensor, the sorted tensor - """ - - def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): - allowed_types = { - tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 - } - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty not in allowed_types: - raise TypeError( - f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" - f"but got {ptr.type}") - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("ascend.sort requires tensor rank >= 1") - last_dim = rank - 1 - norm_dim = dim if dim >= 0 else dim + 
rank - if norm_dim != last_dim: - raise ValueError(f"ascend.sort only supports sorting along the last dimension " - f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") - dim = last_dim - else: - if dim != -1: - raise ValueError("ascend.sort only supports the last dimension; when rank is unknown " - "you must pass dim=-1") - - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - - sorted_vals = builder.create_sort(ptr.handle, dim, descending) - - values = tensor(sorted_vals, type=ptr.type) - - return values - - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e - - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - - ret = sort_impl(ptr, dim, descending, _builder) - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty.is_int8() or base_ty.is_int16(): - compile_hint_impl(ret, "overflow_mode", constexpr("saturate"), _builder) - return ret - - -def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, - overflow_mode: Optional[str] = None) -> tensor: - src_ty = input.type - if isinstance(dst_ty, tl.constexpr): - dst_ty = dst_ty.value - if isinstance(fp_downcast_rounding, tl.constexpr): - fp_downcast_rounding = fp_downcast_rounding.value - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) - if src_ty == dst_ty: - return input - - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - if src_sca_ty == dst_sca_ty: - return input - - # For fp downcasting default rounding mode should be RTNE, for all other conversions it should - # not be set - fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) - use_custom_rounding = False - if 
dst_sca_ty.is_floating() and src_sca_ty.is_floating( - ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: - if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE - elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True - else: - if fp_downcast_rounding is not None: - raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if not is_compile_on_910_95: - if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): - raise ValueError("[fp8, fp64] is unsupported on Ascend for now." - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): - assert builder.codegen_fns.get( - "convert_custom_types") is not None, "target doesn't provide conversion for this type." - return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) - # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 - # and non-default rounding modes for downcasting - if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ - (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ - use_custom_rounding: - return tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) - - # bf16 <=> (not fp32) - if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ - (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): - return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) - - # Standard floating types' casting: truncation - # fp64 => fp32, fp16, bf16 - # fp32 => fp16, bf16 - truncate_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth - if truncate_fp: - return 
tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Standard floating types' casting: extension - # fp32 => fp64 - # fp16 => fp32, fp64 - # bf16 => fp32, fp64 - ext_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth - if ext_fp: - return tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting between integer types - if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ - (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): - sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) - elif overflow_mode == "saturate" and \ - (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ - src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: - return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) - return tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) - - # Casting standard floating types to integer types - if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): - if dst_sca_ty.is_bool(): - ty = input.dtype.to_ir(builder) - _0 = tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) - elif dst_sca_ty.is_int_signed(): - return tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting integer types to standard floating types - if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): - if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - else: - return 
tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to integer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): - bitwidth = dst_sca_ty.int_bitwidth - if bitwidth == 64: - return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) - if bitwidth == 1: - return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), - builder) - - # Casting integer types to pointer types - if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): - return tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - - # Casting pointer types to pointer types - if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - - assert False, f'cannot cast {input} to {dst_ty}' - - -@_tensor_member_fn -@builtin -def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, - overflow_mode: Optional[str] = None, _builder=None): - """ - Casts a tensor to the given :code:`dtype`. - - :param dtype: The target data type. - :type dtype: dtype - :param fp_downcast_rounding: The rounding mode for downcasting - floating-point values. This parameter is only used when self is a - floating-point tensor and dtype is a floating-point type with a - smaller bitwidth. Supported values are :code:`"rtne"` (round to - nearest, ties to even) and :code:`"rtz"` (round towards zero). - :type fp_downcast_rounding: str, optional - :param bitcast: If true, the tensor is bitcasted to the given - :code:`dtype`, instead of being numerically casted. - :type bitcast: bool, optional - :param overflow_mode: When overflow_mode is not set or is "trunc", - truncation (cut-off) will be used to handle overflow. When - overflow_mode is "sautrate", the maximum value of the data type - will be used to handle overflow. 
- :type overflow_mode: string, optional - """ - overflow_modes = ["trunc", "saturate"] - input = semantic.to_tensor(input, _builder) - if isinstance(bitcast, constexpr): - bitcast = bitcast.value - if bitcast: - return semantic.bitcast(input, dtype, _builder) - ret = ascend_cast_impl(input, dtype, _builder, fp_downcast_rounding, overflow_mode) - if overflow_mode is not None: - if overflow_mode in overflow_modes: - compile_hint_impl(ret, "overflow_mode", overflow_mode, _builder) - else: - raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") - return ret +# insert_slice +# extract_slice +# get_element +# sort +# flip +# gather + +import triton.language as tl +from triton.language import semantic, core, standard +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype, + tensor, + check_bit_width, + _unwrap_if_constexpr, + range +) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from . import is_compile_on_910_95 +from .aux_ops import compile_hint_impl + +from typing import Optional, Tuple, List, overload +from triton._C.libtriton import ir + +@_tensor_member_fn +@builtin +def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. 
+ :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [ + semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. 
+ :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + new_offsets = [ + semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) + return sub + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. 
+ :type src: Tensor + :param indice: + :type indice: tuple of ints + """ + + def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + # Handle both tensor and int indices (for interpreter mode) + new_indice = [] + for i in indice: + if isinstance(i, tensor): + new_indice.append(i.handle) + elif isinstance(i, int): + # For interpreter mode: keep as int (passed to the builder unchanged) + new_indice.append(i) + else: + # Try to use .handle attribute if available + new_indice.append(i.handle if hasattr(i, 'handle') else i) + + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + + assert len(src.shape) > 0 + new_indice = [ + semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i + for i in indice + ] + return get_element_impl(src, new_indice, _builder) + +@builtin +def flip(ptr, dim=-1, _builder=None, _generator=None): + + def flip_impl(ptr: tensor, dim: int, builder: ir.builder, generator=None): + """ + Flips a tensor `ptr` along the dimension `dim`. 
+ + :param ptr: the first input tensor + :type ptr: tensor + :param dim: the dimension to flip along + :type dim: int + :param generator: the code generator (required for reduce operations) + :type generator: generator object + """ + + def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index + dim += len(shape) + return constexpr(dim) + + def _log2(i: core.constexpr): + log2 = 0 + n = core.constexpr(i).value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + def flip_simd(ptr: tensor, dim: int, builder: ir.builder): + """ + Triton flip operation for simd + + Args: + ptr: tensor, input tensor + dim: int, dimension to flip (can be negative, normalized here) + builder: ir.builder, underlying IR builder + Returns: + flipped: tensor, same type and shape as input + """ + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.flip requires tensor rank >= 1") + norm_dim = dim if dim >= 0 else dim + rank + if not (0 <= norm_dim < rank): + raise ValueError( + f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}" + ) + dim = norm_dim + else: + if dim < 0: + raise ValueError( + "ascend.flip with unknown rank requires non-negative dim" + ) + + flipped_vals = builder.create_flip(ptr.handle, dim) + flipped = tensor(flipped_vals, type=ptr.type) + return flipped + + # If compile_mode is not simt, use the simd implementation + if not builder.is_simt_mode(): + return flip_simd(ptr, dim, builder) + core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) + _dim: core.constexpr = 
_get_flip_dim(dim, ptr.shape) + core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) + steps: core.constexpr = _log2(ptr.shape[_dim]) + # If steps is 0, return the original tensor + if steps == 0: + return ptr + # reshape the swap dimension to (2, 2, ..., 2) + idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) + y = core.reshape(ptr.to(idtype, bitcast=True, _builder=builder), ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), _builder=builder) + for i in static_range(steps): + y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) + ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) + return ptr + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e + + dim = len(ptr.shape) - 1 if dim == -1 else dim + return flip_impl(ptr, dim, _builder, _generator) + + +class static_range: + """ + Iterator for non-JIT Python functions that need to iterate over constexpr values. + This is used in functions like flip that are called during compilation. 
+ """ + def __init__(self, arg1, arg2=None, step=None): + if step is None: + self.step = core.constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = core.constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + + def __iter__(self): + # Extract actual values from constexpr objects for iteration + start_val = core._constexpr_to_value(self.start) + end_val = core._constexpr_to_value(self.end) + step_val = core._constexpr_to_value(self.step) + # Store as regular Python integers for iteration + self._current = start_val + self._end = end_val + self._step = step_val + return self + + def __next__(self): + if self._current >= self._end: + raise StopIteration + value = self._current + self._current += self._step + return value + + +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + """ + sort the input tensor along 'dim' + + param: + ptr: tensor, input tensor + dim: int or tl.constexpr[int], dimension to sort + descending: bool or tl.constexpr[bool], the result is descending or not + _builder: ir.builder + return: + values: tensor, the sorted tensor + """ + + def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): + allowed_types = {tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5} + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty not in allowed_types: + raise TypeError( + f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" + f"but got {ptr.type}" + ) + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.sort requires tensor rank >= 1") + last_dim = rank - 1 + norm_dim = dim if dim >= 0 else dim + rank 
+ if norm_dim != last_dim: + raise ValueError( + f"ascend.sort only supports sorting along the last dimension " + f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}" + ) + dim = last_dim + else: + if dim != -1: + raise ValueError( + "ascend.sort only supports the last dimension; when rank is unknown " + "you must pass dim=-1" + ) + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + sorted_vals = builder.create_sort(ptr.handle, dim, descending) + + values = tensor(sorted_vals, type=ptr.type) + + return values + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + ret = sort_impl(ptr, dim, descending, _builder) + # interpreter mode not support compile_hint overflow_mode, direct return + from triton.runtime.interpreter import InterpreterBuilder + if isinstance(_builder, InterpreterBuilder): + return ret + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty.is_int8() or base_ty.is_int16(): + compile_hint_impl(ret, "overflow_mode", constexpr("saturate"), _builder) + return ret + + +def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None, overflow_mode: Optional[str] = None) -> tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + + # For fp downcasting default 
rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if not is_compile_on_910_95: + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." 
+ return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif overflow_mode == "saturate" and \ + (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ + src_sca_ty.int_bitwidth 
>= dst_sca_ty.int_bitwidth: + if is_compile_on_910_95: + result = tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + compile_hint_impl(result, "saturate_src_unsigned", src_sca_ty.is_int_unsigned(), builder) + compile_hint_impl(result, "saturate_dst_unsigned", dst_sca_ty.is_int_unsigned(), builder) + return result + else: + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + return tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer 
types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, overflow_mode: Optional[str] = None, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + :param overflow_mode: When overflow_mode is not set or is "trunc", + truncation (cut-off) will be used to handle overflow. When + overflow_mode is "saturate", the maximum value of the data type + will be used to handle overflow. 
+ :type overflow_mode: string, optional + """ + overflow_modes = ["trunc", "saturate"] + input = semantic.to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + ret = ascend_cast_impl(input, dtype, _builder, fp_downcast_rounding, overflow_mode) + if overflow_mode is not None: + if overflow_mode in overflow_modes: + from triton.runtime.interpreter import InterpreterBuilder + if isinstance(_builder, InterpreterBuilder): + overflow_mode = constexpr(overflow_mode) + compile_hint_impl(ret, "overflow_mode", overflow_mode, _builder) + else: + raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") + return ret diff --git a/third_party/ascend/language/cann/libdevice.py b/third_party/ascend/language/cann/libdevice.py index eaba0a831e..d27bdefb4f 100644 --- a/third_party/ascend/language/cann/libdevice.py +++ b/third_party/ascend/language/cann/libdevice.py @@ -22,7 +22,8 @@ from triton.language import core, math, semantic from triton._C.libtriton import ir from triton.runtime.jit import jit -from triton.backends.ascend.utils import get_ascend_arch_from_env +from triton.backends.ascend.utils import get_ascend_arch_from_env, triton_enable_libdevice_simt +from triton.tools.get_ascend_devices import is_compile_on_910_95 @core.extern @@ -82,11 +83,17 @@ def atan(arg0, _builder=None): @core.extern def tanh(arg0, _builder=None): - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")), - (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")), - }, is_pure=True, _builder=_builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanhf", 
core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) @core.extern @@ -109,16 +116,18 @@ def ldexp(arg0, arg1, _builder=None): @core.extern def pow(arg0, arg1, _builder=None): - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")), - (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powf", core.dtype("fp16")), - (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powf", core.dtype("bf16")), - (core.dtype("int64"), core.dtype("int64")): ("__hmf_powi", core.dtype("int64")), - (core.dtype("int32"), core.dtype("int32")): ("__hmf_powi", core.dtype("int32")), - (core.dtype("int16"), core.dtype("int16")): ("__hmf_powi", core.dtype("int16")), - (core.dtype("int8"), core.dtype("int8")): ("__hmf_powi", core.dtype("int8")), - }, is_pure=True, _builder=_builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_pow_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")), + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powDh", core.dtype("fp16")), + (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powDb", core.dtype("bf16")), + }, is_pure=True, _builder=_builder) @core.extern @@ -133,196 +142,351 @@ def isnan(arg0, _builder=None): @core.extern def div_rz(arg0, arg1, _builder=None): - core.static_print("tl.div_rz is unsupported for now. 
Use libdevice.div_rz instead.") - core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_div_rz_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + +@core.builtin +def fast_dividef(arg0, arg1, _builder=None): + arg0 = semantic.to_tensor(arg0, _builder) + arg1 = semantic.to_tensor(arg1, _builder) + ret = semantic.fdiv(arg0, arg1, False, _builder) + return ret + +@core.builtin +def fast_expf(arg0, _builder=None): + arg0 = semantic.to_tensor(arg0, _builder) + ret = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + return ret @core.extern def fmod(arg0, arg1, _builder=None): - core.static_print("tl.fmod is unsupported for now. Use libdevice.fmod instead.") - core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_fmod_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_float_as_int_fp32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + @core.extern +def atan2(arg0, arg1, _builder): + if arg0.dtype == core.dtype("bf16") or arg1.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.atan2 for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_atan2_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_atan2_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.builtin +@math._check_dtype(dtypes=["fp32"]) +@math._add_math_1arg_docstr("trunc") def trunc(arg0, _builder=None): - core.static_print("tl.trunc is unsupported for now. 
Use libdevice.trunc instead.") - core.static_assert(False) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_trunc_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_trunc_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + + + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + condition = semantic.greater_equal(arg0, zero, _builder) + + + floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) + ceil_result = core.tensor(_builder.create_ceil(arg0.handle), arg0.type) + + + return semantic.where(condition, floor_result, ceil_result, _builder) @core.extern def round(arg0, _builder=None): - return core.extern_elementwise("", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) - + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("acos") def acos(arg0: core.tensor, _builder: ir.builder): - pi = 3.1415926536 - pi_half = 1.5707963268 - sqrt2 = 1.4142135624 - eps = 1e-8 + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.acos for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_acos_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_acos_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + pi = 3.1415926536 + pi_half = 1.5707963268 + sqrt2 = 1.4142135624 + eps = 1e-8 - # |x| < 0.5, acos(x) = pi/2 - [x + x*x²*(0.1666667 + x²*(0.075 + x²*(0.0446429 + 0.0303810*x²))] - arg0 = 
semantic.to_tensor(arg0, _builder) - abs_x = math.abs(arg0, _builder=_builder) - dtype = arg0.dtype - arg0_2 = semantic.mul(arg0, arg0, True, _builder) - arg0_4 = semantic.mul(arg0_2, arg0_2, True, _builder) - arg0_6 = semantic.mul(arg0_4, arg0_2, True, _builder) - arg0_8 = semantic.mul(arg0_6, arg0_2, True, _builder) - arg0_10 = semantic.mul(arg0_8, arg0_2, True, _builder) - poly = semantic.add(1.0, semantic.mul(0.166667, arg0_2, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.075, arg0_4, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.044643, arg0_6, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.030380, arg0_8, True, _builder), True, _builder) - poly = semantic.add(poly, semantic.mul(0.022372, arg0_10, True, _builder), True, _builder) - acos_center = semantic.sub(pi_half, semantic.mul(arg0, poly, True, _builder), True, _builder) - - # 0.5<|x|<0.9, acos(x) = 2*arctan(t), t=sqrt((1-abs_x)/(1+abs_x)) - numerator_mid = semantic.sub(1.0, abs_x, True, _builder) - denom_mid = semantic.add(1.0, abs_x, True, _builder) - div_mid = semantic.truediv(numerator_mid, denom_mid, _builder) - t_mid = math.sqrt(div_mid, _builder=_builder) - t2_mid = semantic.mul(t_mid, t_mid, True, _builder) - t4_mid = semantic.mul(t2_mid, t2_mid, True, _builder) - t6_mid = semantic.mul(t4_mid, t2_mid, True, _builder) - - # 1 + t2*(-0.3333310 + t2*(0.1999341 + t2*(-0.1420890 + t2*0.1065976))) - poly_mid1 = semantic.mul(0.1065976, t2_mid, True, _builder) - poly_mid2 = semantic.add(-0.1420890, poly_mid1, True, _builder) - poly_mid3 = semantic.mul(poly_mid2, t2_mid, True, _builder) - poly_mid4 = semantic.add(0.1999341, poly_mid3, True, _builder) - poly_mid5 = semantic.mul(poly_mid4, t2_mid, True, _builder) - poly_mid6 = semantic.add(-0.3333310, poly_mid5, True, _builder) - poly_mid = semantic.add(1.0, semantic.mul(poly_mid6, t2_mid, True, _builder), True, _builder) - arctan_t = semantic.mul(t_mid, poly_mid, True, 
_builder) - acos_mid = semantic.mul(2.0, arctan_t, True, _builder) - is_neg_mid = semantic.less_than(arg0, 0.0, _builder) - acos_mid_signed = semantic.where(is_neg_mid, semantic.sub(pi, acos_mid, True, _builder), acos_mid, _builder) - - is_center = semantic.less_than(abs_x, 0.5, _builder) - res_mid_boundary = semantic.where(is_center, acos_center, acos_mid_signed, _builder) - return res_mid_boundary + # |x| < 0.5, acos(x) = pi/2 - [x + x*x²*(0.1666667 + x²*(0.075 + x²*(0.0446429 + 0.0303810*x²))] + arg0 = semantic.to_tensor(arg0, _builder) + abs_x = math.abs(arg0, _builder=_builder) + dtype = arg0.dtype + arg0_2 = semantic.mul(arg0, arg0, True, _builder) + arg0_4 = semantic.mul(arg0_2, arg0_2, True, _builder) + arg0_6 = semantic.mul(arg0_4, arg0_2, True, _builder) + arg0_8 = semantic.mul(arg0_6, arg0_2, True, _builder) + arg0_10 = semantic.mul(arg0_8, arg0_2, True, _builder) + poly = semantic.add(1.0, semantic.mul(0.166667, arg0_2, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.075, arg0_4, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.044643, arg0_6, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.030380, arg0_8, True, _builder), True, _builder) + poly = semantic.add(poly, semantic.mul(0.022372, arg0_10, True, _builder), True, _builder) + acos_center = semantic.sub(pi_half, semantic.mul(arg0, poly, True, _builder), True, _builder) + + # 0.5<|x|<0.9, acos(x) = 2*arctan(t), t=sqrt((1-abs_x)/(1+abs_x)) + numerator_mid = semantic.sub(1.0, abs_x, True, _builder) + denom_mid = semantic.add(1.0, abs_x, True, _builder) + div_mid = semantic.truediv(numerator_mid, denom_mid, _builder) + t_mid = math.sqrt(div_mid, _builder=_builder) + t2_mid = semantic.mul(t_mid, t_mid, True, _builder) + t4_mid = semantic.mul(t2_mid, t2_mid, True, _builder) + t6_mid = semantic.mul(t4_mid, t2_mid, True, _builder) + + poly_mid1 = semantic.mul(0.1065976, t2_mid, True, _builder) + poly_mid2 = 
semantic.add(-0.1420890, poly_mid1, True, _builder) + poly_mid3 = semantic.mul(poly_mid2, t2_mid, True, _builder) + poly_mid4 = semantic.add(0.1999341, poly_mid3, True, _builder) + poly_mid5 = semantic.mul(poly_mid4, t2_mid, True, _builder) + poly_mid6 = semantic.add(-0.3333310, poly_mid5, True, _builder) + poly_mid = semantic.add(1.0, semantic.mul(poly_mid6, t2_mid, True, _builder), True, _builder) + arctan_t = semantic.mul(t_mid, poly_mid, True, _builder) + acos_mid = semantic.mul(2.0, arctan_t, True, _builder) + is_neg_mid = semantic.less_than(arg0, 0.0, _builder) + acos_mid_signed = semantic.where(is_neg_mid, semantic.sub(pi, acos_mid, True, _builder), acos_mid, _builder) + + is_center = semantic.less_than(abs_x, 0.6, _builder) + res_mid_boundary = semantic.where(is_center, acos_center, acos_mid_signed, _builder) + return res_mid_boundary @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("sinh") def sinh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - exp1 = semantic.truediv(1.0, exp0, _builder) - tmp = semantic.sub(exp0, exp1, True, _builder) - ret = semantic.truediv(tmp, 2.0, _builder) - return ret + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.sinh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_sinh_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_sinh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + exp1 = semantic.truediv(1.0, exp0, _builder) + tmp = semantic.sub(exp0, exp1, True, _builder) + ret = semantic.truediv(tmp, 2.0, _builder) + return ret @core.builtin 
@math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("cosh") def cosh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - exp1 = semantic.truediv(1.0, exp0, _builder) - tmp = semantic.add(exp0, exp1, True, _builder) - ret = semantic.truediv(tmp, 2.0, _builder) - return ret + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.cosh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_cosh_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_cosh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + exp0 = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + exp1 = semantic.truediv(1.0, exp0, _builder) + tmp = semantic.add(exp0, exp1, True, _builder) + ret = semantic.truediv(tmp, 2.0, _builder) + return ret @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("acosh") def acosh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = semantic.sub(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) - sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) - sum_res = semantic.add(arg0, sqrt_res, True, _builder) - return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.acosh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_acosh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_acosh_fp32", core.dtype("fp32")), + }, 
is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = semantic.sub(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) + sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) + sum_res = semantic.add(arg0, sqrt_res, True, _builder) + return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("asinh") def asinh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = semantic.add(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) - sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) - sum_res = semantic.add(arg0, sqrt_res, True, _builder) - return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.asinh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_asinh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_asinh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = semantic.add(semantic.mul(arg0, arg0, True, _builder), 1.0, True, _builder) + sqrt_res = core.tensor(_builder.create_sqrt(tmp.handle), tmp.type) + sum_res = semantic.add(arg0, sqrt_res, True, _builder) + return core.tensor(_builder.create_log(sum_res.handle), sum_res.type) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("atanh") def atanh(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - a = semantic.add(1.0, arg0, True, _builder) - b = semantic.sub(1.0, arg0, True, _builder) - lna = core.tensor(_builder.create_log(a.handle), a.type) - lnb = 
core.tensor(_builder.create_log(b.handle), b.type) - tmp = semantic.sub(lna, lnb, True, _builder) - return semantic.mul(tmp, 0.5, True, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.atanh for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("__hmf_atanh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_atanh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + a = semantic.add(1.0, arg0, True, _builder) + b = semantic.sub(1.0, arg0, True, _builder) + lna = core.tensor(_builder.create_log(a.handle), a.type) + lnb = core.tensor(_builder.create_log(b.handle), b.type) + tmp = semantic.sub(lna, lnb, True, _builder) + return semantic.mul(tmp, 0.5, True, _builder) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("expm1") def expm1(arg0: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - tmp = core.tensor(_builder.create_exp(arg0.handle), arg0.type) - return semantic.sub(tmp, 1, True, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.expm1 for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_expm1_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_expm1_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + tmp = core.tensor(_builder.create_exp(arg0.handle), arg0.type) + return semantic.sub(tmp, 1, True, _builder) @core.builtin -@math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@math._check_dtype(dtypes=["fp16", "fp32"]) 
@math._add_math_2arg_docstr("nextafter") def nextafter(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - x = semantic.to_tensor(arg0, _builder) - y = semantic.to_tensor(arg1, _builder) - dtype_map = {"bf16": core.int16, "fp16": core.int16, "fp32": core.int32} - min_pos_bit = {"bf16": 0x0001, "fp16": 0x0001, "fp32": 0x00000001} - max_neg_bit = {"bf16": 0x8001, "fp16": 0x8001, "fp32": 0x80000001} - int_type = dtype_map[x.type.scalar.name] - x_eq_y = semantic.equal(x, y, _builder) - x_gt_0 = semantic.greater_than(x, 0, _builder) - y_gt_x = semantic.greater_than(y, x, _builder) - next_neg = semantic.xor_(x_gt_0, y_gt_x, _builder) - next_pos = semantic.not_(next_neg, _builder) - - p1 = semantic.full(x.shape, 1, int_type, _builder) - n1 = semantic.full(x.shape, -1, int_type, _builder) - dir_xy = semantic.where(next_pos, p1, n1, _builder) - x_abs = math.abs(x, _builder=_builder) - x_is_0 = semantic.equal(x_abs, 0, _builder) - - min_pos = semantic.full(x.shape, min_pos_bit[x.type.scalar.name], int_type, _builder) - max_neg = semantic.full(x.shape, max_neg_bit[x.type.scalar.name], int_type, _builder) - min_pos = semantic.bitcast(min_pos, x.dtype, _builder) - max_neg = semantic.bitcast(max_neg, x.dtype, _builder) - bits_x = semantic.bitcast(x, int_type, _builder) - bits_next = semantic.add(bits_x, dir_xy, True, _builder) - next_val = semantic.bitcast(bits_next, x.dtype, _builder) - - need_min_pos = semantic.logical_and(x_is_0, next_pos, _builder) - need_max_neg = semantic.logical_and(x_is_0, next_neg, _builder) - next_val = semantic.where(need_min_pos, min_pos, next_val, _builder) - next_val = semantic.where(need_max_neg, max_neg, next_val, _builder) - return semantic.where(x_eq_y, x, next_val, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_nextafter_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): 
("__hmf_nextafter_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + x = semantic.to_tensor(arg0, _builder) + y = semantic.to_tensor(arg1, _builder) + dtype_map = { + "bf16": core.int16, + "fp16": core.int16, + "fp32": core.int32 + } + min_pos_bit = { + "bf16": 0x0001, + "fp16": 0x0001, + "fp32": 0x00000001 + } + max_neg_bit = { + "bf16": 0x8001, + "fp16": 0x8001, + "fp32": 0x80000001 + } + int_type = dtype_map[x.type.scalar.name] + x_eq_y = semantic.equal(x, y, _builder) + x_gt_0 = semantic.greater_than(x, 0, _builder) + y_gt_x = semantic.greater_than(y, x, _builder) + next_neg = semantic.xor_(x_gt_0, y_gt_x, _builder) + next_pos = semantic.not_(next_neg, _builder) + + p1 = semantic.full(x.shape, 1, int_type, _builder) + n1 = semantic.full(x.shape, -1, int_type, _builder) + dir_xy = semantic.where(next_pos, p1, n1, _builder) + x_abs = math.abs(x, _builder=_builder) + x_is_0 = semantic.equal(x_abs, 0, _builder) + + min_pos = semantic.full(x.shape, min_pos_bit[x.type.scalar.name], int_type, _builder) + max_neg = semantic.full(x.shape, max_neg_bit[x.type.scalar.name], int_type, _builder) + min_pos = semantic.bitcast(min_pos, x.dtype, _builder) + max_neg = semantic.bitcast(max_neg, x.dtype, _builder) + bits_x = semantic.bitcast(x, int_type, _builder) + bits_next = semantic.add(bits_x, dir_xy, True, _builder) + next_val = semantic.bitcast(bits_next, x.dtype, _builder) + + need_min_pos = semantic.logical_and(x_is_0, next_pos, _builder) + need_max_neg = semantic.logical_and(x_is_0, next_neg, _builder) + next_val = semantic.where(need_min_pos, min_pos, next_val, _builder) + next_val = semantic.where(need_max_neg, max_neg, next_val, _builder) + return semantic.where(x_eq_y, x, next_val, _builder) @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_2arg_docstr("hypot(Euclidean Distance)") def hypot(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - arg0 = semantic.to_tensor(arg0, _builder) - arg1 = 
semantic.to_tensor(arg1, _builder) - x2 = semantic.mul(arg0, arg0, True, _builder) - y2 = semantic.mul(arg1, arg1, True, _builder) - sum_res = semantic.add(x2, y2, True, _builder) - return core.tensor(_builder.create_sqrt(sum_res.handle), sum_res.type) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("bf16"): + core.static_print("extern livdevice.hypot for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_hypot_fp16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_hypot_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0 = semantic.to_tensor(arg0, _builder) + arg1 = semantic.to_tensor(arg1, _builder) + x2 = semantic.mul(arg0, arg0, True, _builder) + y2 = semantic.mul(arg1, arg1, True, _builder) + sum_res = semantic.add(x2, y2, True, _builder) + return core.tensor(_builder.create_sqrt(sum_res.handle), sum_res.type) # This function is derived from the Cephes Math Library release 2.8: June, 2000 @@ -333,117 +497,134 @@ def hypot(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): @math._check_dtype(dtypes=["fp16", "fp32"]) @math._add_math_2arg_docstr("besseli0 (Modified Bessel function of the first kind, order 0).") def cyl_bessel_i0(arg0: core.tensor, _builder: ir.builder): - param1 = [ - -4.41534164647933937950e-18, - +3.33079451882223809783e-17, - -2.43127984654795469359e-16, - +1.71539128555513303061e-15, - -1.16853328779934516808e-14, - +7.67618549860493561688e-14, - -4.85644678311192946090e-13, - +2.95505266312963983461e-12, - -1.72682629144155570723e-11, - +9.67580903537323691224e-11, - -5.18979560163526290666e-10, - +2.65982372468238665035e-09, - -1.30002500998624804212e-08, - +6.04699502254191894932e-08, - -2.67079385394061173391e-07, - +1.11738753912010371815e-06, - -4.41673835845875056359e-06, - +1.64484480707288970893e-05, - 
-5.75419501008210370398e-05, - +1.88502885095841655729e-04, - -5.76375574538582365885e-04, - +1.63947561694133579842e-03, - -4.32430999505057594430e-03, - +1.05464603945949983183e-02, - -2.37374148058994688156e-02, - +4.93052842396707084878e-02, - -9.49010970480476444210e-02, - +1.71620901522208775349e-01, - -3.04682672343198398683e-01, - +6.76795274409476084995e-01, - ] - param2 = [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - +4.46562142029675999901e-17, - +3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - +1.77256013305652638360e-15, - +3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - +1.54008621752140982691e-14, - +3.85277838274214270114e-13, - +7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - +1.18891471078464383424e-11, - +4.94060238822496958910e-10, - +3.39623202570838634515e-09, - +2.26666899049817806459e-08, - +2.04891858946906374183e-07, - +2.89137052083475648297e-06, - +6.88975834691682398426e-05, - +3.36911647825569408990e-03, - +8.04490411014108831608e-01, - ] - arg0 = semantic.to_tensor(arg0, _builder) - abs_x = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) - x_a = semantic.sub(semantic.mul(abs_x, 0.5, True, _builder), 2.0, True, _builder) - a_n_2 = 0 - a_n_1 = 0 - a_n = param1[0] - for i in range(1, 30): - a_n_2 = a_n_1 - a_n_1 = a_n - a_n = semantic.sub(semantic.mul(x_a, a_n_1, True, _builder), a_n_2, True, _builder) - a_n = semantic.add(a_n, param1[i], True, _builder) - - f_32 = semantic.full(abs_x.shape, 32.0, abs_x.type.scalar, _builder) - x_b = semantic.sub(semantic.fdiv(f_32, abs_x, True, _builder), 2.0, True, _builder) - b_n_2 = 0 - b_n_1 = 0 - b_n = param2[0] - for i in range(1, 25): - b_n_2 = b_n_1 - b_n_1 = b_n - b_n = semantic.sub(semantic.mul(x_b, b_n_1, True, _builder), b_n_2, True, _builder) - b_n = semantic.add(b_n, param2[i], True, _builder) - - 
half_exp = semantic.mul(core.tensor(_builder.create_exp(abs_x.handle), abs_x.type), 0.5, True, _builder) - res_a = semantic.mul(half_exp, semantic.sub(a_n, a_n_2, True, _builder), True, _builder) - res_b = semantic.fdiv(semantic.mul(half_exp, semantic.sub(b_n, b_n_2, True, _builder), True, _builder), \ - core.tensor(_builder.create_sqrt(abs_x.handle), abs_x.type), True, _builder) - cond = semantic.less_equal(abs_x, 8.0, _builder) - res = semantic.where(cond, res_a, res_b, _builder) - return res + if triton_enable_libdevice_simt() and is_compile_on_910_95: + if arg0.dtype == core.dtype("fp16"): + core.static_print("extern livdevice.cyl_bessel_i0 for dtype bf16 is unspported for now.") + core.static_assert(False) + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_cyl_bessel_i0_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + param1 = [ + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + ] + param2 = [ + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + 
-2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + ] + arg0 = semantic.to_tensor(arg0, _builder) + abs_x = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) + x_a = semantic.sub(semantic.mul(abs_x, 0.5, True, _builder), 2.0, True, _builder) + a_n_2 = 0 + a_n_1 = 0 + a_n = param1[0] + for i in range(1, 30): + a_n_2 = a_n_1 + a_n_1 = a_n + a_n = semantic.sub(semantic.mul(x_a, a_n_1, True, _builder), a_n_2, True, _builder) + a_n = semantic.add(a_n, param1[i], True, _builder) + + f_32 = semantic.full(abs_x.shape, 32.0, abs_x.type.scalar, _builder) + x_b = semantic.sub(semantic.fdiv(f_32, abs_x, True, _builder), 2.0, True, _builder) + b_n_2 = 0 + b_n_1 = 0 + b_n = param2[0] + for i in range(1, 25): + b_n_2 = b_n_1 + b_n_1 = b_n + b_n = semantic.sub(semantic.mul(x_b, b_n_1, True, _builder), b_n_2, True, _builder) + b_n = semantic.add(b_n, param2[i], True, _builder) + + half_exp = semantic.mul(core.tensor(_builder.create_exp(abs_x.handle), abs_x.type), 0.5, True, _builder) + res_a = semantic.mul(half_exp, semantic.sub(a_n, a_n_2, True, _builder), True, _builder) + res_b = semantic.fdiv(semantic.mul(half_exp, semantic.sub(b_n, b_n_2, True, _builder), True, _builder), \ + core.tensor(_builder.create_sqrt(abs_x.handle), abs_x.type), True, _builder) + cond = semantic.less_equal(abs_x, 8.0, _builder) + res = semantic.where(cond, res_a, res_b, _builder) + return res @core.extern 
@math._check_dtype(dtypes=["fp16", "fp32"]) def signbit(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - if arg0_scalar_ty == core.float32: - int_ty = core.int32 - else: # arg0 type: float16 / bfloat16 - int_ty = core.int16 + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_signbit_fp16", core.dtype("int32")), + (core.dtype("fp32"),): ("__hmf_signbit_fp32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + if arg0_scalar_ty == core.float32: + int_ty = core.int32 + else: # arg0 type: float16 / bfloat16 + int_ty = core.int16 - arg0 = semantic.to_tensor(arg0, _builder) - int_tensor = semantic.bitcast(arg0, int_ty, _builder) - if int_ty == core.int32: - shift = 31 - elif int_ty == core.int16: - shift = 15 + arg0 = semantic.to_tensor(arg0, _builder) + int_tensor = semantic.bitcast(arg0, int_ty, _builder) + if int_ty == core.int32: + shift = 31 + elif int_ty == core.int16: + shift = 15 - shift = semantic.full(arg0.shape, shift, int_ty, _builder) - sign_bit_tensor = semantic.lshr(int_tensor, shift, _builder) - sign_bit_tensor = semantic.and_(sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) - return semantic.equal(sign_bit_tensor, 1, _builder) + shift = semantic.full(arg0.shape, shift, int_ty, _builder) + sign_bit_tensor = semantic.lshr(int_tensor, shift, _builder) + sign_bit_tensor = semantic.and_( + sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) + return semantic.equal(sign_bit_tensor, 1, _builder) # Note: @@ -455,100 +636,135 @@ def signbit(arg0, _builder=None): @core.extern @math._check_dtype(dtypes=["fp32"]) def erfinv(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - arg0 = semantic.to_tensor(arg0, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + 
(core.dtype("fp32"), ): ("__hmf_erfinv_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + arg0 = semantic.to_tensor(arg0, _builder) - inv_sqrt_pi_times_2 = semantic.full(arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) - coeff_low_numerator = [-0.140543331, 0.914624893, -1.645349621, 0.886226899] - coeff_low_denominator = [0.012229801, -0.329097515, 1.442710462, -2.118377725, 1.0] - coeff_high_numerator = [1.641345311, 3.429567803, -1.624906493, -1.970840454] - coeff_high_denominator = [1.6370678, 3.5438892, 1.0] - - # low cal - arg0_squared = _builder.create_fmul(arg0.handle, arg0.handle) - numerator_low_range = semantic.full(arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_low_numerator)): - numerator_low_range = _builder.create_fma( - numerator_low_range, arg0_squared, - semantic.full(arg0.shape, coeff_low_numerator[i], arg0_scalar_ty, _builder).handle) - - denominator_low_range = semantic.full(arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_low_denominator)): - denominator_low_range = _builder.create_fma( - denominator_low_range, arg0_squared, - semantic.full(arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) - - low_res = _builder.create_fmul(arg0.handle, _builder.create_fdiv(numerator_low_range, denominator_low_range)) - - # high cal - arg0_erf_trans = _builder.create_sqrt( # (log2-log(1-|arg0|))^1/2 - _builder.create_fmul( - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_log( - _builder.create_fdiv( + inv_sqrt_pi_times_2 = semantic.full( + arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) + coeff_low_numerator = [-0.140543331, 0.914624893, -1.645349621, 0.886226899] + coeff_low_denominator = [0.012229801, -0.329097515, 1.442710462, -2.118377725, 1.0] + coeff_high_numerator = [1.641345311, 
3.429567803, -1.624906493, -1.970840454] + coeff_high_denominator = [1.6370678, 3.5438892, 1.0] + + # low cal + arg0_squared = _builder.create_fmul(arg0.handle, arg0.handle) + numerator_low_range = semantic.full( + arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_low_numerator)): + numerator_low_range = _builder.create_fma(numerator_low_range, arg0_squared, + semantic.full(arg0.shape, coeff_low_numerator[i], arg0_scalar_ty, _builder).handle) + + denominator_low_range = semantic.full( + arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_low_denominator)): + denominator_low_range = _builder.create_fma( + denominator_low_range, arg0_squared, semantic.full( + arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) + + low_res = _builder.create_fmul(arg0.handle, _builder.create_fdiv(numerator_low_range, denominator_low_range)) + + # high cal + arg0_erf_trans = _builder.create_sqrt( # (log2-log(1-|arg0|))^1/2 + _builder.create_fmul( + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_log( + _builder.create_fdiv( + _builder.create_fsub( + semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder).handle, + _builder.create_fabs(arg0.handle) + ), + semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle + ) + ) + ) + ) + numerator_high_range = semantic.full(arg0.shape, coeff_high_numerator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_high_numerator)): + numerator_high_range = _builder.create_fma( + numerator_high_range, arg0_erf_trans, semantic.full( + arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) + + denominator_high_range = semantic.full(arg0.shape, coeff_high_denominator[0], arg0_scalar_ty, _builder).handle + for i in range(1, len(coeff_high_denominator)): + denominator_high_range = _builder.create_fma( + denominator_high_range, arg0_erf_trans, semantic.full( + arg0.shape, 
coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) + + high_res = _builder.create_fdiv(numerator_high_range, denominator_high_range) + high_res = semantic.mul( + semantic.where( + signbit(arg0, _builder=_builder), + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), + semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), + _builder), + core.tensor(high_res, arg0.type), True, _builder + ).handle + + for _ in range(2): + low_res = _builder.create_fsub( + low_res, _builder.create_fdiv( _builder.create_fsub( - semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder).handle, - _builder.create_fabs(arg0.handle)), - semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle)))) - numerator_high_range = semantic.full(arg0.shape, coeff_high_numerator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_high_numerator)): - numerator_high_range = _builder.create_fma( - numerator_high_range, arg0_erf_trans, - semantic.full(arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) - - denominator_high_range = semantic.full(arg0.shape, coeff_high_denominator[0], arg0_scalar_ty, _builder).handle - for i in range(1, len(coeff_high_denominator)): - denominator_high_range = _builder.create_fma( - denominator_high_range, arg0_erf_trans, - semantic.full(arg0.shape, coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) - - high_res = _builder.create_fdiv(numerator_high_range, denominator_high_range) - high_res = semantic.mul( - semantic.where(signbit(arg0, _builder=_builder), semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), - semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder), - core.tensor(high_res, arg0.type), True, _builder).handle - - for i in range(2): - low_res = _builder.create_fsub( - low_res, - _builder.create_fdiv( - _builder.create_fsub(_builder.create_erf(low_res), arg0.handle), - _builder.create_fmul( - inv_sqrt_pi_times_2, - _builder.create_exp( - _builder.create_fmul( - semantic.full(arg0.shape, 
-1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(low_res, low_res)))))) - - high_res = _builder.create_fsub( - high_res, - _builder.create_fdiv( - _builder.create_fsub(_builder.create_erf(high_res), arg0.handle), - _builder.create_fmul( - inv_sqrt_pi_times_2, - _builder.create_exp( - _builder.create_fmul( - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(high_res, high_res)))))) - - arg0_abs = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) - # Check if |arg0| > 1 - arg0_over = semantic.greater_than(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) - nan_tensor = semantic.full(arg0.shape, float("nan"), arg0_scalar_ty, _builder) - # Check if |arg0| = 1 - arg0_equal1 = semantic.equal(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) - pos_inf_tensor = semantic.full(arg0.shape, float("inf"), arg0_scalar_ty, _builder) - neg_inf_tensor = semantic.full(arg0.shape, float("-inf"), arg0_scalar_ty, _builder) - inf_res = semantic.where(signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder) - # Check if |arg0| >= 0.7 - arg0_high = semantic.greater_equal(arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder) - - return semantic.where( - arg0_equal1, inf_res, - semantic.where( - arg0_over, nan_tensor, - semantic.where(arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder), - _builder), _builder) + _builder.create_erf(low_res), arg0.handle + ), + _builder.create_fmul( + inv_sqrt_pi_times_2, _builder.create_exp( + _builder.create_fmul( + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_fmul(low_res, low_res) + ) + ) + ) + ) + ) + + high_res = _builder.create_fsub( + high_res, _builder.create_fdiv( + _builder.create_fsub( + _builder.create_erf(high_res), arg0.handle + ), + _builder.create_fmul( + inv_sqrt_pi_times_2, _builder.create_exp( + _builder.create_fmul( + 
semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, + _builder.create_fmul(high_res, high_res) + ) + ) + ) + ) + ) + + arg0_abs = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) + # Check if |arg0| > 1 + arg0_over = semantic.greater_than( + arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) + nan_tensor = semantic.full(arg0.shape, float("nan"), arg0_scalar_ty, _builder) + # Check if |arg0| = 1 + arg0_equal1 = semantic.equal( + arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder + ) + pos_inf_tensor = semantic.full(arg0.shape, float("inf"), arg0_scalar_ty, _builder) + neg_inf_tensor = semantic.full(arg0.shape, float("-inf"), arg0_scalar_ty, _builder) + inf_res = semantic.where( + signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder + ) + # Check if |arg0| >= 0.7 + arg0_high = semantic.greater_equal( + arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder + ) + + return semantic.where( + arg0_equal1, inf_res, semantic.where( + arg0_over, nan_tensor, semantic.where( + arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder + ), _builder + ), _builder + ) # Note: @@ -570,7 +786,9 @@ def gamma(arg0, _builder=None): -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 ] condition = semantic.less_than(arg0, 0.5, _builder) # 1 - x = x -> x = 0.5 - reflect_arg0 = semantic.where(condition, semantic.sub(1, arg0, True, _builder), arg0, _builder) + reflect_arg0 = semantic.where( + condition, semantic.sub(1, arg0, True, _builder), arg0, _builder + ) x = semantic.full(arg0.shape, 0.99999999999980993, arg0_scalar_ty, _builder) for i in range(0, len(lanczos_coeff)): @@ -584,26 +802,39 @@ def gamma(arg0, _builder=None): _builder.create_fmul(sqrt_2pi_tensor, pow(t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder).handle), _builder.create_fmul( - x.handle, - _builder.create_exp( - 
_builder.create_fmul(t.handle, - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle)))) - - gamma_res_reflect = _builder.create_fdiv(_builder.create_fdiv(pi_tensor, gamma_res), - _builder.create_sin(_builder.create_fmul(pi_tensor, arg0.handle))) - - is_neg_int = semantic.logical_and(semantic.equal(math.floor(arg0, _builder=_builder), arg0, _builder), - semantic.less_than(arg0, 0, _builder), _builder) + sqrt_2pi_tensor, pow( + t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder + ).handle + ), + _builder.create_fmul( + x.handle, _builder.create_exp( + _builder.create_fmul( + t.handle, semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle + ) + ) + ) + ) + + gamma_res_reflect = _builder.create_fdiv( + _builder.create_fdiv(pi_tensor, gamma_res), + _builder.create_sin(_builder.create_fmul(pi_tensor, arg0.handle)) + ) + + is_neg_int = semantic.logical_and( + semantic.equal(math.floor(arg0, _builder=_builder), arg0, _builder), + semantic.less_than(arg0, 0, _builder), _builder + ) pos_inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) neg_inf_tensor = semantic.full(arg0.shape, float('-inf'), arg0_scalar_ty, _builder) - gamma_res_reflect = semantic.where(is_neg_int, pos_inf_tensor, core.tensor(gamma_res_reflect, arg0.type), _builder) + gamma_res_reflect = semantic.where( + is_neg_int, pos_inf_tensor, core.tensor(gamma_res_reflect, arg0.type), _builder) res = semantic.where(condition, gamma_res_reflect, core.tensor(gamma_res, arg0.type), _builder) is_pos_inf_input = semantic.equal(arg0, pos_inf_tensor, _builder) is_neg_inf_input = semantic.equal(arg0, neg_inf_tensor, _builder) - return semantic.where(is_pos_inf_input, pos_inf_tensor, - semantic.where(is_neg_inf_input, neg_inf_tensor, res, _builder), _builder) + return semantic.where(is_pos_inf_input, pos_inf_tensor, semantic.where( + is_neg_inf_input, neg_inf_tensor, res, _builder), _builder) # Note: @@ -617,43 +848,23 @@ def gamma(arg0, _builder=None): 
@core.extern @math._check_dtype(dtypes=["fp32"]) def lgamma(arg0, _builder=None): - arg0_scalar_ty = arg0.type.scalar - arg0 = semantic.to_tensor(arg0, _builder) - - inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) - is_inf = semantic.equal(core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder) - gamma_res = _builder.create_fabs(gamma(arg0, _builder=_builder).handle) - lgamma_res = _builder.create_log(gamma_res) - - return semantic.where(is_inf, inf_tensor, core.tensor(lgamma_res, arg0.type), _builder) - - -@core.builtin -@math._check_dtype(dtypes=[ - "fp32", -]) -@math._add_math_1arg_docstr("trunc") -def trunc(arg0: core.tensor, _builder: ir.builder): - """ - Truncate the input to the nearest integer toward zero. - - For positive numbers, this is equivalent to floor(x). - For negative numbers, this is equivalent to ceil(x). - - Special cases: - - trunc(±0) returns ±0. - - trunc(±inf) returns ±inf. - - trunc(NaN) returns NaN. - """ - arg0 = semantic.to_tensor(arg0, _builder) - - zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) - condition = semantic.greater_equal(arg0, zero, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_lgamma_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + arg0_scalar_ty = arg0.type.scalar + arg0 = semantic.to_tensor(arg0, _builder) - floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) - ceil_result = core.tensor(_builder.create_ceil(arg0.handle), arg0.type) + inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) + is_inf = semantic.equal( + core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder + ) + gamma_res = _builder.create_fabs(gamma(arg0, _builder=_builder).handle) + lgamma_res = _builder.create_log(gamma_res) - return semantic.where(condition, floor_result, 
ceil_result, _builder) + return semantic.where(is_inf, inf_tensor, core.tensor(lgamma_res, arg0.type), _builder) @core.builtin @@ -662,47 +873,56 @@ def trunc(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("nearbyint") def nearbyint(arg0: core.tensor, _builder: ir.builder): - """ - Round argument x to an integer value in floating-point format. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_nearbyint_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Round argument x to an integer value in floating-point format. - Uses the current rounding mode (round-to-nearest-even, aka banker's rounding). - """ - arg0 = semantic.to_tensor(arg0, _builder) + Uses the current rounding mode (round-to-nearest-even, aka banker's rounding). + """ + arg0 = semantic.to_tensor(arg0, _builder) - half = semantic.full(arg0.shape, 0.5, arg0.type.scalar, _builder) + half = semantic.full(arg0.shape, 0.5, arg0.type.scalar, _builder) - positive_adjust = semantic.add(arg0, half, True, _builder) - negative_adjust = semantic.sub(arg0, half, True, _builder) + positive_adjust = semantic.add(arg0, half, True, _builder) + negative_adjust = semantic.sub(arg0, half, True, _builder) - positive_result = core.tensor(_builder.create_floor(positive_adjust.handle), arg0.type) - negative_result = core.tensor(_builder.create_ceil(negative_adjust.handle), arg0.type) + positive_result = core.tensor(_builder.create_floor(positive_adjust.handle), arg0.type) + negative_result = core.tensor(_builder.create_ceil(negative_adjust.handle), arg0.type) - zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) - is_positive = semantic.greater_equal(arg0, zero, _builder) - basic_round = semantic.where(is_positive, positive_result, negative_result, _builder) + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + is_positive = 
semantic.greater_equal(arg0, zero, _builder) + basic_round = semantic.where(is_positive, positive_result, negative_result, _builder) - # Banker's rounding special treatment: For values exactly in the middle, round to the nearest even number. - fractional = semantic.sub(arg0, basic_round, True, _builder) - abs_fractional = core.tensor(_builder.create_fabs(fractional.handle), fractional.type) + # Banker's rounding special treatment: For values exactly in the middle, round to the nearest even number. + fractional = semantic.sub(arg0, basic_round, True, _builder) + abs_fractional = core.tensor(_builder.create_fabs(fractional.handle), fractional.type) - is_half = semantic.equal(abs_fractional, half, _builder) + is_half = semantic.equal(abs_fractional, half, _builder) - two = semantic.full(arg0.shape, 2.0, arg0.type.scalar, _builder) + two = semantic.full(arg0.shape, 2.0, arg0.type.scalar, _builder) - half_value = math.fdiv(basic_round, two, _builder=_builder) - half_floor = core.tensor(_builder.create_floor(half_value.handle), half_value.type) - double_half = semantic.mul(half_floor, two, True, _builder) + half_value = math.fdiv(basic_round, two, _builder=_builder) + half_floor = core.tensor(_builder.create_floor(half_value.handle), half_value.type) + double_half = semantic.mul(half_floor, two, True, _builder) - is_even = semantic.equal(basic_round, double_half, _builder) + is_even = semantic.equal(basic_round, double_half, _builder) - adjustment = semantic.where(is_positive, semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), - semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), _builder) + adjustment = semantic.where(is_positive, + semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), + semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), + _builder) - banker_result = semantic.where(is_even, basic_round, semantic.add(basic_round, adjustment, True, _builder), - _builder) + banker_result = semantic.where(is_even, basic_round, + 
semantic.add(basic_round, adjustment, True, _builder), + _builder) - # Final result: Use banker's rounding for cases exactly at 0.5, otherwise use basic rounding. - return semantic.where(is_half, banker_result, basic_round, _builder) + # Final result: Use banker's rounding for cases exactly at 0.5, otherwise use basic rounding. + return semantic.where(is_half, banker_result, basic_round, _builder) @core.builtin @@ -711,18 +931,25 @@ def nearbyint(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("arcsine") def asin(arg0: core.tensor, _builder: ir.builder): - """ - Calculate the principal value of the arc sine of the input argument x. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"),): ("__hmf_asin_fp16", core.dtype("fp16")), + (core.dtype("fp32"),): ("__hmf_asin_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Calculate the principal value of the arc sine of the input argument x. - Returns result in radians, in the interval [-π/2, +π/2] for x inside [-1, +1]. - Returns NaN for x outside [-1, +1]. - """ - arg0 = semantic.to_tensor(arg0, _builder) + Returns result in radians, in the interval [-π/2, +π/2] for x inside [-1, +1]. + Returns NaN for x outside [-1, +1]. 
+ """ + arg0 = semantic.to_tensor(arg0, _builder) - # asin(x) = π/2 - acos(x) - half_pi = semantic.full(arg0.shape, 1.5707963267948966, arg0.type.scalar, _builder) # π/2 - acos_val = acos(arg0, _builder=_builder) - return semantic.sub(half_pi, acos_val, True, _builder) + # asin(x) = π/2 - acos(x) + half_pi = semantic.full(arg0.shape, 1.5707963267948966, arg0.type.scalar, _builder) # π/2 + acos_val = acos(arg0, _builder=_builder) + return semantic.sub(half_pi, acos_val, True, _builder) @core.builtin @@ -731,18 +958,24 @@ def asin(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_1arg_docstr("base-10 logarithm") def log10(arg0: core.tensor, _builder: ir.builder): - """ - Calculate the base 10 logarithm of the input argument x. + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_log10_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Calculate the base 10 logarithm of the input argument x. - Returns NaN for x < 0, -inf for x = 0, and +0 for x = 1. - log10(x) = log(x) / log(10) - """ - arg0 = semantic.to_tensor(arg0, _builder) + Returns NaN for x < 0, -inf for x = 0, and +0 for x = 1. + log10(x) = log(x) / log(10) + """ + arg0 = semantic.to_tensor(arg0, _builder) - log_val = math.log(arg0, _builder=_builder) - log10_const = semantic.full(arg0.shape, 2.302585092994046, arg0.type.scalar, _builder) + log_val = math.log(arg0, _builder=_builder) + log10_const = semantic.full(arg0.shape, 2.302585092994046, arg0.type.scalar, _builder) - return math.fdiv(log_val, log10_const, _builder=_builder) + return math.fdiv(log_val, log10_const, _builder=_builder) @core.builtin @@ -751,29 +984,34 @@ def log10(arg0: core.tensor, _builder: ir.builder): ]) @math._add_math_2arg_docstr("copysign") def copysign(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): - """ - Create a floating-point value with the magnitude of x and the sign of y. 
- """ - x = semantic.to_tensor(arg0, _builder) - y = semantic.to_tensor(arg1, _builder) + if triton_enable_libdevice_simt() and is_compile_on_910_95: + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_copysign_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + else: + """ + Create a floating-point value with the magnitude of x and the sign of y. + """ + x = semantic.to_tensor(arg0, _builder) + y = semantic.to_tensor(arg1, _builder) - magnitude = core.tensor(_builder.create_fabs(x.handle), x.type) + magnitude = core.tensor(_builder.create_fabs(x.handle), x.type) - zero = semantic.full(y.shape, 0.0, y.type.scalar, _builder) - one = semantic.full(y.shape, 1.0, y.type.scalar, _builder) + zero = semantic.full(y.shape, 0.0, y.type.scalar, _builder) + one = semantic.full(y.shape, 1.0, y.type.scalar, _builder) - is_zero = semantic.equal(y, zero, _builder) - reciprocal = math.fdiv(one, y, _builder=_builder) - is_negative_reciprocal = semantic.less_than(reciprocal, zero, _builder) - is_negative_zero = semantic.and_(is_zero, is_negative_reciprocal, _builder) + is_zero = semantic.equal(y, zero, _builder) + y_reciprocal = math.fdiv(one, y, _builder=_builder) + is_negative_reciprocal = semantic.less_than(y_reciprocal, zero, _builder) + is_negative_zero = semantic.and_(is_zero, is_negative_reciprocal, _builder) - is_negative_nonzero = semantic.less_than(y, zero, _builder) - is_negative = semantic.or_(is_negative_zero, is_negative_nonzero, _builder) + is_negative_nonzero = semantic.less_than(y, zero, _builder) + is_negative = semantic.or_(is_negative_zero, is_negative_nonzero, _builder) - neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), True, - _builder) + neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), True, _builder) - return semantic.where(is_negative, neg_magnitude, magnitude, 
_builder) + return semantic.where(is_negative, neg_magnitude, magnitude, _builder) if get_ascend_arch_from_env() == "Ascend910_9589": diff --git a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp new file mode 100644 index 0000000000..ec621adf79 --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/AutoBlockify.h" +#include "AutoBlockify/Utils.h" +#include "Dialect/TritonAscend/IR/TritonAscendDialect.h" +#include "Utils/Utils.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify" + +using namespace mlir; +using namespace triton; + +PropagateUnrealizedCastDown::PropagateUnrealizedCastDown(MLIRContext *context, + Value logicalBlockId, + Value logicalBlockNum, + int autoBlockifySize) + : OpRewritePattern(context), + logicalBlockId(logicalBlockId), logicalBlockNum(logicalBlockNum), + autoBlockifySize(autoBlockifySize) {} + +LogicalResult +PropagateUnrealizedCastDown::matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const { + if (op.getInputs().size() != 2) + return failure(); + auto funcOp = op->getParentOfType(); + auto input = op.getInputs()[0]; + auto res = op->getResult(0); + SmallPtrSet users(op->user_begin(), op->user_end()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Handling UnrealizedConversionCastOp:\n" << op << "\n"; + os << "Users:\n"; + for (auto *user : users) + os << *user << "\n"; + }); + for (auto *user : users) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(user); + if (auto uccOp = dyn_cast(user)) { + if (uccOp->getResultTypes()[0] != input.getType()) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << *user << "\n"; + }); + return op.emitError("UnrealizedConversionCastOp cannot be resolved\n"); + } + rewriter.replaceOp(user, input); + } else if (auto blockifyLoop = getBlockifyLoop(user)) { + handleBlockifyLoop(blockifyLoop.value(), user, rewriter); + } else if (auto splatOp = dyn_cast(user)) { + rewriteSplat(op, splatOp, rewriter); + } else if (auto expandDimsOp = dyn_cast(user)) { + rewriteExpandDims(op, expandDimsOp, rewriter); + } else 
if (auto reduceOp = dyn_cast(user)) { + rewriteReduce(op, reduceOp, rewriter); + } else if (auto scanOp = dyn_cast(user)) { + rewriteScan(op, scanOp, rewriter); + } else if (auto loadOp = dyn_cast(user)) { + rewriteLoad(op, loadOp, rewriter); + } else if (auto storeOp = dyn_cast(user)) { + rewriteStore(op, storeOp, rewriter); + } else if (auto atomicRMWOp = dyn_cast(user)) { + rewriteAtomicRMW(op, atomicRMWOp, rewriter); + } else if (auto assertOp = dyn_cast(user)) { + rewriteAssert(op, assertOp, rewriter); + } else if (auto extractSliceOp = dyn_cast(user)) { + rewriteExtractSlice(op, extractSliceOp, rewriter); + } else if (auto insertSliceOp = dyn_cast(user)) { + rewriteInsertSlice(op, insertSliceOp, rewriter); + } else if (auto whileOp = dyn_cast(user)) { + rewriteWhile(op, whileOp, rewriter); + } else if (auto loopOp = dyn_cast(user)) { + rewriteLoop(op, loopOp, rewriter); + } else if (auto yieldOp = dyn_cast(user)) { + rewriteYield(op, yieldOp, rewriter); + } else if (auto conditionOp = dyn_cast(user)) { + rewriteCondition(op, conditionOp, rewriter); + } else if (user->hasTrait() || + isa(user)) { + rewriteGeneraleOp(op, user, rewriter); + } else if (isa(user)) { + auto *newOp = + createBlockifyLoop(user, op, logicalBlockId, logicalBlockNum, + autoBlockifySize, rewriter); + rewriter.setInsertionPoint(newOp); + handleBlockifyLoop(*getBlockifyLoop(newOp), newOp, rewriter); + } else { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Unhandled Op\n" << *user << "\n"; + }); + llvm_unreachable("Unhandled operation"); + } + } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After successful conversion\n"; + os << funcOp << "\n"; + }); + rewriter.eraseOp(op); + return success(); +} + +AutoBlockifyPass::AutoBlockifyPass(const AutoBlockifyOptions &options) + : AutoBlockifyBase(options) {} + +bool AutoBlockifyPass::checkBlockifiable(Value v) { + if (!checkedValues.insert(v).second) + return true; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Checking 
blockifiable:\n" << v << "\n"; + }); + for (auto &use : v.getUses()) { + auto *user = use.getOwner(); + auto opNum = use.getOperandNumber(); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "User:\n" << *user << "\n"; + }); + if (isa(user) || + llvm::any_of(user->getOperandTypes(), isTensorPtrType)) + return false; + if (auto ifOp = dyn_cast(user)) { + user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext())); + return true; + } else if (auto whileOp = dyn_cast(user)) { + if (!checkBlockifiable(whileOp.getBeforeArguments()[opNum])) + return false; + } else if (auto loopOp = dyn_cast(user)) { + auto regionIterArg = loopOp.getTiedLoopRegionIterArg(&use); + auto loopResult = loopOp.getTiedLoopResult(&use); + if (!regionIterArg || !loopResult) { + user->setAttr(autoBlockifyRegionOpAttr, UnitAttr::get(v.getContext())); + return true; + } + if (!checkBlockifiable(regionIterArg) || !checkBlockifiable(loopResult)) + return false; + } else if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(user->getParentOp()); + if (opNum == 0) { + whileOp->setAttr(autoBlockifyRegionOpAttr, + UnitAttr::get(v.getContext())); + return true; + } + if (!checkBlockifiable(whileOp.getAfterArguments()[opNum - 1]) || + !checkBlockifiable(whileOp->getResult(opNum - 1))) + return false; + } else if (auto conditionOp = dyn_cast(user)) { + if (auto loopOp = dyn_cast(user->getParentOp()); + loopOp && !checkBlockifiable(loopOp.getInits()[opNum])) + return false; + } else { + for (auto res : user->getResults()) { + if (!checkBlockifiable(res)) + return false; + } + } + } + return true; +} + +void AutoBlockifyPass::preProcess(triton::FuncOp func) { + IRRewriter rewriter(func.getContext()); + rewriter.setInsertionPointToStart(&func.getBody().front()); + auto loc = rewriter.getUnknownLoc(); + // Get logical block num + auto xNum = + rewriter.create(loc, triton::ProgramIDDim::X); + auto yNum = + rewriter.create(loc, triton::ProgramIDDim::Y); + auto zNum = + rewriter.create(loc, 
triton::ProgramIDDim::Z); + auto yzNum = rewriter.create(loc, yNum, zNum); + logicalBlockNum = rewriter.create(loc, yzNum, xNum); + + // Get logical block id + auto xDim = + rewriter.create(loc, triton::ProgramIDDim::X); + auto yDim = + rewriter.create(loc, triton::ProgramIDDim::Y); + auto zDim = + rewriter.create(loc, triton::ProgramIDDim::Z); + xDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + yDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + zDim->setAttr(logicalBlockIdAttr, rewriter.getUnitAttr()); + auto xFlatten = rewriter.create(loc, xDim, yzNum); + auto yFlatten = rewriter.create(loc, yDim, zNum); + logicalBlockId = rewriter.create(loc, xFlatten, yFlatten); + logicalBlockId = rewriter.create(loc, logicalBlockId, zDim); + + // get blockified block id + auto blockifyTensorType = + RankedTensorType::get({autoBlockifySize}, rewriter.getI32Type()); + auto blockfyRange = rewriter.create( + loc, blockifyTensorType, 0, autoBlockifySize); + auto splatedLogicalBlockId = rewriter.create( + loc, blockfyRange.getType(), logicalBlockId); + Value blockifiedId = + rewriter.create(loc, splatedLogicalBlockId, blockfyRange); + + // get mask + auto splatedBlockNum = rewriter.create( + loc, blockfyRange.getType(), logicalBlockNum); + auto upperboundMask = rewriter.create( + loc, arith::CmpIPredicate::slt, blockifiedId, splatedBlockNum); + auto splatedZero = rewriter.create( + loc, DenseElementsAttr::get(blockifyTensorType, + rewriter.getI32IntegerAttr(0))); + auto lowerboundMask = rewriter.create( + loc, arith::CmpIPredicate::sge, blockifiedId, splatedZero); + Value blockifiedIdMask = + rewriter.create(loc, upperboundMask, lowerboundMask); + + blockifiedId = rewriter + .create( + loc, logicalBlockId.getType(), + ValueRange({blockifiedId, blockifiedIdMask})) + ->getResult(0); + + // replace program id to be computed from blockified id + SmallVector toReplace; + func.walk([&](triton::GetProgramIdOp id) { + if (id->hasAttr(logicalBlockIdAttr)) + return; + 
toReplace.push_back(id); + }); + for (auto id : toReplace) { + rewriter.setInsertionPoint(id); + Value newId; + if (id.getAxis() == triton::ProgramIDDim::X) { + newId = rewriter.create(id.getLoc(), blockifiedId, yzNum); + newId = rewriter.create(id.getLoc(), newId, xNum); + } else if (id.getAxis() == triton::ProgramIDDim::Y) { + newId = rewriter.create(id.getLoc(), blockifiedId, zNum); + newId = rewriter.create(id.getLoc(), newId, yNum); + } else { + newId = rewriter.create(id.getLoc(), blockifiedId, zNum); + } + rewriter.replaceOp(id, newId); + } + + // Create for loop for region ops + func.walk([&](Operation *op) { + if (op->hasAttr(autoBlockifyRegionOpAttr)) { + auto *newOp = createBlockifyLoop( + op, blockifiedId.getDefiningOp(), + logicalBlockId, logicalBlockNum, autoBlockifySize, rewriter); + newOp->removeAttr(autoBlockifyRegionOpAttr); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); +} + +void AutoBlockifyPass::runOnOperation() { + if (autoBlockifySize == 1) + return; + ModuleOp moduleOp = getOperation(); + if (autoBlockifySize <= 0) { + moduleOp->emitWarning("[AutoBlockify V2] AutoBlockifySize cannot be " + "negative integer, skipping."); + return signalPassFailure(); + } + + MLIRContext *ctx = &getContext(); + + moduleOp.walk([&](triton::FuncOp func) { + LogicalResult result = success(); + func.walk([&](triton::GetProgramIdOp id) { + if (!checkBlockifiable(id.getResult())) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (failed(result)) { + func->emitWarning("Cannot apply auto blockify"); + return WalkResult::skip(); + } + preProcess(func); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After preprocess:\n" << func << "\n"; + }); + + RewritePatternSet patterns(ctx); + patterns.add( + ctx, logicalBlockId, logicalBlockNum, autoBlockifySize); + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + 
signalPassFailure(); + return WalkResult::interrupt(); + } + + IRRewriter rewriter(ctx); + func->walk([&](UnrealizedConversionCastOp op) { + rewriter.setInsertionPoint(op); + auto input = op.getInputs()[0]; + auto resType = cast(op->getResultTypes()[0]); + if (auto constantOp = input.getDefiningOp()) { + Attribute val = constantOp.getValue(); + if (auto denseAttr = dyn_cast(val)) + val = denseAttr.getSplatValue(); + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(resType, val)); + } else if (auto tensorType = + dyn_cast(input.getType())) { + input = rewriter.create(input.getLoc(), input, 0); + rewriter.replaceOpWithNewOp(op, resType, input); + } else { + rewriter.replaceOpWithNewOp(op, resType, input); + } + }); + func->setAttr(autoBlockifySizeAttr, + rewriter.getI32IntegerAttr(autoBlockifySize)); + return WalkResult::skip(); + }); + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, moduleOp))) { + signalPassFailure(); + } +} + +std::unique_ptr> +triton::createAutoBlockifyPass(const AutoBlockifyOptions &options) { + return std::make_unique(options); +} \ No newline at end of file diff --git a/third_party/ascend/lib/AutoBlockify/CMakeLists.txt b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt new file mode 100644 index 0000000000..a0ccd59b2e --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(AutoBlockify + AutoBlockify.cpp + RewriteOperation.cpp + Utils.cpp + + DEPENDS + AutoBlockifyPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + TritonIR + TritonTransforms + TritonAnalysis + MLIRTransforms + MLIRSupport + MLIRSCFTransforms +) \ No newline at end of file diff --git a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp new file mode 100644 
index 0000000000..d51c44997a --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp @@ -0,0 +1,509 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/AutoBlockify.h" +#include "AutoBlockify/Utils.h" +#include "Utils/Utils.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify-rewrite-operation" + +using namespace mlir; +using namespace triton; + +void PropagateUnrealizedCastDown::handleBlockifyLoop( + scf::ForOp blockifyLoop, Operation *op, PatternRewriter &rewriter) const { + SmallVector newOperands; + for (auto opr : op->getOperands()) { + auto uccOp = opr.getDefiningOp(); + if (!uccOp) { + newOperands.push_back(opr); + continue; + } + auto input = uccOp.getInputs()[0]; + auto tensorType = cast(input.getType()); + Value newOperand; + if (tensorType.getRank() > 1) { + SmallVector offsets(tensorType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector sizes(1, rewriter.getIndexAttr(1)); + SmallVector strides(tensorType.getRank(), + rewriter.getIndexAttr(1)); + offsets[0] = blockifyLoop.getInductionVar(); + for (auto dim : llvm::drop_begin(tensorType.getShape())) + sizes.push_back(rewriter.getIndexAttr(dim)); + newOperand = rewriter.create( + input.getLoc(), cast(opr.getType()), input, offsets, + sizes, strides); + } else { + newOperand = rewriter.create( + input.getLoc(), input, ValueRange{blockifyLoop.getInductionVar()}); + if (isa(opr.getType())) { + newOperand = rewriter.create( + input.getLoc(), rewriter.getIndexType(), newOperand); + } + } + newOperands.push_back(newOperand); + } + rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); }); +} + +void PropagateUnrealizedCastDown::rewriteGeneraleOp( + UnrealizedConversionCastOp op, Operation *generalOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + auto inputType = cast(input.getType()); + SmallVector newOperands; + SmallVector newResults; + SmallVector newResultTypes; + + for (auto operand : generalOp->getOperands()) + newOperands.push_back(rewriteValue(operand, op, rewriter)); + for (auto resType : 
generalOp->getResultTypes()) { + newResultTypes.push_back(getExpandedType(resType, op)); + } + auto *newOp = + rewriter.create(generalOp->getLoc(), generalOp->getName().getIdentifier(), + newOperands, newResultTypes, generalOp->getAttrs()); + replaceValue(newOp, generalOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteSplat( + UnrealizedConversionCastOp op, triton::SplatOp splatOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto resType = cast(splatOp.getResult().getType()); + auto curShape = + llvm::to_vector(cast(input.getType()).getShape()); + auto splatedShape = resType.getShape(); + for (auto dim : splatedShape) { + input = rewriter.create(input.getLoc(), input, + curShape.size()); + curShape.push_back(dim); + input = rewriter.create( + input.getLoc(), + RankedTensorType::get(curShape, getElementTypeOrSelf(input)), input); + } + replaceValue(input.getDefiningOp(), splatOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteExpandDims( + UnrealizedConversionCastOp op, triton::ExpandDimsOp expandDimsOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto newOp = rewriter.create( + expandDimsOp.getLoc(), input, expandDimsOp.getAxis() + 1); + for (auto attr : expandDimsOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, expandDimsOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteReduce( + UnrealizedConversionCastOp op, triton::ReduceOp reduceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto srcs = llvm::map_to_vector(reduceOp.getSrcs(), [&](Value src) { + return rewriteValue(src, op, rewriter); + }); + auto newOp = rewriter.create(reduceOp.getLoc(), srcs, + reduceOp.getAxis() + 1); + auto &newCombineOp = newOp.getCombineOp(); + rewriter.cloneRegionBefore(reduceOp.getCombineOp(), 
newCombineOp, + newCombineOp.end()); + for (auto attr : reduceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, reduceOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteScan(UnrealizedConversionCastOp op, + triton::ScanOp scanOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto srcs = llvm::map_to_vector(scanOp.getSrcs(), [&](Value src) { + return rewriteValue(src, op, rewriter); + }); + auto newOp = rewriter.create( + scanOp.getLoc(), srcs, scanOp.getAxis() + 1, scanOp.getReverse()); + auto &newCombineOp = newOp.getCombineOp(); + rewriter.cloneRegionBefore(scanOp.getCombineOp(), newCombineOp, + newCombineOp.end()); + for (auto attr : scanOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, scanOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteLoad(UnrealizedConversionCastOp op, + triton::LoadOp loadOp, + PatternRewriter &rewriter) const { + auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(loadOp.getPtr(), op, rewriter); + auto other = rewriteValue(loadOp.getOther(), op, rewriter); + auto mask = rewriteValue(loadOp.getMask(), op, rewriter); + auto res = loadOp.getResult(); + auto resType = getExpandedType(res.getType(), op); + if (!other) { + other = rewriter.create( + rewriter.getUnknownLoc(), + DenseElementsAttr::get( + resType, rewriter.getZeroAttr(getElementTypeOrSelf(res)))); + } + mask = createMask(mask, uccMask, resType.getShape(), rewriter); + auto boundaryCheck = llvm::map_to_vector(loadOp.getBoundaryCheck(), + [](int32_t idx) { return idx + 1; }); + auto newOp = rewriter.create( + loadOp.getLoc(), ptr, mask, other, boundaryCheck, loadOp.getPadding(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + for (auto attr : loadOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + 
newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, loadOp, uccMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteStore( + UnrealizedConversionCastOp op, triton::StoreOp storeOp, + PatternRewriter &rewriter) const { + auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(storeOp.getPtr(), op, rewriter); + auto value = rewriteValue(storeOp.getValue(), op, rewriter); + auto mask = rewriteValue(storeOp.getMask(), op, rewriter); + auto ptrShape = cast(ptr.getType()).getShape(); + mask = createMask(mask, uccMask, ptrShape, rewriter); + auto boundaryCheck = llvm::map_to_vector(storeOp.getBoundaryCheck(), + [](int32_t idx) { return idx + 1; }); + auto newOp = rewriter.create( + storeOp.getLoc(), ptr, value, mask, boundaryCheck, storeOp.getCache(), + storeOp.getEvict()); + for (auto attr : storeOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(storeOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteAtomicRMW( + UnrealizedConversionCastOp op, triton::AtomicRMWOp atomicRMWOp, + PatternRewriter &rewriter) const { + auto uccMask = op.getInputs()[1]; + auto ptr = rewriteValue(atomicRMWOp.getPtr(), op, rewriter); + auto val = rewriteValue(atomicRMWOp.getVal(), op, rewriter); + auto mask = rewriteValue(atomicRMWOp.getMask(), op, rewriter); + auto resType = getExpandedType(atomicRMWOp.getResult().getType(), op); + mask = createMask(mask, uccMask, resType.getShape(), rewriter); + auto newOp = rewriter.create( + atomicRMWOp.getLoc(), resType, atomicRMWOp.getAtomicRmwOp(), ptr, val, + mask, atomicRMWOp.getSem(), atomicRMWOp.getScope()); + for (auto attr : atomicRMWOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, atomicRMWOp, uccMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteAssert( + UnrealizedConversionCastOp op, triton::AssertOp assertOp, + 
PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto inputShape = cast(input.getType()).getShape(); + auto conditionType = cast(mask.getType()); + auto oneAttr = rewriter.getIntegerAttr(getElementTypeOrSelf(mask), 1); + auto one = rewriter.create( + mask.getLoc(), DenseElementsAttr::get(conditionType, oneAttr)); + Value condition = rewriter.create(input.getLoc(), mask, one); + condition = createMask(nullptr, condition, inputShape, rewriter); + condition = + rewriter.create(condition.getLoc(), condition, input); + auto newOp = rewriter.create(assertOp.getLoc(), condition, + assertOp.getMessage()); + for (auto attr : assertOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(assertOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteExtractSlice( + UnrealizedConversionCastOp op, tensor::ExtractSliceOp extractSliceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto src = rewriteValue(extractSliceOp.getSource(), op, rewriter); + auto offsets = llvm::to_vector(extractSliceOp.getMixedOffsets()); + auto sizes = llvm::to_vector(extractSliceOp.getMixedSizes()); + auto strides = llvm::to_vector(extractSliceOp.getMixedStrides()); + auto srcType = cast(src.getType()); + offsets.insert(offsets.begin(), rewriter.getIndexAttr(0)); + sizes.insert(sizes.begin(), rewriter.getIndexAttr(srcType.getShape()[0])); + strides.insert(strides.begin(), rewriter.getIndexAttr(1)); + auto newOp = rewriter.create( + extractSliceOp.getLoc(), src, offsets, sizes, strides); + auto newMask = rewriter.create( + mask.getLoc(), mask, offsets, sizes, strides); + for (auto attr : extractSliceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, extractSliceOp, newMask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteInsertSlice( + 
UnrealizedConversionCastOp op, tensor::InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const { + auto mask = op.getInputs()[1]; + auto src = rewriteValue(insertSliceOp.getSource(), op, rewriter); + auto dst = rewriteValue(insertSliceOp.getDest(), op, rewriter); + auto offsets = llvm::to_vector(insertSliceOp.getMixedOffsets()); + auto sizes = llvm::to_vector(insertSliceOp.getMixedSizes()); + auto strides = llvm::to_vector(insertSliceOp.getMixedStrides()); + auto srcType = cast(src.getType()); + offsets.insert(offsets.begin(), rewriter.getIndexAttr(0)); + sizes.insert(sizes.begin(), rewriter.getIndexAttr(srcType.getShape()[0])); + strides.insert(strides.begin(), rewriter.getIndexAttr(1)); + auto newOp = rewriter.create( + insertSliceOp.getLoc(), src, dst, offsets, sizes, strides); + for (auto attr : insertSliceOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + replaceValue(newOp, insertSliceOp, mask, rewriter); +} + +void PropagateUnrealizedCastDown::rewriteWhile( + UnrealizedConversionCastOp op, scf::WhileOp whileOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + SmallVector newInits; + IRMapping mapping; + for (auto [idx, init] : llvm::enumerate(whileOp.getInits())) { + if (init == res) { + indices.push_back(idx); + newInits.push_back(input); + } else { + newInits.push_back(init); + } + } + auto newOp = rewriter.create( + whileOp.getLoc(), whileOp->getResultTypes(), newInits, + [&](OpBuilder &b, Location loc, ValueRange args) { + mapRegionIterArg(mapping, whileOp.getBeforeArguments(), args, indices, + mask, b); + for (auto &bodyOp : *whileOp.getBeforeBody()) + b.clone(bodyOp, mapping); + }, + [&](OpBuilder &b, Location loc, ValueRange args) { + mapRegionIterArg(mapping, whileOp.getAfterArguments(), args, {}, mask, + b); + for (auto &bodyOp : 
whileOp.getAfterBody()->without_terminator()) + b.clone(bodyOp, mapping); + auto yieldOp = + cast(whileOp.getAfterBody()->getTerminator()); + mapYieldedValue(mapping, yieldOp, indices, op, b); + }); + for (auto attr : whileOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(whileOp, newOp); +} + +void PropagateUnrealizedCastDown::rewriteLoop(UnrealizedConversionCastOp op, + LoopLikeOpInterface loopOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + SmallVector newInits; + IRMapping mapping; + for (auto [idx, init] : llvm::enumerate(loopOp.getInits())) { + if (init == res) { + indices.push_back(idx); + newInits.push_back(input); + } else { + newInits.push_back(init); + } + } + LoopLikeOpInterface newOp; + if (auto forOp = dyn_cast(loopOp.getOperation())) { + newOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInits, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + mapping.map(forOp.getInductionVar(), iv); + mapRegionIterArg(mapping, forOp.getRegionIterArgs(), args, indices, + mask, b); + for (auto &bodyOp : forOp.getBody()->without_terminator()) + b.clone(bodyOp, mapping); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + mapYieldedValue(mapping, yieldOp, indices, op, b); + }); + for (auto attr : forOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + } else { + llvm_unreachable("Unhandled loopOp"); + } + replaceValue(newOp, loopOp, mask, rewriter, indices); +} + +void PropagateUnrealizedCastDown::rewriteIf(UnrealizedConversionCastOp &op, + scf::IfOp ifOp, + ArrayRef indices, + PatternRewriter &rewriter) const { + IRMapping mapping; + auto mask = op.getInputs()[1]; + auto thenBlockBuilder = [&](OpBuilder &b, Location loc) { + for 
(auto &bodyOp : *ifOp.thenBlock()) + b.clone(bodyOp, mapping); + }; + function_ref elseBlockBuilder = + [&](OpBuilder &b, Location loc) { + for (auto &bodyOp : *ifOp.elseBlock()) + b.clone(bodyOp, mapping); + }; + if (!ifOp.elseBlock()) + elseBlockBuilder = nullptr; + auto newOp = rewriter.create(ifOp.getLoc(), ifOp.getCondition(), + thenBlockBuilder, elseBlockBuilder); + for (auto attr : ifOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + if (mapping.contains(op)) + op = cast(mapping.lookup(op)); + replaceValue(newOp, ifOp, mask, rewriter, indices); +} + +void PropagateUnrealizedCastDown::rewriteYield( + UnrealizedConversionCastOp &op, scf::YieldOp yieldOp, + PatternRewriter &rewriter) const { + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + SmallVector indices; + auto newOperands = llvm::to_vector(yieldOp.getOperands()); + for (auto [idx, opr] : llvm::enumerate(newOperands)) { + if (opr == res) + indices.push_back(idx); + } + if (auto loopOp = dyn_cast(yieldOp->getParentOp())) { + auto uccOp = rewriter.create( + op.getLoc(), res.getType(), ValueRange({input})); + for (auto curIdx : indices) + newOperands[curIdx] = uccOp->getResult(0); + auto newOp = rewriter.create(yieldOp.getLoc(), newOperands); + for (auto attr : yieldOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(yieldOp, newOp); + rewriter.setInsertionPoint(loopOp); + for (auto curIdx : indices) { + auto &initArg = loopOp.getInitsMutable()[curIdx]; + auto initVal = initArg.get(); + uccOp = rewriter.create( + initVal.getLoc(), input.getType(), ValueRange({initVal})); + uccOp = rewriter.create( + initVal.getLoc(), initVal.getType(), + ValueRange({uccOp->getResult(0), mask})); + rewriter.modifyOpInPlace(loopOp, + [&]() { initArg.set(uccOp->getResult(0)); }); + } + } else if (auto ifOp = 
dyn_cast(yieldOp->getParentOp())) { + for (auto curIdx : indices) + newOperands[curIdx] = input; + auto newOp = rewriter.create(yieldOp.getLoc(), newOperands); + for (auto attr : yieldOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(yieldOp, newOp); + yieldOp = ifOp.thenYield() == yieldOp ? ifOp.elseYield() : ifOp.thenYield(); + if (yieldOp) { + rewriter.setInsertionPoint(yieldOp); + newOperands = llvm::to_vector(yieldOp.getOperands()); + for (auto curIdx : indices) { + auto uccOp = rewriter.create( + op.getLoc(), input.getType(), ValueRange({newOperands[curIdx]})); + newOperands[curIdx] = uccOp->getResult(0); + } + rewriter.replaceOpWithNewOp(yieldOp, newOperands); + } + rewriter.setInsertionPoint(ifOp); + rewriteIf(op, ifOp, indices, rewriter); + } +} + +void PropagateUnrealizedCastDown::rewriteCondition( + UnrealizedConversionCastOp op, scf::ConditionOp conditionOp, + PatternRewriter &rewriter) const { + auto whileOp = cast(conditionOp->getParentOp()); + auto input = op.getInputs()[0]; + auto mask = op.getInputs()[1]; + auto res = op->getResult(0); + int64_t curIdx = -1; + auto args = llvm::to_vector(conditionOp.getArgs()); + for (auto [idx, opr] : llvm::enumerate(args)) { + if (opr == res) + curIdx = idx; + } + args[curIdx] = input; + auto newOp = rewriter.create( + conditionOp.getLoc(), conditionOp.getCondition(), args); + for (auto attr : conditionOp->getAttrs()) { + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + } + rewriter.replaceOp(conditionOp, newOp); + + res = whileOp->getResult(curIdx); + auto oldResType = res.getType(); + auto newResType = getExpandedType(oldResType, op); + rewriter.modifyOpInPlace(whileOp, [&]() { res.setType(newResType); }); + rewriter.setInsertionPointAfter(whileOp); + auto newUccOp = rewriter.create( + res.getLoc(), oldResType, ValueRange({res, mask})); + rewriter.replaceAllUsesExcept(res, 
newUccOp->getResult(0), newUccOp); + auto arg = whileOp.getAfterArguments()[curIdx]; + auto oldArgType = arg.getType(); + auto newArgType = getExpandedType(oldArgType, op); + rewriter.modifyOpInPlace(whileOp, [&]() { arg.setType(newArgType); }); + rewriter.setInsertionPointToStart(whileOp.getAfterBody()); + newUccOp = rewriter.create( + arg.getLoc(), oldArgType, ValueRange({arg, mask})); + rewriter.replaceAllUsesExcept(arg, newUccOp->getResult(0), newUccOp); +} \ No newline at end of file diff --git a/third_party/ascend/lib/AutoBlockify/Utils.cpp b/third_party/ascend/lib/AutoBlockify/Utils.cpp new file mode 100644 index 0000000000..bdb09e792e --- /dev/null +++ b/third_party/ascend/lib/AutoBlockify/Utils.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "AutoBlockify/Utils.h" +#include "Utils/Utils.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "auto-blockify-utils" + +using namespace mlir; +using namespace triton; + +RankedTensorType getExpandedType(Type type, UnrealizedConversionCastOp op) { + auto target = op.getInputs()[0]; + auto targetType = cast(target.getType()); + SmallVector targetShape{targetType.getShape()[0]}; + if (auto valueType = dyn_cast(type)) { + targetShape.append(valueType.getShape().begin(), + valueType.getShape().end()); + } + return RankedTensorType::get(targetShape, getElementTypeOrSelf(type)); +} + +Value rewriteValue(Value value, UnrealizedConversionCastOp op, + OpBuilder &builder) { + if (value == nullptr) + return nullptr; + if (value == op->getResult(0)) + return op.getInputs()[0]; + return builder + .create( + value.getLoc(), getExpandedType(value.getType(), op), value) + ->getResult(0); +} + +void replaceValue(Operation *newOp, Operation *oldOp, Value newMask, + RewriterBase &rewriter, ArrayRef replaceIndices) { + int64_t idx = 0; + for (auto [res, oldRes] : + llvm::zip_equal(newOp->getResults(), oldOp->getResults())) { + if (replaceIndices.empty() || + llvm::find(replaceIndices, idx) != replaceIndices.end()) { + auto resType = res.getType(); + auto newUccOp = rewriter.create( + newOp->getLoc(), oldRes.getType(), ValueRange({res, newMask})); + rewriter.replaceAllUsesExcept(oldRes, newUccOp->getResult(0), newUccOp); + } else { + rewriter.replaceAllUsesWith(oldRes, res); + } + idx++; + } + rewriter.eraseOp(oldOp); +} + +Value createMask(Value mask, Value uccMask, ArrayRef targetShape, + RewriterBase &rewriter) { + SmallVector curShape{targetShape[0]}; + for (auto [idx, dim] : llvm::drop_begin(llvm::enumerate(targetShape))) { + curShape.push_back(dim); + uccMask = + rewriter.create(uccMask.getLoc(), uccMask, idx); + uccMask = rewriter.create( + uccMask.getLoc(), + RankedTensorType::get(curShape, getElementTypeOrSelf(uccMask)), + uccMask); + } + if (mask) { 
+ mask = rewriter.create(mask.getLoc(), mask, uccMask); + } else { + mask = uccMask; + } + return mask; +} + +void mapRegionIterArg(IRMapping &mapping, ValueRange oldArgs, + ValueRange newArgs, ArrayRef indices, Value mask, + OpBuilder &builder) { + auto newArgIter = newArgs.begin(); + for (auto [idx, oldArg] : llvm::enumerate(oldArgs)) { + if (llvm::find(indices, idx) != indices.end()) { + auto newUccOp = builder.create( + oldArg.getLoc(), oldArg.getType(), ValueRange({*newArgIter, mask})); + mapping.map(oldArg, newUccOp->getResult(0)); + } else { + mapping.map(oldArg, *newArgIter); + } + ++newArgIter; + } +} + +void mapYieldedValue(IRMapping &mapping, scf::YieldOp yieldOp, + ArrayRef indices, UnrealizedConversionCastOp op, + OpBuilder &builder) { + SmallVector newOperands; + for (auto [idx, operand] : llvm::enumerate(yieldOp.getOperands())) { + operand = mapping.lookup(operand); + if (llvm::find(indices, idx) != indices.end()) + newOperands.push_back(rewriteValue(operand, op, builder)); + else + newOperands.push_back(operand); + } + builder.create(yieldOp.getLoc(), newOperands); +} + +Operation *createBlockifyLoop(Operation *targetOp, + UnrealizedConversionCastOp op, + Value logicalBlockId, Value logicalBlockNum, + int autoBlockifySize, RewriterBase &rewriter) { + auto loc = targetOp->getLoc(); + rewriter.setInsertionPoint(targetOp); + auto initVal = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto stepVal = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto blockifySizeVal = rewriter.create( + loc, rewriter.getIndexAttr(autoBlockifySize)); + Value upperBound = + rewriter.create(loc, logicalBlockNum, logicalBlockId); + auto i32Zero = + rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + upperBound = rewriter.create(loc, upperBound, i32Zero); + upperBound = rewriter.create(loc, rewriter.getIndexType(), + upperBound); + upperBound = + rewriter.create(loc, upperBound, blockifySizeVal); + SmallVector inits; + if (auto loopOp = dyn_cast(targetOp)) { + 
inits = llvm::map_to_vector(loopOp.getInits(), + [&rewriter, &op](Value v) -> Value { + return rewriteValue(v, op, rewriter); + }); + } else { + auto resultTypes = + llvm::map_to_vector(targetOp->getResultTypes(), [&op](Type type) { + return getExpandedType(type, op); + }); + inits = + llvm::map_to_vector(resultTypes, [&rewriter, &loc](Type type) -> Value { + auto tensorType = cast(type); + return rewriter.create(loc, tensorType.getShape(), + tensorType.getElementType()); + }); + } + auto mask = op.getInputs()[1]; + Operation *newOp; + auto blockifyLoop = rewriter.create( + loc, initVal, upperBound, stepVal, inits, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + newOp = b.clone(*targetOp); + + SmallVector newResults; + for (auto [arg, res] : llvm::zip_equal(args, newOp->getResults())) { + auto tensorType = cast(arg.getType()); + auto rank = tensorType.getRank(); + Value newRes; + if (rank > 1) { + SmallVector offsets(tensorType.getRank(), + b.getIndexAttr(0)); + SmallVector sizes(1, b.getIndexAttr(1)); + SmallVector strides(tensorType.getRank(), + b.getIndexAttr(1)); + offsets[0] = iv; + for (auto dim : llvm::drop_begin(tensorType.getShape())) + sizes.push_back(b.getIndexAttr(dim)); + newRes = b.create(loc, res, arg, offsets, + sizes, strides); + } else { + newRes = b.create(loc, res, arg, ValueRange{iv}); + } + newResults.push_back(newRes); + } + b.create(loc, newResults); + }); + + replaceValue(blockifyLoop, targetOp, mask, rewriter); + blockifyLoop->setAttr(autoBlockifyLoopAttr, rewriter.getUnitAttr()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After creating blockify loop:\n" << blockifyLoop << "\n"; + }); + return newOp; +} + +std::optional getBlockifyLoop(Operation *op) { + while (auto forOp = op->getParentOfType()) { + if (forOp->hasAttr(autoBlockifyLoopAttr)) + return forOp; + op = forOp; + } + return std::nullopt; +} \ No newline at end of file diff --git a/third_party/ascend/lib/CMakeLists.txt 
b/third_party/ascend/lib/CMakeLists.txt index bd3c0c6c01..c5fa61143a 100644 --- a/third_party/ascend/lib/CMakeLists.txt +++ b/third_party/ascend/lib/CMakeLists.txt @@ -1 +1,42 @@ -add_subdirectory(Conversion) +add_subdirectory(AutoBlockify) +add_subdirectory(Dialect) +add_subdirectory(TritonToAnnotation) +add_subdirectory(TritonToHFusion) +add_subdirectory(TritonToHIVM) +add_subdirectory(TritonToLinalg) +add_subdirectory(Utils) +add_subdirectory(DiscreteMaskAccessConversion) +add_subdirectory(TritonToUnstructure) +add_subdirectory(TritonToLLVM) +add_subdirectory(TritonToStructured) +add_subdirectory(TritonAffinityOpt) + +if(TRITON_ENABLE_COVERAGE_HITEST) + set(_instrument_targets + DiscreteMaskAccessConversion + TritonToAnnotation + TritonToHFusion + TritonToHIVM + TritonToLinalg + TritonToLLVM + TritonToStructured + TritonToUnstructure + MLIRTritonNPUUtils # from Utils + TritonAscendIR # from Dialect/TritonAscend/IR + TritonStructuredIR # from Dialect/TritonStructured/IR + AutoBlockify + TritonAffinityOpt + ) + + foreach(_target ${_instrument_targets}) + if(TARGET ${_target}) + set_target_properties(${_target} PROPERTIES + RULE_LAUNCH_COMPILE "hitestwrapper" + RULE_LAUNCH_LINK "hitestwrapper" + ) + message(STATUS "Enabled hitestwrapper for target: ${_target}") + else() + message(WARNING "Target ${_target} not found, please check the actual target name") + endif() + endforeach() +endif() diff --git a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp index 9c4d41f9ff..75ea897ef4 100644 --- a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp +++ b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp @@ -36,14 +36,10 @@ static Type getElementType(Value value) { return type; } -static int64_t get1DTensorLength(Value tensor) { - auto type = mlir::cast(tensor.getType()); - auto shape = type.getShape(); - - assert(shape.size() == 1 && - "ElementwiseInlineAsm now can 
operate only with 1D tensors"); - - return shape[0]; +static int64_t getTensorNumElements(Value tensor) +{ + auto type = mlir::cast(tensor.getType()); + return type.getNumElements(); } static Value getInt32Value(RewriterBase &rewriter, Location loc, int val) { @@ -81,21 +77,25 @@ SmallVector packOperands(mlir::triton::ElementwiseInlineAsmOp op, return packedOperands; } -static SmallVector unpackElements(Location loc, Value packedValues, - RewriterBase &rewriter) { - auto type = mlir::cast(packedValues.getType()); - auto elementType = type.getElementType(); +static SmallVector unpackElements(Location loc, Value packedValues, RewriterBase &rewriter) +{ + auto type = mlir::cast(packedValues.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); - int64_t length = get1DTensorLength(packedValues); + int64_t numElements = type.getNumElements(); - SmallVector result; - for (int64_t idx = 0; idx < length; idx++) { - SmallVector indexes{ - rewriter.create(loc, idx)}; - Value extracted = rewriter.create(loc, elementType, - packedValues, indexes); - result.push_back(extracted); - } + SmallVector result; + for (int64_t linearIdx = 0; linearIdx < numElements; linearIdx++) { + SmallVector indexes(shape.size()); + int64_t remaining = linearIdx; + for (int64_t dim = shape.size() - 1; dim >= 0; dim--) { + indexes[dim] = rewriter.create(loc, remaining % shape[dim]); + remaining /= shape[dim]; + } + Value extracted = rewriter.create(loc, elementType, packedValues, indexes); + result.push_back(extracted); + } return result; } @@ -175,46 +175,79 @@ createDestOps(triton::ElementwiseInlineAsmOp op, RewriterBase &rewriter, return ret; } -} // namespace +static LogicalResult processScalarInlineAsm(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) +{ + Location loc = op.getLoc(); -struct ElementwiseInlineAsmOpConversion - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + auto outsWrapped = createDestOps(op, rewriter, {}, loc); 
- LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); + SmallVector outs; + for (const auto &resWrapped : outsWrapped) { + outs.push_back(resWrapped[0]); + } + rewriter.replaceOp(op, outs); - SmallVector> unpackedOperands; - for (auto operand : op.getOperands()) { - auto unpackedOperand = unpackElements(loc, operand, rewriter); - unpackedOperands.push_back(unpackedOperand); - } + return success(); +} - int64_t resultLength = get1DTensorLength(op->getResult(0)); - if (resultLength % op.getPackedElement()) { - op.emitError("Result tensor should be diveded to pack"); - return failure(); - } +static LogicalResult processVectorInlineAsm(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) +{ + Location loc = op.getLoc(); - SmallVector> unpackedResults(op->getNumResults()); - for (int64_t i = 0; i < resultLength; i += op.getPackedElement()) { - // Block of elements to process with one call to the inline asm. This is - // ordered opposite `unpackedResults`: The outer dim is - // op.getPackedElement(), and the inner dim is the operand. 
- SmallVector> block(op.getPackedElement()); - for (auto &os : unpackedOperands) { - for (int j = 0; j < op.getPackedElement(); j++) { - block[j].push_back(os[i + j]); - } - } - auto cur = createDestOps(op, rewriter, block, loc); - assert(cur.size() == unpackedResults.size()); - for (unsigned j = 0; j < cur.size(); j++) { - unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), - cur[j].end()); + SmallVector> unpackedOperands; + for (auto operand : op.getOperands()) { + auto unpackedOperand = unpackElements(loc, operand, rewriter); + unpackedOperands.push_back(unpackedOperand); + } + + int64_t resultLength = getTensorNumElements(op->getResult(0)); + if (resultLength % op.getPackedElement()) { + op.emitError("Result tensor should be diveded to pack"); + return failure(); + } + + SmallVector> unpackedResults(op->getNumResults()); + for (int64_t i = 0; i < resultLength; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); } } + auto cur = createDestOps(op, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + // Reorder and pack the results. 
+ SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + outs.push_back(rewriter.create( + loc, op->getResult(i).getType(), unpackedResults[i])); + } + rewriter.replaceOp(op, outs); + + return success(); +} + +} // namespace + +struct ElementwiseInlineAsmOpConversion : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) const final + { + return op.getOperands().empty() ? processScalarInlineAsm(op, rewriter) + : processVectorInlineAsm(op, rewriter); + } // Reorder and pack the results. SmallVector outs; for (int i = 0; i < unpackedResults.size(); i++) { diff --git a/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt new file mode 100644 index 0000000000..2f49f3f0b4 --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(TritonAffinityOpt + DAGSSBuffer.cpp + DAG.cpp + DAGSync.cpp + DAGScope.cpp + + DEPENDS + TritonAffinityOptConversionPassIncGen + + LINK_LIBS + BiShengIRHIVMDialect + BiShengIRScopeDialect + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR + MLIRSCFDialect +) \ No newline at end of file diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp new file mode 100644 index 0000000000..b796f8eb96 --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp @@ -0,0 +1,534 @@ +#include "TritonAffinityOpt/DAG.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include 
"mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include "bishengir/Dialect/Annotation/IR/Annotation.h" + +namespace mlir { namespace AffinityDAG { + +const auto printFlags = OpPrintingFlags() + .enableDebugInfo(true, true) + .skipRegions(); + +const char* literalCoreType(CoreType ct) { + switch (ct) { + case VECTOR_ONLY: + return "VECTOR_ONLY"; + case CUBE_ONLY: + return "CUBE_ONLY"; + case CUBE_AND_VECTOR: + return "CUBE_AND_VECTOR"; + case UNDETERMINED: + return "UNDETERMINED"; + } + return "Unknown"; +} + + +bool opIsScf(Operation* op) { + if (!llvm::isa(op->getDialect())) + return false; + return true; +} + +Graph::Graph(Block* block, Graph* parent, OpMap opMap, ValueMap valueMap, bool inheritParent) : + block(block), + parent(parent), + opMap(opMap), + valueMap(valueMap) +{ + + if (parent && inheritParent) { + if (!this->opMap) { + this->opMap = parent->opMap; + } + + if (!this->valueMap) { + this->valueMap = parent->valueMap; + } + } + + if (!this->opMap) { + this->opMap = std::make_shared(); + } + + if (!this->valueMap) { + this->valueMap = std::make_shared(); + } + + for(auto blockArg : block->getArguments()) { + (*this->valueMap)[blockArg] = std::make_unique(blockArg); + blockArgs.push_back((*this->valueMap)[blockArg].get()); + } + + for(auto& opRef : block->getOperations()) { + opCount += 1; + auto op = &opRef; + auto opNodeUnique = std::make_unique(op, this); + auto opNode = 
opNodeUnique.get(); + (*this->opMap)[op] = std::move(opNodeUnique); + + if (block->mightHaveTerminator() && op == block->getTerminator()) { + terminator = opNode; + } + + for (auto& subgraph : opNode->subgraphs) { + opCount += subgraph.opCount; + } + } +}; + +bool valueIsScalar(Value value) { + auto type = value.getType(); + + if (type.isIntOrIndexOrFloat()) { + return true; + } + + if (auto tensorType = llvm::dyn_cast(type)) { + return tensorType.getRank() == 0; + } + + if (auto _ = llvm::dyn_cast(type)) { + return true; + } + + return false; +} + +bool valueIsTensorOfPtr(Value value) { + auto type = value.getType(); + if (auto tensorType = llvm::dyn_cast(type)) { + auto elementType = tensorType.getElementType(); + if (llvm::isa(elementType)) { + return true; + } + } + + return false; +} + +OpAbility OpNode::canRunOn() const { + if (opIsScf(op)) { + return OpAbility::CUBE_AND_VECTOR; + } + return llvm::TypeSwitch(op) + .Case([](auto) { + return OpAbility::CUBE_ONLY; + }) + .Case([](auto) { + return OpAbility::CUBE_AND_VECTOR; + }) + .Case([](arith::SelectOp op) { + // when cond is vector, selectOp should be vector, otherwise scalar + return ( + valueIsScalar(op.getCondition()) ? 
OpAbility::CUBE_AND_VECTOR : OpAbility::PREFER_VECTOR + ); + }) + .Default([](Operation* op) { + auto isVector = false; + for(auto operand : op->getOperands()) { + if (!valueIsScalar(operand)) { + // if (valueIsTensorOfPtr(operand)) { + // return SCALAR; + // } + isVector = true; + } + } + + for(auto result : op->getResults()) { + if (!valueIsScalar(result)) { + // if (valueIsTensorOfPtr(result)) { + // return SCALAR; + // } + isVector = true; + } + } + + if (isVector) { + return OpAbility::PREFER_VECTOR; + } + + return OpAbility::CUBE_AND_VECTOR; + }); +} + +OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { + if (op == nullptr) { + return; + } + + llvm::outs() << op << "\n"; + + auto& valueMap = *graph->valueMap.get(); + auto& opMap = *graph->opMap.get(); + for(const auto operand : op->getOperands()) { + auto valueNode = valueMap.at(operand).get(); + valueNode->outputs.push_back(this); + inputs.push_back(valueNode); + } + + for(const auto& result : op->getResults()) { + auto valueNodeUnique = std::make_unique(result); + auto valueNode = valueNodeUnique.get(); + valueMap[result] = std::move(valueNodeUnique); + valueNode->source = this; + outputs.push_back(valueNode); + } + + // if (!op->hasTrait()) { + // llvm::dbgs() << "Not building subgraph because op is not SingleBlock: " << op << '\n'; + // return; + // } + + if (auto branchOp = llvm::dyn_cast(op)) { + + OpNode* terminator = nullptr; + llvm::SmallVector, 2> validRegions; + + for(auto& region : branchOp->getRegions()) { + if (region.getBlocks().empty()) + continue; + subgraphs.emplace_back(®ion.getBlocks().front(), graph); + validRegions.emplace_back(region, subgraphs.back()); + } + + for(auto [region, subgraph] : validRegions) { + SmallVector succRegions; + + branchOp.getSuccessorRegions(region, succRegions); + if (auto currTerminator = dyn_cast(subgraph.terminator->op)) { + for(auto& succ : succRegions) { + auto forwardedVal = currTerminator.getSuccessorOperands(succ); + if 
(succ.isParent()) { + // Step1: first yield to parent -> results: double direction + if (!terminator && subgraph.terminator) { + terminator = subgraph.terminator; + for(auto [forwardedVal, resultNode] : llvm::zip_equal( + forwardedVal, + outputs + )) { + auto resultValueNode = llvm::dyn_cast(resultNode); + assert(resultValueNode && "Output of a OpNode should be ValueNode!"); + auto forwardedNode = valueMap[forwardedVal].get(); + resultValueNode->source = forwardedNode; + forwardedNode->outputs.push_back(resultNode); + } + } + + } else { + // Step2: Region terminator -> Succ Operands + auto succRegion = succ.getSuccessor(); + + for(auto [operand, succInput] : llvm::zip_equal( + forwardedVal, + succ.getSuccessorInputs() + )) { + auto forwardedNode = valueMap[operand].get(); + auto succNode = valueMap[succInput].get(); + forwardedNode->outputs.push_back(succNode); + succNode->source = forwardedNode; + } + } + } + } + } + + if (auto loopOp = llvm::dyn_cast(op)) { + // Step3: inits->iter_args (single directional) (should be handled in step 2: ) last terminator -> iter_args (bidirectional) + for(auto [init, iterArgVal] : llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs())) { + auto& initNode = valueMap[init]; + auto& iterArgNode = valueMap[iterArgVal]; + initNode->outputs.push_back(iterArgNode.get()); + } + // for(auto [init, iterArgVal, yieldNode] : llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs(), terminator->outputs)) { + // auto& initNode = valueMap[init]; + // auto& iterArgNode = valueMap[iterArgVal]; + // initNode->outputs.push_back(iterArgNode.get()); + // yieldNode->outputs.push_back(iterArgNode.get()); + // iterArgNode->source = yieldNode; + // } + } + } +} + +// llvm::SmallVector getWriteOperandPriority(OpNode* op) { + +// llvm::SmallVector result(op->getInputs()); + +// auto getPriority = [](ValueNode* node) { +// auto typ = getElementTypeOrSelf(node->value); +// if (typ.isInteger(1)) { +// return 2; +// } +// if (llvm::isa(typ)) { +// 
return 1; +// } +// return 0; +// }; + +// std::stable_sort(result.begin(), result.end(), [&](ValueNode* a, ValueNode* b) { +// return getPriority(a) < getPriority(b); +// }); + +// return result; +// } + +ValueNode* getWriteDataSource(OpNode* op) { + auto inputRange = op->getInputs(); + for(auto node : inputRange.drop_front()) { + auto typ = getElementTypeOrSelf(node->value); + if (!typ.isInteger(1)) { + return node; + } + }; + + return nullptr; +} + +enum class MemPolicy { + NONE, + READ, + WRITE +}; + +CoreType Node::absorbCommon() { + + auto sourceNode = getSourceOpNode(); + auto op = sourceNode ? sourceNode->op : nullptr; + + if (!sourceNode || !op) { + CoreType newCoreType = isOnPrivate; + for(auto output : outputs) { + newCoreType = newCoreType | output->isOn(); + isUpstreamOfCubeMem = isUpstreamOfCubeMem || output->isUpstreamOfCubeMem; + } + return newCoreType; + } + + CoreType newCoreType = sourceNode->isOn(); + + OpAbility ability = sourceNode->canRunOn(); + + if (ability == OpAbility::CUBE_ONLY) { + return CUBE_ONLY; + } + + auto memIface = llvm::dyn_cast(op); + auto memPolicy = MemPolicy::NONE; + + if (memIface) { + // Possible improvements: Determine the policy to use based on shapes, inputs and outputs, etc + if (memIface.hasEffect()) { + memPolicy = MemPolicy::WRITE; + } else if (memIface.hasEffect()) { + memPolicy = MemPolicy::READ; + } + } + + if (memPolicy == MemPolicy::WRITE) { + if (auto data = getWriteDataSource(sourceNode)) { + auto currCt = data->isOn(); + if (exactlyOneType(currCt)) { + if (currCt == CUBE_ONLY) { + isUpstreamOfCubeMem = true; + } + return currCt; + } + } + + // data is not cube_only + return VECTOR_ONLY; + } + + for(auto output : outputs) { + switch (output->isOn()) { + case CUBE_AND_VECTOR: + newCoreType = newCoreType | VECTOR_ONLY; + // not breaking the switch because we need to handle cube + case CUBE_ONLY: + if ( + ability != OpAbility::PREFER_VECTOR || + output->isUpstreamOfCubeMem || + memPolicy == MemPolicy::READ + ) 
{ + isUpstreamOfCubeMem = ( + isUpstreamOfCubeMem || + output->isUpstreamOfCubeMem || + memPolicy == MemPolicy::READ + ); + newCoreType = newCoreType | CUBE_ONLY; + } + break; + case VECTOR_ONLY: + newCoreType = newCoreType | VECTOR_ONLY; + default: // UNDETERMINED, skip + break; + }; + } + + return newCoreType; +} + +CoreType OpNode::absorbImpl() { + if (opIsScf(op)) { + return CUBE_AND_VECTOR; + } + + auto newCoreType = absorbCommon(); + + // if (canRunOn() == OpAbility::CUBE_AND_VECTOR) { + // for (auto input : inputs) { + // newCoreType = newCoreType | input->isOn(); + // } + // } + + return newCoreType; +} + +CoreType ValueNode::absorbImpl() { + return absorbCommon(); +} + +std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { + + auto dummyBlock = new Block(); + auto dummyGraph = std::make_unique(dummyBlock); + auto dummyNode = std::make_unique(nullptr, dummyGraph.get()); + size_t opCount = 0; + + for (auto& block : funcOp.getBody()) { + auto& subgraph = dummyNode->subgraphs.emplace_back( + &block, + dummyGraph.get() + ); + opCount += subgraph.opCount; + } + + auto& opMap = *dummyGraph->opMap.get(); + auto& valueMap = *dummyGraph->valueMap.get(); + + llvm::SmallVector nodes; + nodes.reserve(opMap.size() + valueMap.size()); + + for(auto& [_, node] : opMap) { + if (node.get()) + nodes.push_back(node.get()); + } + + for(auto& [_, node] : valueMap) { + if (node.get()) + nodes.push_back(node.get()); + } + + auto diffuse = [&]() { + // Not sure if determinism is required + llvm::SmallSetVector worklist(nodes.begin(), nodes.end()); + + size_t threshold = worklist.size() * 5; + + for(size_t i = 0; i< threshold; i++) { + if (worklist.empty()) { + break; + } + + auto node = worklist.pop_back_val(); + + if (node->absorb()) { + auto affected = node->getAffected(); + worklist.insert(affected.begin(), affected.end()); + } + } + }; + + diffuse(); + + for(auto node : nodes) { + if (node->isOn() == UNDETERMINED) { + node->isOnPrivate = VECTOR_ONLY; + } + } + + 
diffuse(); + + OpPrintingFlags flags; + flags.skipRegions(); + + for(auto [idx, node] : llvm::enumerate(nodes)) { + llvm::TypeSwitch(node) + .Case([&, idx=idx](OpNode* node) { + if (node->op) { + llvm::dbgs() << llvm::formatv("\n\n====== OpNode on: {1} @ {0} ======\n", + node->op, + literalCoreType(node->isOn()) + ); + node->op->print(llvm::dbgs(), flags); + llvm::dbgs() << "\nAbility: " << literalCoreType(toCoreType(node->canRunOn())); + llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->op); + } + }) + .Case([&, idx=idx](ValueNode* node) { + if (node->value) { + llvm::dbgs() << llvm::formatv("\n\n====== ValueNode on {1} @ {0} ======\n", + node->value, + literalCoreType(node->isOn()) + ); + node->value.print(llvm::dbgs(), flags); + llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->value); + } + }); + // if (auto opNode = llvm::dyn_cast(node)) { + // if (auto forOp = llvm::dyn_cast_if_present(opNode->op)) { + // llvm::dbgs() << "\n==== ForOp ====\n"; + // llvm::dbgs() << forOp << "\n"; + // llvm::dbgs() << "\n---- IterArgs ----\n"; + // for(auto iterArg : forOp.getRegionIterArgs()) { + // auto& valueNode = valueMap[iterArg]; + // llvm::dbgs() << llvm::formatv( + // "{0}: {1} upstream: {2} definingOp: {3} \n", + // iterArg.getArgNumber(), + // literalCoreType(valueNode->isOn()), + // literalCoreType(valueNode->source->isOn()), + // valueNode->getSourceOp()->op + // ); + // } + // llvm::dbgs() << "\n---- Results ----\n"; + // for(auto result : forOp.getResults()) { + // llvm::dbgs() << result.getResultNumber() << ' ' << literalCoreType(valueMap[result]->isOn()) << '\n'; + // } + // } + // } + } + + return dummyGraph; +}; + +} } diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp new file mode 100644 index 0000000000..211adc311d --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp @@ -0,0 +1,5534 @@ +/* + * Copyright (c) Huawei Technologies Co., 
Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "TritonAffinityOpt/Passes.h" + +#include "bishengir/Dialect/Scope/IR/Scope.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h" +#include "bishengir/Dialect/HIVM/Transforms/Passes.h" +#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" +#include "bishengir/Dialect/HIVM/Utils/Utils.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Block.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" + +#include "Utils/Utils.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include + +// #include "mlir/Pass/Pass.h" +// #include "mlir/Pass/PassManager.h" + +// #include "mlir/Transforms/Canonicalizer.h" +// #include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DAGSSBUFFER +#include "ascend/include/TritonAffinityOpt/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace hivm; + +namespace { +struct DAGSSBufferPass + : public mlir::triton::impl::DAGSSBufferBase< + DAGSSBufferPass> { + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; +} // namespace + +void ControlSsbufV2(ModuleOp module) { + mlir::OpBuilder builder(module.getContext()); + // 用于记录已经处理过的scope.scope操作 + llvm::DenseSet processedScopes; + + auto aiCAttr = hivm::TCoreTypeAttr::get( + 
builder.getContext(), + hivm::TCoreType::CUBE); + int cubeControlIndex = 15; + int vectorControlIndex = 14; + + llvm::DenseSet processedScopes2; + module->walk([&](SyncBlockWaitOp op) { + auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S); + if (op.getTpipe() == pipeS || op.getPipe() == pipeS) { + return; + } + + // 向上查找父scope.scope操作 + mlir::Operation* parentOp = op->getParentOp(); + mlir::Operation* scopeOp = nullptr; + mlir::Operation* forOp = nullptr; + + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + parentOp = op->getParentOp(); + while (parentOp) { + if (dyn_cast(parentOp)) { + forOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + // 如果没有找到scope.scope操作,则跳过 + if (!scopeOp) { + return; + } + if (!forOp) { + return; + } + + // 如果该scope已经处理过,则跳过 + if (processedScopes2.count(forOp) > 0) return; + + // 标记该scope为已处理 + processedScopes2.insert(forOp); + + }); + bool firstSet = true; + bool firstWait = true; + for (auto forOp : processedScopes2) { + mlir::Operation* parentOp = forOp->getParentOp(); + mlir::Operation* scopeOp = nullptr; + + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + bool isAIC = false; + // 1. 
先检查操作是否有这个属性 + + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + isAIC = true; + } + } + + if (isAIC) { + // 在for循环的开头插入代码 + builder.setInsertionPoint(scopeOp); + // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 + // add sync_block_wait + auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + + // 在循环末尾(yield之前)插入代码 + auto &loopBody = forOp->getRegion(0).front(); + // 找到循环体的terminator(应该是yield操作) + auto *terminator = loopBody.getTerminator(); + builder.setInsertionPoint(terminator); + + // add sync_block_set + coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + + if (firstWait) { + auto &scopeBlock = scopeOp->getRegion(0).front(); + auto *scope_terminator = scopeBlock.getTerminator(); + builder.setInsertionPoint(scope_terminator); + // add sync_block_wait + coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = builder.getIntegerAttr(builder.getI64Type(), 
vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + firstWait = false; + } + } + else { + // 1. 在scopeop的开头插入代码 + // 假设scopeOp是一个具有区域的操作,我们获取其第一个块 + if (firstSet) { + auto &scopeBlock = scopeOp->getRegion(0).front(); + builder.setInsertionPointToStart(&scopeBlock); + + // add sync_block_wait + auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + firstSet = false; + } + + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + + // 创建需要的常量 + auto c32ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto c32ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c32ConstAttr); + + auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0); + auto c0i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c0i64ConstAttr); + + auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0); + auto c0i32ConstOp = builder.create( + scopeOp->getLoc(), i32Type, c0i32ConstAttr); + + auto c1i32ConstAttr = mlir::IntegerAttr::get(i32Type, 1); + auto c1i32ConstOp = builder.create( + scopeOp->getLoc(), i32Type, c1i32ConstAttr); + + // %sub_id = hivm.hir.get_sub_block_idx -> i64 + // 这里假设有一个getSubBlockIdxOp操作 + auto subIdOp = builder.create( + scopeOp->getLoc(), i64Type); + + // %ssb_addr_offset = arith.muli %sub_id, %c32_i64 : i64 + auto ssbAddrOffsetOp = builder.create( + scopeOp->getLoc(), + subIdOp.getResult(), + c32ConstOp.getResult()); + + // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 + auto ssbAddrOp = builder.create( + scopeOp->getLoc(), + ssbAddrOffsetOp.getResult(), + c32ConstOp.getResult()); + + // %vec_id = arith.cmpi eq, %sub_id, %c0_i64 : i64 + auto 
vecIdOp = builder.create( + scopeOp->getLoc(), + mlir::arith::CmpIPredicate::eq, + subIdOp.getResult(), + c0i64ConstOp.getResult()); + + // 2. 在parentop的开头插入代码 + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + + // add sync_block_wait + auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + + // 在循环末尾(yield之前)插入代码 + auto &loopBody = forOp->getRegion(0).front(); + // 找到循环体的terminator(应该是yield操作) + auto *terminator = loopBody.getTerminator(); + builder.setInsertionPoint(terminator); + + // add sync_block_wait + coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); + } + } + + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + auto initPtrType = mlir::LLVM::LLVMPointerType::get(builder.getContext(), 11); + SmallVector scopeOps; + module->walk([&](mlir::Operation* op) { + // 检查是否为目标操作 + if (auto scopeOp = dyn_cast(op)) { + scopeOps.push_back(scopeOp); + } + }); + if (!scopeOps.empty()) { + auto scopeOp = scopeOps[0]; + builder.setInsertionPoint(scopeOp); + auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0); + auto c0i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c0i64ConstAttr); + auto c32i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto c32i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c32i64ConstAttr); + auto c64i64ConstAttr = 
mlir::IntegerAttr::get(i64Type, 64); + auto c64i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c64i64ConstAttr); + auto c96i64ConstAttr = mlir::IntegerAttr::get(i64Type, 96); + auto c96i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c96i64ConstAttr); + auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0); + auto c0i32ConstOp = builder.create( + scopeOp->getLoc(), i32Type, c0i32ConstAttr); + + auto c0initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c0i64ConstOp.getResult()); + auto c32initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c32i64ConstOp.getResult()); + auto c64initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c64i64ConstOp.getResult()); + auto c96initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c96i64ConstOp.getResult()); + + builder.create( + scopeOp->getLoc(), + c0i32ConstOp, + c0initInttoptrOp + ); + builder.create( + scopeOp->getLoc(), + c0i32ConstOp, + c32initInttoptrOp + ); + builder.create( + scopeOp->getLoc(), + c0i32ConstOp, + c64initInttoptrOp + ); + builder.create( + scopeOp->getLoc(), + c0i32ConstOp, + c96initInttoptrOp + ); + } +} + +scf::ForOp transformLoop(scf::ForOp forOp, OpBuilder &builder) { + + // 1. 获取原始循环的信息 + Value originalLowerBound = forOp.getLowerBound(); + Value originalUpperBound = forOp.getUpperBound(); + Value originalStep = forOp.getStep(); + SmallVector iterArgs; + for (auto arg : forOp.getInitArgs()) { + iterArgs.push_back(arg); + } + auto yields = forOp.getBody()->getTerminator(); + + // 2. 检查循环体中是否有特定操作 + int hasTargetOps = 0; + forOp.walk([&](Operation* op) { + if (auto ifOp = dyn_cast(op)) { + if (ifOp->hasAttr("ssbuffer")) { + hasTargetOps++; + } + } + }); + // 3. 
如果存在目标操作,在迭代参数中添加计数器 + Value counterInit = nullptr; + mlir::Operation* parentOp = forOp->getParentOp(); + mlir::Operation* scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + + builder.setInsertionPoint(scopeOp); + for (int i = 0; i < hasTargetOps; i++) { + Location loc = forOp.getLoc(); + auto argType = originalLowerBound.getType(); + + // 添加到迭代参数列表 + iterArgs.push_back(originalLowerBound); + } + // 2. 创建新的上界:originalUpperBound * 2 + Location loc = forOp.getLoc(); + Type ubType = originalStep.getType(); + builder.setInsertionPoint(forOp); + + int count = 0; + for (auto &op : forOp.getBody()->getOperations()) { + if (auto ifOp = dyn_cast(op)) { + auto parentOp = ifOp->getParentOp(); + if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { + count++; + } + } + } + + Value two; + if (ubType.isIndex()) { + two = builder.create(loc, count - 1); + } else if (auto intType = dyn_cast(ubType)) { + // 对于整数类型,创建相应类型的常数2 + two = builder.create(loc, count - 1, intType); + } else { + // 其他类型可能需要特殊处理 + llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType << "\n"; + // 尝试创建索引类型的2然后转换 + auto indexTwo = builder.create(loc, count - 1); + two = builder.create(loc, ubType, indexTwo); + } + + auto steps = builder.create( + forOp.getLoc(), + originalStep, + two + ); + + auto nowUpperBound = builder.create( + forOp.getLoc(), + originalUpperBound, + steps + ); + + // 3. Create a new for loop + auto newForOp = builder.create( + forOp.getLoc(), + originalLowerBound, + nowUpperBound, + originalStep, + iterArgs); + + // 4. 
设置IR映射表,将旧循环的变量映射到新循环 + IRMapping mapper; + + // 映射迭代变量 + mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // 映射迭代参数 + for (auto [oldArg, newArg] : + llvm::zip(forOp.getRegionIterArgs(), + newForOp.getRegionIterArgs())) { + mapper.map(oldArg, newArg); + } + + SmallVector newCounterArgs; + for (int i = forOp.getRegionIterArgs().size(); i < newForOp.getRegionIterArgs().size(); i++) { + newCounterArgs.push_back(newForOp.getRegionIterArgs()[i]); + } + // 5. 克隆循环体内容到新循环 + auto &newLoopBody = *newForOp.getBody(); + builder.setInsertionPointToStart(&newLoopBody); + + for (auto &op : forOp.getBody()->without_terminator()) { + builder.clone(op, mapper); + } + + // 6. 克隆yield操作 + if (auto yieldOp = dyn_cast(yields)) { + SmallVector newYieldOperands; + for (auto operand : yieldOp.getOperands()) { + newYieldOperands.push_back(mapper.lookupOrDefault(operand)); + } + if (hasTargetOps != 0) { + for (auto currentCounter : newCounterArgs) { + // 将更新后的计数器添加到yield操作数中 + newYieldOperands.push_back(currentCounter); + } + } + builder.create(yieldOp.getLoc(), newYieldOperands); + } + + // 7. 替换原循环的结果 + if (hasTargetOps != 0) { + // 新循环有额外的计数器结果,但原循环没有对应结果 + // 我们可以选择只替换原循环对应的结果,或者忽略计数器结果 + unsigned numOriginalResults = forOp.getNumResults(); + SmallVector originalResults; + for (unsigned i = 0; i < numOriginalResults; i++) { + originalResults.push_back(newForOp.getResult(i)); + } + forOp.replaceAllUsesWith(originalResults); + } else { + forOp.replaceAllUsesWith(newForOp.getResults()); + } + + // 8. 
删除原循环 + forOp.erase(); + return newForOp; + +} + +// Find the first occurrence of convert_layout or fixpipe operation after the specified operation +Value findFirstTargetOpAfterWait(SyncBlockWaitOp waitOp, SmallVector& excludedValues) +{ + bool startSearching = false; + + for (Operation &op : waitOp->getBlock()->getOperations()) { + Value res = nullptr; + if (&op == waitOp) { + startSearching = true; + continue; + } + + if (startSearching) { + if (isa(op)) { + res = op.getOperands()[0]; + } + if (isa(op)) { + res = op.getOperands()[1]; + } + if (isa(op)) { + res = op.getOperands()[1]; + } + if (isa(op)) { + res = op.getOperands()[0]; + } + } + if (res) { + if (llvm::is_contained(excludedValues, res)) { + continue; + } + excludedValues.push_back(res); + return res; + } + } + + return nullptr; +} + +void getWaitType(std::string CoreType, scf::ForOp forOp, SmallVector& waitTypes, SmallVector& allocTypes) +{ + auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); + auto cubeWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_FIX); + auto vectorWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_MTE3); + SmallVector excludedValues; + forOp.walk([&](Operation* op) { + if (auto waitOp = dyn_cast(op)) { + auto parentOp = op->getParentOp(); + if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { + auto ifOp = dyn_cast(parentOp); + if (forOp == ifOp->getParentOp()) { + auto waitPipe = waitOp.getPipe(); + if ((waitPipe == cubeWaitPipe && CoreType == "cube") || (waitPipe == vectorWaitPipe && CoreType == "vector")) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + waitTypes.push_back(0); + allocTypes.push_back(allocOp); + } + else if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + waitTypes.push_back(1); + allocTypes.push_back(allocOp); + } + } + } + } + }); +} + +DenseMap getCounterOffset(scf::ForOp forOp) { + int i = 0; + DenseMap bufferMap; + auto 
scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); + forOp.walk([&](Operation* op) { + bufferMap[i] = 0; + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer") && ifOp->getParentOp() == forOp) { + ifOp.walk([&](Operation* op) { + if (auto waitOp = dyn_cast(op)) { + if (auto waitIfOp = dyn_cast(op->getParentOp())) { + if (waitIfOp == ifOp) { + auto waitPipe = waitOp.getPipe(); + if ((waitPipe != scalarWaitPipe)) { + bufferMap[i]++; + } + } + } + } + }); + i ++; + } + }); + return bufferMap; +} + +SmallVector addBufValLoop(scf::ForOp forOp, DenseMap VecBitMap, DenseMapCubeBitMap, OpBuilder &builder) +{ + auto aiCAttr = hivm::TCoreTypeAttr::get( + builder.getContext(), + hivm::TCoreType::CUBE); + bool isAIC = false; + // 向上查找父scope.scope操作 + mlir::Operation* parentOp = forOp->getParentOp(); + mlir::Operation* scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + isAIC = true; + } + } + auto bufferMap = getCounterOffset(forOp); + SmallVector buf_vals; + SmallVector if_conditions; + builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); + + // 1. 提取并处理end值 + Value startValue = forOp.getLowerBound(); + Value endValue = forOp.getUpperBound(); + // 2. 
提取并处理step值 + Value stepValue = forOp.getStep(); + builder.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + int count = 0; + for (auto &op : forOp.getBody()->getOperations()) { + if (auto ifOp = dyn_cast(op)) { + auto parentOp = ifOp->getParentOp(); + if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { + count++; + } + } + } + + Value two; + Type ubType = stepValue.getType(); + if (ubType.isIndex()) { + two = builder.create(loc, count - 1); + } else if (auto intType = dyn_cast(ubType)) { + // 对于整数类型,创建相应类型的常数2 + two = builder.create(loc, count - 1, intType); + } else { + // 其他类型可能需要特殊处理 + llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType << "\n"; + // 尝试创建索引类型的2然后转换 + auto indexTwo = builder.create(loc, count - 1); + two = builder.create(loc, ubType, indexTwo); + } + + auto steps = builder.create( + forOp.getLoc(), + endValue.getType(), + stepValue, + two + ); + + auto subLoopValue = builder.create( + forOp.getLoc(), + endValue.getType(), + endValue, + steps + ); + + SmallVector WaitType; + SmallVector AllocType; + SmallVector bufferPtrs; + if (isAIC) { + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // 创建常量32和64 + Value c0 = builder.create( + forOp.getLoc(), 0, 32 // 值32,64位 + ); + Value c32 = builder.create( + forOp.getLoc(), 32, 64 // 值32,64位 + ); + Value c64 = builder.create( + forOp.getLoc(), 64, 64 // 值64,64位 + ); + // 创建inttoptr操作 + Value ssb_vec0_ptr = builder.create( + forOp.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + c32 + ); + Value ssb_vec1_ptr = builder.create( + forOp.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + c64 + ); + bufferPtrs.push_back(ssb_vec0_ptr); + bufferPtrs.push_back(ssb_vec1_ptr); + // 创建load操作 + Value status_vec0 = builder.create( + forOp.getLoc(), builder.getI32Type(), ssb_vec0_ptr + ); + + Value status_vec1 = builder.create( + forOp.getLoc(), builder.getI32Type(), ssb_vec1_ptr + ); + + getWaitType("cube", forOp, 
WaitType, AllocType); + + for (auto i = 0; i < WaitType.size(); i++) { + auto correnspondAlloc = CubeBitMap[AllocType[i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + Value bufi_vec0_val = builder.create( + forOp.getLoc(), status_vec0, buf_constant_set + ); + Value bufi_vec1_val = builder.create( + forOp.getLoc(), status_vec1, buf_constant_set + ); + Value flag_bufi_vec0; + Value flag_bufi_vec1; + // 创建比较操作 + if (WaitType[i] == 0) { + flag_bufi_vec0 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, c0 + ); + flag_bufi_vec1 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, c0 + ); + } + else { + flag_bufi_vec0 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, buf_constant_set + ); + flag_bufi_vec1 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, buf_constant_set + ); + } + // 创建最终的and操作 + Value bufi_val = builder.create( + forOp.getLoc(), flag_bufi_vec0, flag_bufi_vec1 + ); + buf_vals.push_back(bufi_val); + } + + } else { + builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); + Value c0 = builder.create( + forOp.getLoc(), 0, 32 // 值32,64位 + ); + auto i64Type = builder.getIntegerType(64); + // %sub_id = hivm.hir.get_sub_block_idx -> i64 + // 这里假设有一个getSubBlockIdxOp操作 + auto subIdOp = builder.create( + scopeOp->getLoc(), i64Type); + auto i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto cst_offset = builder.create( + scopeOp->getLoc(), i64Type, i64ConstAttr); + auto ssb_addr_offset = builder.create( + scopeOp->getLoc(), subIdOp, cst_offset); + auto ssb_addr = builder.create( + scopeOp->getLoc(), ssb_addr_offset, cst_offset); + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // 创建inttoptr操作 + Value ssb_cube_ptr = builder.create( + forOp.getLoc(), + 
LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + ssb_addr + ); + bufferPtrs.push_back(ssb_cube_ptr); + // 创建load操作 + Value status_cube = builder.create( + forOp.getLoc(), builder.getI32Type(), ssb_cube_ptr + ); + + getWaitType("vector", forOp, WaitType, AllocType); + for (auto i = 0; i < WaitType.size(); i++) { + auto correnspondAlloc = VecBitMap[AllocType[i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + Value bufi_cube_val = builder.create( + forOp.getLoc(), status_cube, buf_constant_set + ); + + Value flag_bufi_cube; + // 创建比较操作 + if (WaitType[i] == 0) { + flag_bufi_cube = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, c0 + ); + } + else { + flag_bufi_cube = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, buf_constant_set + ); + } + buf_vals.push_back(flag_bufi_cube); + } + } + int bufIdx = 0; + int groupIdx = 0; + + for (const auto &pair : bufferMap) { + if (bufferMap[groupIdx] == 0) { + continue; + } + + // 获取对应的region迭代参数 + Value cnti = builder.create( + forOp.getLoc(), arith::CmpIPredicate::slt, + forOp.getRegionIterArgs()[forOp.getRegionIterArgs().size() - (bufferMap.size() - 1 - groupIdx)], + subLoopValue + ); + + // 计算该组中所有buffer值的AND + Value finalBufVal = buf_vals[bufIdx]; + for (int count = 1; count < bufferMap[groupIdx]; count++) { + finalBufVal = builder.create( + forOp.getLoc(), finalBufVal, buf_vals[bufIdx + count] + ); + } + + auto cond = builder.create( + forOp.getLoc(), finalBufVal, cnti + ); + if_conditions.push_back(cond); + + // 更新索引 + bufIdx += bufferMap[groupIdx]; + groupIdx++; + } + int ifIndex = 0; + int acc = 0; + int bufferBit = 0; + for (int i = 0; i < CubeBitMap.size(); i++) { + bufferBit += (1 << i); + } + forOp.getBody()->walk([&](Operation* op) { + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer")) { 
+ // 获取then区域 + Block* thenBlock = &ifOp.getThenRegion().front(); + + // 找到then区域中的yield操作 + Operation* yieldOp = nullptr; + for (auto& op : *thenBlock) { + if (isa(op)) { + yieldOp = &op; + break; + } + } + if (yieldOp) { + builder.setInsertionPoint(yieldOp); + + if (isAIC) { + // 创建插入的语句 + // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 + Value status_v2_0 = builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 + ); + Value status_v2_1 = builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[1] // 假设ssb_ptr已在作用域中定义 + ); + Value buf_val_new_0 = status_v2_0; + Value buf_val_new_1 = status_v2_1; + auto bufferNum = bufferMap[ifIndex]; + for (int i = 0; i < bufferNum; i++) { + if (WaitType[acc + i] == 0) { + auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new_0 = builder.create( + yieldOp->getLoc(), + buf_val_new_0, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + buf_val_new_1 = builder.create( + yieldOp->getLoc(), + buf_val_new_1, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } + else { + auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; + int bitPos = correnspondAlloc; + int basePattern = bufferBit; + int finalValue = basePattern ^ (1 << bitPos); + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), finalValue); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new_0 = builder.create( + yieldOp->getLoc(), + buf_val_new_0, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + buf_val_new_1 = builder.create( + yieldOp->getLoc(), + buf_val_new_1, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } + } + acc += bufferNum; + builder.create( + yieldOp->getLoc(), + buf_val_new_0, + 
bufferPtrs[0] + ); + builder.create( + yieldOp->getLoc(), + buf_val_new_1, + bufferPtrs[1] + ); + + } + else { + // 创建插入的语句 + // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 + Value status_v2 = builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 + ); + Value buf_val_new = status_v2; + auto bufferNum = bufferMap[ifIndex]; + for (int i = 0; i < bufferNum; i++) { + if (WaitType[acc + i] == 0) { + auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new = builder.create( + yieldOp->getLoc(), + buf_val_new, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } + else { + auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; + int bitPos = correnspondAlloc; + int basePattern = bufferBit; + int finalValue = basePattern ^ (1 << bitPos); + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), finalValue); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new = builder.create( + yieldOp->getLoc(), + buf_val_new, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } + } + acc += bufferNum; + builder.create( + yieldOp->getLoc(), + buf_val_new, + bufferPtrs[0] + ); + } + ifIndex ++; + } + } + }); + + return if_conditions; +} + +void ReplaceIf(scf::ForOp forOp, SmallVector conditions, SmallVector& opsToErase, DenseMap& ifArgMap, OpBuilder &builder, ModuleOp moduleOp) +{ + SmallVector ifToProcess; + llvm::outs()<<"enter replaceif\n"; + Value step = forOp.getStep(); + auto aiCAttr = hivm::TCoreTypeAttr::get( + builder.getContext(), + hivm::TCoreType::CUBE); + forOp.getBody()->walk([&](Operation* op) { + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer") && forOp == ifOp->getParentOp()) { + ifToProcess.push_back(ifOp); + } + }); + 
+ IRMapping IRMap; + for (int i = 0; i < ifToProcess.size(); i++) { + auto ifOp = ifToProcess[i]; + auto parentOp = ifOp->getParentOp(); + auto loc = ifOp.getLoc(); + // 获取for循环的iterargs(迭代参数) + auto iterArgs = forOp.getRegionIterArgs(); + if (iterArgs.size() < conditions.size()) { + return; + } + auto thenYieldOp = dyn_cast(ifOp.getThenRegion().front().getTerminator()); + SmallVector thenResults; + if (thenYieldOp) { + // 如果已有返回值,保留它们 + for (auto result : thenYieldOp.getResults()) { + thenResults.push_back(result); + } + } + // 创建新的else区域,返回两个迭代参数 + SmallVector elseResults; + scf::YieldOp elseYieldOp = nullptr; + bool hasElse = false; + if (!ifOp.getElseRegion().empty()) { + elseYieldOp = dyn_cast(ifOp.getElseRegion().front().getTerminator()); + hasElse = true; + } + if (elseYieldOp) { + for (auto result : elseYieldOp.getResults()) { + elseResults.push_back(result); + } + } + // 获取最后两个迭代参数 + Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; + // 创建新的then区域,返回两个迭代参数 + thenResults.push_back(iterArgMinus); + elseResults.push_back(iterArgMinus); + + // 保存原有的操作,以便后续克隆 + SmallVector thenOps; + for (auto &op : ifOp.getThenRegion().front()) { + thenOps.push_back(&op); + } + + SmallVector elseOps; + if (!ifOp.getElseRegion().empty()) { + for (auto &op : ifOp.getElseRegion().front()) { + elseOps.push_back(&op); + } + } + SmallVector resultTypes; + for (auto val : thenResults) { + resultTypes.push_back(val.getType()); + } + // 创建新的scf.if操作 + builder.setInsertionPoint(ifOp); + auto newIfOp = builder.create( + loc, + resultTypes, + conditions[i], + /*withElseRegion=*/true); + newIfOp->setAttr("ssbuffer", builder.getUnitAttr()); + // 处理then区域 + auto &newThenBlock = newIfOp.getThenRegion().front(); + builder.setInsertionPointToStart(&newThenBlock); + + // 克隆then区域的操作 + for (auto op : thenOps) { + if (auto yieldOp = dyn_cast(op)) { + // 处理yield的操作数映射 + SmallVector mappedOperands; + for (auto operand : yieldOp->getOperands()) { + 
mappedOperands.push_back(IRMap.lookupOrDefault(operand)); + } + // 获取最后两个迭代参数 + Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; + + // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 + auto AddIOp = builder.create( + forOp->getLoc(), + iterArgMinus, + step); + // 这里加个add1 + mappedOperands.push_back(AddIOp); + builder.create(loc, mappedOperands); + } else { + auto newOp = builder.clone(*op, IRMap); + IRMap.map(op->getResults(), newOp->getResults()); + } + } + + // 处理else区域 + auto &newElseBlock = newIfOp.getElseRegion().front(); + builder.setInsertionPointToStart(&newElseBlock); + // 克隆else区域的操作 + if (hasElse) { + for (auto op : elseOps) { + if (auto yieldOp = dyn_cast(op)) { + // 处理yield的操作数映射 + SmallVector mappedOperands; + for (auto operand : yieldOp->getOperands()) { + mappedOperands.push_back(IRMap.lookupOrDefault(operand)); + } + Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; + mappedOperands.push_back(iterArgMinus); + builder.create(loc, mappedOperands); + } else { + auto newOp = builder.clone(*op, IRMap); + IRMap.map(op->getResults(), newOp->getResults()); + } + } + } else { + SmallVector cntOperands; + cntOperands.push_back(iterArgMinus); + builder.create(loc, cntOperands); + } + + // 替换原有if操作的使用 + // 首先,将原if操作的结果替换为新if操作的对应结果 + for (unsigned j = 0; j < ifOp.getNumResults(); ++j) { + ifOp.getResult(j).replaceAllUsesWith(newIfOp.getResult(j)); + } + // 获取新if操作所在的块 + Block* newIfBlock = ifOp->getBlock(); + // 在for循环体内替换迭代参数的使用 + forOp.getBody()->walk([&](Operation* op) { + // 检查操作是否与新ifOp在同一个块中 + Block* opBlock = op->getBlock(); + if (opBlock != newIfBlock) { + // 不在同一个块中,跳过 + return; + } + if (op->isBeforeInBlock(newIfOp)) { + return; // 只处理if操作之后的use + } + for (unsigned j = 0; j < op->getNumOperands(); ++j) { + for (auto argIndex = 0; argIndex < conditions.size(); argIndex ++) { + // 获取最后两个迭代参数 + Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; + if (op->getOperand(j) == 
iterArgMinus) { + op->setOperand(j, newIfOp.getResults()[newIfOp.getNumResults() - 1]); + } + } + } + }); + + // // 删除原有的if操作 + opsToErase.push_back(ifOp); + if (ifArgMap.find(newIfOp) == ifArgMap.end()) { + ifArgMap[newIfOp] = iterArgMinus; + } + } +} + +int getNestingDepth(scf::ForOp forOp) { + int depth = 0; + Operation* op = forOp.getOperation(); + while (op) { + if (op->getDialect() && op->getDialect()->getNamespace() == "scf") { + ++depth; + } + op = op->getParentOp(); + } + return depth; +} + +void printDenseMap(const mlir::DenseMap& Map) +{ + for (const auto& pair : Map) { + mlir::Value val = pair.first; + int bitValue = pair.second; + llvm::outs()<& VecBitMap, DenseMap& CubeBitMap, OpBuilder builder) +{ + auto aiCAttr = hivm::TCoreTypeAttr::get( + builder.getContext(), + hivm::TCoreType::CUBE); + auto scalarWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto cubeWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_FIX); + auto vectorWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_MTE3); + + int cubeAcc = 0; + int vecAcc = 0; + SmallVector scopeOpToEdit; + module.walk([&](scope::ScopeOp scopeOp) { + scopeOpToEdit.push_back(scopeOp); + }); + for (auto scopeOp : scopeOpToEdit) { + SmallVector excludedValues; + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + scopeOp.walk([&](SyncBlockWaitOp waitOp) { + auto parentOp = waitOp->getParentOp(); + if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { + auto waitPipe = waitOp.getPipe(); + if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + if (VecBitMap.find(allocOp) != VecBitMap.end()) { + CubeBitMap[allocOp] = VecBitMap[allocOp]; + } else { + CubeBitMap[allocOp] = cubeAcc; + cubeAcc++; + } + } + } + }); + } else { + scopeOp.walk([&](SyncBlockWaitOp waitOp) { + auto parentOp = waitOp->getParentOp(); + if (isa(parentOp) && 
parentOp->hasAttr("ssbuffer")) { + auto waitPipe = waitOp.getPipe(); + if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + if (VecBitMap.find(allocOp) == VecBitMap.end()) { + VecBitMap[allocOp] = vecAcc; + vecAcc++; + } + } + } + }); + } + } + } +} + +void modifyForIterargDeps(scf::ForOp forOp, DenseMap ifCounters) +{ + Value iterArg = forOp.getInductionVar(); + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto ifOp = dyn_cast(op)) { + if (ifCounters.find(ifOp) != ifCounters.end()) { + Value counter = ifCounters[ifOp]; + + ifOp.walk([&](Operation* opInIf) { + for (auto [i, operand] : llvm::enumerate(opInIf->getOperands())) { + if (operand == iterArg) { + opInIf->setOperand(i, counter); + } + } + }); + } + } + } +} + +void FlowSssbuf(ModuleOp module) { + mlir::OpBuilder builder(module.getContext()); + // 收集所有需要转换的循环 + SmallVector targetLoops; + llvm::outs()<<"enter flowsssbuf\n\n"; + module.walk([&](Operation* op) { + if (auto forOp = dyn_cast(op)) { + // 检查循环是否包含特定的 sync_block_set 操作 + bool hasSyncBlockSet = false; + forOp.walk([&](Operation *op) { + if (isa(op)) { + if (auto ifOp = dyn_cast(op->getParentOp())) { + if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { + hasSyncBlockSet = true; + } + } + } + }); + + if (hasSyncBlockSet) { + if (llvm::find(targetLoops, forOp) == targetLoops.end()) { + targetLoops.push_back(forOp); + } + } + } + }); + llvm::outs()<<"enter flowsssbuf\n\n"; + + SmallVector transformLoops; + // 转换每个目标循环 + for (scf::ForOp forOp : targetLoops) { + auto newforOp = transformLoop(forOp, builder); + } + + module.walk([&](Operation* op) { + if (auto forOp = dyn_cast(op)) { + // 检查循环是否包含特定的 sync_block_set 操作 + bool hasSyncBlockSet = false; + forOp.walk([&](Operation *op) { + if (isa(op)) { + if (auto ifOp = dyn_cast(op->getParentOp())) { + if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { + hasSyncBlockSet = true; + } + } + } + }); + + if 
(hasSyncBlockSet) { + if (llvm::find(transformLoops, forOp) == transformLoops.end()) { + transformLoops.push_back(forOp); + } + } + } + + }); + + llvm::sort(transformLoops, [](scf::ForOp a, scf::ForOp b) { + return getNestingDepth(a) > getNestingDepth(b); + }); + DenseMap VecBitMap; + DenseMap CubeBitMap; + getAllocBit(module, VecBitMap, CubeBitMap, builder); + printDenseMap(CubeBitMap); + printDenseMap(VecBitMap); + SmallVector opsToErase; + for (scf::ForOp forOp : transformLoops) { + DenseMap ifArgMap; + llvm::outs()<<"before replaceif\n"; + auto bufvals = addBufValLoop(forOp, VecBitMap, CubeBitMap, builder); + ReplaceIf(forOp, bufvals, opsToErase, ifArgMap, builder, module); + llvm::outs()<<"after replaceif\n"; + for (const auto& pair : ifArgMap) { + auto val = pair.first; + auto bitValue = pair.second; + llvm::outs()<erase(); + } + + +} + +bool isTransOp(mlir::Operation *op) { + auto fixpipeOp = dyn_cast(op); + if (fixpipeOp) + return true; + + auto copyOp = dyn_cast(op); + if (!copyOp) + return false; + else { + + Value copySrc = copyOp.getODSOperands(0).front(); + MemRefType copySrcTy = dyn_cast(copySrc.getType()); + auto SrcAddrSpace = dyn_cast_or_null(copySrcTy.getMemorySpace()); + bool isSrcUbSpace = SrcAddrSpace.getAddressSpace() == hivm::AddressSpace::UB; + + Value copyDst = copyOp.getODSOperands(1).front(); + MemRefType copyDstTy = dyn_cast(copyDst.getType()); + auto DstAddrSpace = dyn_cast_or_null(copyDstTy.getMemorySpace()); + bool isDstCbufSpace = DstAddrSpace.getAddressSpace() == hivm::AddressSpace::L1; + + return isSrcUbSpace && isDstCbufSpace; + } +} + +void FindAndMarkBuffer(ModuleOp module) { + OpBuilder builder(module.getContext()); + unsigned int BufferIdx = 0; + Type idxType = builder.getI32Type(); + StringAttr setFlagAttr = builder.getStringAttr("Set flag"); + StringAttr waitFlagAttr = builder.getStringAttr("Wait flag"); + IntegerAttr idxAttr = builder.getI32IntegerAttr(BufferIdx); + + module.walk([&](mlir::Operation *op) { + + if 
(isTransOp(op)) { + llvm::outs() << "Buffer idx" << BufferIdx << "\n"; + llvm::outs() << "Trans Op" << *op << "\n"; + Value SharedBuffer; + if (auto fixpipeOp = dyn_cast(op)) { + SharedBuffer = fixpipeOp.getODSOperands(1).front(); + } else { + auto copyOp = dyn_cast(op); + SharedBuffer = copyOp.getODSOperands(1).front(); + } + llvm::outs() << "SharedBuffer" << SharedBuffer << "\n"; + + if (!SharedBuffer) { + op->emitWarning("fixpipe op has empty output operand!"); + return; + } + + // 在Buffer的生产op后set flag标记,在Buffer消费op前增加wait flag标记 + op->setAttr("Buffer idx", builder.getI32IntegerAttr(BufferIdx)); + op->setAttr("Wait Flag", builder.getI32IntegerAttr(0)); + op->setAttr("Set Flag", builder.getI32IntegerAttr(1)); + + for (Operation *consumerOp : SharedBuffer.getUsers()) { + if (consumerOp == op) + continue; + if (!consumerOp) continue; + + llvm::outs() << "consumerOp: " << *consumerOp << "\n"; + + consumerOp->setAttr("Buffer idx", builder.getI32IntegerAttr(BufferIdx)); + consumerOp->setAttr("Wait Flag", builder.getI32IntegerAttr(0)); + } + BufferIdx++; + } + }); +} + +// 结构体存 wait-set 区块信息 +struct WaitSetRegion { + Operation *waitOp; + Operation *lastSetOp; + SmallVector opsToMove; + bool hasCopyOrFixpipe = false; +}; + +struct MergedRegion { + SmallVector regions; + SmallVector opsToMove; + SmallVector yieldValues; + SmallVector resultTypes; +}; + +void MoveIterArgUsersIntoIf( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // iter_arg -> mergedRegion index + DenseMap iterArgToRegion; + + for (int r = 0; r < mergedRegions.size(); ++r) { + MergedRegion &mr = mergedRegions[r]; + + for (Operation *op : mr.opsToMove) { + for (Value v : op->getOperands()) { + if (auto barg = mlir::dyn_cast(v)) { + if (barg.getOwner() == &body) { + iterArgToRegion.try_emplace(barg, r); + } + } + } + } + } + + if (iterArgToRegion.empty()) + return; + + // 找最后一个 mergedRegion 的最后一个 op + Operation *lastOp = nullptr; + for (MergedRegion &mr 
: mergedRegions) + lastOp = mr.opsToMove.back(); + + if (!lastOp) + return; + + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) + opIndex[&op] = idx++; + + int startIdx = opIndex[lastOp] + 1; + + // 扫描 for body 尾部 op + for (Operation &op : body) { + if (opIndex[&op] < startIdx) + continue; + + llvm::SmallDenseSet usedRegions; + for (Value v : op.getOperands()) { + if (auto barg = mlir::dyn_cast(v)) { + auto it = iterArgToRegion.find(barg); + if (it != iterArgToRegion.end()) + usedRegions.insert(it->second); + } + } + + // 必须且只能依赖一个 mergedRegion + if (usedRegions.size() != 1) + continue; + + int target = *usedRegions.begin(); + + mergedRegions[target].opsToMove.push_back(&op); + } +} + +void ComputeYieldForMergedRegion( + MergedRegion &mr, Block &body) { + + mr.yieldValues.clear(); + mr.resultTypes.clear(); + + SmallPtrSet inRegion( + mr.opsToMove.begin(), mr.opsToMove.end()); + + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { + bool usedOutside = false; + + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); + + // 不在同一个 for body,交给外层处理(通常不会出现) + if (user->getBlock() != &body) + continue; + + // 只要有一个 use 在 region 外,就必须 yield + if (!inRegion.contains(user)) { + usedOutside = true; + break; + } + } + + if (usedOutside) { + mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } + } + } +} + +static void ComputeYieldForMergedRegionV2( + MergedRegion &mr, Block &body) { + + mr.yieldValues.clear(); + mr.resultTypes.clear(); + + // 当前 region 内的 ops + SmallPtrSet inRegion( + mr.opsToMove.begin(), mr.opsToMove.end()); + + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { + + bool usedOutside = false; + + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); + + // 如果使用在 region 内部 op,跳过 + if (inRegion.contains(user)) + continue; + + // 使用在 region 外部,包括嵌套 region 内部的 block + usedOutside = true; + break; + } + + if (usedOutside) { + 
mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } + } + } +} + +static void ComputeYieldForMergedRegionV3(MergedRegion &mr) { + mr.yieldValues.clear(); + mr.resultTypes.clear(); + + // 用 DenseSet 暂存当前 region 的所有 ops + DenseSet regionOps(mr.opsToMove.begin(), mr.opsToMove.end()); + + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { + + bool needsYield = false; + + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); + + // 如果 user 不在当前 region,则需要 yield + if (!regionOps.contains(user)) { + needsYield = true; + break; + } + } + + if (needsYield) { + mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } + } + } +} + +// 递归收集 op 和它所有 region 内的 ops +static void CollectAllNestedOps(Operation *op, DenseSet ®ionOps) { + if (!op) + return; + + if (regionOps.contains(op)) + return; // 已经收集过 + + regionOps.insert(op); + + // 遍历所有 region,递归收集 + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : block) { + CollectAllNestedOps(&nestedOp, regionOps); + } + } + } +} + +static void ComputeYieldForMergedRegionV4(MergedRegion &mr) { + mr.yieldValues.clear(); + mr.resultTypes.clear(); + + // 用 DenseSet 暂存当前 region 的所有 ops + // 初始 DenseSet: 顶层 opsToMove + DenseSet regionOps; + for (Operation *op : mr.opsToMove) { + CollectAllNestedOps(op, regionOps); // 完整展开嵌套 + } + + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { + + bool needsYield = false; + + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); + + // 如果 user 不在当前 region,则需要 yield + if (!regionOps.contains(user)) { + needsYield = true; + break; + } + } + + if (needsYield) { + mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } + } + } +} + +int findTargetRegion( + Operation *startOp, + Block &body, + DenseMap &opToRegion) { + + SmallVector worklist{startOp}; + SmallPtrSet visited; + + while (!worklist.empty()) 
{ + Operation *op = worklist.pop_back_val(); + if (!visited.insert(op).second) + continue; + + auto it = opToRegion.find(op); + if (it != opToRegion.end()) + return it->second; + + for (Value operand : op->getOperands()) { + if (isa(operand)) + continue; + + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == &body) + worklist.push_back(defOp); + } + } + + return -1; +} + +void greedyAbsorbToRegion( + Operation *startOp, + int regionIdx, + int lowerBound, + Block &body, + DenseMap &opIndex, + DenseMap &opToRegion, + SmallVector &mergedRegions) { + + auto &mr = mergedRegions[regionIdx]; + + SmallVector worklist; + SmallPtrSet visited( + mr.opsToMove.begin(), mr.opsToMove.end()); + + // 先把 startOp 本身吸收(如果还没被吸收) + if (!opToRegion.count(startOp)) { + mr.opsToMove.push_back(startOp); + opToRegion[startOp] = regionIdx; + visited.insert(startOp); + } + + worklist.push_back(startOp); + + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + + for (Value operand : op->getOperands()) { + if (isa(operand)) + continue; + + Operation *defOp = operand.getDefiningOp(); + if (!defOp || defOp->getBlock() != &body) + continue; + + int defIdx = opIndex[defOp]; + + // 超过前一个 region 的末尾 + if (defIdx < lowerBound) + continue; + + auto it = opToRegion.find(defOp); + + // 不能跨到其他 region + if (it != opToRegion.end() && + it->second != regionIdx) + continue; + + // 去重 + if (!visited.insert(defOp).second) + continue; + + // 吸收 defOp + mr.opsToMove.push_back(defOp); + opToRegion[defOp] = regionIdx; + worklist.push_back(defOp); + } + } +} + +SmallVector getOperationInput(Operation *op, SmallVector dependValues, + DenseMap>> &collectDepValueMap) +{ + // Analyse each Op's input + DenseSet opInput; + if (isa(op) || isa(op)) { + SmallVector regionBlocks; + if (auto ifOp = dyn_cast(op)) { + regionBlocks.push_back(&(ifOp.getThenRegion().front())); + regionBlocks.push_back(&(ifOp.getElseRegion().front())); + } else { + auto forOp = dyn_cast(op); + 
regionBlocks.push_back(forOp.getBody()); + } + + // recursively walk scf op + for (Block *curBlock: regionBlocks) { + for (auto &curOp : *curBlock) { + for (auto operand : getOperationInput(&curOp, dependValues, collectDepValueMap)) { + Operation *defOp; + if (auto blockArg = dyn_cast(operand)) { + Block* ownerBlock = blockArg.getOwner(); + defOp = ownerBlock->getParentOp(); + } else { + defOp = operand.getDefiningOp(); + } + Block *defBlock = defOp->getBlock(); + + if (!(defOp == op || llvm::is_contained(regionBlocks, defBlock))) { + opInput.insert(operand); + } + } + } + } + SmallVector retVector(opInput.begin(), opInput.end()); + return retVector; + } else { + SmallVector operands = op->getOperands(); + // store ifresult value that will be replaced + for (auto operand : operands) { + if (llvm::is_contained(dependValues, operand)) { + if (collectDepValueMap.find(operand) != collectDepValueMap.end()) { + collectDepValueMap[operand].second.push_back(op); + } else { + SmallVector userOps; + userOps.push_back(op); + collectDepValueMap[operand] = {operand, userOps}; + } + } + } + return operands; + } +} + +SmallVector collectDepValuesCalculation(DenseSet forRegionOps, + DenseSet regionOps, Operation *op, SmallVector dependValues, + DenseMap>> &collectDepValueMap) +{ + DenseSet collectOps; + std::deque opStack; + bool flag = false; + + opStack.push_back(op); + while (opStack.size()) { + Operation *curOp = opStack.front(); + opStack.pop_front(); + + for (auto operand : getOperationInput(curOp, dependValues, collectDepValueMap)) { + if (llvm::is_contained(dependValues, operand)) { + flag = true; + } + + Operation *parentOp = operand.getDefiningOp(); + if (llvm::is_contained(regionOps, parentOp)) { + opStack.push_back(parentOp); + continue; + } else if (llvm::is_contained(forRegionOps, parentOp)) { + opStack.push_back(parentOp); + collectOps.insert(parentOp); + } + } + } + + if (flag) { + SmallVector retVector(collectOps.begin(), collectOps.end()); + return retVector; + } 
else { + collectDepValueMap.clear(); + SmallVector emptyVector; + emptyVector.clear(); + return emptyVector; + } +} + +void copyOpsToMergedRegion(scf::ForOp forOp, SmallVector collectOps, MergedRegion &mergedRegion, + DenseMap>> &collectDepValueMap) +{ + Block *forBodyBlock = forOp.getBody(); + OpBuilder builder(forOp); + SmallVector clonedOps; + IRMapping mapper; + + // copy calculation of ifreult value related to load/store op + int cnt = 0; + for (Operation &origOp : forBodyBlock->without_terminator()) { + if (cnt >= collectOps.size()) + break; + + if (llvm::is_contained(collectOps, &origOp)) { + builder.setInsertionPointAfter(&origOp); + + Operation *clonedOp = (&origOp)->clone(mapper); + builder.insert(clonedOp); + mapper.map(&origOp, clonedOp); + + clonedOps.push_back(clonedOp); + cnt++; + + // replace the ifresult value by new cloned op's result + SmallVector results = origOp.getResults(); + for (auto [idx, result] : llvm::enumerate(origOp.getResults())) { + if (collectDepValueMap.find(result) != collectDepValueMap.end()) { + collectDepValueMap[result].first = clonedOp->getResult(idx); + } + } + } + } + + DenseSet mergedRegionOps; + for (Operation *op : mergedRegion.opsToMove) { + CollectAllNestedOps(op, mergedRegionOps); + } + + // replace the ifresult value by new cloned op's result + for (Operation *op : mergedRegionOps) { + for (auto [idx, operand] : llvm::enumerate(op->getOperands())) { + if (collectDepValueMap.find(operand) != collectDepValueMap.end()) { + op->setOperand(idx, collectDepValueMap[operand].first); + } + } + } + + // update MergedRegion + clonedOps.append(mergedRegion.opsToMove); + mergedRegion.opsToMove = clonedOps; +} + +void copyLoadCalculation(scf::ForOp forOp, SmallVector dependValues, SmallVector &mergedRegions) +{ + mlir::Operation* parentOp = forOp->getParentOp(); + mlir::Operation* scopeOp = nullptr; + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + auto 
coreTypeAttr = scopeOp->getAttrOfType( + hivm::TCoreTypeAttr::name); + // only process the vector core + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) { + return; + } + + // recursively collect all op in forOp + DenseSet forRegionOps; + for (Operation &op : forOp.getBody()->without_terminator()) { + CollectAllNestedOps(&op, forRegionOps); + } + + for (MergedRegion &mr : mergedRegions) { + DenseSet regionOps; + for (Operation *op : mr.opsToMove) { + CollectAllNestedOps(op, regionOps); + } + + for (Operation *op : regionOps) { + if (isa(op) || isa(op)) { + // recusively check that whether load/store op's operands originated from if results + DenseMap>> collectDepValueMap; + SmallVector collectOps = \ + collectDepValuesCalculation(forRegionOps, regionOps, op, dependValues, collectDepValueMap); + copyOpsToMergedRegion(forOp, collectOps, mr, collectDepValueMap); + } + } + } +} + +// 以 forOp 的 yield value 为中心 +// 决定它应该归属哪个 mergedRegion, 然后再向前吸 operand +void ExpandMergedRegionOpsForAIV( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // 记录 block 中 op 顺序 + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) + opIndex[&op] = idx++; + + // 建立 op -> region 映射 + DenseMap opToRegion; + for (int r = 0; r < mergedRegions.size(); ++r) + for (Operation *op : mergedRegions[r].opsToMove) + opToRegion[op] = r; + + // 取 scf.yield + auto yieldOp = + cast(body.getTerminator()); + + // 依次处理每个 yield value(按编号顺序) + for (Value yv : yieldOp.getOperands()) { + + Operation *defOp = yv.getDefiningOp(); + if (!defOp || defOp->getBlock() != &body) + continue; + + int targetRegion = -1; + + // 如果已经在 region 中 + auto it = opToRegion.find(defOp); + if (it != opToRegion.end()) { + targetRegion = it->second; + } else { + // 否则向前搜索确定归属 + targetRegion = + findTargetRegion(defOp, body, opToRegion); + } + + if (targetRegion == -1) + continue; + + // 计算边界 lowerBound + int lowerBound = 0; + + if (targetRegion > 0) { + Operation 
*prevLast = + mergedRegions[targetRegion - 1] + .opsToMove.back(); + lowerBound = opIndex[prevLast] + 1; + } + + // 真正贪心吸收 + greedyAbsorbToRegion(defOp, + targetRegion, + lowerBound, + body, + opIndex, + opToRegion, + mergedRegions); + } + + // 每个 region 内按 block 顺序排序 + for (auto &mr : mergedRegions) { + llvm::sort(mr.opsToMove, + [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); + } +} + +// 以 mergedRegion 为中心, 向前吸 operand +void ExpandMergedRegionOpsForAIC(scf::ForOp forOp, + SmallVector &mergedRegions) { + Block &body = forOp.getRegion().front(); + + // 记录每个 mergedRegion 的起始 op index + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) { + opIndex[&op] = idx++; + } + + for (int r = 0; r < mergedRegions.size(); ++r) { + MergedRegion &mr = const_cast(mergedRegions[r]); + + // 本 mergedRegion 的最早 op + Operation *firstOp = mr.opsToMove.front(); + int lowerBound = 0; + + // 边界: 前一个 mergedRegion 的最后一个 op + if (r > 0) { + Operation *prevLast = + mergedRegions[r - 1].opsToMove.back(); + lowerBound = opIndex[prevLast] + 1; + } + + SmallVector worklist(mr.opsToMove.begin(), + mr.opsToMove.end()); + SmallPtrSet visited( + mr.opsToMove.begin(), mr.opsToMove.end()); + + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + + // 往前吸收operand + for (Value operand : op->getOperands()) { + // BlockArgument + if (mlir::isa(operand)) + continue; + + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + + // 不在 for body + if (defOp->getBlock() != &body) + continue; + + int defIdx = opIndex[defOp]; + + // 超出允许向前吸收的边界 + if (defIdx < lowerBound) + continue; + + // 已经在 opsToMove + if (!visited.insert(defOp).second) + continue; + + // 吸收这个 defOp + mr.opsToMove.push_back(defOp); + worklist.push_back(defOp); + } + } + + // 最后按原 block 顺序排序 + llvm::sort(mr.opsToMove, + [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); + } +} + +static void pullInRegionDependencies( + Operation *regionOp, + int 
regionId, + DenseMap &opToRegion, + Block &body) { + + SmallVector worklist; + + // 先把 region 内的 op 放进去 + for (Region ®ion : regionOp->getRegions()) + for (Block &block : region) + for (Operation &inner : block) + worklist.push_back(&inner); + + SmallPtrSet visited; + + while (!worklist.empty()) { + Operation *innerOp = worklist.pop_back_val(); + + if (!visited.insert(innerOp).second) + continue; + + // operand 的 defining op + for (Value operand : innerOp->getOperands()) { + + Operation *def = operand.getDefiningOp(); + if (!def) + continue; + + if (def->getBlock() != &body) + continue; + + if (!opToRegion.count(def)) { + + opToRegion[def] = regionId; + + // 如果 def 也是 region-op,继续扩展 + if (def->getNumRegions() > 0) + worklist.push_back(def); + } + } + + // 继续遍历 region + for (Region &r : innerOp->getRegions()) + for (Block &b : r) + for (Operation &child : b) + worklist.push_back(&child); + } +} + +// BFS 查找某个 op 最早被哪个 region 使用 +static int findEarliestRegion( + Operation *startOp, + const DenseMap &seedRegionMap, + Block &body) { + + SmallVector worklist{startOp}; + SmallPtrSet visited; + int earliestRegion = -1; + + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + + if (!visited.insert(op).second) + continue; + + for (Value result : op->getResults()) { + for (OpOperand &use : result.getUses()) { + Operation *user = use.getOwner(); + + if (user->getBlock() != &body) + continue; + + auto it = seedRegionMap.find(user); + if (it != seedRegionMap.end()) { + int region = it->second; + if (earliestRegion == -1 || region < earliestRegion) + earliestRegion = region; + } else { + worklist.push_back(user); + } + } + } + } + + return earliestRegion; +} + +void ExpandMergedRegionOpsForAll( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // block 内 op 顺序 + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) + opIndex[&op] = idx++; + + // seed region map + DenseMap seedRegionMap; + for (int r = 
0; r < mergedRegions.size(); r++) { + for (Operation *op : mergedRegions[r].opsToMove) { + seedRegionMap[op] = r; + } + } + + // 最终 op -> region + DenseMap opToRegion = seedRegionMap; + + // ---------- Step1 顺序扫描 ---------- + for (Operation &op : body) { + + if (isa(&op)) + continue; + + if (opToRegion.count(&op)) + continue; + + int region = findEarliestRegion(&op, seedRegionMap, body); + + if (region != -1) + opToRegion[&op] = region; + } + + // ---------- Step2 region-op 依赖补全 ---------- + for (Operation &op : body) { + + auto it = opToRegion.find(&op); + if (it == opToRegion.end()) + continue; + + if (op.getNumRegions() == 0) + continue; + + pullInRegionDependencies(&op, it->second, opToRegion, body); + } + + // ---------- Step3 append op ---------- + SmallPtrSet seen; + + for (Operation &op : body) { + + auto it = opToRegion.find(&op); + if (it == opToRegion.end()) + continue; + + if (!seen.insert(&op).second) + continue; + + int region = it->second; + mergedRegions[region].opsToMove.push_back(&op); + } + + // ---------- Step4 排序 ---------- + for (auto &mr : mergedRegions) { + + llvm::sort(mr.opsToMove, + [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); + } +} + +void ExpandMergedRegionOpsByInput( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // block 内 op 顺序 + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) + opIndex[&op] = idx++; + + // seed region map + DenseMap seedRegionMap; + for (int r = 0; r < mergedRegions.size(); r++) { + for (Operation *op : mergedRegions[r].opsToMove) { + seedRegionMap[op] = r; + } + } + + // 最终 op -> region + DenseMap opToRegion = seedRegionMap; + + // ---------- Step1 顺序扫描 ---------- + for (Operation &op : body) { + + if (isa(&op)) + continue; + + if (opToRegion.count(&op)) + continue; + + int region = findEarliestRegion(&op, seedRegionMap, body); + + if (region != -1) + opToRegion[&op] = region; + } + + // ---------- Step2 region-op 依赖补全 
---------- + for (Operation &op : body) { + + auto it = opToRegion.find(&op); + if (it == opToRegion.end()) + continue; + + if (op.getNumRegions() == 0) + continue; + + pullInRegionDependencies(&op, it->second, opToRegion, body); + } + + // ---------- Step3 append op ---------- + SmallPtrSet seen; + + for (Operation &op : body) { + + auto it = opToRegion.find(&op); + if (it == opToRegion.end()) + continue; + + if (!seen.insert(&op).second) + continue; + + int region = it->second; + mergedRegions[region].opsToMove.push_back(&op); + } + + // ---------- Step4 排序 ---------- + for (auto &mr : mergedRegions) { + + llvm::sort(mr.opsToMove, + [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); + } +} + +static void ExpandMergedRegionOpsByOutput( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // block 顺序(保持 IR 顺序) + DenseMap opOrder; + int idx = 0; + for (Operation &op : body) + opOrder[&op] = idx++; + + for (auto &merged : mergedRegions) { + + // 收集 region 当前产生的 value + SmallPtrSet regionValues; + + for (Operation *op : merged.opsToMove) + for (Value res : op->getResults()) + regionValues.insert(res); + + bool changed = true; + + while (changed) { + changed = false; + + for (Operation &op : body) { + + if (isa(op) || isa(op)) + continue; + + if (llvm::is_contained(merged.opsToMove, &op)) + continue; + + bool depends = false; + + for (Value operand : op.getOperands()) { + if (regionValues.contains(operand)) { + depends = true; + break; + } + } + + if (!depends) + continue; + + // 加入 region + merged.opsToMove.push_back(&op); + + // 更新 regionValues + for (Value res : op.getResults()) + regionValues.insert(res); + + changed = true; + } + } + + // 排序保持原 block 顺序 + llvm::sort(merged.opsToMove, + [&](Operation *a, Operation *b) { + return opOrder[a] < opOrder[b]; + }); + } +} + +static void MoveIndependentOpsIntoIf( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = 
forOp.getRegion().front(); + + // 记录哪些 op 已经在 region 里 + SmallPtrSet alreadyAssigned; + + for (auto &mr : mergedRegions) + for (Operation *op : mr.opsToMove) + alreadyAssigned.insert(op); + + // 记录 iter_arg -> region + DenseMap iterArgToRegion; + + for (int r = 0; r < mergedRegions.size(); r++) { + for (Operation *op : mergedRegions[r].opsToMove) { + + for (Value operand : op->getOperands()) { + + if (auto barg = mlir::dyn_cast(operand)) { + + if (barg.getOwner() == &body) + iterArgToRegion[barg] = r; + } + } + } + } + + // block 顺序 + DenseMap opIndex; + int idx = 0; + for (Operation &op : body) + opIndex[&op] = idx++; + + // 扫描所有 op + for (Operation &op : body) { + + if (isa(op) || isa(op)) + continue; + + if (alreadyAssigned.contains(&op)) + continue; + + int targetRegion = -1; + + // 看 operand 是否来自 iter_arg + for (Value operand : op.getOperands()) { + + if (auto barg = mlir::dyn_cast(operand)) { + + if (barg.getOwner() != &body) + continue; + + auto it = iterArgToRegion.find(barg); + if (it != iterArgToRegion.end()) { + + targetRegion = it->second; + break; + } + } + } + + if (targetRegion == -1) + continue; + + mergedRegions[targetRegion].opsToMove.push_back(&op); + alreadyAssigned.insert(&op); + } + + // 排序保持 block 顺序 + for (auto &mr : mergedRegions) { + + llvm::sort(mr.opsToMove, + [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); + } +} + +// 暴力包裹 +static void ExpandMergedRegionOpsGreedyMaximum( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // 记录哪些 op 已经属于 region + DenseSet regionOps; + + for (auto ®ion : mergedRegions) + for (Operation *op : region.opsToMove) + regionOps.insert(op); + + // block op 列表 + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + DenseMap opIndex; + for (int i = 0; i < ops.size(); i++) + opIndex[ops[i]] = i; + + for (auto ®ion : mergedRegions) { + + if (region.opsToMove.empty()) + continue; + + // 找到 region 在 block 中的范围 + int start = 
ops.size(); + int end = -1; + + for (Operation *op : region.opsToMove) { + int idx = opIndex[op]; + start = std::min(start, idx); + end = std::max(end, idx); + } + + SmallVector newOps; + + // ---------- backward 扩展 ---------- + for (int i = start - 1; i >= 0; i--) { + Operation *op = ops[i]; + + if (isa(op)) + break; + + if (regionOps.contains(op)) + break; + + newOps.push_back(op); + } + + // ---------- forward 扩展 ---------- + for (int i = end + 1; i < ops.size(); i++) { + Operation *op = ops[i]; + + if (isa(op)) + break; + + if (regionOps.contains(op)) + break; + + newOps.push_back(op); + } + + // 加入 region + for (Operation *op : newOps) { + region.opsToMove.push_back(op); + regionOps.insert(op); + } + } + + // 最后保持 block 顺序 + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), region.opsToMove.end()), + region.opsToMove.end()); + } +} + +static void CollectForYieldRelatedOps( + scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &yieldRelatedOps) { + + Block &body = forOp.getRegion().front(); + + // 已经属于 region 的 op + DenseSet regionOps; + for (auto ®ion : mergedRegions) + for (Operation *op : region.opsToMove) + regionOps.insert(op); + + auto yield = cast(body.getTerminator()); + + SmallVector worklist; + DenseSet visited; + + // 初始化 worklist + for (Value v : yield.getOperands()) + worklist.push_back(v); + + while (!worklist.empty()) { + Value v = worklist.pop_back_val(); + + if (!visited.insert(v).second) + continue; + + Operation *def = v.getDefiningOp(); + if (!def) + continue; + + // 只处理 for body 内的 op + if (def->getBlock() != &body) + continue; + + // 已经在 region 内 + if (regionOps.contains(def)) + continue; + + // 记录 + if (yieldRelatedOps.insert(def).second) { + + // 继续向上找依赖 + for (Value operand : def->getOperands()) + worklist.push_back(operand); + } + } +} + +// 贪心吸收region前后的op +static void 
ExpandMergedRegionOpsGreedy( + scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &skipOps) { + + Block &body = forOp.getRegion().front(); + + // 记录哪些 op 已经属于 region + DenseSet regionOps; + for (auto ®ion : mergedRegions) + for (Operation *op : region.opsToMove) + regionOps.insert(op); + + // block op 列表 + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + // op -> index + DenseMap opIndex; + for (int i = 0; i < ops.size(); i++) + opIndex[ops[i]] = i; + + for (auto ®ion : mergedRegions) { + + if (region.opsToMove.empty()) + continue; + + // 找到 region 在 block 中的范围 + int start = ops.size(); + int end = -1; + + for (Operation *op : region.opsToMove) { + int idx = opIndex[op]; + start = std::min(start, idx); + end = std::max(end, idx); + } + + SmallVector newOps; + + // ---------- backward 扩展 ---------- + for (int i = start - 1; i >= 0; i--) { + Operation *op = ops[i]; + + // block terminator + if (isa(op)) + break; + + // 遇到其他 region 的 op + if (regionOps.contains(op)) + break; + + // yield 关联 op,跳过但继续扫描 + if (skipOps.contains(op)) + continue; + + newOps.push_back(op); + } + + // ---------- forward 扩展 ---------- + for (int i = end + 1; i < ops.size(); i++) { + Operation *op = ops[i]; + + // block terminator + if (isa(op)) + break; + + // 遇到其他 region 的 op + if (regionOps.contains(op)) + break; + + // yield 关联 op,跳过 + if (skipOps.contains(op)) + continue; + + newOps.push_back(op); + } + + // 加入 region + for (Operation *op : newOps) { + region.opsToMove.push_back(op); + regionOps.insert(op); + } + } + + // 最后保持 block 顺序 + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), region.opsToMove.end()), + region.opsToMove.end()); + } +} + +// 贪心吸收region前面的op +static void ExpandMergedRegionOpsGreedyV2( + scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &skipOps) { + + Block &body = 
forOp.getRegion().front(); + + // block op 列表 + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + // op -> index + DenseMap opIndex; + for (int i = 0; i < ops.size(); i++) + opIndex[ops[i]] = i; + + // 记录哪些 op 已经属于 region + DenseSet regionOps; + for (auto ®ion : mergedRegions) + for (Operation *op : region.opsToMove) + regionOps.insert(op); + + for (int r = 0; r < mergedRegions.size(); r++) { + + auto ®ion = mergedRegions[r]; + if (region.opsToMove.empty()) + continue; + + // ---------- 当前 region block 范围 ---------- + int start = ops.size(); + int end = -1; + + for (Operation *op : region.opsToMove) { + int idx = opIndex[op]; + start = std::min(start, idx); + end = std::max(end, idx); + } + + // ---------- 前一个 region 的末尾 ---------- + int prevEnd = -1; + + if (r > 0 && !mergedRegions[r - 1].opsToMove.empty()) { + for (Operation *op : mergedRegions[r - 1].opsToMove) { + prevEnd = std::max(prevEnd, opIndex[op]); + } + } + + SmallVector newOps; + + // ---------- backward expand ---------- + for (int i = start - 1; i > prevEnd; i--) { + + Operation *op = ops[i]; + + // terminator + if (isa(op)) + break; + + // 已属于 region + if (regionOps.contains(op)) + break; + + // yield chain op + if (skipOps.contains(op)) + continue; + + newOps.push_back(op); + } + + // ---------- 加入 region ---------- + for (Operation *op : newOps) { + region.opsToMove.push_back(op); + regionOps.insert(op); + } + } + + // ---------- 保持 block 顺序 ---------- + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), + region.opsToMove.end()), + region.opsToMove.end()); + } +} + +// 贪心吸收region前面的op +static void ExpandMergedRegionOpsGreedyV2ForAIC( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + + // block op 列表 + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + 
// op -> index + DenseMap opIndex; + for (int i = 0; i < ops.size(); i++) + opIndex[ops[i]] = i; + + // 记录哪些 op 已经属于 region + DenseSet regionOps; + for (auto ®ion : mergedRegions) + for (Operation *op : region.opsToMove) + regionOps.insert(op); + + for (int r = 0; r < mergedRegions.size(); r++) { + + auto ®ion = mergedRegions[r]; + if (region.opsToMove.empty()) + continue; + + // ---------- 当前 region block 范围 ---------- + int start = ops.size(); + int end = -1; + + for (Operation *op : region.opsToMove) { + int idx = opIndex[op]; + start = std::min(start, idx); + end = std::max(end, idx); + } + + // ---------- 前一个 region 的末尾 ---------- + int prevEnd = -1; + + if (r > 0 && !mergedRegions[r - 1].opsToMove.empty()) { + for (Operation *op : mergedRegions[r - 1].opsToMove) { + prevEnd = std::max(prevEnd, opIndex[op]); + } + } + + SmallVector newOps; + + // ---------- backward expand ---------- + for (int i = start - 1; i > prevEnd; i--) { + + Operation *op = ops[i]; + + // terminator + if (isa(op)) + break; + + // 已属于 region + if (regionOps.contains(op)) + break; + + newOps.push_back(op); + } + + // ---------- 加入 region ---------- + for (Operation *op : newOps) { + region.opsToMove.push_back(op); + regionOps.insert(op); + } + } + + // ---------- 保持 block 顺序 ---------- + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), + region.opsToMove.end()), + region.opsToMove.end()); + } +} + +static void MoveForYieldOpIntoRegion( + scf::ForOp forOp, + DenseSet &yieldRelatedOps, + SmallVector &mergedRegions) { + + DenseMap opToRegion; + + for (int i = 0; i < mergedRegions.size(); i++) + for (Operation *op : mergedRegions[i].opsToMove) + opToRegion[op] = i; + + auto yield = cast(forOp.getBody()->getTerminator()); + + for (int i = 0; i < yield.getNumOperands(); i++) { + + Value iterArg = forOp.getRegionIterArgs()[i]; + Value 
yieldVal = yield.getOperand(i); + + Operation *def = yieldVal.getDefiningOp(); + if (!def) + continue; + + if (!yieldRelatedOps.contains(def)) + continue; + + int targetRegion = -1; + + for (Operation *user : iterArg.getUsers()) { + + if (opToRegion.count(user)) { + targetRegion = opToRegion[user]; + break; + } + } + + if (targetRegion == -1) + continue; + + SmallVector stack; + stack.push_back(def); + + while (!stack.empty()) { + Operation *op = stack.pop_back_val(); + + if (!yieldRelatedOps.contains(op)) + continue; + + mergedRegions[targetRegion].opsToMove.push_back(op); + + yieldRelatedOps.erase(op); + + for (Value operand : op->getOperands()) { + if (Operation *dep = operand.getDefiningOp()) + stack.push_back(dep); + } + } + } + + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), region.opsToMove.end()), + region.opsToMove.end()); + } +} + +static void MoveRemainingYieldOpsToPrevRegion( + scf::ForOp forOp, + DenseSet &yieldRelatedOps, + SmallVector &mergedRegions) { + + if (yieldRelatedOps.empty()) + return; + + Block &body = forOp.getRegion().front(); + + // op -> region index + DenseMap opToRegion; + for (int i = 0; i < mergedRegions.size(); i++) + for (Operation *op : mergedRegions[i].opsToMove) + opToRegion[op] = i; + + // block 顺序 + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + DenseMap opIndex; + for (int i = 0; i < ops.size(); i++) + opIndex[ops[i]] = i; + + for (Operation *op : yieldRelatedOps) { + + if (op->getBlock() != &body) + continue; + + int idx = opIndex[op]; + + int targetRegion = -1; + + // 向前找最近的 region + for (int i = idx - 1; i >= 0; i--) { + Operation *prev = ops[i]; + + if (opToRegion.count(prev)) { + targetRegion = opToRegion[prev]; + break; + } + } + + if (targetRegion == -1) + continue; + + mergedRegions[targetRegion].opsToMove.push_back(op); + } + + // 
排序 + 去重 + for (auto ®ion : mergedRegions) { + + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), region.opsToMove.end()), + region.opsToMove.end()); + } +} + +static void MoveIndependentOpsIntoRegionBackwardV2( + scf::ForOp forOp, + SmallVector &mergedRegions) { + + Block &body = forOp.getRegion().front(); + SmallVector ops; + for (Operation &op : body) + ops.push_back(&op); + + DenseMap opToRegion; + for (int i = 0; i < mergedRegions.size(); i++) + for (Operation *op : mergedRegions[i].opsToMove) + opToRegion[op] = i; + + // ----------- 收集移动计划 ----------- + DenseMap movePlan; + + for (int i = 0; i < mergedRegions.size(); i++) { + MergedRegion ®ion = mergedRegions[i]; + if (region.opsToMove.empty()) continue; + + Operation *firstOp = region.opsToMove.front(); + Operation *lastOp = region.opsToMove.back(); + auto itFirst = std::find(ops.begin(), ops.end(), firstOp); + auto itLast = std::find(ops.begin(), ops.end(), lastOp); + if (itFirst == ops.end() || itLast == ops.end()) continue; + + int startIdx = std::distance(ops.begin(), itFirst); + int endIdx = std::distance(ops.begin(), itLast); + + // ----------- 收集 wait-set 区间 ----------- + SmallVector> waitIntervals; + bool inWait = false; + int begin = -1; + for (int j = startIdx; j <= endIdx; j++) { + Operation *op = ops[j]; + if (op->getName().getStringRef().contains("sync_block_wait")) { + inWait = true; begin = j + 1; continue; + } + if (op->getName().getStringRef().contains("sync_block_set") && inWait) { + inWait = false; + waitIntervals.push_back({begin, j - 1}); + } + } + auto isInWaitSet = [&](int idx) { + for (auto &p : waitIntervals) + if (idx >= p.first && idx <= p.second) return true; + return false; + }; + + // ----------- 从后往前扫描 region 内的 op ----------- + for (int j = endIdx; j >= startIdx; j--) { + Operation *op = ops[j]; + if (isa(op) || isInWaitSet(j)) continue; + + // ---------- 
operand 是否依赖本 region ---------- + bool dependCurrentRegion = false; + for (Value operand : op->getOperands()) { + Operation *def = operand.getDefiningOp(); + if (!def) continue; + if (std::find(region.opsToMove.begin(), + region.opsToMove.end(), + def) != region.opsToMove.end()) { + dependCurrentRegion = true; break; + } + } + if (dependCurrentRegion) continue; + + // ---------- 当前 region 后续是否使用 ---------- + bool usedLaterInSameRegion = false; + for (Value result : op->getResults()) + for (Operation *user : result.getUsers()) + if (std::find(region.opsToMove.begin(), + region.opsToMove.end(), + user) != region.opsToMove.end() && + std::find(region.opsToMove.begin(), + region.opsToMove.end(), op) < + std::find(region.opsToMove.begin(), + region.opsToMove.end(), user)) { + usedLaterInSameRegion = true; break; + } + if (usedLaterInSameRegion) continue; + + // ---------- 找使用该 op 的后续 region ---------- + int targetRegion = -1; + for (int k = i + 1; k < mergedRegions.size(); ++k) { + for (Operation *candidate : mergedRegions[k].opsToMove) + for (Value operand : candidate->getOperands()) + if (operand.getDefiningOp() == op) { + targetRegion = k; break; + } + if (targetRegion != -1) break; + if (targetRegion != -1) break; + } + if (targetRegion == -1) continue; + + movePlan[op] = targetRegion; + // llvm::outs() << "MJ: plan move " << *op + // << " -> region " << targetRegion << "\n"; + } + } + + // ----------- 统一应用移动 ----------- + for (auto &it : movePlan) { + Operation *op = it.first; + int targetRegionIdx = it.second; + MergedRegion &targetRegion = mergedRegions[targetRegionIdx]; + // 更新数据结构 + targetRegion.opsToMove.push_back(op); + + llvm::outs() << "MJ: move " << *op + << " -> region " << targetRegionIdx << "\n"; + } + + // ----------- 更新原 region 的 opsToMove ----------- + for (int i = 0; i < mergedRegions.size(); ++i) { + MergedRegion ®ion = mergedRegions[i]; + SmallVector newOps; + for (Operation *op : region.opsToMove) { + auto it = movePlan.find(op); + if (it == 
movePlan.end() || it->second == i) { + // 没有移动计划,或者移动的目标就是自己,保留 + newOps.push_back(op); + } + } + region.opsToMove.swap(newOps); + } + + // ----------- 排序 + 去重 ----------- + for (auto ®ion : mergedRegions) { + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + region.opsToMove.erase( + std::unique(region.opsToMove.begin(), + region.opsToMove.end()), + region.opsToMove.end()); + } +} + +// // debug: 如果一个forop的第一个region的最后3条op是%27 = tt.expand_dims %25#1 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> +// %28 = tt.broadcast %27 : tensor<64x1xf32> -> tensor<64x128xf32> +// %29 = arith.mulf %arg10, %28 : tensor<64x128xf32> +// 直接放到第2个region里 +static void TempChange(scf::ForOp forOp, + SmallVector &mergedRegions) { + + if (mergedRegions.size() < 2) + return; + + auto &srcRegion = mergedRegions[0]; + auto &dstRegion = mergedRegions[1]; + + if (srcRegion.opsToMove.size() < 3) + return; + + Operation *op1 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 3]; + Operation *op2 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 2]; + Operation *op3 = srcRegion.opsToMove[srcRegion.opsToMove.size() - 1]; + + // ---------- pattern 匹配 ---------- + if (!op1->getName().getStringRef().contains("tt.expand_dims")) + return; + + if (!op2->getName().getStringRef().contains("tt.broadcast")) + return; + + if (!op3->getName().getStringRef().contains("arith.mulf")) + return; + + llvm::outs() << "TempChange triggered\n"; + + SmallVector opsToMove = {op1, op2, op3}; + + // ---------- 移动到 region2 末尾 ---------- + for (Operation *op : opsToMove) { + dstRegion.opsToMove.push_back(op); + llvm::outs() << "TempChange move: " << *op << "\n"; + } + + // ---------- 从 region1 删除 ---------- + srcRegion.opsToMove.resize(srcRegion.opsToMove.size() - 3); + + // ---------- 排序 ---------- + for (auto ®ion : mergedRegions) { + llvm::sort(region.opsToMove, + [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + 
region.opsToMove.erase( + std::unique(region.opsToMove.begin(), + region.opsToMove.end()), + region.opsToMove.end()); + } +} + +static void sortOperationsByDataFlow(llvm::SmallVector &ops) { + llvm::DenseSet visited; + llvm::SmallVector result; + + std::function dfs = [&](Operation *op) { + if (!visited.insert(op).second) + return; + + for (Value operand : op->getOperands()) { + if (Operation *def = operand.getDefiningOp()) { + if (llvm::is_contained(ops, def)) + dfs(def); + } + } + + result.push_back(op); + }; + + for (Operation *op : ops) + dfs(op); + + ops.assign(result.begin(), result.end()); +} + +static void rewriteOperandsRecursively(Operation *op, + DenseMap &valueMap) { + + // 1 rewrite 当前 op 的 operands + for (OpOperand &operand : op->getOpOperands()) { + Value v = operand.get(); + auto it = valueMap.find(v); + if (it != valueMap.end()) + operand.set(it->second); + } + + // 2 递归进入 region + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : block) { + rewriteOperandsRecursively(&nestedOp, valueMap); + } + } + } +} + +static void CopyOpsToAfterwardRegions( + SmallVector &mergedRegions, + DenseMap &yieldMap, + DenseMap &cloneAndOriYieldMap, + SmallVector &copiedForOps) { + + if (mergedRegions.size() <= 1) + return; + + // 先整理一个 set,方便判断哪些 op 是 yield defining op + DenseSet yieldDefOps; + for (auto &it : yieldMap) + yieldDefOps.insert(it.second); + + // 倒序遍历 region + for (int i = mergedRegions.size() - 1; i >= 0; --i) { + MergedRegion &curRegion = mergedRegions[i]; + + DenseMap valueMap; + SmallVector clonedOps; + + // 遍历前面的 region + for (int k = 0; k < i; ++k) { + MergedRegion &prevRegion = mergedRegions[k]; + + int waitSetLevel = 0; + + for (Operation *op : prevRegion.opsToMove) { + + if (isa(op)) { + waitSetLevel++; + continue; + } + + if (isa(op)) { + waitSetLevel = std::max(waitSetLevel - 1, 0); + continue; + } + + if (waitSetLevel > 0) + continue; + + if (isa(op)) + continue; + + IRMapping mapper; + + for 
(auto result : op->getResults()) + if (valueMap.count(result)) + mapper.map(result, valueMap[result]); + + Operation *insertPoint = + curRegion.opsToMove.empty() ? nullptr : curRegion.opsToMove.front(); + + OpBuilder builder(insertPoint ? insertPoint : op); + + Operation *cloned = builder.clone(*op, mapper); + + // 记录 result mapping + for (auto it : llvm::zip(op->getResults(), cloned->getResults())) + valueMap[std::get<0>(it)] = std::get<1>(it); + + // 如果这个 op 是 yield defining op,记录 clone -> original + if (yieldDefOps.contains(op)) { + cloneAndOriYieldMap[cloned] = op; + } + + // 记录copy的for op + if (auto forOp = dyn_cast(cloned)) { + copiedForOps.push_back(forOp); + } + + clonedOps.push_back(cloned); + } + } + + // 插入到当前 region 开头 + curRegion.opsToMove.insert(curRegion.opsToMove.begin(), + clonedOps.begin(), clonedOps.end()); + + // rebuild SSA + for (Operation *op : curRegion.opsToMove) { + rewriteOperandsRecursively(op, valueMap); + } + + // 排序保证拓扑顺序 + sortOperationsByDataFlow(curRegion.opsToMove); + + } +} + +/// 记录 forOp 的 yield value 与其原始生成的 op 的映射 +static void GetYieldMap(scf::ForOp forOp, + DenseMap &yieldMap) { + yieldMap.clear(); + + // 取 forOp body 的 scf.yield + auto yieldOp = dyn_cast(forOp.getBody()->getTerminator()); + if (!yieldOp) + return; + + for (Value yieldVal : yieldOp.getOperands()) { + // 获取生成 yieldVal 的原始 op + Operation *defOp = yieldVal.getDefiningOp(); + + // 对 block arg(可能是 iter_arg)没有 definingOp 的情况,可以跳过或直接记录 nullptr + if (!defOp) + continue; + + yieldMap[yieldVal] = defOp; + } +} + +static Value findIterArgForAIC(Value v, scf::ForOp forOp) { + while (true) { + if (auto arg = dyn_cast(v)) { + if (arg.getOwner() == forOp.getBody()) + return v; + return Value(); + } + + Operation *def = v.getDefiningOp(); + if (!def) + return Value(); + + if (def->getNumOperands() == 0) + return Value(); + + v = def->getOperand(0); + } +} + +static Operation *findCloneOfYieldOp( + Operation *oriYieldOp, + DenseMap &cloneAndOriYieldMap, + MergedRegion ®ion) 
{ + + for (Operation *op : region.opsToMove) { + auto it = cloneAndOriYieldMap.find(op); + if (it != cloneAndOriYieldMap.end() && it->second == oriYieldOp) + return op; + } + return nullptr; +} + +static void RebuildForYielValuesForAIC( + scf::ForOp forOp, + SmallVector &mergedRegions, + DenseMap &yieldMap, + DenseMap &cloneAndOriYieldMap) { + + auto yieldOp = cast(forOp.getBody()->getTerminator()); + + for (MergedRegion ®ion : mergedRegions) { + + triton::DotOp dotOp = nullptr; + + for (Operation *op : region.opsToMove) { + if (auto d = dyn_cast(op)) { + dotOp = d; + break; + } + } + + if (!dotOp) + continue; + + // 处理 dot operand + for (Value operand : dotOp->getOperands()) { + + Value iterArg = findIterArgForAIC(operand, forOp); + if (!iterArg) + continue; + + auto arg = cast(iterArg); + int idx = arg.getArgNumber(); + + if (idx >= yieldOp.getNumOperands()) + continue; + + Value oriYieldValue = yieldOp.getOperand(idx); + + auto it = yieldMap.find(oriYieldValue); + if (it == yieldMap.end()) + continue; + + Operation *oriYieldOp = it->second; + + Operation *cloneOp = + findCloneOfYieldOp(oriYieldOp, cloneAndOriYieldMap, region); + + if (!cloneOp) + continue; + + yieldOp.setOperand(idx, cloneOp->getResult(0)); + } + } +} + +void ExpandMergedRegionOps(scf::ForOp forOp, + SmallVector &mergedRegions, + SmallVector &copiedForOps) { + bool isInAIV = false; + auto scopeOp = forOp->getParentOfType(); + if (!scopeOp) + return; + + auto coreTypeAttr = scopeOp->getAttrOfType( + hivm::TCoreTypeAttr::name); + + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) { + isInAIV = true; + } + + if (isInAIV) { + DenseSet yieldRelatedOps; + + // 1 收集 yield 相关 op + CollectForYieldRelatedOps(forOp, mergedRegions, yieldRelatedOps); + + // 2 greedy 扩展 + // ExpandMergedRegionOpsGreedy(forOp, mergedRegions, yieldRelatedOps); + ExpandMergedRegionOpsGreedyV2(forOp, mergedRegions, yieldRelatedOps); + + // 3 与前面wait-set region独立的op应该被放入后面的关联的region + 
MoveIndependentOpsIntoRegionBackwardV2(forOp, mergedRegions); + + // 4 根据 iter_arg 使用位置放入 region + MoveForYieldOpIntoRegion(forOp, yieldRelatedOps, mergedRegions); + + // 5 剩余 yield chain 放入前一个 region + MoveRemainingYieldOpsToPrevRegion(forOp, yieldRelatedOps, mergedRegions); + } + else { // AIC单独处理, 避免出现CUBE内的tensor变量依赖 + // 用Map记录原始的for yield op的的映射 + DenseMap yieldMap; + GetYieldMap(forOp, yieldMap); + + llvm::outs()<<"YieldMap:\n"; + for(auto it: yieldMap) { + llvm::outs()<<*(it.second)<<"\n"; + } + + // 2 greedy 扩展, yield value后续处理 + ExpandMergedRegionOpsGreedyV2ForAIC(forOp, mergedRegions); + + // 复制当前region的除tt.dot、以及[wait - set]之间的op到后续的所有MergedRegion + // 倒序实现 + // 记录clone和original的yield对应op的map + DenseMap cloneAndOriYieldMap; + CopyOpsToAfterwardRegions(mergedRegions, yieldMap, cloneAndOriYieldMap, copiedForOps); + + // 4 先确定每个MergedRegion的tt.dot的operand的来源是for的哪个iter_arg(递归查找), 假设为%arg0, 依据yieldMap可以得到oriYield + // 遍历当前MergedRegion的所有op, 确定哪条op对应的cloneAndOriYieldMap的second是oriYield, 假设为%45 + // 最后替换for yield op对应位置的operand为%45 + RebuildForYielValuesForAIC( + forOp, + mergedRegions, + yieldMap, + cloneAndOriYieldMap); + + } +} + +void MergeWaitSetRegions(SmallVector ®ions, + SmallVector &merged) { + for (int i = 0; i < regions.size();) { + MergedRegion mr; + mr.regions.push_back(®ions[i]); + mr.opsToMove.append(regions[i].opsToMove); + + int j = i; + while (!regions[j].hasCopyOrFixpipe && + j + 1 < regions.size()) { + j++; + mr.regions.push_back(®ions[j]); + mr.opsToMove.append(regions[j].opsToMove); + } + + merged.push_back(std::move(mr)); + i = j + 1; + } + + for (MergedRegion &mr : merged) { + SmallPtrSet regionValues; + SmallPtrSet opSet; + + for (Operation *op : mr.opsToMove) + opSet.insert(op); + + for (Operation *op : mr.opsToMove) { + for (Value v : op->getResults()) { + bool usedOutside = false; + for (OpOperand &use : v.getUses()) { + Operation *user = use.getOwner(); + if (!opSet.contains(user) && + user->getBlock() == op->getBlock()) { + 
usedOutside = true; + break; + } + } + if (usedOutside) { + mr.yieldValues.push_back(v); + mr.resultTypes.push_back(v.getType()); + } + } + } + } +} + +void GetBlockInfos(SmallVector ®ions, Block &body) { + for (auto it = body.begin(); it != body.end();) { + Operation *op = &*it; + + auto waitOp = dyn_cast(op); + if (!waitOp) { + it++; + continue; + } + + auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S); + if (auto syncWait = dyn_cast(op)) { + if (syncWait.getTpipe() == pipeS || syncWait.getPipe() == pipeS) { + return; + } + } + Operation *lastSetOp = nullptr; + + // 扫描到下一个 wait, 收集所有 set + auto curIt = std::next(it); + auto endIt = curIt; + int setOpCount = 0; + SmallVector opsInRegion; + for (; curIt != body.end(); ++curIt) { + Operation *curOp = &*curIt; + if (isa(curOp) && setOpCount >= 1) break; + if (isa(curOp)) { + setOpCount++; + endIt = curIt; //setop的位置 + lastSetOp = curOp; // 最后一个 set + } + } + + if (!lastSetOp) { + it = curIt; + continue; + }// 没有 set, 不包 + + // 收集 [wait, ..., lastSet] 之间的 ops + bool hasCopyOrFixpipe = false; + for (auto it2 = it; it2 != std::next(endIt); ++it2) { + Operation *curOp = &*it2; + opsInRegion.push_back(curOp); + if (isa(curOp) || isa(curOp)) { + hasCopyOrFixpipe = true; + } + } + + it = endIt++; + regions.push_back({waitOp, lastSetOp, opsInRegion, hasCopyOrFixpipe}); + } +} + +Value findIterArg(Value v, Type t) { + SmallVector worklist = {v}; + SmallPtrSet visited; + + while (!worklist.empty()) { + Value cur = worklist.front(); + worklist.erase(worklist.begin()); + if (!visited.insert(cur).second) + continue; + + // 匹配scf.for原始迭代参数, 直接返回 + if (auto b = mlir::dyn_cast(cur)) { + auto forOp = mlir::dyn_cast(b.getOwner()->getParentOp()); + if (forOp && b.getType() == t) { + for (Value iterArg : forOp.getRegionIterArgs()) { + if (iterArg.getAsOpaquePointer() == b.getAsOpaquePointer()) { + return b; + } + } + } + } + + Operation *defOp = cur.getDefiningOp(); + if (!defOp) continue; + + // 
核心逻辑:如果当前值是scf.if的结果 + // 进入then块找源头 + if (auto ifOp = mlir::dyn_cast(defOp)) { + Block &thenBlock = ifOp.getThenRegion().front(); + // 找到then块最后一个op(scf.yield) + // 取其operands(即ifOp结果的源头值) + for (auto &innerOp : llvm::reverse(thenBlock)) { + if (auto yieldOp = mlir::dyn_cast(&innerOp)) { + // 按索引匹配: cur是ifOp的第n个结果, 取yieldOp的第n个operand + for (auto [idx, res] : llvm::enumerate(ifOp.getResults())) { + if (res.getAsOpaquePointer() == cur.getAsOpaquePointer()) { + Value srcVal = yieldOp.getOperand(idx); + if (!visited.count(srcVal)) worklist.push_back(srcVal); + break; + } + } + break; // 找到yield即退出, 无需遍历其他op + } + } + } else { + // 非if结果值 + // 正常往前追溯operands + for (Value operand : defOp->getOperands()) { + if (!visited.count(operand)) worklist.push_back(operand); + } + } + } + + llvm::outs() << "未找到迭代参数, 返回原值: "; v.print(llvm::outs()); llvm::outs() << "\n"; + return v; +} + +// 如果 v 最终被 scf.for 的 yield 使用 +// → 返回对应的 forOp 的 iter_arg +// 如果 v 只是流向后面的 wait-set region / 其他 op +// → 直接返回原值 v +Value findIterArgForAll(Value v, Type t) { + for (Operation *user : v.getUsers()) { + + if (auto yieldOp = dyn_cast(user)) { + + if (auto forOp = dyn_cast(yieldOp->getParentOp())) { + + for (auto [idx, operand] : llvm::enumerate(yieldOp.getOperands())) { + + if (operand.getAsOpaquePointer() == v.getAsOpaquePointer()) { + + Value iterArg = forOp.getRegionIterArgs()[idx]; + + if (iterArg.getType() == t) + return iterArg; + } + } + } + } + } + + return v; +} + +void FindDependValues (SmallVector &dependValues, SmallVector mergedRegions) { + dependValues.clear(); + for (auto &curMR : mergedRegions) { + for (Value yieldValue : curMR.yieldValues) { + // llvm::outs() << "yieldValue: "<< yieldValue << "\n"; + // 遍历当前区域的yieldValue的所有user OP,判断是否存在依赖关系 + for (OpOperand &use : yieldValue.getUses()) { + Operation *userOp = use.getOwner(); + + // llvm::outs() << "userOp: "<< *userOp << "\n"; + bool isUserInOtherRegion = false; + for (auto &otherMR : mergedRegions) { + // 
跳过当前区域,只检查yieldValue是否被其他区域使用 + if (&otherMR == &curMR) continue; + + // 只要有一个 userOp在 otherMR 的 opsToMove 列表中,就认为是dependValue + // llvm::outs() << "judge comtain\n"; + // for (size_t k = 0; k < otherMR.opsToMove.size(); k++) { + // llvm::outs() << "otherMR op: " << *(otherMR.opsToMove[k]) << "\n"; + // } + // llvm::outs() << "otherMR end\n"; + + // if (llvm::is_contained(otherMR.opsToMove, userOp)) { + // isUserInOtherRegion = true; + // llvm::outs() << "is_contained\n"; + // break; + // } + + // 用 DenseSet 暂存当前 region 的所有 ops + // 初始 DenseSet: 顶层 opsToMove + DenseSet otherOps; + for (Operation *op : otherMR.opsToMove) { + CollectAllNestedOps(op, otherOps); // 完整展开嵌套 + } + if (otherOps.contains(userOp)) { + isUserInOtherRegion = true; + break; + } + + } + + // 无重复的添加依赖变量 + if (isUserInOtherRegion) { + if (!llvm::is_contained(dependValues, yieldValue)) { + dependValues.push_back(yieldValue); + } + break; + } + } + } + } +} + +void UpdateMergedRegionsWithNewForOp(SmallVector &mergedRegions, IRMapping &mapper) { + for (auto &mr : mergedRegions) { + // WaitSetRegion 后续已经不使用了,直接释放,否则会出现野指针 + SmallVector newRegions; + newRegions.clear(); + mr.regions = newRegions; + // // 更新 opsToMove 列表 + // llvm::outs() << "before \n"; + // for (auto &op : mr.opsToMove) { + // llvm::outs() << "opsToMove: " << op << ", " << *op << '\n'; + // } + SmallVector newOpsToMove; + newOpsToMove.clear(); + for (Operation *op : mr.opsToMove) { + if (op) { + Operation *newOp = mapper.lookupOrNull(op); + newOpsToMove.push_back(newOp); + } + } + mr.opsToMove = newOpsToMove; + // llvm::outs() << "after \n"; + // for (auto &op : mr.opsToMove) { + // llvm::outs() << "opsToMove: " << op << ", " << *op << '\n'; + // } + // 更新 yieldValues 列表 + SmallVector newYieldValues; + newYieldValues.clear(); + for (Value v : mr.yieldValues) { + if (v) { + newYieldValues.push_back(mapper.lookupOrNull(v)); + } + } + mr.yieldValues = newYieldValues; + // resultTypes 是type 类型,无需更新 + } +} + +void AddArgsForDependValues 
(scf::ForOp forOp, SmallVector &dependValues, SmallVector &mergedRegions, ModuleOp module) { + OpBuilder moduleBuilder(module.getContext()); + SmallVector valueTypes; + valueTypes.clear(); + + if (dependValues.empty()) { + return ; + } else { + for (Value v : dependValues) { + Type valueType = v.getType(); + valueTypes.push_back(valueType); + } + } + + // 为每个 dependValue 创建一个初始值(可能不存在相同shape和type的常量tensor) + SmallVector initTensors; + initTensors.clear(); + module.walk([&](Operation *op) { + if (auto constOp = dyn_cast(op)) { + moduleBuilder.setInsertionPoint(constOp); + for (Type valueType : valueTypes) { + auto tensorType = dyn_cast(valueType); + triton::PointerType ptrType; + ptrType = (tensorType) ? dyn_cast(tensorType.getElementType()) : dyn_cast(valueType); + if (ptrType) { + // 如果依赖变量是一个ptr类型 + // 1. 创建 i64 0 + // 2. cast 成 !tt.ptr<...> + Value zero = moduleBuilder.create(constOp.getLoc(), 0, 64); + Value ptrValue = moduleBuilder.create(constOp.getLoc(), ptrType, zero); + if (tensorType) { + // 3. 
splat 成 tensor<...x!tt.ptr<...>> + Value ptrTensor = moduleBuilder.create(constOp.getLoc(), tensorType, ptrValue); + initTensors.push_back(ptrTensor); + } else { + initTensors.push_back(ptrValue); + } + } else if (auto memrefType = dyn_cast(valueType)) { + // 如果中间变量是一个memref类型,为iterarg创建一个 alloc = memref + // 仅支持#hivm.address_space,对于#hivm.address_space,不存在 copy cbuf to cbuf 行为 + auto spaceAttr = cast(memrefType.getMemorySpace()); + if (spaceAttr && spaceAttr.getAddressSpace() == hivm::AddressSpace::L1) { + llvm::dbgs() << "AddArgsForDependValues: dependValue type is a memref hivm::AddressSpace::L1 type!!!\n"; + return mlir::WalkResult::interrupt(); + } else { + mlir::Value alloc = moduleBuilder.create(constOp.getLoc(), memrefType); + initTensors.push_back(alloc); + } + } else { + // 非 ptr 类型创建零值常量 + auto zeroAttr = moduleBuilder.getZeroAttr(valueType); + Value zeroTensor = moduleBuilder.create(constOp.getLoc(), zeroAttr); + initTensors.push_back(zeroTensor); + } + } + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + + auto initArgs = forOp.getInitArgs(); + + // 构建新的初始化参数列表 + SmallVector newInitArgs(initArgs.begin(), initArgs.end()); + // 添加 dependValue 的初始化参数 + for (Value initTensor : initTensors) { + newInitArgs.push_back(initTensor); + } + + // 获取原循环的边界和步长 + Value lb = forOp.getLowerBound(); + Value ub = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // 创建新的 ForOp,插入点位于原操作之前 + OpBuilder builder(forOp); + auto newForOp = builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); + + // 获取新循环的 region 块(已自动包含循环索引和迭代参数) + Block &newBlock = newForOp.getRegion().front(); + Block &oldBlock = forOp.getRegion().front(); + + // 建立块参数的映射:原块参数 -> 新块参数 + IRMapping mapper; + for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) { + mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); + } + // 将原循环体中的操作(不包括终结符)克隆到新块中 + // 同时按照顺序克隆新的 dependValues + SmallVector newDependValues = dependValues; + int cnt = 0; + 
builder.setInsertionPointToStart(&newBlock); + for (auto &op : oldBlock) { + auto newOp = builder.clone(op, mapper); + // dependValue 的定义OP 可能有多个 result + for (size_t i = 0; i < dependValues.size(); i++) { + Operation *defineOp = dependValues[i].getDefiningOp(); + if (defineOp == &op) { + unsigned int index = cast(dependValues[i]).getResultNumber(); + newDependValues[i] = newOp->getResult(index); + cnt++; + break; + } + } + } + // 判断是否找到了所有的 dependValue + if (newDependValues.size() != cnt) { + llvm::outs() << "can not find the depend value! \n"; + return; + } + dependValues = newDependValues; + + // 更新 mergedRegions 中的 op 为新的for循环的 op + UpdateMergedRegionsWithNewForOp(mergedRegions, mapper); + + // 创建新的循环 yield 操作:原操作数 + dependValues + auto oldYield = cast(newBlock.getTerminator()); + SmallVector newYieldOps(oldYield.getOperands()); + // 按顺序增加找到的 dependvalue + for (Value v : newDependValues) { + newYieldOps.push_back(v); + } + builder.setInsertionPointToEnd(&newBlock); + builder.create(oldYield.getLoc(), newYieldOps); + oldYield.erase(); + + // 将原 forOp 的所有使用替换为新 forOp + int oldResultNum = forOp->getResults().size(); + for (auto it : llvm::zip(forOp->getResults(), newForOp->getResults().take_front(oldResultNum))) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + } + forOp.erase(); +} + +void ComputeElseYieldValues (MergedRegion mergedRegion, SmallVector &elseYieldValues, SmallVector dependValues) { + int idx = 0; + for (Value v : mergedRegion.yieldValues) { + Type yieldType = mergedRegion.resultTypes[idx]; + elseYieldValues.push_back(findIterArg(v, yieldType)); + idx++; + } +} + +void ComputeElseYieldValuesV2 (MergedRegion mergedRegion, SmallVector &elseYieldValues, SmallVector dependValues) { + // 对于yieldValues,其中的 yield value 一定是被 for op yield 所引用,或者被其他 region 所使用 + auto forOp = dyn_cast(mergedRegion.yieldValues[0].getDefiningOp()->getBlock()->getParentOp()); + if (!forOp) { + llvm::outs() << "define op's parent is not ForOp \n"; + return; + } + auto 
iterArgs = forOp.getRegionIterArgs(); + auto forYieldValues = forOp.getYieldedValues(); + + // 新增的与 dependvalue 相关的 initarg 是接在原本for循环args后面,数量与dependvalue数量相等 + int baseDependIdx = iterArgs.size() - dependValues.size(); + + int idx = 0; + for (Value v : mergedRegion.yieldValues) { + Type yieldType = mergedRegion.resultTypes[idx]; + // yieldValue 中是dependvalue 的情况下 + // else yield value 使用对应的新增 iterargs + if (llvm::is_contained(dependValues, v)) { + int dependIdx = 0; + for (; dependIdx < dependValues.size(); dependIdx++) { + if (v == dependValues[dependIdx]) { + break; + } + } + // llvm::outs()<<"v2for:"< newYieldValues; + SmallVector newResultTypes; + + SmallPtrSet seen; + + for (auto [idx, v] : llvm::enumerate(region.yieldValues)) { + if (seen.insert(v).second) { + newYieldValues.push_back(v); + newResultTypes.push_back(region.resultTypes[idx]); + } + } + + region.yieldValues.swap(newYieldValues); + region.resultTypes.swap(newResultTypes); +} + +static void replaceExternalIfOpUses(scf::IfOp ifOp, + ArrayRef oldYieldValues) { + + for (size_t i = 0; i < oldYieldValues.size(); ++i) { + Value oldVal = oldYieldValues[i]; + Value newVal = ifOp.getResult(i); + + SmallVector usesToReplace; + + for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { + + Operation *user = use.getOwner(); + + // 跳过 ifOp 内部的使用(then / else region) + if (ifOp->isAncestor(user)) + continue; + + // 只替换 ifOp 之后的使用 + if (user->getBlock() == ifOp->getBlock()) { + if (!ifOp->isBeforeInBlock(user)) + continue; + } + + usesToReplace.push_back(&use); + } + + for (OpOperand *use : usesToReplace) + use->set(newVal); + } +} + +void CreateIfOps (SmallVector &mergedRegions, SmallVector dependValues) { + for (auto ®ion : mergedRegions) { + + // 去重yieldvalues + RemoveRedundantYieldValues(region); + + Operation *insertPt = region.opsToMove.front(); + OpBuilder builder(insertPt); + Location loc = insertPt->getLoc(); + Value cond = builder.create( + loc, builder.getI1Type(), 
builder.getBoolAttr(true)); + + bool needsYield = !region.yieldValues.empty(); + scf::IfOp ifOp; + if (needsYield) + ifOp = builder.create(loc, region.resultTypes, cond, true); + else + ifOp = builder.create(loc, TypeRange{}, cond, false); + + // 加标记 + ifOp->setAttr("ssbuffer", builder.getUnitAttr()); + + // 获取if yield value 在 else块 返回值 + SmallVector elseYieldValues; + + llvm::outs()<<"before ComputeElseYieldValuesV2"<<"\n"; + if (needsYield) { + // ComputeElseYieldValues(region, elseYieldValues, dependValues); + ComputeElseYieldValuesV2(region, elseYieldValues, dependValues); + } + + llvm::outs()<<"after ComputeElseYieldValuesV2"<<"\n"; + // 将op移进then块 + Block &thenBlock = ifOp.getThenRegion().front(); + for (Operation *m : llvm::reverse(region.opsToMove)) { + m->moveBefore(&thenBlock, thenBlock.begin()); + } + + // 创建 then/else yield + if (needsYield) { + OpBuilder thenBuilder(builder.getContext()); + thenBuilder.setInsertionPointToEnd(&thenBlock); + thenBuilder.create(loc, region.yieldValues); + + // else block + Block &elseBlock = ifOp.getElseRegion().front(); + OpBuilder elseBuilder(&elseBlock, elseBlock.end()); + elseBuilder.create(loc, elseYieldValues); + + // 替换外部使用 + + replaceExternalIfOpUses(ifOp, region.yieldValues); + + // 旧的逻辑 + // Block *block = ifOp->getBlock(); + // auto ifIt = Block::iterator(ifOp); + + // for (size_t i = 0; i < region.yieldValues.size(); ++i) { + // Value oldVal = region.yieldValues[i]; + // Value newVal = ifOp.getResult(i); + + // SmallVector usesToReplace; + + // for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { + // Operation *user = use.getOwner(); + // // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / else) + // if (user->getBlock() != ifOp->getBlock() || !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) + // continue; + // usesToReplace.push_back(&use); + // } + + // for (OpOperand *use : usesToReplace) + // use->set(newVal); + // } + + } + + llvm::outs() <<"Create ifOp: "<< *ifOp << "\n"; + } 
+} + +void CreateIfOpsOrigin (SmallVector &mergedRegions) { + for (auto ®ion : mergedRegions) { + + // 去重yieldvalues + RemoveRedundantYieldValues(region); + + Operation *insertPt = region.opsToMove.front(); + OpBuilder builder(insertPt); + Location loc = insertPt->getLoc(); + Value cond = builder.create( + loc, builder.getI1Type(), builder.getBoolAttr(true)); + + bool needsYield = !region.yieldValues.empty(); + scf::IfOp ifOp; + if (needsYield) + ifOp = builder.create(loc, region.resultTypes, cond, true); + else + ifOp = builder.create(loc, TypeRange{}, cond, false); + + // 加标记 + ifOp->setAttr("ssbuffer", builder.getUnitAttr()); + + // 将op移进then块 + Block &thenBlock = ifOp.getThenRegion().front(); + for (Operation *m : llvm::reverse(region.opsToMove)) { + m->moveBefore(&thenBlock, thenBlock.begin()); + } + + // 创建 then/else yield + if (needsYield) { + OpBuilder thenBuilder(builder.getContext()); + thenBuilder.setInsertionPointToEnd(&thenBlock); + thenBuilder.create(loc, region.yieldValues); + + // else block + SmallVector elseYieldValues; + int idx = 0; + for (Value v : region.yieldValues) { + Type yieldType = region.resultTypes[idx]; + elseYieldValues.push_back(findIterArgForAll(v, yieldType)); + idx++; + } + Block &elseBlock = ifOp.getElseRegion().front(); + OpBuilder elseBuilder(&elseBlock, elseBlock.end()); + elseBuilder.create(loc, elseYieldValues); + + // 替换外部使用 + Block *block = ifOp->getBlock(); + auto ifIt = Block::iterator(ifOp); + + for (size_t i = 0; i < region.yieldValues.size(); ++i) { + Value oldVal = region.yieldValues[i]; + Value newVal = ifOp.getResult(i); + + SmallVector usesToReplace; + + for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { + Operation *user = use.getOwner(); + // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / else) + if (user->getBlock() != ifOp->getBlock() || !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) + continue; + usesToReplace.push_back(&use); + } + + for (OpOperand *use : usesToReplace) + 
use->set(newVal); + } + + } + + llvm::outs() <<"Create ifOp: "<< *ifOp << "\n"; + } +} + +void AddIfCondition(ModuleOp module) { + SmallVector copiedForOps; + SmallVector forOpList; + SmallVector, 1> regionList; + + module.walk([&](scf::ForOp forOp) { + Block &body = forOp.getRegion().front(); + SmallVector regions; + + // 获取基本的wait-set分块信息 + GetBlockInfos(regions, body); + + SmallVector mergedRegions; + // 合并wait-set块, 依据copyop / fixpipeop合并 + MergeWaitSetRegions(regions, mergedRegions); + + // 扩展if包裹的op范围 + // AIV、AIC处理有区别 + ExpandMergedRegionOps(forOp, mergedRegions, copiedForOps); + + // 处理forop的末尾对于iter_arg的自增操作, 如tt.advance, 移进对应的if op + MoveIterArgUsersIntoIf(forOp, mergedRegions); + + // 获取if yield的value, 并更新if内op的user为yield value + for (MergedRegion &mr : mergedRegions) { + // ComputeYieldForMergedRegion(mr, body); + ComputeYieldForMergedRegionV4(mr); + } + + // // 创建最终的if op + // CreateIfOpsOrigin(mergedRegions); + // }); + + forOpList.push_back(forOp); + regionList.push_back(mergedRegions); + }); + + llvm::outs()<<"CopyForOp:\n"; + for(auto op : copiedForOps){ + llvm::outs()<<*op<<"\n"; + } + + SmallVector tmpOps; + for (auto copiedOp : copiedForOps) { + Block &body = copiedOp.getRegion().front(); + SmallVector regions; + + // 获取基本的wait-set分块信息 + GetBlockInfos(regions, body); + + SmallVector mergedRegions; + // 合并wait-set块, 依据copyop / fixpipeop合并 + MergeWaitSetRegions(regions, mergedRegions); + + // 扩展if包裹的op范围 + // AIV、AIC处理有区别 + ExpandMergedRegionOps(copiedOp, mergedRegions, tmpOps); + + // 处理forop的末尾对于iter_arg的自增操作, 如tt.advance, 移进对应的if op + MoveIterArgUsersIntoIf(copiedOp, mergedRegions); + + // 获取if yield的value, 并更新if内op的user为yield value + for (MergedRegion &mr : mergedRegions) { + // ComputeYieldForMergedRegion(mr, body); + ComputeYieldForMergedRegionV4(mr); + } + + // // 创建最终的if op + // CreateIfOpsOrigin(mergedRegions); + // }); + + forOpList.push_back(copiedOp); + regionList.push_back(mergedRegions); + } + + for (size_t i = 0; i < 
forOpList.size(); ++i) { + scf::ForOp oldForOp = forOpList[i]; + SmallVector newMergedRegions = regionList[i]; + + // 找到所有的VV或CC依赖 + SmallVector dependValues; + llvm::outs() << "FindDependValues! \n "; + FindDependValues(dependValues, newMergedRegions); + + if (dependValues.size() != 0) { + copyLoadCalculation(oldForOp, dependValues, newMergedRegions); + + // repeat previous operations + for (MergedRegion &mr : newMergedRegions) { + mr.yieldValues.clear(); + mr.resultTypes.clear(); + ComputeYieldForMergedRegionV4(mr); + } + FindDependValues(dependValues, newMergedRegions); + } + + // 如果存在VV或CC依赖,更新ForOp添加新的对应args + if (dependValues.size() != 0) { + AddArgsForDependValues(oldForOp, dependValues, newMergedRegions, module); + } + + // 创建最终的if op + llvm::outs() << "before create if ops" << '\n'; + CreateIfOps(newMergedRegions, dependValues); + } +} + +void ChangeAdvanceOpForm(ModuleOp module) { + module.walk([&](scf::ForOp forOp) { + Block &body = forOp.getRegion().front(); + constexpr int num = 8; + SmallVector ifOps; + for (Operation &op : body) + if (auto ifOp = dyn_cast(&op)) + ifOps.push_back(ifOp); + + for (scf::IfOp ifOp : ifOps) { + // 找 then region 中的 advance + triton::AdvanceOp advanceOp; + for (Operation &thenOp : ifOp.getThenRegion().front()) { + if (auto adv = dyn_cast(thenOp)) { + advanceOp = adv; + break; + } + } + if (!advanceOp) continue; + + // base 必须是 for的iter_arg + Value base = advanceOp.getPtr(); + auto barg = dyn_cast(base); + if (!barg || barg.getOwner() != &body) continue; + + // yield 去掉 advance 的返回值 + auto thenYield = cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = cast(ifOp.getElseRegion().front().getTerminator()); + + int advanceIdx = -1; + for (auto it : llvm::enumerate(thenYield.getOperands())) { + if (it.value() == advanceOp.getResult()) { + advanceIdx = it.index(); + break; + } + } + + if (advanceIdx == -1) continue; + + // 删除 advance + SmallVector thenOps(thenYield.getOperands().begin(), 
thenYield.getOperands().end()); + SmallVector elseOps(elseYield.getOperands().begin(), elseYield.getOperands().end()); + + thenOps.erase(thenOps.begin() + advanceIdx); + elseOps.erase(elseOps.begin() + advanceIdx); + + thenYield->setOperands(thenOps); + elseYield->setOperands(elseOps); + + // 重建 ifOp(去掉 advance 对应的 result) + OpBuilder ifBuilder(ifOp); + ifBuilder.setInsertionPoint(ifOp); + + // 构造新的 result types + SmallVector newResultTypes; + for (int i = 0; i < ifOp.getNumResults(); ++i) { + if (i != advanceIdx) + newResultTypes.push_back(ifOp.getResult(i).getType()); + } + + // 创建新的 if + auto newIf = ifBuilder.create( + ifOp.getLoc(), + newResultTypes, + ifOp.getCondition(), + /*withElseRegion=*/true); + newIf->setAttr("ssbuffer", ifBuilder.getUnitAttr()); + // 把已经修改过 yield 的 region 搬过去 + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + // 替换if result的user + int newIdx = 0; + for (int oldIdx = 0; oldIdx < ifOp.getNumResults(); ++oldIdx) { + if (oldIdx == advanceIdx) + continue; + ifOp.getResult(oldIdx).replaceAllUsesWith(newIf.getResult(newIdx++)); + } + + OpBuilder builder(newIf); + builder.setInsertionPointAfter(newIf); + + Value flag = newIf.getCondition(); + + SmallVector newOffsets; + for (Value off : advanceOp.getOffsets()) { + auto intTy = cast(off.getType()); + auto zero = builder.create( + newIf.getLoc(), 0, intTy.getWidth()); + auto sel = builder.create( + newIf.getLoc(), flag, off, zero); + newOffsets.push_back(sel); + } + + auto newAdvance = builder.create( + newIf.getLoc(), base.getType(), base, newOffsets); + + // 原 if 的 advance result 的 users,接到 newAdvance + ifOp.getResult(advanceIdx).replaceAllUsesWith(newAdvance.getResult()); + + // 删除旧的ifOp和advance + advanceOp.erase(); + ifOp.erase(); + } + }); +} + +void processRedudantIf(ModuleOp module) { + SmallVector forOps; + llvm::outs()< newInitArgs(initArgs.begin(), initArgs.end()); + newInitArgs.push_back(newInit); + + // 获取原循环的边界和步长 + 
Value lb = forOp.getLowerBound(); + Value ub = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // 创建新的 ForOp,插入点位于原操作之前 + OpBuilder builder(forOp); + auto newForOp = builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); + + // 获取新循环的 region 块(已自动包含循环索引和迭代参数) + Block &newBlock = newForOp.getRegion().front(); + Block &oldBlock = forOp.getRegion().front(); + + // 建立块参数的映射:原块参数 -> 新块参数(前6个对应) + IRMapping mapper; + for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) { + mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); + } + // 将原循环体中的操作(不包括终结符)克隆到新块中 + builder.setInsertionPointToStart(&newBlock); + for (auto &op : oldBlock) { + auto newOp = builder.clone(op, mapper); + } + + // 在新块中查找第一个 scf::IfOp(即原代码中的第一个 if) + scf::IfOp firstIfOp = nullptr; + for (auto &op : newBlock.getOperations()) { + if (auto ifOp = dyn_cast(&op)) { + firstIfOp = ifOp; + break; + } + } + assert(firstIfOp && "Expected at least one if op in the loop body"); + + // 修改第一个 if 的 else 分支的 yield 操作: + // 将其第二个操作数(索引1)从原来的 %arg9 改为新迭代参数(新块参数索引6) + Block &elseBlock = firstIfOp.getElseRegion().front(); + auto elseYield = cast(elseBlock.getTerminator()); + SmallVector newElseYieldOps(elseYield.getOperands()); + newElseYieldOps[1] = newBlock.getArgument(6); // 新迭代参数 + builder.setInsertionPoint(elseYield); + builder.create(elseYield.getLoc(), newElseYieldOps); + elseYield->erase(); + + // 创建新的循环 yield 操作:原5个操作数 + 第一个 if 的第二个结果 + auto oldYield = cast(newBlock.getTerminator()); + SmallVector newYieldOps(oldYield.getOperands()); + newYieldOps.push_back(firstIfOp.getResult(1)); // 第一个 if 的第二个结果 + builder.setInsertionPointToEnd(&newBlock); + builder.create(oldYield.getLoc(), newYieldOps); + oldYield.erase(); + + // 将原 forOp 的所有使用替换为新 forOp 的前5个结果 + for (auto it : llvm::zip(forOp->getResults(), newForOp->getResults().take_front(5))) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + } + } + for (auto forOp : forOps) { + forOp.erase(); + } +} +// 针对依赖变量,对原本的for op增加double 
buffer相关的迭代参数 +scf::ForOp addDoubleBuffForArgs(ModuleOp module, SmallVector uniqueDeps, int bufferNum) { + mlir::OpBuilder builder(module.getContext()); + SmallVector depValueForIdxs; + + // ========== 找到scf.if所在的scf::ForOp ========== + if (!isa(uniqueDeps[0].getDefiningOp()->getParentOp())) { + llvm::errs() << "Error: parent op of scf.if is not scf.for"; + } + scf::ForOp forOp = dyn_cast(uniqueDeps[0].getDefiningOp()->getParentOp()); + + for(Value dependencyValue : uniqueDeps){ + // ========== 步骤1:验证目标Value是scf.if的返回值,并找到对应的scf::IfOp ========== + Operation *ifOp = dependencyValue.getDefiningOp(); + if (!ifOp || !isa(ifOp)) { + llvm::errs() << "Error: 目标Value不是scf.if的返回值\n"; + return nullptr; + } + scf::IfOp targetIfOp = dyn_cast(ifOp); + + // 确认当前Value是scf.if的第几个返回值 + int64_t depValueIdx = -1; + for (auto [idx, result] : llvm::enumerate(targetIfOp.getResults())) { + if (result == dependencyValue) { + depValueIdx = idx; + break; + } + } + + // ========== 步骤2:找到%38#2关联的scf.for迭代参数以及索引 ========== + // %38#2对应scf.if else分支yield的第2个操作数 → 即%arg10 + Operation *elseYield = targetIfOp.elseYield(); + Value dependencyArg = elseYield->getOperand(depValueIdx); // depValueIdx=2,对应else yield的第2个参数 + + int64_t depValueForIdx = -1; + for (auto [idx, result] : llvm::enumerate(forOp.getRegionIterArgs())) { + if (result == dependencyArg) { + depValueForIdx = idx; + break; + } + } + depValueForIdxs.push_back(depValueForIdx); + llvm::outs() << "depValueForIdx: " << depValueForIdx << '\n'; + } + + llvm::outs() << "oldFor: " << forOp << '\n'; + + // 获取原始循环的信息 + Value originalLowerBound = forOp.getLowerBound(); + Value originalUpperBound = forOp.getUpperBound(); + Value originalStep = forOp.getStep(); + SmallVector originalInitArgs = forOp.getInitArgs(); + SmallVector iterArgs; + for (auto arg : originalInitArgs) { + iterArgs.push_back(arg); + } + auto yields = forOp.getBody()->getTerminator(); + + // 创建计数器初始零值 + Value counterInit = nullptr; + mlir::Operation* parentOp = 
forOp->getParentOp(); + mlir::Operation* scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + + builder.setInsertionPoint(scopeOp); + Location loc = forOp.getLoc(); + auto boundType = originalLowerBound.getType(); + counterInit = builder.create(loc, 0, boundType); + + // 添加和depValueForIdxs相同的迭代参数和计数器 + for (int64_t idx : depValueForIdxs) { + for (int i = 0; i < bufferNum - 1; i++) { + iterArgs.push_back(originalInitArgs[idx]); + } + + // 在迭代参数中添加计数器 + for (int i = 0; i < 2; i++) { + iterArgs.push_back(counterInit); + } + } + + builder.setInsertionPoint(forOp); + // 创建新的for循环 + auto newForOp = builder.create( + forOp.getLoc(), + originalLowerBound, + originalUpperBound, + originalStep, + iterArgs); + + // 设置IR映射表,将旧循环的变量映射到新循环 + IRMapping mapper; + + // 映射迭代变量 + mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // 映射迭代参数 + for (auto [oldArg, newArg] : + llvm::zip(forOp.getRegionIterArgs(), + newForOp.getRegionIterArgs())) { + mapper.map(oldArg, newArg); + } + + SmallVector newArgs; + for (int i = forOp.getRegionIterArgs().size(); i < newForOp.getRegionIterArgs().size(); i++) { + newArgs.push_back(newForOp.getRegionIterArgs()[i]); + } + // 克隆循环体内容到新循环 + auto &newLoopBody = *newForOp.getBody(); + builder.setInsertionPointToStart(&newLoopBody); + + for (auto &op : forOp.getBody()->without_terminator()) { + builder.clone(op, mapper); + } + + // 克隆yield操作 + if (auto yieldOp = dyn_cast(yields)) { + SmallVector newYieldOperands; + for (auto operand : yieldOp.getOperands()) { + newYieldOperands.push_back(mapper.lookupOrDefault(operand)); + } + // 将新增的迭代参数添加到yield操作数中 + for (auto currentCounter : newArgs) { + newYieldOperands.push_back(currentCounter); + } + builder.create(yieldOp.getLoc(), newYieldOperands); + } + + // 替换原循环的结果 + unsigned numOriginalResults = forOp.getNumResults(); + SmallVector originalResults; + for (unsigned i = 0; i < 
numOriginalResults; i++) { + originalResults.push_back(newForOp.getResult(i)); + } + forOp.replaceAllUsesWith(originalResults); + + // 8. 删除原循环 + forOp.erase(); + + llvm::outs() << "for op erased!\n"; + return newForOp; +} + +SmallVector buildNBufferProducer(OpBuilder &builder, Location loc, + Value frontCnt, Value newDepVal, + ArrayRef buffs, + ArrayRef constants) { + // N-buffer producer: determines which buffer is written to newDepVal based on frontCnt % N + const int N = buffs.size(); + SmallVector results; + + // idx = frontCnt % N + Value bufferIndex = + builder.create(loc, frontCnt, constants[N]); + + // 1. buffer0: handle the first buffer separately + Value isBuffer0 = builder.create(loc, arith::CmpIPredicate::eq, + bufferIndex, constants[0]); + + auto dstShapedType = mlir::dyn_cast(newDepVal.getType()); + auto maskType = RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); + Value mask = builder.create(loc, maskType, isBuffer0); + Value newBuff0 = builder.create(loc, mask, newDepVal, buffs[0]); + + results.push_back(newBuff0); + + // 2. Double-buffer specialization (when N == 2, a direct select is sufficient) + if (N == 2) { + + Value newBuff1 = builder.create(loc, mask, buffs[1], newDepVal); + + auto nextCnt = builder.create(loc, frontCnt, constants[1]); + + results.push_back(newBuff1); + results.push_back(nextCnt.getResult()); + + return results; + } + + // 3. Build the root IF: when idx == 0, + // use the first buffer; otherwise enter the nestedIf chain to use other buffers + SmallVector resultTypes; + for (int i = 1; i < N; ++i) + resultTypes.push_back(buffs[i].getType()); + + auto rootIf = builder.create(loc, resultTypes, isBuffer0, true); + + // ---- THEN: buffers are directly forwarded ---- + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&rootIf.getThenRegion().front()); + + SmallVector unchangedBuffers(buffs.begin() + 1, buffs.end()); + + builder.create(loc, unchangedBuffers); + } + + // 4. 
Construct the nested-if chain, updating one buffer at each level + Block *currentElseBlock = &rootIf.getElseRegion().front(); + + scf::IfOp parentIf = rootIf; + + for (int i = 1; i < N - 1; ++i) { + + builder.setInsertionPointToStart(currentElseBlock); + + // Check whether the current buffer is selected + Value isCurrent = builder.create( + loc, arith::CmpIPredicate::eq, bufferIndex, constants[i]); + + // Update buffer[i] + dstShapedType = mlir::dyn_cast(newDepVal.getType()); + maskType = RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + mask = builder.create(loc, maskType, isCurrent); + Value updatedBuffer = builder.create(loc, mask, newDepVal, buffs[i]); + + // If this is the last level: directly yield both buffers + if (i == N - 2) { + + dstShapedType = mlir::dyn_cast(newDepVal.getType()); + maskType = RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + mask = builder.create(loc, maskType, isCurrent); + Value lastBuffer = builder.create(loc, mask, buffs[N - 1], newDepVal); + + builder.create(loc, ValueRange {updatedBuffer, lastBuffer}); + + break; + } + + // Create the next nested if + SmallVector subResultTypes; + for (int j = i + 1; j < N; ++j) + subResultTypes.push_back(buffs[j].getType()); + + auto nextIf = + builder.create(loc, subResultTypes, isCurrent, true); + + // THEN: forward the remaining buffers + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&nextIf.getThenRegion().front()); + + SmallVector remainingBuffers(buffs.begin() + i + 1, buffs.end()); + + builder.create(loc, remainingBuffers); + } + + // Update the else yield + builder.setInsertionPointToEnd(&parentIf.getElseRegion().front()); + + SmallVector yields; + yields.push_back(updatedBuffer); + yields.append(nextIf.getResults().begin(), nextIf.getResults().end()); + + builder.create(loc, yields); + + parentIf = nextIf; + currentElseBlock = &nextIf.getElseRegion().front(); + } + + // 5. 
Update the frontCnt counter + builder.setInsertionPointAfter(rootIf); + + auto nextCnt = builder.create(loc, frontCnt, constants[1]); + + // Collect results + results.append(rootIf.getResults().begin(), rootIf.getResults().end()); + + results.push_back(nextCnt.getResult()); + + return results; +} + +SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, + Value postCnt, ArrayRef oldBuffs, + ArrayRef constants) { + // Consumer: selects which buffer to read based on postCnt % N + const int bufferNum = oldBuffs.size(); + SmallVector results; + + // idx = postCnt % N + Value bufferIndex = + builder.create(loc, postCnt, constants[bufferNum]); + + Value isBuffer0 = builder.create(loc, arith::CmpIPredicate::eq, + bufferIndex, constants[0]); + auto dstShapedType = mlir::dyn_cast(oldBuffs[0].getType()); + auto maskType = RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); + auto mask = builder.create(loc, maskType, isBuffer0); + + // 1. Double-buffer specialization (avoid generating scf.if) + if (bufferNum == 2) { + Value selected = builder.create(loc, mask, oldBuffs[0], oldBuffs[1]); + auto nextCnt = builder.create(loc, postCnt, constants[1]); + + results.push_back(selected); + results.push_back(nextCnt); + + return results; + } + + // 2. Build the root IF: + // when idx == 0, use the first buffer; otherwise enter the nestedIf chain to use other buffers + SmallVector resultTypes{oldBuffs[0].getType()}; + + auto rootIf = builder.create(loc, resultTypes, isBuffer0, true); + + // ---- THEN: directly return buffer0 ---- + { + builder.setInsertionPointToStart(&rootIf.getThenRegion().front()); + + builder.create(loc, oldBuffs[0]); + } + + // 3. 
Construct the nested-if chain + Block *currentElse = &rootIf.getElseRegion().front(); + + for (int i = 1; i < bufferNum - 2; ++i) { + + builder.setInsertionPointToStart(currentElse); + + Value isCurrent = builder.create( + loc, arith::CmpIPredicate::eq, bufferIndex, constants[i]); + + auto nestedIf = builder.create( + loc, TypeRange{oldBuffs[0].getType()}, isCurrent, true); + + // THEN → return the current buffer + { + builder.setInsertionPointToStart(&nestedIf.getThenRegion().front()); + + builder.create(loc, oldBuffs[i]); + } + + // ELSE → yield nested result + builder.setInsertionPointToEnd(currentElse); + builder.create(loc, nestedIf.getResult(0)); + + // Enter the next else branch + currentElse = &nestedIf.getElseRegion().front(); + } + + // 4. Final level (use select to finish) + builder.setInsertionPointToStart(currentElse); + + int last = bufferNum - 2; + + Value isLast = builder.create(loc, arith::CmpIPredicate::eq, + bufferIndex, constants[last]); + + maskType = RankedTensorType::get({}, isLast.getType()); + dstShapedType = mlir::dyn_cast(oldBuffs[last].getType()); + maskType = RankedTensorType::get(dstShapedType.getShape(), isLast.getType()); + mask = builder.create(loc, maskType, isLast); + + Value finalSelect = builder.create(loc, mask, oldBuffs[last], oldBuffs[last + 1]); + + builder.create(loc, finalSelect); + + // rootIf result = selected buffer + results.push_back(rootIf.getResult(0)); + + // 5. 
Update the postCnt counter + builder.setInsertionPointAfter(rootIf); + + auto nextCnt = builder.create(loc, postCnt, constants[1]); + + results.push_back(nextCnt); + + return results; +} + +void replaceDepsMap( + scf::IfOp oldIfOp, + scf::IfOp newIfOp, + SmallVector &newDeps, + bool isFront, + DenseMap> &newIfResultDeps) +{ + mlir::IRMapping valueMap; + + // old result -> new result + for (unsigned i = 0; i < oldIfOp.getNumResults(); ++i) { + valueMap.map(oldIfOp.getResult(i), newIfOp.getResult(i)); + } + + if (isFront) { + for (int i = 0; i < newDeps.size(); i++) { + Value v = newDeps[i]; + if (valueMap.contains(v)) + newDeps[i] = valueMap.lookup(v); + } + } + + // rewrite deps in-place + for (auto &it : newIfResultDeps) { + auto &deps = it.second; + + for (auto &value : deps) { + if (auto mapped = valueMap.lookupOrNull(value)) + value = mapped; + } + } +} + +scf::IfOp addResultsForFrontIfOp(scf::IfOp frontIfOp, OpBuilder builder, + int bufferNum, Value depValue, + SmallVector constants, + SmallVector buffs, Value frontCnt, + Value postCnt, + SmallVector &extraResultIndices, + SmallVector &newDeps, + DenseMap> &newIfResultDeps) +{ + OpBuilder::InsertionGuard guard(builder); + + Location loc = frontIfOp.getLoc(); + Value cond = frontIfOp.getCondition(); + + auto &oldThenBlock = frontIfOp.getThenRegion().front(); + auto &oldElseBlock = frontIfOp.getElseRegion().front(); + + // New result types = old results + extra buffers + counter + SmallVector newResultTypes(frontIfOp.getResultTypes().begin(), + frontIfOp.getResultTypes().end()); + + for (int i = 1; i < bufferNum; ++i) + newResultTypes.push_back(buffs[i].getType()); + + newResultTypes.push_back(frontCnt.getType()); + + unsigned oldNumResults = frontIfOp.getNumResults(); + + // Create new IfOp + builder.setInsertionPoint(frontIfOp); + auto newIfOp = + builder.create(loc, newResultTypes, cond, /*hasElse=*/true); + + SmallVector bufferIndices(bufferNum); + SmallVector newBuffs; + int frontCntIndex = -1; + + // THEN 
region + { + mlir::IRMapping mapping; + Block &newThenBlock = newIfOp.getThenRegion().front(); + + builder.setInsertionPointToStart(&newThenBlock); + + // Clone original then body + for (auto &op : oldThenBlock.without_terminator()) + builder.clone(op, mapping); + + // Update dependency value position inf ifOp results + auto result = dyn_cast(depValue); + if (!result) { + llvm::outs() << "depValue is not a result Value!\n"; + return nullptr; + } + + int depIdx = result.getResultNumber(); + Value depYieldValue = frontIfOp.thenYield()->getOperand(depIdx); + + Value newDepVal = mapping.contains(depYieldValue) + ? mapping.lookup(depYieldValue) + : depYieldValue; + + builder.setInsertionPointAfter(newDepVal.getDefiningOp()); + + // Create N buffer + SmallVector produced = buildNBufferProducer( + builder, loc, frontCnt, newDepVal, buffs, constants); + + // Last value in newBuffs is the counter + newBuffs.append(produced.begin(), produced.end() - 1); + + // Rebuild new yield + SmallVector thenOperands; + + for (Value v : oldThenBlock.getTerminator()->getOperands()) { + Value mapped = mapping.lookupOrDefault(v); + + // Replace first buffer + if (mapped == newDepVal) { + thenOperands.push_back(newBuffs[0]); + bufferIndices[0] = thenOperands.size() - 1; + } else { + thenOperands.push_back(mapped); + } + } + + // Replace other buffer + for (int i = 1; i < bufferNum; ++i) { + thenOperands.push_back(newBuffs[i]); + bufferIndices[i] = thenOperands.size() - 1; + } + + // Add counter + thenOperands.push_back(produced.back()); + frontCntIndex = thenOperands.size() - 1; + + builder.setInsertionPointToEnd(&newThenBlock); + builder.create(loc, thenOperands); + + // record new result indices + for (int idx : bufferIndices) + extraResultIndices.push_back(idx); + + extraResultIndices.push_back(frontCntIndex); + } + + // ELSE region + { + mlir::IRMapping mapping; + Block &newElseBlock = newIfOp.getElseRegion().front(); + + builder.setInsertionPointToStart(&newElseBlock); + + // Clone 
original else body + for (auto &op : oldElseBlock.without_terminator()) + builder.clone(op, mapping); + + builder.setInsertionPointToEnd(&newElseBlock); + + SmallVector elseOperands; + + for (Value v : oldElseBlock.getTerminator()->getOperands()) + elseOperands.push_back(mapping.lookupOrDefault(v)); + + // Add buffer + for (int i = 1; i < bufferNum; ++i) + elseOperands.push_back(buffs[i]); + + // Add counter + elseOperands.push_back(frontCnt); + + builder.create(loc, elseOperands); + } + + // Update dependency value + replaceDepsMap(frontIfOp, newIfOp, newDeps, true, newIfResultDeps); + + // Replace old ifOp + frontIfOp.replaceAllUsesWith(newIfOp.getResults().take_front(oldNumResults)); + + frontIfOp.erase(); + + return newIfOp; +} + +scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, + OpBuilder builder, int bufferNum, + Value newDepValue, SmallVector constants, + SmallVector buffs, Value frontCnt, + Value postCnt, + SmallVector &extraResultIndices, + SmallVector &newDeps, + DenseMap> &newIfResultDeps) +{ + // 1. Parse the extra result indices produced by frontIf (added buffers and counters) + SmallVector bufferIndices(extraResultIndices.begin(), + extraResultIndices.end() - 1); + int frontCntIndex = extraResultIndices[bufferNum]; + + Location ifLoc = postIfOp.getLoc(); + Value cond = postIfOp.getCondition(); + + auto &oldThenBlock = postIfOp.getThenRegion().front(); + auto &oldElseBlock = postIfOp.getElseRegion().front(); + + // 2. Create a new IfOp (add a new postCnt result) + SmallVector newResultTypes(postIfOp.getResultTypes().begin(), + postIfOp.getResultTypes().end()); + newResultTypes.push_back(postCnt.getType()); + + builder.setInsertionPoint(postIfOp); + auto newIfOp = builder.create(ifLoc, newResultTypes, cond, + /*hasElse=*/true); + + mlir::IRMapping mapping; + + // 3. 
THEN region: clone the original logic, insert the multibuffer consumer and update dependency buffers + auto &newThenBlock = newIfOp.getThenRegion().front(); + builder.setInsertionPointToStart(&newThenBlock); + + // clone then body + for (auto &op : oldThenBlock.without_terminator()) + builder.clone(op, mapping); + builder.setInsertionPointToStart(&newThenBlock); + + // Find dependency uses that need to be replaced (located inside the current IfOp) + SmallVector replaceUses; + for (auto &use : newDepValue.getUses()) { + if (newIfOp == dyn_cast(use.getOwner()->getParentOp())) { + replaceUses.push_back(&use); + } + } + + // Collect buffers produced by frontIf + SmallVector oldBuffers; + for (int i = 0; i < bufferIndices.size(); ++i) + oldBuffers.push_back(newfrontIfOp.getResult(bufferIndices[i])); + + // Multibuffer consumer caculation + SmallVector consumerResults = + buildNBufferConsumer(builder, ifLoc, postCnt, oldBuffers, constants); + + Value selectedBuffer = consumerResults[0]; + Value nextPostCnt = consumerResults[1]; + + // Replace dependent buffer + for (auto *usePtr : replaceUses) { + usePtr->set(selectedBuffer); + } + + // Create then yield + SmallVector thenOperands; + for (auto v : oldThenBlock.getTerminator()->getOperands()) + thenOperands.push_back(mapping.lookupOrDefault(v)); + + int postCntIndex = thenOperands.size(); + thenOperands.push_back(nextPostCnt); + + builder.setInsertionPointToEnd(&newThenBlock); + builder.create(ifLoc, thenOperands); + extraResultIndices.push_back(postCntIndex); + + // 4. 
ELSE region:forward counter directly + auto &newElseBlock = newIfOp.getElseRegion().front(); + + for (auto &op : oldElseBlock.without_terminator()) + builder.clone(op, mapping); + + builder.setInsertionPointToEnd(&newElseBlock); + + SmallVector elseOperands; + for (auto v : oldElseBlock.getTerminator()->getOperands()) + elseOperands.push_back(mapping.lookupOrDefault(v)); + + elseOperands.push_back(postCnt); + + builder.create(ifLoc, elseOperands); + + // 5. Replace old ifOp with new one + auto oldNumResults = postIfOp.getNumResults(); + + // Update depency value + replaceDepsMap(postIfOp, newIfOp, newDeps, false, newIfResultDeps); + + postIfOp.replaceAllUsesWith(newIfOp.getResults().take_front(oldNumResults)); + + postIfOp.erase(); + + return newIfOp; +} + +void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, + DenseMap> &ifResultDeps, scf::ForOp &newForOp, int bufferNum) +{ + + // ============================================================ + // Overall Idea + // + // For each dependency Value: + // 1. Find the front IfOp that produces it + // 2. Add multi-buffer results to the front IfOp + // 3. Find the post IfOp that consumes the result and extend it accordingly + // 4. Update the for-loop yield so that buffer states are correctly propagated + // ============================================================ + + OpBuilder builder(module.getContext()); + int processedDepCount = 0; + + SmallVector postIfOps; + newForOp.walk([&](scf::IfOp postIfOp) { + postIfOps.push_back(postIfOp); + }); + for (auto postIfOp:postIfOps) { + if (!ifResultDeps.count(postIfOp)) { + continue; + } + auto newDeps = ifResultDeps[postIfOp]; + for (int depValueIdx = 0; depValueIdx < newDeps.size(); depValueIdx++) { + Value depValue = newDeps[depValueIdx]; + + // Step 1. 
Locate the front IfOp that produces depValue + Operation *defOp = depValue.getDefiningOp(); + if (!defOp || !isa(defOp)) { + llvm::outs() << "Error: depValue is not produced by scf.if\n"; + break; + } + + scf::IfOp frontIfOp = cast(defOp); + + // Position of depValue in the IfOp results + auto result = dyn_cast(depValue); + if (!result) { + llvm::outs() << "depValue is not an OpResult!\n"; + return; + } + + int64_t depResultIndex = result.getResultNumber(); + + // Position of depValue in the IfOp results + Value depYieldValue = frontIfOp.thenYield()->getOperand(depResultIndex); + + // Step 2. Find the multi-buffer position in the ForOp + int64_t extraArgBaseIdx = + newForOp.getRegionIterArgs().size() - + (2 + bufferNum - 1) * (newUniqueDeps.size() - processedDepCount++); + + // Collect all buffers + SmallVector buffers; + + // buffer0 来自 else yield + buffers.push_back(frontIfOp.elseYield()->getOperand(depResultIndex)); + + // Other buffers come from for iter args + for (int i = 1; i < bufferNum; ++i) { + buffers.push_back(newForOp.getRegionIterArgs()[extraArgBaseIdx + i - 1]); + } + + // Two counters + Value frontCnt = + newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum - 1]; + Value postCnt = newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum]; + + // Step 3. Create constants (0 ~ bufferNum) for rem / cmp buffer selection logic + SmallVector constants; + builder.setInsertionPoint(frontIfOp); + + auto dataType = frontCnt.getType(); + for (int i = 0; i <= bufferNum; ++i) { + constants.push_back(builder.create( + frontIfOp.getLoc(), dataType, + builder.getIntegerAttr(dataType, i))); + } + + // Record the positions of newly added results in the IfOp + SmallVector extraResultIndices(bufferNum + 1); + extraResultIndices.clear(); + + // Step 4. 
Extend the front IfOp + scf::IfOp newFrontIfOp = addResultsForFrontIfOp( + frontIfOp, builder, bufferNum, depValue, constants, buffers, frontCnt, + postCnt, extraResultIndices, newDeps, ifResultDeps); + + // buffer result indices + SmallVector bufferResultIndices(extraResultIndices.begin(), + extraResultIndices.end() - 1); + + int frontCntResultIndex = extraResultIndices[bufferNum]; + + Value newDepValue = newFrontIfOp.getResult(depResultIndex); + + // Step 5. Find the post IfOp that consumes the dependency value + scf::IfOp postIfOp = nullptr; + + for (auto &use : newDepValue.getUses()) { + if (auto candidate = dyn_cast(use.getOwner()->getParentOp())) { + postIfOp = candidate; + break; + } + } + + if (!postIfOp) { + llvm::outs() << "Error: no consuming IfOp found.\n"; + return; + } + + // Step 6. Extend the post IfOp + + scf::IfOp newPostIfOp = addResultsForPostIfOp( + postIfOp, newFrontIfOp, builder, bufferNum, newDepValue, constants, + buffers, frontCnt, postCnt, extraResultIndices, newDeps, ifResultDeps); + + llvm::outs() << "after addResultsForPostIfOp.\n"; + + int postCntResultIndex = extraResultIndices.back(); + + // Step 7. Update the ForOp yield (buffer propagation) + auto forYield = cast(newForOp.getBody()->getTerminator()); + + // Update buffer1 ~ bufferN + for (int i = 1; i < bufferNum; ++i) { + + int yieldIdx = extraArgBaseIdx + (i - 1); + + if (yieldIdx < forYield->getNumOperands() && + bufferResultIndices[i] < newFrontIfOp.getNumResults()) { + + forYield->setOperand(yieldIdx, newFrontIfOp.getResult(bufferResultIndices[i])); + + llvm::outs() << "Replaced yield operand " << yieldIdx << "\n"; + } else { + llvm::errs() << "Warning: index out of range\n"; + } + } + + // Step 8. 
Update frontCnt + OpOperand *frontCntYieldUse = nullptr; + + for (auto &use : frontCnt.getUses()) { + if (isa(use.getOwner()) && + newForOp == use.getOwner()->getParentOp()) { + frontCntYieldUse = &use; + break; + } + } + + frontCntYieldUse->set(newFrontIfOp.getResult(frontCntResultIndex)); + + // Step 9. Update postCnt + OpOperand *postCntYieldUse = nullptr; + + for (auto &use : postCnt.getUses()) { + if (isa(use.getOwner()) && + newForOp == use.getOwner()->getParentOp()) { + postCntYieldUse = &use; + break; + } + } + + postCntYieldUse->set(newPostIfOp.getResult(postCntResultIndex)); + } + } + + + llvm::outs() << "multibuffer end!\n"; +} + +// Compute the nesting level of an ifOp within the specified forOp +static int computeIfLevel(scf::IfOp ifOp, scf::ForOp rootForOp) +{ + int level = 1; + + Operation *parent = ifOp->getParentOp(); + + while (parent && parent != rootForOp.getOperation()) { + if (isa(parent)) + level++; + + parent = parent->getParentOp(); + } + + return level; +} + +int assignIfOpLevels(scf::ForOp forOp) +{ + SmallVector targetIfOps; + int maxLevel = 0; + // Collect all ifOp assigned with ssbuffer tag + forOp.walk([&](scf::IfOp ifOp) { + if (ifOp->hasAttr("ssbuffer")) { + targetIfOps.push_back(ifOp); + } + }); + + // Caculate buffer levels + for (auto ifOp : targetIfOps) { + int level = computeIfLevel(ifOp, forOp); + maxLevel = std::max(level, maxLevel); + Builder builder(ifOp.getContext()); + ifOp->setAttr("ssbuffer.level", + builder.getI32IntegerAttr(level)); + } + return maxLevel; +} + +static bool hasSSBufferIf(scf::ForOp forOp) +{ + bool found = false; + + forOp.walk([&](scf::IfOp ifOp) { + if (ifOp->hasAttr("ssbuffer")) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + return found; +} + +static bool hasAncestorSSBufferFor(scf::ForOp forOp) +{ + Operation *parent = forOp->getParentOp(); + + while (parent) { + if (auto parentFor = dyn_cast(parent)) { + if (hasSSBufferIf(parentFor)) + return 
true; + } + parent = parent->getParentOp(); + } + + return false; +} + +static bool hasAncestorRootFor(scf::ForOp forOp) +{ + Operation *parent = forOp->getParentOp(); + + while (parent) { + if (auto parentFor = dyn_cast(parent)) { + if (hasSSBufferIf(parentFor)) + return true; + } + parent = parent->getParentOp(); + } + return false; +} + +SmallVector collectIfInfo( + scf::ForOp &curForOp, + DenseMap> &ifDeps, + int level) +{ + // Find all dependency variables based on the inputs and outputs of ifOp + SmallVector allDeps; + DenseSet producedValues; + scf::ForOp newForOp = nullptr; + curForOp.walk([&](scf::IfOp ifOp) { + auto attr = ifOp->getAttrOfType("ssbuffer.level"); + // No level or level mismatch → continue searching + if (!attr || attr.getInt() != level) + return WalkResult::advance(); + + // Levels match → check the direct parent + if (auto parentFor = dyn_cast(ifOp->getParentOp())) { + newForOp = parentFor; // 更新 + } + + // Stop walking regardless of whether the parent is a for-loop + return WalkResult::interrupt(); + }); + + if (newForOp) + curForOp = newForOp; + + // Step 1: Collect first to preserve order + SmallVector ifOps; + curForOp.walk([&](scf::IfOp ifOp) { + auto curLevel = ifOp->getAttrOfType("ssbuffer.level"); + if (!curLevel || curLevel.getInt() != level) { + return WalkResult::advance(); + } + ifOps.push_back(ifOp); + return WalkResult::advance(); + }); + llvm::outs()<<"ifOps:"<getOperands():"<getOperands().size()<<"\n"; + SmallVector deps; + if (producedValues.empty()) { + llvm::outs()<<"producedValues为空!"<<"\n"; + } + + // inputs + Region &thenRegion = ifOp.getThenRegion(); + for (Operation &op : thenRegion.front()) { + for (Value operand : op.getOperands()) { + for (Value v : producedValues) { + if (operand == v && !llvm::is_contained(deps, operand)) { + deps.push_back(operand); + } + } + } + } + + // outputs + for (Value result : ifOp.getResults()) { + producedValues.insert(result); + } + + if (!deps.empty()) { + ifDeps[ifOp] = deps; + 
allDeps.append(deps.begin(), deps.end()); + } + } + llvm::outs().flush(); + return allDeps; +} + +bool isCube(scope::ScopeOp scope) { + bool ret = false; + scope.walk([&](Operation *op) { + if (isa(op)) { + ret = true; + } + }); + return ret; +} + +// Traverse each Vector scope, find the outer ForOp, and process internal IfOps +void WalkAIVNestedForAndProcess( + ModuleOp module, DenseMap> &ifResultDeps, + int bufferNum) { + if (bufferNum < 2) { + return; + } + + module.walk([&](scope::ScopeOp scope) { + if (isCube(scope)) { + return; + } + + // Traverse ForOps inside the Cube scope (outer loops) + SmallVector targetFors; + + scope.walk([&](scf::ForOp forOp) { + + // Must contain an ssbuffer if + if (!hasSSBufferIf(forOp)) + return WalkResult::advance(); + + // Skip if an ancestor is already the root + if (hasAncestorRootFor(forOp)) + return WalkResult::advance(); + + // Find rootForOp + targetFors.push_back(forOp); + + return WalkResult::advance(); + }); + llvm::outs() << "targetFors: " << targetFors.size(); + int maxLevels; + for (auto outerFor : targetFors) { + ifResultDeps.clear(); + scf::ForOp currentFor = outerFor; + maxLevels = assignIfOpLevels(currentFor); + for (int level = 1; level <= maxLevels; level++) { + auto uniqueDeps = collectIfInfo(currentFor, ifResultDeps, level); + llvm::outs()<<"maxLevels:"<> newIfResultDeps; + auto uniqueList = collectIfInfo(newForOp, newIfResultDeps, level); + addMultiBuffCaculate(module, uniqueList, newIfResultDeps, newForOp, bufferNum); + } + } + }); +} + +void DAGSSBufferPass::runOnOperation() { + auto module = getOperation(); + + AddIfCondition(module); + + FlowSssbuf(module); + ControlSsbufV2(module); + + // advance不能出现在if里, 规避处理 + ChangeAdvanceOpForm(module); + + DenseMap> ifResultDeps; + WalkAIVNestedForAndProcess(module, ifResultDeps, 2); + + return; +} + +std::unique_ptr> +mlir::triton::createDAGSSBufferPass() { + return std::make_unique(); +} + diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp 
b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp new file mode 100644 index 0000000000..0b17b913fe --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp @@ -0,0 +1,1139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "TritonAffinityOpt/Passes.h" + +#include "bishengir/Dialect/Scope/IR/Scope.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h" +#include "bishengir/Dialect/HIVM/Transforms/Passes.h" +#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" +#include "bishengir/Dialect/HIVM/Utils/Utils.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Block.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include "Utils/Utils.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include + +#include "TritonAffinityOpt/DAG.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DAGSCOPE +#include "ascend/include/TritonAffinityOpt/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace hivm; + +namespace { +struct DAGScopePass + : public mlir::triton::impl::DAGScopeBase< + DAGScopePass> { + void runOnOperation() override; +}; +} // namespace + + +static std::pair encapsulateWithScope(triton::FuncOp funcOp) { + Block &entryBlock = funcOp.getBody().front(); + Block &lastBlock = funcOp.getBody().back(); + Operation *terminator = lastBlock.getTerminator(); + + + // 辅助函数:判断操作是否应该被跳过 + auto shouldSkipOp = [](Operation *op) -> bool { + return isa(op) || isa(op) || isa(op); + }; + + // 第三步:准备要移动的操作列表(按顺序) + SmallVector opsToMove; + DenseMap opOrder; + int order = 0; + + // 记录原始顺序并收集需要移动的操作 + for (Operation &op : lastBlock.without_terminator()) { + opOrder[&op] = order++; + if (!shouldSkipOp(&op)) { + opsToMove.push_back(&op); + } + } + + // 按原始顺序排序 + std::sort(opsToMove.begin(), opsToMove.end(), + 
[&](Operation *a, Operation *b) { + return opOrder[a] < opOrder[b]; + }); + + if (opsToMove.empty()) { + return std::make_pair(nullptr, nullptr); + } + + // 第四步:创建scope操作并移动操作 + Operation *lastOpToMove = opsToMove.back(); + OpBuilder builder(&lastBlock, ++lastOpToMove->getIterator()); + + // 创建第一个scope + auto scopeOp = builder.create(builder.getUnknownLoc(), llvm::ArrayRef{}); + scopeOp.getBodyRegion().emplaceBlock(); + Block *scopeBody = &scopeOp.getBodyRegion().front(); + + // 移动操作到scope中 + OpBuilder scopeBuilder(scopeBody, scopeBody->end()); + DenseMap valueMapping; + + for (Operation *op : opsToMove) { + SmallVector originalResults = op->getResults(); + op->remove(); + scopeBuilder.insert(op); + + // 更新值的映射 + for (size_t i = 0; i < originalResults.size(); ++i) { + valueMapping[originalResults[i]] = op->getResult(i); + } + } + + // 添加return操作 + scopeBuilder.create(builder.getUnknownLoc()); + + // 创建第二个scope(如果需要) + scopeBuilder.setInsertionPointAfter(scopeOp); + auto newScopeOp = scopeBuilder.create(builder.getUnknownLoc(), llvm::ArrayRef{}); + newScopeOp.getRegion().emplaceBlock(); + + OpBuilder newScopeBuilder(&newScopeOp.getRegion().front(), + newScopeOp.getRegion().front().begin()); + newScopeBuilder.create(scopeOp->getLoc()); + + // 设置属性 + auto vecAttr = hivm::TCoreTypeAttr::get( + builder.getContext(), + hivm::TCoreType::VECTOR); + auto aicAttr = hivm::TCoreTypeAttr::get( + builder.getContext(), + hivm::TCoreType::CUBE); + + scopeOp->setAttr(hivm::TCoreTypeAttr::name, vecAttr); + newScopeOp->setAttr(hivm::TCoreTypeAttr::name, aicAttr); + + return std::make_pair(scopeOp, newScopeOp); +} + +struct OpMoveInfo { + Operation* op; + Operation* targetParent; // 目标父操作(nullptr表示aicScope本身) + }; + +// 递归遍历函数 - 优化版本 +void collectOpsToMove(Operation* op, AffinityDAG::Graph& graph, + Operation* parentFor, llvm::SmallVector& aivToMove, llvm::SmallVector& cubeToMove) { + // 检查当前操作是否需要移动 + bool needsMoveAiv = false; + bool needsMoveCube = false; + auto& valueTypes = 
graph.getValueTypes(); + // 检查结果类型 + int i = 0; + for (auto res : op->getResults()) { + i++; + if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::VECTOR_ONLY)) { + needsMoveAiv = true; + } + if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::CUBE_ONLY)) { + needsMoveCube = true; + } + } + + if (isa(op)) { + auto res = op->getOperand(0); + if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::VECTOR_ONLY)) { + needsMoveAiv = true; + } + if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::CUBE_ONLY)) { + needsMoveCube = true; + } + } + // 检查特定操作类型 + if (isa(op)) { + needsMoveAiv = true; + } + + // 检查特定操作类型 + if (isa(op)) { + needsMoveCube = true; + } + + // 检查特定操作类型 + if (isa(op) || isa(op) || isa(op)) { + needsMoveAiv = true; + needsMoveCube = true; + } + + if (isa(op)) { + if (auto storeOp = dyn_cast(op)) { + // 获取所有操作数列表 + auto operands = storeOp.getOperands(); + bool typeMatched = false; + + // 按顺序检查第1个、第0个、第2个操作数 + std::vector checkOrder = {1, 0, 2}; + for (size_t idx : checkOrder) { + // 先判断操作数索引是否有效,避免越界访问 + if (idx >= operands.size()) { + continue; + } + auto operand = operands[idx]; + auto coreType = valueTypes[operand]; + + if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { + needsMoveAiv = true; + typeMatched = true; + } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { + needsMoveCube = true; + typeMatched = true; + } + } + // 所有指定操作数都不匹配时,执行原else逻辑 + if (!typeMatched) { + needsMoveAiv = true; + needsMoveCube = true; + } + } + } + + if (isa(op)) { + if (auto assertOp = dyn_cast(op)) { + // 获取所有操作数列表 + auto operand = assertOp.getCondition(); + + auto coreType = valueTypes[operand]; + if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { + needsMoveAiv = true; + } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { + needsMoveCube = true; + } else { + needsMoveAiv = true; + needsMoveCube = true; + } + } + } + + // 检查 Sync 操作的 tcore_type 属性 + if ((isa(op) || isa(op))) { + 
mlir::OpBuilder builder(op); + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + if (op->getAttr("tcore_type") == coreAttr) { + needsMoveCube = true; + } + else { + needsMoveAiv = true; + } + } + + // Neither side needs this op: only log it (there is no early return here) + if (!needsMoveAiv && !needsMoveCube) { + llvm::outs()<<"Unsupport Op: "<< *op<<" \n"; + } + + // Handle for loops + if (auto forOp = dyn_cast(op)) { + // Determine the parent op + Operation* targetParent = parentFor != nullptr ? parentFor : nullptr; + aivToMove.push_back({op, targetParent}); + cubeToMove.push_back({op, targetParent}); + + // Recurse into the loop body + for (auto &block : forOp.getRegion()) { + for (auto &innerOp : block) { + collectOpsToMove(&innerOp, graph, forOp, aivToMove, cubeToMove); + } + } + } else if (auto ifOp = dyn_cast(op)) { + // Determine the parent op + Operation* targetParent = parentFor != nullptr ? parentFor : nullptr; + aivToMove.push_back({op, targetParent}); + cubeToMove.push_back({op, targetParent}); + + // Recurse into the then branch + for (auto &block : ifOp.getThenRegion()) { + for (auto &innerOp : block) { + collectOpsToMove(&innerOp, graph, ifOp, aivToMove, cubeToMove); + } + } + + // Also walk the IfOp's else branch (if present) + for (auto &block : ifOp.getElseRegion()) { + for (auto &innerOp : block) { + collectOpsToMove(&innerOp, graph, ifOp, aivToMove, cubeToMove); + } + } + } else { + if (needsMoveAiv) { + // Handle all other ops + aivToMove.push_back({op, parentFor}); + } + if (needsMoveCube) { + cubeToMove.push_back({op, parentFor}); + } + + } +}
+
+// Return the block at position `blockIndex` (0-based, in iteration order) of
+// `region`, or nullptr when the index is negative or past the last block.
+mlir::Block* getBlockByIndex(mlir::Region& region, int blockIndex) {
+  // Bounds check: a negative index can never match.
+  if (blockIndex < 0) return nullptr;
+
+  int currentIdx = 0;
+  for (auto& block : region) {
+    if (currentIdx == blockIndex) {
+      return &block;  // Found the block at the requested index.
+    }
+    currentIdx++;
+  }
+  // The index is past the end of the region.
+  return nullptr;
+}
+
+void processOperationToMove(const OpMoveInfo& info, + llvm::DenseMap& parentMap, + mlir::OpBuilder& builder, + mlir::IRMapping& mapper, + mlir::Block* aivBlock, + mlir::Operation* terminator, + 
AffinityDAG::Graph& graph, + int MoveType) { + // llvm::outs()<<*info.op<<" ssss\n\n\n"; + // llvm::outs().flush(); + // 获取原始Block信息并计算索引 + mlir::Block* originalBlock = info.op->getBlock(); + int originalRegionIndex = -1; + int originalBlockIndex = -1; + int blockCounter = 0; + auto& valueTypes = graph.getValueTypes(); + if (originalBlock) { + mlir::Operation* parentOp = info.op->getParentOp(); // 原始父操作 + if (parentOp) { // 确保父操作存在 + // 老版本MLIR用 getParent() 替代 getParentRegion(),返回值就是Region* + mlir::Region* blockBelongsToRegion = originalBlock->getParent(); + int regionCounter = 0; + for (auto& region : parentOp->getRegions()) { // 遍历父操作的所有region + // 直接对比指针,判断当前region是否是block所属的region + if (®ion == blockBelongsToRegion) { + originalRegionIndex = regionCounter; + break; + } + regionCounter++; + } + } + } + + if (originalBlock) { + for (auto& block : originalBlock->getParent()->getBlocks()) { + if (&block == originalBlock) { + originalBlockIndex = blockCounter; + break; + } + blockCounter++; + } + } + + if (originalBlockIndex == -1) { + originalBlockIndex = 0; + } + if (originalRegionIndex == -1) { + originalRegionIndex = 0; + } + + // 处理 scf::ForOp 类型操作 + if (mlir::isa(info.op)) { + auto forOp = mlir::cast(info.op); + + auto getMapped = [&](mlir::Value v) { return mapper.lookupOrDefault(v); }; + auto inputs = forOp.getInitArgs(); + auto outputs = forOp.getResults(); + + // 分离需要移动到aivScope的参数 + llvm::SmallVector aivInputs; + llvm::DenseMap aivInputsMap; + int aivIndex = 1; + + for (int i = 0; i < inputs.size(); ++i) { + if (valueTypes[outputs[i]] != MoveType) { + aivInputs.push_back(inputs[i]); + aivInputsMap[i + 1] = aivIndex; + aivIndex++; + } + } + + // 创建新的for循环 + auto aivForOp = builder.create( + forOp.getLoc(), + getMapped(forOp.getLowerBound()), + getMapped(forOp.getUpperBound()), + getMapped(forOp.getStep()), + llvm::to_vector(llvm::map_range(aivInputs, getMapped)) + ); + + // 清空循环体 + if (!aivForOp.getBody()->empty()) { + 
aivForOp.getBody()->getTerminator()->erase(); + } + + // 处理原始循环的yield操作 + auto oldBody = forOp.getBody(); + auto oldYield = mlir::dyn_cast(oldBody->getTerminator()); + assert(oldYield && "scf::ForOp must have a yield terminator"); + + llvm::SmallVector aivYieldOperands; + for (int i = 0; i < inputs.size(); ++i) { + if (valueTypes[outputs[i]] != MoveType) { + aivYieldOperands.push_back(oldYield.getOperand(i)); + } + } + + // 映射循环参数 + auto oldBodyArgs = forOp.getBody()->getArguments(); + auto aivBodyArgs = aivForOp.getBody()->getArguments(); + + for (auto it = aivInputsMap.begin(); it != aivInputsMap.end(); ++it) { + int oldInputIndex = it->first; + int mappedNewIndex = it->second; + mapper.map(oldBodyArgs[oldInputIndex], aivBodyArgs[mappedNewIndex]); + mapper.map((*info.op).getResults()[oldInputIndex - 1], aivForOp->getResults()[mappedNewIndex - 1]); + } + mapper.map(oldBodyArgs[0], aivBodyArgs[0]); + + // 将新循环移动到目标位置 + if (info.targetParent == nullptr) { + mlir::Block* targetBlock = aivBlock; + if (terminator) { + aivForOp->moveBefore(terminator); + } else { + aivForOp->moveBefore(targetBlock, targetBlock->end()); + } + parentMap[forOp] = aivForOp; + } else { + auto targetParent = parentMap[info.targetParent]; + auto& region = targetParent->getRegion(originalRegionIndex); + + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + aivForOp->moveBefore(targetBlock, targetBlock->end()); + parentMap[forOp] = aivForOp; + } else { + llvm::outs()<<"Can't find block by index\n"; + } + + } + } + + // 处理 scf::YieldOp 类型操作 + else if (mlir::isa(info.op)) { + auto yieldOp = mlir::cast(info.op); + + // 处理父节点为 scf::ForOp 的情况 + if (auto parentForOp = mlir::dyn_cast(info.targetParent)) { + auto it = parentMap.find(parentForOp); + if (it == parentMap.end()) { + return; + } + auto targetOp = it->second; + auto newForOp = mlir::cast(targetOp); + + auto oldInputs = 
parentForOp.getInitArgs(); + auto oldOutputs = parentForOp.getResults(); + auto oldYieldOperands = yieldOp.getOperands(); + + llvm::SmallVector newYieldOperands; + for (int i = 0; i < oldInputs.size(); ++i) { + if (valueTypes[oldOutputs[i]] != MoveType) { + mlir::Value oldOperand = oldYieldOperands[i]; + mlir::Value newOperand = mapper.lookupOrDefault(oldOperand); + newYieldOperands.push_back(newOperand); + } + } + + auto newYieldOp = builder.create(yieldOp.getLoc(), newYieldOperands); + auto& region = newForOp->getRegion(0); + mlir::Block* targetBlock = ®ion.front(); + newYieldOp->moveBefore(targetBlock, targetBlock->end()); + } + // 处理父节点为 scf::IfOp 的情况 + else if (auto parentIfOp = mlir::dyn_cast(info.targetParent)) { + auto it = parentMap.find(parentIfOp); + if (it == parentMap.end()) { + return; + } + auto targetOp = it->second; + auto newIfOp = mlir::cast(targetOp); + + auto oldInputs = parentIfOp.getResults(); + auto oldOutputs = parentIfOp.getResults(); + auto oldYieldOperands = yieldOp.getOperands(); + + llvm::SmallVector newYieldOperands; + for (int i = 0; i < oldInputs.size(); ++i) { + if (valueTypes[oldOutputs[i]] != MoveType) { + mlir::Value oldOperand = oldYieldOperands[i]; + mlir::Value newOperand = mapper.lookupOrDefault(oldOperand); + newYieldOperands.push_back(newOperand); + } + } + + auto& region = newIfOp->getRegion(originalRegionIndex); + auto newYieldOp = builder.create(yieldOp.getLoc(), newYieldOperands); + mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + newYieldOp->moveBefore(targetBlock, targetBlock->end()); + } else { + llvm::outs()<<"Can't find block by index\n"; + } + } + } + + // 处理 scf::IfOp 类型操作 + else if (mlir::isa(info.op)) { + auto ifOp = mlir::cast(info.op); + + auto getMapped = [&](mlir::Value v) { return mapper.lookupOrDefault(v); }; + mlir::Value condition = ifOp.getCondition(); + + // 分离需要移动到aivScope的结果 + llvm::SmallVector aivResults; + llvm::SmallVector aivResultTypes; + 
llvm::DenseMap aivResultMap; + int aivResultIndex = 0; + + for (int i = 0; i < ifOp.getNumResults(); ++i) { + mlir::Value result = ifOp.getResult(i); + if (valueTypes[result] != MoveType) { + aivResults.push_back(result); + aivResultTypes.push_back(result.getType()); + aivResultMap[i] = aivResultIndex; + aivResultIndex++; + } + } + + // 创建新的if操作 + auto aivIfOp = builder.create( + ifOp.getLoc(), + aivResultTypes, + getMapped(condition) + ); + + // 映射if操作结果 + for (auto& [oldIdx, newIdx] : aivResultMap) { + mapper.map(ifOp.getResult(oldIdx), aivIfOp.getResult(newIdx)); + } + + // 初始化then和else区域 + mlir::Region& thenRegion = aivIfOp.getThenRegion(); + mlir::Block* thenBlock = new mlir::Block(); + thenRegion.push_back(thenBlock); + + mlir::Region& elseRegion = ifOp.getElseRegion(); + if (!elseRegion.empty()) { + mlir::Region& elseRegion = aivIfOp.getElseRegion(); + mlir::Block* elseBlock = new mlir::Block(); + elseRegion.push_back(elseBlock); + } + + // 将新if操作移动到目标位置 + if (info.targetParent == nullptr) { + mlir::Block* targetBlock = aivBlock; + if (terminator) { + aivIfOp->moveBefore(terminator); + } else { + aivIfOp->moveBefore(targetBlock, targetBlock->end()); + } + parentMap[ifOp] = aivIfOp; + } else { + auto& region = parentMap[info.targetParent]->getRegion(originalRegionIndex); + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + aivIfOp->moveBefore(targetBlock, targetBlock->end()); + parentMap[ifOp] = aivIfOp; + } else { + llvm::outs()<<"Can't find block by index\n"; + } + } + } + + // 处理其他类型操作(克隆) + else { + auto clonedOp = builder.clone(*info.op, mapper); + auto numberRes = clonedOp->getNumResults(); + for (auto i = 0; i < numberRes; i++) { + mapper.map((*info.op).getResults()[i], clonedOp->getResults()[i]); + } + + if (info.targetParent == nullptr) { + mlir::Block* targetBlock = aivBlock; + clonedOp->moveBefore(terminator); + parentMap[info.op] = 
clonedOp; + } else { + auto parentIt = parentMap.find(info.targetParent); + auto mappedParentOp = parentIt->second; + auto& region = mappedParentOp->getRegion(originalRegionIndex); + + if (region.empty()) { + region.push_back(new mlir::Block()); + } + + mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + if (targetBlock) { + clonedOp->moveBefore(targetBlock, targetBlock->end()); + } else { + llvm::outs()<<"Can't find block by index\n"; + } + + } + } +}
+
+// Distribute the ops currently held in `aivScope` over the vector (AIV) and
+// cube (AIC) scopes, then erase the originals.
+//
+// collectOpsToMove() fills one list per side. Each entry is handed to
+// processOperationToMove() together with the target scope's first block and
+// terminator; CUBE_ONLY / VECTOR_ONLY is forwarded as that function's
+// MoveType argument. Afterwards every collected op is erased, last collected
+// first. `funcOp` and `module` are not read by the current body.
+static void SplitScope(triton::FuncOp funcOp, AffinityDAG::Graph& graph, Operation* aivScope, Operation* aicScope, ModuleOp module) {
+  llvm::SmallVector aivToMove;
+  llvm::SmallVector cubeToMove;
+  for (auto &block : aivScope->getRegion(0)) {
+    for (auto &op : block) {
+      collectOpsToMove(&op, graph, nullptr, aivToMove, cubeToMove);
+    }
+  }
+  mlir::IRMapping aivmapper;
+  mlir::OpBuilder builder(aivScope);
+  llvm::DenseMap aivparentMap;
+
+  // Second pass: actually move the operations.
+  // Loops are moved first.
+  mlir::Block* aivBlock = &aivScope->getRegion(0).front();  // or another suitable block
+  SmallVector deleteOp;
+  auto* terminator = aivBlock->getTerminator();
+  // An op that was already handled is skipped.
+  llvm::SmallVector aivUsedOp;  // ops already processed for the vector side
+  for (const auto& info : aivToMove) {
+    if (std::find(aivUsedOp.begin(), aivUsedOp.end(), info.op) != aivUsedOp.end()) {
+      // Skip just this duplicate. A `return` here would abandon the cube
+      // side and the erase loop below, leaving a half-split function.
+      continue;
+    }
+    aivUsedOp.push_back(info.op);
+    processOperationToMove(info, aivparentMap, builder, aivmapper, aivBlock, terminator, graph, AffinityDAG::CoreType::CUBE_ONLY);
+  }
+
+  llvm::DenseMap aicparentMap;
+  mlir::IRMapping aicmapper;
+  mlir::Block* aicBlock = &aicScope->getRegion(0).front();  // or another suitable block
+  terminator = aicBlock->getTerminator();
+  llvm::SmallVector aicUsedOp;  // ops already processed for the cube side
+  for (const auto& info : cubeToMove) {
+    if (std::find(aicUsedOp.begin(), aicUsedOp.end(), info.op) != aicUsedOp.end()) {
+      // Same as above: skip the duplicate instead of aborting the split.
+      continue;
+    }
+    aicUsedOp.push_back(info.op);
+    processOperationToMove(info, aicparentMap, builder, aicmapper, aicBlock, terminator, graph, AffinityDAG::CoreType::VECTOR_ONLY);
+  }
+
+  // Gather every collected op exactly once, in collection order.
+  for (const auto& info : aivToMove) {
+    if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == deleteOp.end()) {
+      deleteOp.push_back(info.op);
+    }
+  }
+  for (const auto& info : cubeToMove) {
+    if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == deleteOp.end()) {
+      deleteOp.push_back(info.op);
+    }
+  }
+
+  // llvm::outs() << "\n" << module<<" ====== ddd ====== \n\n\n";
+  // llvm::outs().flush();
+  // Erase in reverse collection order.
+  for (auto it = deleteOp.rbegin(); it != deleteOp.rend(); ++it) {
+    (*it)->erase();
+  }
+  return;
+
+}
+
+ /// Create the set op + static hivm::SyncBlockSetOp createSyncBlockSetOp( + OpBuilder &builder, + Location loc, + hivm::TCoreType coreType, + hivm::PIPE setPipeEnum, + hivm::PIPE waitPipeEnum, + int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } + + /// Create the wait op + static hivm::SyncBlockWaitOp createSyncBlockWaitOp( + OpBuilder &builder, + Location loc, + hivm::TCoreType coreType, + hivm::PIPE setPipeEnum, + hivm::PIPE waitPipeEnum, + int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } + + // Insert a wait before the scope's return + static void insertWaitBeforeFinalReturn(Region *region, OpBuilder &builder, int64_t flag, bool coretypebool) { + for (Block &block : *region) { + if (auto returnOp = dyn_cast_or_null(block.getTerminator())) { + builder.setInsertionPoint(returnOp); + if (coretypebool) { + createSyncBlockWaitOp( + 
builder, + returnOp->getLoc(), + hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, + flag + ); + return; + } + else { + createSyncBlockWaitOp( + builder, + returnOp->getLoc(), + hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, + flag + ); + return; + } + } + } + } + + /// Add a set at the start of the scope region + static void insertSetAtRegionStart(Region *region, OpBuilder &builder, int64_t flag, bool coretypebool) { + if (!region->empty()) { + Block &entry = region->front(); + Location loc = entry.empty() ? region->getParentOp()->getLoc() : entry.front().getLoc(); + builder.setInsertionPointToStart(&entry); + if (coretypebool) { + createSyncBlockSetOp( + builder, + loc, + hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, + flag + ); + } + else { + createSyncBlockSetOp( + builder, + loc, + hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, + flag + ); + } + } + } + + static Operation *findNextSyncBlockSetAfter(Operation *startOp) { + Block *block = startOp->getBlock(); + auto it = ++startOp->getIterator(); + for (; it != block->end(); ++it) { + if (isa(*it)) + return &*it; + } + return nullptr; + } + + static hivm::SyncBlockWaitOp findWaitOpInRegionWithFlag(Region *region, int64_t flag) { + hivm::SyncBlockWaitOp result; + region->walk([&](hivm::SyncBlockWaitOp op) { + auto flagAttr = op->getAttrOfType("static_flag_id"); + if (flagAttr && flagAttr.getInt() == flag) { + result = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; + } +
+  // Pick the operation in front of which a sync op should be inserted after
+  // `waitOp`, searching only within waitOp's own block. Returns nullptr when
+  // no such operation exists, which callers already test for.
+  static Operation *findInsertionPointAfterWaitForAIV(Operation *waitOp) {
+    Block *block = waitOp->getBlock();
+    auto it = ++waitOp->getIterator();
+
+    // Scan forward to the first op accepted by the isa checks.
+    // NOTE(review): presumably a ToMemrefOp or a yield, going by the caller's
+    // "step 6" comment -- confirm against the upstream file.
+    for (; it != block->end(); ++it) {
+      if (isa(*it) || isa(*it)) {
+        break;
+      }
+    }
+
+    // Step back over the run of matching ops that directly precede that
+    // position, so the insertion point lands in front of them.
+    while (it != block->begin()) {
+      auto prevIt = std::prev(it);
+      if (isa(*prevIt)) {
+        it = prevIt;
+      } else {
+        break;
+      }
+    }
+
+    // The forward scan may have run off the block. end() is a sentinel, not
+    // an Operation, so it must not be dereferenced.
+    if (it == block->end()) {
+      return nullptr;
+    }
+
+    return &*it;
+  }
+
+ static Operation 
*findInsertionPointAfterWaitForAIC(Operation *waitOp) { + Block *block = waitOp->getBlock(); + auto it = ++waitOp->getIterator(); + for (; it != block->end(); ++it) { + if (auto fixpipe = dyn_cast(*it)) { + if (it != block->begin()) { + auto prev = std::prev(it); + if (isa(*prev)) + return &*prev; + } + return &*it; + } + if (isa(*it)) + return &*it; + } + return nullptr; + }
+
+  // Return the static_flag_id of the first "hivm.hir.sync_block_set" op that
+  // follows `fixpipeOp` in its block. Returns -1 when the op has no block, no
+  // such op follows, or that op carries no static_flag_id attribute.
+  static int findFixPipeFlagSafe(hivm::FixpipeOp fixpipeOp) {
+    mlir::Operation *fixpipeOperation = fixpipeOp.getOperation();
+    if (!fixpipeOperation || !fixpipeOperation->getBlock()) {
+      return -1;
+    }
+
+    // Start at the op right after the FixpipeOp.
+    auto it = ++fixpipeOperation->getIterator();
+
+    // Walk the following ops until a sync_block_set is found.
+    while (it != fixpipeOperation->getBlock()->end()) {
+      mlir::Operation &op = *it++;
+
+      if (op.getName().getStringRef() == "hivm.hir.sync_block_set") {
+        auto staticFlagAttr = op.getAttrOfType("static_flag_id");
+        // getAttrOfType yields a null attribute when the attribute is
+        // missing; calling getInt() on it would crash.
+        if (!staticFlagAttr) {
+          return -1;
+        }
+        return staticFlagAttr.getInt();
+      }
+    }
+
+    return -1;
+  }
+
+ /// Cube-side processing + static void processFixpipeOpsInAIC( + Region *aicRegion, + Region *aivRegion) { + + MLIRContext *ctx = aicRegion->getContext(); + OpBuilder builder(ctx); + SmallVector fixpipes; + aicRegion->walk([&](hivm::FixpipeOp op) { + fixpipes.push_back(op); + }); + + + for (auto fixpipeOp : fixpipes) { + + auto newflag = findFixPipeFlagSafe(fixpipeOp); + // 1. Insert a Wait before the FixpipeOp + builder.setInsertionPoint(fixpipeOp); + createSyncBlockWaitOp( + builder, + fixpipeOp->getLoc(), + hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, + newflag); + bool coretypebool = true; + + // 2. Insert a Wait before the Return at the end of aicRegion + insertWaitBeforeFinalReturn(aicRegion, builder, newflag, coretypebool); + + // 3. Insert a Set at the start of aivRegion + insertSetAtRegionStart(aivRegion, builder, newflag, coretypebool); + + // 4. 
在 aicRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(fixpipeOp)) { + auto setFlagAttr = nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set + // llvm::dbgs() << "aicnextSetOp:"; + // nextSetOp->dump(); + if (!setFlagAttr) { + llvm::dbgs() << "AIC can not find setop in aic\n"; + continue; + } + int64_t setflag = setFlagAttr.getInt(); + + // 5. 在 aivRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aivRegion, setflag); + if (!targetWait) { + llvm::dbgs() << "AIC can not find waitop in aiv\n"; + continue; + } + + // 调试:打印wait + // llvm::dbgs() << "aictargetWait:"; + // llvm::dbgs() << targetWait << "\n"; + + // 6. 从该 Wait 向下找 ToMemrefOp 或 Yield,插 Set(newflag) + if (auto *insertPt = findInsertionPointAfterWaitForAIV(targetWait)) { + builder.setInsertionPoint(insertPt); + createSyncBlockSetOp( + builder, + fixpipeOp->getLoc(), + hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, + newflag); + } + } + } + } + + // 查找 copyOp 下一行的 sync_block_set 操作的 flag 值 + static int findCopyFlagSafe(bufferization::ToMemrefOp toMemrefOp) { + mlir::Operation *toMemrefOperation = toMemrefOp.getOperation(); + if (!toMemrefOperation || !toMemrefOperation->getBlock()) { + return -1; + } + + // 获取 copyOp 的迭代器 + auto it = ++toMemrefOperation->getIterator(); + + // 遍历后续操作直到找到 sync_block_set + while (it != toMemrefOperation->getBlock()->end()) { + mlir::Operation &op = *it++; + + if (op.getName().getStringRef() == "hivm.hir.sync_block_set") { + auto staticFlagAttr = op.getAttrOfType("static_flag_id"); + return staticFlagAttr.getInt(); + break; + } + } + + return -1; + + } + /// vector处理逻辑 + static void processToMemrefOpsInAIV( + Region *aivRegion, + Region *aicRegion) { + + MLIRContext *ctx = aivRegion->getContext(); + OpBuilder builder(ctx); + SmallVector toMemrefs; + aivRegion->walk([&](bufferization::ToMemrefOp op) { + toMemrefs.push_back(op); + }); + + for (auto toMemrefOp : toMemrefs) { + auto newflag = 
findCopyFlagSafe(toMemrefOp); + + // 1. 在 ToMemrefOp 前插 Wait + builder.setInsertionPoint(toMemrefOp); + createSyncBlockWaitOp( + builder, + toMemrefOp->getLoc(), + hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, + newflag); + bool coretypebool = false; + + // 2. 在 aivRegion 末尾 Return 前插 Wait + insertWaitBeforeFinalReturn(aivRegion, builder, newflag, coretypebool); + + // 3. 在 aicRegion 开头插 Set + insertSetAtRegionStart(aicRegion, builder, newflag, coretypebool); + + // 4. 在 aivRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(toMemrefOp)) { + auto setFlagAttr = nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set及其所有attribute + // llvm::dbgs() << "aivnextSetOp:"; + // nextSetOp->dump(); + // llvm::dbgs() << "Attributes:\n"; + // for (auto namedAttr : nextSetOp->getAttrs()) { + // llvm::dbgs() << " " << namedAttr.getName() << " = "; + // namedAttr.getValue().print(llvm::dbgs()); + // llvm::dbgs() << "\n"; + // } + if (!setFlagAttr) { + llvm::dbgs() << "AIV can not find setop in aiv\n"; + continue; + } + int64_t setflag = setFlagAttr.getInt(); + + // 5. 在 aicRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aicRegion, setflag); + + if (!targetWait) { + llvm::dbgs() << "AIV can not find waitop in aic\n"; + continue; + } + + // 调试:打印wait + // llvm::dbgs() << "aivtargetWait:"; + // llvm::dbgs() << targetWait << "\n"; + + // 6. 
从该 Wait 向下找 Fixpipe 前 Wait 或 Yield,插 Set(newflag) + if (auto *insertPt = findInsertionPointAfterWaitForAIC(targetWait)) { + builder.setInsertionPoint(insertPt); + createSyncBlockSetOp( + builder, + toMemrefOp->getLoc(), + hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, + newflag); + } + } + } + } + + /// 同步点增强 + void addSyncOpsForBufferWait(ModuleOp module) { + for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + if (funcOp.getBody().empty()) { + continue; + } + + Region *aicRegion = nullptr; + Region *aivRegion = nullptr; + + funcOp.walk([&](scope::ScopeOp scopeOp) { + auto coreTypeAttr = scopeOp->getAttrOfType( + hivm::TCoreTypeAttr::name); + if (!coreTypeAttr) return; + + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) { + aicRegion = &scopeOp.getRegion(); + } + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) { + aivRegion = &scopeOp.getRegion(); + } + }); + + if (!aicRegion || !aivRegion) { + continue; + } + + processFixpipeOpsInAIC(aicRegion, aivRegion); + processToMemrefOpsInAIV(aivRegion, aicRegion); + } + } + + +void DAGScopePass::runOnOperation() { + auto module = getOperation(); + // llvm::outs()<())) { + // skip invalid function + if (funcOp.getBody().empty()) { + continue; + } + + // 收集所有 memref.alloc 操作 + llvm::SmallVector allocOps; + + // 遍历函数中的所有操作(包括嵌套区域中的操作) + funcOp.walk([&](mlir::Operation *op) { + if (mlir::isa(op)) { + allocOps.push_back(op); + } + }); + + mlir::Block& entryBlock = funcOp.getBody().front(); + mlir::Block::iterator insertPos = entryBlock.begin(); + + // 将 alloc 操作移动到函数的最前面 + for (mlir::Operation* allocOp : allocOps) { + // 如果 alloc 操作已经是最前面的操作,跳过 + if (allocOp->getBlock() == &entryBlock && + allocOp->isBeforeInBlock(&*insertPos)) { + continue; + } + + // 将 alloc 操作移动到指定位置 + allocOp->moveBefore(&entryBlock, insertPos); + } + + auto funcName = funcOp.getName(); + auto* graph_ptr = AffinityDAG::GraphManager::getInstance().getGraph(funcName); + if (!graph_ptr) { + continue; 
+ } + auto& main_graph = *graph_ptr; + + + auto ScopeList = encapsulateWithScope(funcOp); + auto aivScope = ScopeList.first; // 第一个元素 + auto aicScope = ScopeList.second; // 第二个元素 + + SplitScope(funcOp, main_graph, aivScope, aicScope, module); + } + + addSyncOpsForBufferWait(module); + // llvm::outs()<> +mlir::triton::createDAGScopePass() { + return std::make_unique(); +} + diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp new file mode 100644 index 0000000000..17568cc9f0 --- /dev/null +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp @@ -0,0 +1,1333 @@ +#include "TritonAffinityOpt/Passes.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "bishengir/Dialect/Scope/IR/Scope.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h" +#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" +#include "bishengir/Dialect/HIVM/Transforms/Passes.h" +#include "bishengir/Dialect/HIVM/Utils/Utils.h" +#include "bishengir/Dialect/Scope/IR/Scope.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/Casting.h" + +#include "Utils/Utils.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" +#include +#include + +#include "TritonAffinityOpt/DAG.h" + 
+namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DAGSYNC +#include "ascend/include/TritonAffinityOpt/Passes.h.inc" +} // namespace triton +} // namespace mlir + +// 使用 DAG 命名空间 +using namespace mlir; +using namespace hivm; +using namespace AffinityDAG; + +llvm::DenseMap* valueTypes; +// 修改类声明,将数据搬运逻辑集成到同步插入中 +namespace { +struct DAGSyncPass : public mlir::triton::impl::DAGSyncBase { + void runOnOperation() override; + +private: + // 原有的辅助函数 + CoreType getNodeDeviceType(OpNode *node, llvm::DenseMap *valueTypes); + bool needVectorCubeSync(CoreType src, CoreType dst); + + // 修改后的同步插入函数,包含数据搬运 + void insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, llvm::DenseMap* valueMap, Graph &mainGraph); + + // 新增:处理跨 block 的同步和数据搬运 + void insertSyncAndMovementForCrossBlock(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, + bool dstIsInnerBlock, llvm::DenseMap* valueMap, Graph &mainGraph); + + // 新增:处理 scf.for 循环迭代参数的同步 + void processScfForSync(mlir::scf::ForOp forOp, + Node* forNode, + llvm::DenseMap *valueTypes, + mlir::OpBuilder &builder, + int &flag); + + // 数据搬运相关的辅助函数 + void insertCubeToVectorDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + mlir::Value srcResult, mlir::OpBuilder &builder, + mlir::Location loc, mlir::Value iterArgs); + + void insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, Operation * posOp, + mlir::Value srcResult, mlir::OpBuilder &builder, + mlir::Location loc, llvm::DenseMap* valueMap); + + // 获取或创建合适的 memref.alloc + mlir::Value getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, + hivm::AddressSpace addressSpace, + mlir::OpBuilder &builder, mlir::Location loc); + + // 获取 tensor 的形状和元素类型 + mlir::RankedTensorType getTensorType(mlir::Value tensorValue); + + // 替换 dstOp 中使用 srcResult 的操作数 + void 
replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, + mlir::Value newValue); + + // Find sync position + Operation* FindLastestPosition(Operation* srcOp, Graph &mainGraph, OpBuilder &builder); + Operation* FindEarliestPosition(Operation* dstOp, Graph &mainGraph, OpBuilder &builder); +}; +} // namespace + +void DAGSyncPass::processScfForSync(mlir::scf::ForOp forOp, + Node* forNode, + llvm::DenseMap *valueTypes, + mlir::OpBuilder &builder, + int &flag) { + + mlir::Block* loopBody = forOp.getBody(); + mlir::scf::YieldOp yieldOp = nullptr; + for (mlir::Operation &op : *loopBody) { + if (auto yield = mlir::dyn_cast(&op)) { + yieldOp = yield; + break; + } + } + Location loc = forOp.getLoc(); + + for (int i = 0; i < forOp.getInitArgs().size(); i++) { + mlir::BlockArgument iterArg = loopBody->getArgument(i+1); + // 找到首次使用 + mlir::Operation* firstUser = nullptr; + + for (mlir::Operation &op : *loopBody) { + // 跳过 yield 操作 + if (mlir::isa(&op)) { + continue; + } + + // 检查是否使用该迭代参数 + bool usesIterArg = false; + for (mlir::Value operand : op.getOperands()) { + if (operand == iterArg) { + usesIterArg = true; + break; + } + } + + if (usesIterArg) { + firstUser = &op; + break; + } + } + // map 内找到对应的iterType,iterType由首次在loop内使用到的op定义 + if (!firstUser) { + continue; + } + CoreType iterType = CoreType::CUBE_AND_VECTOR; + if (valueTypes->find(firstUser->getResult(0)) != valueTypes->end()) { + iterType = valueTypes->find(firstUser->getResult(0))->second; + } + + // 获取对应yield + mlir::Value yieldOperand = yieldOp->getOperand(i); + CoreType yieldType = CoreType::CUBE_AND_VECTOR; + if (valueTypes->find(yieldOperand) != valueTypes->end()) { + yieldType = valueTypes->find(yieldOperand)->second; + } + mlir::Operation* yieldDefiningOp = yieldOperand.getDefiningOp(); + + if (yieldType == CoreType::CUBE_ONLY && iterType == CoreType::VECTOR_ONLY) { + + // 2. 
插入同步指令 + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 yieldDefiningOp 后 + builder.setInsertionPointAfter(yieldDefiningOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + mlir::Value srcResult = yieldDefiningOp->getResult(0); + + // // 1. 插入数据搬运 + insertCubeToVectorDataMovement(yieldDefiningOp, firstUser, srcResult, builder, loc, iterArg); + + // wait 在 firstUser 前 + builder.setInsertionPoint(firstUser); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser << "\n"; + // llvm::outs() << "Inserted CUBE->VECTOR sync and data movement (flag=" << flag << ")\n"; + } + // VECTOR -> CUBE + else if (yieldType == CoreType::VECTOR_ONLY && iterType == CoreType::CUBE_ONLY) { + + // 2. 插入同步指令 + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 yieldDefiningOp 后 + builder.setInsertionPointAfter(yieldDefiningOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运 + // insertVectorToCubeDataMovement(yieldDefiningOp, firstUser, srcResult, builder, loc, iterArg); + + // wait 在 firstUser 前 + builder.setInsertionPoint(firstUser); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser << "\n"; + // llvm::outs() << "Inserted VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; + } + } +} + +// 获取节点的设备类型 +CoreType DAGSyncPass::getNodeDeviceType(OpNode *node, llvm::DenseMap *valueTypes) +{ + if (!node || !node->op) { + return CoreType::CUBE_AND_VECTOR; + } + + // 尝试从节点的结果中获取设备类型 + // 通常使用第一个结果来代表节点的设备类型 + if (node->op->getNumResults() > 0) { + mlir::Value result = node->op->getResult(0); + auto it = valueTypes->find(result); + if (it != valueTypes->end()) { + return it->second; + } + } + + // 如果没有找到,检查操作数 + // for (mlir::Value operand : node->op->getOperands()) { + // auto it = valueTypes->find(operand); + // if (it != valueTypes->end()) { + // return it->second; + // } + // } + + return CoreType::CUBE_AND_VECTOR; // 默认 +} + +// 判断是否需要vector<->cube同步 +bool DAGSyncPass::needVectorCubeSync(CoreType src, CoreType dst) +{ + return (src == CoreType::VECTOR_ONLY && dst == CoreType::CUBE_ONLY) || + (src == CoreType::CUBE_ONLY && dst == CoreType::VECTOR_ONLY); +} + +// 获取 tensor 类型 +mlir::RankedTensorType DAGSyncPass::getTensorType(mlir::Value tensorValue) { + if (auto tensorType = dyn_cast(tensorValue.getType())) { + return tensorType; + } + return nullptr; +} + +// 替换操作数 +void DAGSyncPass::replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, + mlir::Value newValue) { + for (unsigned i = 0; i < dstOp->getNumOperands(); ++i) { + if (dstOp->getOperand(i) == oldValue) { + dstOp->setOperand(i, newValue); + // llvm::outs() << "Replaced operand " << i << " of " << dstOp->getName().getStringRef() + // << " with new value\n"; + } + } +} + +// 修改 
getOrCreateAllocation 函数,将 alloc 提到函数最外层 +mlir::Value DAGSyncPass::getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, + hivm::AddressSpace addressSpace, + mlir::OpBuilder &builder, mlir::Location loc) { + auto rankedTensorType = cast(tensorType); + auto elementType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + + auto addressSpaceAttr = hivm::AddressSpaceAttr::get(builder.getContext(), addressSpace); + auto memrefType = mlir::MemRefType::get(shape, elementType, /*layout=*/nullptr, addressSpaceAttr); + + // 查找是否已经存在相同类型的 allocation(在函数的 entry block 中) + mlir::Operation* funcOp = op; + while (funcOp && !mlir::isa(funcOp)) { + funcOp = funcOp->getParentOp(); + } + + if (auto func = mlir::dyn_cast(funcOp)) { + // 在函数的 entry block 中查找现有的 allocation + mlir::Block& entryBlock = func.getBody().front(); + // for (auto& blockOp : entryBlock) { + // if (auto allocOp = mlir::dyn_cast(&blockOp)) { + // if (allocOp.getType() == memrefType) { + // // 找到匹配的 allocation,直接复用 + // llvm::outs() << "Reusing existing allocation: " << allocOp << "\n"; + // return allocOp.getResult(); + // } + // } + // } + + // 没有找到现有的 allocation,在函数开头创建新的 + builder.setInsertionPointToStart(&entryBlock); + return builder.create(loc, memrefType); + } + + // 如果没有找到函数,回退到原逻辑 + builder.setInsertionPoint(op); + return builder.create(loc, memrefType); +} + +// 插入 CUBE -> VECTOR 数据搬运 +void DAGSyncPass::insertCubeToVectorDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + mlir::Value srcResult, mlir::OpBuilder &builder, + mlir::Location loc, mlir::Value iterArgs) { + auto srcTensorType = getTensorType(srcResult); + if (!srcTensorType) { + return; + } + + // 1. 在 srcOp 之后创建 UB 空间的 memref.alloc + builder.setInsertionPointAfter(srcOp); + mlir::Value ubAlloc = getOrCreateAllocation(srcOp, srcTensorType, + hivm::AddressSpace::UB, builder, loc); + + // 2. 
创建 fixpipe 指令 + builder.setInsertionPointAfter(srcOp); + FixpipeDMAModeAttr dmaModeAttr = FixpipeDMAModeAttr::get(builder.getContext(), FixpipeDMAMode::NZ2ND); + + auto fixpipeOp = builder.create( + loc, + mlir::TypeRange{}, // 没有返回值 + srcResult, // src + ubAlloc, // dst + /*unit_flag_cond=*/mlir::ValueRange{}, + /*dma_mode=*/dmaModeAttr, + /*dual_dst_mode=*/nullptr, + /*pre_quant=*/nullptr, + /*pre_relu=*/nullptr, + /*channel_split=*/nullptr, + /*unit_flag_mode=*/mlir::ArrayAttr{}); + + llvm::outs() << "Inserted fixpipe after " << srcOp->getName().getStringRef() + << " for CUBE->VECTOR data movement\n"; + + // 3. 在 dstOp 前创建 memory_space_cast 和 to_tensor + builder.setInsertionPoint(dstOp); + + // memory_space_cast(如果需要) + mlir::Value plainMemref = ubAlloc; + auto memrefType = cast(ubAlloc.getType()); + if (memrefType.getMemorySpace()) { + auto plainMemrefType = mlir::MemRefType::get(memrefType.getShape(), + memrefType.getElementType()); + plainMemref = builder.create(loc, plainMemrefType, ubAlloc); + (*valueTypes)[plainMemref] = CoreType::VECTOR_ONLY; + } + + // 4. 创建 to_tensor + auto toTensorOp = builder.create( + loc, + srcTensorType, // 原始的 tensor 类型 + plainMemref, + /*restrict=*/true, + /*writable=*/true + ); + (*valueTypes)[toTensorOp.getResult()] = CoreType::VECTOR_ONLY; + + // 5. 
替换 dstOp 的操作数 + if (!iterArgs) { + replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); + } else { + replaceOperandWithNewValue(dstOp, iterArgs, toTensorOp.getResult()); + } +} + +static uint64_t getElemBytesForAlign(Type t) { + if (auto ft = dyn_cast(t)) + return (uint64_t)((ft.getWidth() + 7) / 8); + if (auto it = dyn_cast(t)) + return (uint64_t)((it.getWidth() + 7) / 8); + if (isa(t)) + return 8ULL; + if (auto ct = dyn_cast(t)) + return 2ULL * getElemBytesForAlign(ct.getElementType()); + return 0ULL; +} + +static FailureOr getBlockElemsFor32BAlign(Type elemType) { + constexpr uint64_t kAlignBytes = 32; + uint64_t elemBytes = getElemBytesForAlign(elemType); + if (elemBytes <= 0) + return failure(); + if (elemBytes >= kAlignBytes) + return 1; + if (kAlignBytes % elemBytes != 0) + return failure(); + return kAlignBytes / elemBytes; +} + +static std::optional> newCbubAllocShape(memref::AllocOp allocOp) { + auto type = dyn_cast(allocOp.getType()); + // 仅支持静态 2D MemRef + if (!type || type.getRank() != 2) + return std::nullopt; + + auto shape = type.getShape(); + int64_t M = shape[0]; + int64_t N = shape[1]; + auto elemType = type.getElementType(); + auto blkOr = getBlockElemsFor32BAlign(elemType); + int64_t blk = (int64_t)*blkOr; + // 必须是静态且 16 对齐 + if (ShapedType::isDynamic(M) || ShapedType::isDynamic(N)) + return std::nullopt; + if (M % 16 != 0) + return std::nullopt; + + // 新 shape: (N/16, M/16, 16, 16) + SmallVector newShape = {N / blk, M / 16, 16, blk}; + + return newShape; +} + +// 修改 VECTOR->CUBE 数据搬运函数 +void DAGSyncPass::insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, Operation* posOp, + mlir::Value srcResult, mlir::OpBuilder &builder, + mlir::Location loc, llvm::DenseMap* valueMap) { + auto srcTensorType = getTensorType(srcResult); + if (!srcTensorType) { + return; + } + if (isa(srcOp) && isa(dstOp)) { + return; + } + + // 1. 
在 srcOp 之后创建 UB 空间的 memref.alloc(用于 to_memref) + builder.setInsertionPointAfter(srcOp); + + // 首先创建 UB 空间的 memref type + auto ubSpaceAttr = hivm::AddressSpaceAttr::get(builder.getContext(), hivm::AddressSpace::UB); + auto ubMemrefType = mlir::MemRefType::get(srcTensorType.getShape(), + srcTensorType.getElementType(), + /*layout=*/nullptr, + ubSpaceAttr); + + // 创建 bufferization.to_memref + if (srcOp->getBlock() == dstOp->getBlock()) { + builder.setInsertionPoint(posOp); + } + auto toMemrefOp = builder.create( + loc, + ubMemrefType, + srcResult + ); + + // 2. 创建 CBUF 空间的 memref.alloc(用于 copy 的目标) + mlir::Value cbufAllocOld = getOrCreateAllocation(srcOp, srcTensorType, + hivm::AddressSpace::L1, builder, loc); + auto cbufShape = *newCbubAllocShape(dyn_cast(cbufAllocOld.getDefiningOp())); + // 获取旧的memref类型并创建新的类型 + auto oldType = dyn_cast(cbufAllocOld.getType()); + + // 获取新的维度数量 + unsigned newRank = cbufShape.size(); + + // 方法1:创建新的恒等布局映射 + AffineMap identityMap = builder.getMultiDimIdentityMap(newRank); + MemRefLayoutAttrInterface layout = AffineMapAttr::get(identityMap); + + // 方法2:如果旧类型有布局,尝试调整它(更安全的选择) + // 先检查旧类型是否有布局 + if (auto oldLayout = oldType.getLayout()) { + if (auto affineMapAttr = dyn_cast(oldLayout)) { + // 如果旧布局是AffineMap,尝试创建新的恒等映射 + // 因为维度改变,旧的affine map可能不再有效 + layout = AffineMapAttr::get(identityMap); + } else { + // 对于其他类型的布局,可能需要特殊处理 + layout = oldLayout; + } + } + + // 创建新的alloc类型 + auto newAllocType = MemRefType::get( + cbufShape, + oldType.getElementType(), + layout, // 使用新创建的布局 + oldType.getMemorySpace() + ); + + builder.setInsertionPoint(cbufAllocOld.getDefiningOp()); + // 创建新的alloc操作 + auto cbufAlloc = builder.create( + cbufAllocOld.getDefiningOp()->getLoc(), + newAllocType + ); + + builder.setInsertionPointAfter(toMemrefOp); + // 3. 
创建 copy 指令(src 是 ub memref,dst 是 cbuf memref) + auto copyOp = builder.create( + loc, + mlir::TypeRange{}, // 没有返回值 + toMemrefOp.getResult(), // src (memref in UB) + cbufAlloc // dst (memref in CBUF) + ); + + // llvm::outs() << "Inserted copy after " << srcOp->getName().getStringRef() + // << " for VECTOR->CUBE data movement\n"; + + // 4. 在 dstOp 前创建 convert_layout + builder.setInsertionPoint(dstOp); + auto ndLayout = hivm::DataLayoutAttr::get(builder.getContext(), hivm::DataLayout::ND); + // 创建 convert_layout + auto convertLayoutOp = builder.create( + loc, + cbufAllocOld.getType(), // 输出类型与输入相同 + cbufAlloc, + ndLayout, // srcLayout + ndLayout // dstLayout + ); + (*valueTypes)[convertLayoutOp.getResult()] = CoreType::CUBE_ONLY; + + // 5. 创建 memory_space_cast + auto cbufMemrefType = cast(convertLayoutOp.getType()); + auto plainMemrefType = mlir::MemRefType::get(cbufMemrefType.getShape(), + cbufMemrefType.getElementType()); + + auto memspaceCastOp = builder.create( + loc, + plainMemrefType, + convertLayoutOp.getResult() + ); + (*valueTypes)[memspaceCastOp.getResult()] = CoreType::CUBE_ONLY; + + // 6. 创建 to_tensor + auto toTensorOp = builder.create( + loc, + srcTensorType, // 原始的 tensor 类型 + memspaceCastOp.getResult(), + /*restrict=*/true, + /*writable=*/true + ); + (*valueTypes)[toTensorOp.getResult()] = CoreType::CUBE_ONLY; + + // 7. 替换 dstOp 的操作数 + replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); +} + +Operation* DAGSyncPass::FindLastestPosition(Operation* srcOp, Graph &mainGraph, OpBuilder &builder) { + Operation* insertPos = nullptr; + auto opMap = mainGraph.getOpMapLegacy(); + auto valueTypes = &mainGraph.getValueTypes(); + // Find the first cube-dependent vector core operation. 
+ for(auto nextOp = srcOp->getNextNode();nextOp!=nullptr; nextOp=nextOp->getNextNode()) { + auto nextType = getNodeDeviceType(opMap[nextOp], valueTypes); + if(nextType == CoreType::CUBE_ONLY) continue; + // No memref ops in IR yet; directly tracing operands + for(auto operand: nextOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + auto defType = getNodeDeviceType(opMap[defOp], valueTypes); + if(defType == CoreType::CUBE_ONLY) { + //To prevent UB overflow, we need to break the dependency at the point where the result shape is minimized + // — i.e., trace upward to find the first broadcast. + for(auto prevOp = nextOp->getPrevNode(); prevOp != nullptr && prevOp != srcOp; prevOp = prevOp->getPrevNode()) { + if(isa(prevOp)) { + if(prevOp->getPrevNode() && isa(prevOp->getPrevNode())) { + return prevOp->getPrevNode(); + } + return prevOp; + } + } + // Can't find the result shape is minimized + return nextOp; + } + } + + // Once meet SyncBlockWaitOp, return now! + if(auto waitOp = dyn_cast(nextOp)) { + if(waitOp.getTcoreType() == hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR)) { + return nextOp; + } + } + insertPos = nextOp; + } + return insertPos; +} + +Operation* DAGSyncPass::FindEarliestPosition(Operation* dstOp, Graph &mainGraph, OpBuilder &builder) +{ + auto insertPos = dstOp; + auto opMap = mainGraph.getOpMapLegacy(); + auto valueTypes = &mainGraph.getValueTypes(); + for (auto prevOp = dstOp->getPrevNode(); prevOp != nullptr; prevOp = prevOp->getPrevNode()) { + if (dstOp->getBlock() != prevOp->getBlock()) continue; + // Once meet SyncBlockSetOp, return now! 
+ if (auto waitOp = dyn_cast(prevOp)) { + if (waitOp.getTcoreType() == hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR)) { + return insertPos; + } + } + insertPos = prevOp; + } + return insertPos; +} + +// 主要的同步和数据搬运插入函数 +void DAGSyncPass::insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, llvm::DenseMap* valueMap, Graph &mainGraph) { + mlir::Location loc = srcOp->getLoc(); + // 保存当前的插入点 + mlir::OpBuilder::InsertionGuard guard(builder); + + // 检查是否是跨 block + mlir::Block *srcBlock = srcOp->getBlock(); + mlir::Block *dstBlock = dstOp->getBlock(); + bool sameBlock = (srcBlock == dstBlock); + + if (!sameBlock) { + // 检查是否是外层到内层的依赖 + bool dstIsInnerBlock = false; + mlir::Operation *dstParentOp = dstBlock->getParentOp(); + while (dstParentOp) { + if (dstParentOp->getBlock() == srcBlock) { + dstIsInnerBlock = true; + break; + } + if (dstParentOp->getBlock()) { + dstParentOp = dstParentOp->getBlock()->getParentOp(); + } else { + break; + } + } + + if (dstIsInnerBlock) { + insertSyncAndMovementForCrossBlock(srcOp, dstOp, srcType, dstType, builder, flag, true, valueMap, mainGraph); + return; + } + } + + // 同一 block 内的处理 + // 获取 srcOp 的输出(假设第一个结果) + if (srcOp->getNumResults() == 0) { + return; + } + mlir::Value srcResult = srcOp->getResult(0); + + // CUBE -> VECTOR + if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) { + + // 2. 
插入同步指令 + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto lastSetPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto lastWaitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); + auto lastFlagAddId = builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); + + // set 在 srcOp 后 + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // wait 在 dstOp 前 + + auto posOp = FindEarliestPosition(dstOp, mainGraph, builder); + builder.setInsertionPoint(posOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 插入数据搬运 + insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, nullptr); + + // llvm::outs() << "Inserted CUBE->VECTOR sync and data movement (flag=" << flag << ")\n"; + } + // VECTOR -> CUBE + else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { + + // 2. 
插入同步指令 + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto lastSetPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto lastWaitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); + auto lastFlagAddId = builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); + + // set 在 srcOp 后 + // builder.setInsertionPointAfter(srcOp); + auto posOp = FindLastestPosition(srcOp, mainGraph, builder); + if (posOp) { + builder.setInsertionPoint(posOp); + } else { + builder.setInsertionPointAfter(srcOp); + } + auto setOp = builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // wait 在 dstOp 前 + builder.setInsertionPoint(dstOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运 + insertVectorToCubeDataMovement(srcOp, dstOp, setOp, srcResult, builder, loc, valueMap); + + // llvm::outs() << "Inserted VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; + } +} + +// 跨 block 的同步和数据搬运 +void DAGSyncPass::insertSyncAndMovementForCrossBlock(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, + bool dstIsInnerBlock, llvm::DenseMap* valueMap, Graph &mainGraph) { + if (!dstIsInnerBlock) { + insertSyncAndMovement(srcOp, dstOp, srcType, dstType, builder, flag, valueMap, mainGraph); + return; + } + + mlir::Location loc = srcOp->getLoc(); + mlir::Block *dstBlock = dstOp->getBlock(); + + // 获取 srcOp 的输出 + if (srcOp->getNumResults() == 0) { + return; + } + mlir::Value srcResult = srcOp->getResult(0); + + // CUBE -> VECTOR + if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) { + + // 2. 插入同步指令(跨 block 特殊处理) + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 srcOp 后(外层) + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运(同 block 内逻辑) + insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, nullptr); + + // wait 在内层 block 入口前 + mlir::Operation *parentOp = dstBlock->getParentOp(); + if (parentOp) { + while (srcOp->getBlock() != parentOp->getBlock()) { + parentOp = parentOp->getBlock()->getParentOp(); + } + builder.setInsertionPoint(parentOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } else { + builder.setInsertionPoint(dstOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } + + } + // VECTOR -> CUBE + else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { + + // 2. 插入同步指令(跨 block 特殊处理) + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 srcOp 后(外层) + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运(同 block 内逻辑) + insertVectorToCubeDataMovement(srcOp, dstOp, srcOp, srcResult, builder, loc, valueMap); + + // wait 在内层 block 入口前 + mlir::Operation *parentOp = dstBlock->getParentOp(); + if (parentOp) { + builder.setInsertionPoint(parentOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } else { + builder.setInsertionPoint(dstOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } + + // llvm::outs() << "Inserted cross-block VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; + } +} + +void LegalizeDot(triton::FuncOp funcOp) { + mlir::OpBuilder builder(funcOp); + funcOp.walk([&](triton::DotOp dotOp) { + // 获取dot操作的输入 + Value a = dotOp.getOperands()[0]; + Value b = dotOp.getOperands()[1]; + Value c = dotOp.getOperands()[2]; // 累加器参数 + + // 检查累加器是否为全零常量 + bool isZeroAccumulator = false; + + // 检查是否直接是arith.constant 0 + if (auto constantOp = c.getDefiningOp()) { + if (auto denseAttr = dyn_cast(constantOp.getValue())) { + if (denseAttr.isSplat() && denseAttr.getSplatValue().getValueAsDouble() == 0.0) { + isZeroAccumulator = true; + } + } + } + + if (!isZeroAccumulator) { + // 创建新的零累加器 + Location loc = dotOp.getLoc(); + auto resultType = dotOp.getResult().getType(); + + Value originalResult = dotOp.getResult(); + builder.setInsertionPoint(dotOp); + // 创建全零张量 + auto zeroAttr = DenseElementsAttr::get( + dyn_cast(resultType), + APFloat(0.0f)); + auto zeroConstant = builder.create(loc, zeroAttr); + + // 创建新的dot操作,使用零作为累加器 + auto newDot = builder.create( + loc, resultType, a, b, zeroConstant); + + // 创建加法操作,将新的dot结果与原来的累加器c相加 + auto addOp = builder.create(loc, newDot, c); + + // 用addOp替换原来的dotOp + originalResult.replaceAllUsesWith(addOp.getResult()); + + // 删除原dotOp(如果它没有其他用途) + if (dotOp.use_empty()) { + dotOp.erase(); + } + } + + }); +} + +static void 
rewriteCopyChainForCbub( + hivm::CopyOp copyOp, + ArrayRef newShape, + OpBuilder &builder) { + + // 获取 copy 的输入(ins),应为 to_memref 的结果 + Value insVal = copyOp.getOperands()[0]; + auto toMemRefOp = insVal.getDefiningOp(); + if (!toMemRefOp) + return; + + Value inputTensor = toMemRefOp.getTensor(); + auto inputTensorType = dyn_cast(inputTensor.getType()); + if (!inputTensorType || inputTensorType.getRank() != 2) + return; + + // blk = 32/位宽 + // 中间 reshape 形状:[M/16, 16, N/ blk, blk] + int64_t M = inputTensorType.getShape()[0]; + int64_t N = inputTensorType.getShape()[1]; + auto elemType = inputTensorType.getElementType(); + auto blkOr = getBlockElemsFor32BAlign(elemType); + int64_t blk = (int64_t)*blkOr; + SmallVector intermediateShape3D = {M, N / blk, blk}; + SmallVector intermediateShapetrans = {N / blk, M, blk}; + auto elementType = inputTensorType.getElementType(); + auto interTensor3DType = RankedTensorType::get(intermediateShape3D, elementType); + auto interTensortransType = RankedTensorType::get(intermediateShapetrans, elementType); + + auto finalTensorType = RankedTensorType::get(newShape, elementType); + + auto loc = inputTensor.getLoc(); + + // Set insertion point before copyOp (or toMemRefOp) + auto tensorOp = inputTensor.getDefiningOp(); + builder.setInsertionPointAfter(tensorOp); + + // 插入 triton.reshape 将 2D tensor 展开为 3D + auto reshape3DOp = builder.create( + loc, interTensor3DType, inputTensor); + (*valueTypes)[reshape3DOp.getResult()] = CoreType::VECTOR_ONLY; + + // nark tiling dim for reshapeop + auto markOp3d = builder.create(loc, reshape3DOp); + auto tilingDimAttr3d = builder.getDictionaryAttr(SmallVector{ + NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); + markOp3d->setAttr("tiling_dim_mapping", tilingDimAttr3d); + + // 插入 triton.trans 调整维度顺序 Insert tt.trans {order = [1, 0, 2]} + SmallVector order = {1, 0, 2}; + auto orderAttr = builder.getDenseI32ArrayAttr(order); // OpBuilder supports this + auto transOp = builder.create( 
+ loc, interTensortransType, reshape3DOp.getResult(), orderAttr); + (*valueTypes)[transOp.getResult()] = CoreType::VECTOR_ONLY; + + // 插入 triton.reshape 将 3D tensor 展开为 4D + auto reshape4DOp = builder.create( + loc, finalTensorType, transOp.getResult()); + (*valueTypes)[reshape4DOp.getResult()] = CoreType::VECTOR_ONLY; + + // nark tiling dim for reshapeop + auto markOp4d = builder.create(loc, reshape4DOp); + auto tilingDimAttr4d = builder.getDictionaryAttr(SmallVector{ + NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); + markOp4d->setAttr("tiling_dim_mapping", tilingDimAttr4d); + + // Create new to_memref + builder.setInsertionPoint(toMemRefOp); + auto newMemRefType = MemRefType::get( + newShape, + elementType, + mlir::AffineMap{}, + toMemRefOp.getType().getMemorySpace()); + auto newToMemRefOp = builder.create( + toMemRefOp.getLoc(), + newMemRefType, + reshape4DOp.getResult()); + (*valueTypes)[newToMemRefOp.getResult()] = CoreType::VECTOR_ONLY; + + // Create NEW copyOp (replacing the old one) + builder.setInsertionPoint(copyOp); + auto resultTypes = copyOp->getResultTypes(); + auto newCopyOp = builder.create( + copyOp.getLoc(), + resultTypes, // TypeRange + reshape4DOp.getResult(), // src (ins) + copyOp.getOperands()[1] // dst (outs) + ); + + // 替换 uses 并清理旧 op + copyOp.replaceAllUsesWith(newCopyOp); + copyOp.erase(); + toMemRefOp.erase(); + + return; +} + +template +OpTy createBlockSync(OpBuilder builder, + hivm::TCoreType coreType, + hivm::PIPE srcPipe, + hivm::PIPE dstPipe, + int flag, + Operation *cause) +{ + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), coreType); + auto setPipe = PipeAttr::get(builder.getContext(), srcPipe); + auto waitPipe = PipeAttr::get(builder.getContext(), dstPipe); + return builder.create(cause->getLoc(), coreAttr, setPipe, waitPipe, flagId); +} + +// since we do not have llvm::set_intersects in this version... 
+template bool intersects(S1Ty &s1, S2Ty &s2) +{ + if (s1.size() > s2.size()) { + return intersects(s2, s1); + } + + return llvm::any_of(s1, [&](auto e) { return s2.count(e); }); +} + +bool mayAlias(DataFlowSolver &solver, Value ptrA, Value ptrB) +{ + if (ptrA == ptrB) { + return true; + } + const auto *stateA = solver.lookupState>(ptrA); + const auto *stateB = solver.lookupState>(ptrB); + if (!stateA || !stateB) { // not triton ptr type + return true; + } + auto infoA = stateA->getValue(); + auto infoB = stateB->getValue(); + + return intersects(infoA.getAllocs(), infoB.getAllocs()); +} + +const size_t MAX_EXPECTED_PARENTS_COUNT = 8; + +std::optional> findAncestorCommonBlock(mlir::Operation *opA, mlir::Operation *opB) +{ + if (opA->getBlock() == opB->getBlock()) { + return std::make_pair(opA, opB); + } + + // record all ancestors of opA + llvm::SmallPtrSet ancestorsA; + mlir::Operation *curr = opA; + while (curr) { + ancestorsA.insert(curr); + curr = curr->getParentOp(); + } + + // find the last ancestor of opB which is also the ancestor of opA + mlir::Operation *commonAncOp = nullptr; + curr = opB; + while (curr) { + if (ancestorsA.count(curr)) { + commonAncOp = curr; + break; + } + curr = curr->getParentOp(); + } + + if (!commonAncOp) { + return std::nullopt; + } + + // find the ancestors in the given region + for (mlir::Region ®ion : commonAncOp->getRegions()) { + for (mlir::Block &block : region) { + auto *ancA = block.findAncestorOpInBlock(*opA); + auto *ancB = block.findAncestorOpInBlock(*opB); + if (ancA && ancB) { + return std::make_pair(ancA, ancB); + } + } + } + return std::nullopt; +} + +struct SyncCandidate { + CoreType srcCoreType; + Operation *setCause; + Operation *setAfter; + Operation *waitCause; + Operation *waitBefore; +}; + +// setOp, waitOp +void createBlockSyncBetween(OpBuilder builder, + hivm::PIPE srcPipe, + hivm::PIPE dstPipe, + SyncCandidate candidate, + int flag) +{ + auto srcCoreType = toHivm(candidate.srcCoreType); + auto dstCoreType = 
toHivm(!candidate.srcCoreType); + + builder.setInsertionPointAfter(candidate.setAfter); + auto setOp = createBlockSync(builder, srcCoreType, srcPipe, dstPipe, flag, candidate.setCause); + builder.setInsertionPoint(candidate.waitBefore); + auto waitOp = createBlockSync(builder, dstCoreType, srcPipe, dstPipe, flag, candidate.waitCause); +}; + +void addMemEffectsSync(triton::FuncOp funcOp, Graph *graph, OpBuilder &builder, int &syncFlag) +{ + DominanceInfo domInfo(funcOp); + PostDominanceInfo postDomInfo(funcOp); + DataFlowSolver solver; + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(funcOp))) { + funcOp->emitWarning("SharedMemoryAliasAnalysis failed! This could lead to potential memory related issues! \n"); + } + + // [(node, EffectInstance, LinearisationPt)] + llvm::SmallVector> memOps; + + // [(setAfter, waitBefore, srcOP, dstOp)][CoreType] + llvm::SmallVector candidates; + + funcOp.walk([&](MemoryEffectOpInterface memIface) { + auto *op = memIface.getOperation(); + if (llvm::isa(op)) { + return; + } + + auto *currNode = graph->getOpMap()[op].get(); + SmallVector effects; + + memIface.getEffects(effects); + + for (auto &effect : effects) { + if (!isa(effect.getEffect())) { + continue; + } + memOps.emplace_back(currNode, effect); + bool isWrite = isa(effect.getEffect()); + for (auto &[prevNode, prevEffect] : memOps) { + if ((isa(prevEffect.getEffect()) || isWrite) && + mayAlias(solver, prevEffect.getValue(), effect.getValue()) && + prevNode->isOn() != currNode->isOn() // write is forced on single core type, so we are safe to judge + // based on whether the core types are different + ) { + CoreType srcCoreType = isWrite ? 
!currNode->isOn() : prevNode->isOn(); + auto opPair = findAncestorCommonBlock(prevNode->op, currNode->op); + if (!opPair.has_value()) { + op->emitWarning( + llvm::formatv("Unable to find ancestors in common block with {0}\n", *prevNode->op)); + continue; + } + auto [setAfter, waitBefore] = opPair.value(); + if (setAfter == waitBefore) { + continue; + } + candidates.push_back(SyncCandidate {srcCoreType, prevNode->op, setAfter, op, waitBefore}); + } + } + } + }); + + auto addBlockSyncCommon = [&builder, &syncFlag](SyncCandidate cand) { + llvm::dbgs() << "\n\n=== Insert sync between ===\n" + << *cand.setAfter << "\n" + << *cand.waitBefore << "\n=== Insert Sync End ===\n\n"; + + auto srcPipe = cand.srcCoreType == CoreType::CUBE_ONLY ? hivm::PIPE::PIPE_FIX : hivm::PIPE::PIPE_MTE2; + auto dstPipe = hivm::PIPE::PIPE_S; + createBlockSyncBetween(builder, srcPipe, dstPipe, cand, syncFlag % 14); + syncFlag++; + }; + + if (candidates.empty()) { + return; + } + + auto setAfterDominate = [&domInfo](Operation *a, Operation *b) { + if (domInfo.dominates(a, b)) { + return true; + } + if (domInfo.dominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { + return false; + } + if (b->isAncestor(a)) { + return true; + } + return false; + }; + + auto waitBeforePostDominate = [&postDomInfo](Operation *a, Operation *b) { + if (postDomInfo.postDominates(a, b)) { + return true; + } + if (postDomInfo.postDominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { + return true; + } + if (b->isAncestor(a)) { + return false; + } + return false; + }; + + llvm::sort(candidates, [&](const SyncCandidate &a, const SyncCandidate &b) { + if (a.setAfter != b.setAfter) { + return setAfterDominate(a.setAfter, b.setAfter); + } + + if (a.waitBefore != b.waitBefore) { + return waitBeforePostDominate(a.waitBefore, b.waitBefore); + } + + return false; + }); + + for (auto [i, cand] : llvm::enumerate(candidates)) { + bool shouldInsert = true; + for (auto otherCand : 
ArrayRef(candidates).drop_front(i + 1)) { + bool duplicated = (cand.waitBefore == otherCand.waitBefore && cand.setAfter == otherCand.setAfter && + cand.srcCoreType == otherCand.srcCoreType); + bool containsOther = + (cand.srcCoreType == otherCand.srcCoreType && setAfterDominate(cand.setAfter, otherCand.setAfter) && + waitBeforePostDominate(cand.waitBefore, otherCand.waitBefore)); + if (duplicated || containsOther) { + shouldInsert = false; + break; + } + } + + if (shouldInsert) { + addBlockSyncCommon(cand); + } + } +} + +void DAGSyncPass::runOnOperation() +{ + auto module = getOperation(); + mlir::OpBuilder builder(&getContext()); + + // 遍历所有函数 + for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + // 跳过无效函数 + LegalizeDot(funcOp); + if (funcOp.getBody().empty()) { + continue; + } + + // llvm::outs() << "\n====================================\n"; + // llvm::outs() << "处理函数: " << funcOp.getName() << "\n"; + // llvm::outs() << "====================================\n"; + + auto unique_graph = Graph::fromMultiBlockFunc(funcOp); + std::shared_ptr shared_graph = std::move(unique_graph); + auto& main_graph = *shared_graph; + + auto funcName = funcOp.getName(); + + // 获取 DAG 图的映射 + auto opMapRaw = main_graph.getOpMapLegacy(); + valueTypes = &main_graph.getValueTypes(); + auto *opMap = &opMapRaw; + + if (!opMap || !valueTypes) { + llvm::errs() << "Warning: Failed to create DAG graph for function " << funcOp.getName() << "\n"; + continue; + } + + // 用于避免重复插入同步 + llvm::DenseSet> processedPairs; + int syncFlag = 1; + addMemEffectsSync(funcOp, shared_graph.get(), builder, syncFlag); + + // 3. 
使用 walk 遍历函数中的所有操作 + funcOp.walk([&](mlir::Operation *op) { + // 查找当前操作对应的 Node + auto nodeIt = opMap->find(op); + if (nodeIt == opMap->end()) { + // 这个操作不在 entry block 的 DAG 图中 + // 可能是嵌套在控制流内部的操作 + return; + } + + OpNode *currentNode = nodeIt->second; + + // 检查是否是 scf.for 操作 + if (auto forOp = mlir::dyn_cast(op)) { + // 处理 scf.for 循环的特殊同步逻辑 + int temp = syncFlag % 14; + processScfForSync(forOp, currentNode, valueTypes, builder, temp); + } + + // 获取当前节点的设备类型 + CoreType currentType = getNodeDeviceType(currentNode, valueTypes); + + // 打印操作信息(可选) + // if (!llvm::isa(op->getDialect())) { + // llvm::outs() << "操作: " << *op + // << " 设备类型: " + // << (currentType == CoreType::VECTOR_ONLY ? "VECTOR" : + // currentType == CoreType::CUBE_ONLY ? "CUBE" : "SCALAR") + // << "\n"; + // } + + // 4. 遍历当前节点的所有输入节点 + for (ValueNode *inputValNode : currentNode->getInputs()) { + auto inputOp = inputValNode->value.getDefiningOp(); + if (!inputOp || !opMap->contains(inputOp)) { + continue; + } + + auto inputNode = (*opMap)[inputOp]; + + // 获取输入节点的设备类型 + CoreType inputType = getNodeDeviceType(inputNode, valueTypes); + + // 5. 
判断是否需要插入同步和数据搬运 + if (needVectorCubeSync(inputType, currentType)) { + // 检查是否已经处理过这对操作 + auto opPair = std::make_pair(inputOp, op); + if (processedPairs.insert(opPair).second) { + // 插入同步和数据搬运指令 + // 检查是否是跨 block 的依赖 + mlir::Block *srcBlock = inputOp->getBlock(); + mlir::Block *dstBlock = op->getBlock(); + + if (srcBlock == dstBlock) { + // 同一 block 内 + insertSyncAndMovement(inputOp, op, inputType, currentType, builder, syncFlag % 14, valueTypes, main_graph); + syncFlag ++; + } else { + // 跨 block,判断是否是外层到内层 + llvm::outs() << "#########\n"; + bool dstIsInnerBlock = false; + mlir::Operation *dstParentOp = dstBlock->getParentOp(); + + // 向上查找,看 dstBlock 是否在 srcBlock 的区域内 + while (dstParentOp) { + if (dstParentOp->getBlock() == srcBlock) { + dstIsInnerBlock = true; + break; + } + if (dstParentOp->getBlock()) { + dstParentOp = dstParentOp->getBlock()->getParentOp(); + } else { + break; + } + } + if (dstIsInnerBlock) { + + insertSyncAndMovementForCrossBlock(inputOp, op, inputType, currentType, + builder, syncFlag % 14, dstIsInnerBlock, valueTypes, main_graph); + syncFlag ++; + } + } + } + } + } + }); + + // llvm::outs() << "\n函数 " << funcOp.getName() << " 统计:\n"; + // llvm::outs() << " - 插入的总同步操作数: " << syncFlag << "\n"; + funcOp.walk([&](hivm::CopyOp copyOp) { + llvm::outs()<(copyOp.getOperands()[1].getType()).getShape(), builder); + }); + GraphManager::getInstance().registerGraph(funcName, shared_graph); + } + + // llvm::outs()<> mlir::triton::createDAGSyncPass() +{ + return std::make_unique(); +} diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index d52868ea6f..400c376a66 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -231,6 +231,13 @@ void init_triton_ir(py::module &&m) { py::class_(m, "context", py::module_local()) .def(py::init<>()) + .def("__enter__", [](MLIRContext &self) -> MLIRContext& { return self; }, + py::return_value_policy::reference) + .def("__exit__", + [](MLIRContext 
&, py::object, py::object, py::object) -> bool { + // Keep context alive for the duration of the scope. + return false; + }) .def("printOpOnDiagnostic", [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) .def("printStackTraceOnDiagnostic", @@ -659,9 +666,13 @@ void init_triton_ir(py::module &&m) { "get_unit_attr", [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) .def("get_i64_array_attr", - [](TritonOpBuilder &self, const std::vector &array) { - return self.getBuilder().getI64ArrayAttr(array); - }) + [](TritonOpBuilder &self, const std::vector& array) { + return self.getBuilder().getI64ArrayAttr(array); + }) + .def("get_type_array_attr", + [](TritonOpBuilder &self, const std::vector& array) { + return self.getBuilder().getTypeArrayAttr(array); + }) // Use arith.ConstantOp to create constants // Constants .def("get_int1", @@ -1712,6 +1723,10 @@ void init_triton_ir(py::module &&m) { // TODO: maybe dump module to file and print error for better // diagnostics + auto *context = mod.getContext(); + if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) + context->disableMultithreading(); + auto reproducerPath = triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); if (!reproducerPath.empty()) { @@ -1719,6 +1734,7 @@ void init_triton_ir(py::module &&m) { auto passes = self.getPasses(); Operation *op = mod.getOperation(); makeReproducer(anchorName, passes, op, reproducerPath); + context->disableMultithreading(); } if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { @@ -1753,7 +1769,8 @@ void init_triton_ir(py::module &&m) { if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); - }); + }, + py::call_guard()); } void init_triton_env_vars(py::module &m) { diff --git a/third_party/ascend/triton_ascend.cc b/third_party/ascend/triton_ascend.cc index 2f08b0f331..ec40899e18 100644 --- a/third_party/ascend/triton_ascend.cc +++ b/third_party/ascend/triton_ascend.cc @@ -6,21 +6,23 @@ #include 
"mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/PassManager.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Pass/PassManager.h" -#include "ascend/include/TritonToHFusion/Passes.h" +#include "ascend/include/AutoBlockify/Passes.h" +#include "ascend/include/TritonToStructured/Passes.h" +#include "ascend/include/TritonToAnnotation/Passes.h" +#include "ascend/include/TritonToLinalg/Passes.h" +#include "ascend/include/Dialect/TritonAscend/IR/TritonAscendDialect.h" +#include "ascend/include/DiscreteMaskAccessConversion/Passes.h" +#include "ascend/include/TritonToUnstructure/Passes.h" #include "ascend/include/TritonToHIVM/Passes.h" +#include "ascend/include/TritonToHFusion/Passes.h" #include "ascend/include/TritonToLLVM/Passes.h" -#include "incubated/Conversion/DiscreteMaskAccessConversion/Passes.h" -#include "incubated/Conversion/TritonToAnnotation/Passes.h" -#include "incubated/Conversion/TritonToLinalgIncubated/Passes.h" -#include "incubated/Conversion/TritonToStructuredIncubated/Passes.h" -#include "incubated/Conversion/TritonToUnstructureIncubated/Passes.h" -#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" + #include "ascend/include/TritonAffinityOpt/Passes.h" -#include "ir.h" // TritonOpBuilder #include "triton/Dialect/Triton/IR/Dialect.h" +#include "ir.h" // TritonOpBuilder #include @@ -31,338 +33,326 @@ using namespace mlir; void init_triton_ascend_ir(py::module &&m) { auto *builder_cls = ir::getBuilderClass(); builder_cls - ->def("create_extract_scalar", - [](TritonOpBuilder &self, Value &src, - std::vector &indices) -> Value { - llvm::SmallVector arg_indices; - for (const auto &i : indices) { - auto iTy = i.getType(); - if (!iTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), i); - arg_indices.push_back(v); - } else { - arg_indices.push_back(i); - } - } - auto ret = self.create(src, arg_indices); - return ret; - }) - 
.def("create_extract_slice", - [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, - std::vector &sizs_vec, std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - - return self.create(retTy, ful, offsets, - sizes, strides); - }) - .def("create_insert_slice", - [](TritonOpBuilder &self, Value &ful, Value &sub, - std::vector &offs_vec, std::vector &sizs_vec, - std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - auto ret = self.create(sub, ful, offsets, - sizes, strides); - return ret; - }) - .def("create_custom_op_for_inter_core_sync", - [](TritonOpBuilder &self, std::string &op_name, - std::string &mode_or_sender, int id) -> void { - auto args = self.getBuilder().getArrayAttr( - {self.getBuilder().getStringAttr(mode_or_sender), - 
self.getBuilder().getI32IntegerAttr(id)}); - self.create(op_name, args, ValueRange()); - }) - .def("create_index_select_simd", - [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, - std::vector &srcShape, std::vector &srcOffset, - std::vector &readShape, - std::vector &returnShape) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - // Get element type from source pointer - Type elemType; - if (auto ptrTy = dyn_cast(src.getType())) { - elemType = ptrTy.getPointeeType(); - } else { - llvm::report_fatal_error( - "index_select_simd: src must be pointer type"); - } - - // Create return tensor type - llvm::SmallVector retShape; - for (const auto &s : returnShape) { - retShape.push_back(s); - } - auto retTensorType = RankedTensorType::get(retShape, elemType); - - // Convert srcShape and srcOffset values to index type if needed - llvm::SmallVector srcShapeIndex; - for (auto val : srcShape) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), - val); - } - srcShapeIndex.push_back(val); - } - - llvm::SmallVector srcOffsetIndex; - for (auto val : srcOffset) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), - val); - } - srcOffsetIndex.push_back(val); - } - - // Create attributes - auto dimAttr = builder.getI32IntegerAttr(dim); - auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); - - // Create the IndexSelectSimdOp - // Parameter order must match TritonOps.td definition: - // src, index, dim, src_shape, src_offset, read_shape - auto indexSelectSimdOp = - builder.create( - loc, - retTensorType, // result type - src, // src pointer - index, // index tensor - dimAttr, // dim attribute - srcShapeIndex, // src_shape (variadic, index type) - srcOffsetIndex, // src_offset (variadic, index type) - readShapeAttr // read_shape attribute - ); - - return indexSelectSimdOp.getResult(); - }) - .def("create_embedding_gather", - [](TritonOpBuilder &self, Value &src, Value &idx, - 
const int64_t bound, const int64_t blksiz, - std::vector &offsets, - std::vector &numels) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(idx.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - retShape.push_back(blksiz); - auto resType = RankedTensorType::get(retShape, elemTy); - auto idxBitWidth = idxTy.getElementType().getIntOrFloatBitWidth(); - auto bound_val = - self.create(bound, idxBitWidth); - auto blksiz_val = - self.create(blksiz, idxBitWidth); - - return self.create( - resType, src, idx, bound_val, blksiz_val, offsets, numels); - }) - .def("create_index_put", - [](TritonOpBuilder &self, Value &ptr, Value &index, Value &value, - const int32_t dim, const int64_t indexBoundary, - std::vector &endOffset, std::vector &startOffset, - std::vector &dstStride) -> void { - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - - self.create(ptr, index, value, dim_val, - bound_val, endOffset, - startOffset, dstStride); - }) - .def("create_gather_out_to_ub", - [](TritonOpBuilder &self, Value &src, Value &index, - const int64_t indexBoundary, const int32_t dim, - std::vector &srcStride, std::vector &endOffset, - std::vector &startOffset, - std::optional &other) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(index.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - auto resType = RankedTensorType::get(retShape, elemTy); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = 
self.create(dim, dimI32Ty); - return self.create( - resType, src, index, bound_val, dim_val, srcStride, endOffset, - startOffset, other.value_or(Value())); - }) - .def("create_scatter_ub_to_out", - [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, - const int64_t indexBoundary, const int32_t dim, - std::vector &dstStride, std::vector &endOffset, - std::vector &startOffset) -> void { - auto idxTy = cast(index.getType()); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - - self.create( - ptr, value, index, bound_val, dim_val, dstStride, endOffset, - startOffset); - }) - // Add sort - .def("create_sort", - [](TritonOpBuilder &self, Value src, int64_t dim, - bool descending) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - auto descendingAttr = builder.getBoolAttr(descending); - - auto op = builder.create(loc, src, dimAttr, - descendingAttr); - - return op->getResult(0); - }) - // Add flip - .def("create_flip", - [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - - auto op = - builder.create(loc, src, dimAttr); - - return op->getResult(0); - }) - .def("create_tanh", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - // Add an annotation - .def("create_annotation", - [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, - Attribute &attrVal) { - auto annotationOp = self.create(ptr); - annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), - attrVal); - }); + ->def("create_extract_scalar", + [](TritonOpBuilder &self, Value &src, std::vector &indices) -> Value { + 
llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + .def("create_extract_slice", + [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get(retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, sizes, strides); + }) + .def("create_insert_slice", + [](TritonOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + 
sizes, strides); + return ret; + }) + .def("create_custom_op_for_inter_core_sync", + [](TritonOpBuilder &self, std::string &op_name, + std::string &mode_or_sender, int id) -> void { + auto args = self.getBuilder().getArrayAttr( + {self.getBuilder().getStringAttr(mode_or_sender), + self.getBuilder().getI32IntegerAttr(id)} + ); + self.create(op_name, args, ValueRange()); + }) + .def("create_index_select_simd", + [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, + std::vector &srcShape, std::vector &srcOffset, + std::vector &readShape, std::vector &returnShape) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + // Get element type from source pointer + Type elemType; + if (auto ptrTy = dyn_cast(src.getType())) { + elemType = ptrTy.getPointeeType(); + } else { + llvm::report_fatal_error("index_select_simd: src must be pointer type"); + } + + // Create return tensor type + llvm::SmallVector retShape; + for (const auto &s : returnShape) { + retShape.push_back(s); + } + auto retTensorType = RankedTensorType::get(retShape, elemType); + + // Convert srcShape and srcOffset values to index type if needed + llvm::SmallVector srcShapeIndex; + for (auto val : srcShape) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), val); + } + srcShapeIndex.push_back(val); + } + + llvm::SmallVector srcOffsetIndex; + for (auto val : srcOffset) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), val); + } + srcOffsetIndex.push_back(val); + } + + // Create attributes + auto dimAttr = builder.getI32IntegerAttr(dim); + auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); + + // Create the IndexSelectSimdOp + // Parameter order must match TritonOps.td definition: + // src, index, dim, src_shape, src_offset, read_shape + auto indexSelectSimdOp = builder.create( + loc, + retTensorType, // result type + src, // src pointer + index, // index tensor + dimAttr, // dim attribute + 
srcShapeIndex, // src_shape (variadic, index type) + srcOffsetIndex, // src_offset (variadic, index type) + readShapeAttr // read_shape attribute + ); + + return indexSelectSimdOp.getResult(); + }) + .def("create_index_put", + [](TritonOpBuilder &self, Value &ptr, Value &index, + Value &value, const int32_t dim, const int64_t indexBoundary, + std::vector &endOffset, std::vector &startOffset, + std::vector &dstStride) -> void { + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = self.create(indexBoundary, BoundI64Ty); + + self.create( + ptr, + index, + value, + dim_val, + bound_val, + endOffset, + startOffset, + dstStride + ); + }) + .def("create_gather_out_to_ub", + [](TritonOpBuilder &self, Value &src, Value &index, const int64_t indexBoundary, + const int32_t dim, std::vector &srcStride, std::vector &endOffset, + std::vector &startOffset, std::optional &other) -> Value { + auto elemTy = cast(src.getType()).getPointeeType(); + auto idxTy = cast(index.getType()); + auto idxShape = idxTy.getShape(); + std::vector retShape(idxShape.begin(), idxShape.end()); + auto resType = RankedTensorType::get(retShape, elemTy); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + return self.create( + resType, + src, + index, + bound_val, + dim_val, + srcStride, + endOffset, + startOffset, + other.value_or(Value()) + ); + }) + .def("create_scatter_ub_to_out", + [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, + const int64_t indexBoundary, const int32_t dim, std::vector &dstStride, + std::vector &endOffset, std::vector &startOffset) -> void { + auto idxTy = 
cast(index.getType()); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + + self.create( + ptr, + value, + index, + bound_val, + dim_val, + dstStride, + endOffset, + startOffset + ); + }) + // Add sort + .def("create_sort", + [](TritonOpBuilder &self, Value src, int64_t dim, bool descending) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + auto descendingAttr = builder.getBoolAttr(descending); + + auto op = builder.create(loc, src, dimAttr, descendingAttr); + + return op->getResult(0); + }) + // Add flip + .def("create_flip", + [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + + auto op = builder.create(loc, src, dimAttr); + + return op->getResult(0); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + // Add an annotation + .def("create_annotation", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }); } void init_triton_ascend_passes_ttir(py::module &&m) { - m.def("add_triton_to_structure_incubated", - [](mlir::PassManager &pm, bool enableMaskFallbackConversion, - bool optimizeDynamicOffset, bool compileOn91095) { - pm.addPass(mlir::triton::createTritonToStructuredIncubatedPass( - enableMaskFallbackConversion, optimizeDynamicOffset, - compileOn91095)); - }); + m.def("add_auto_blockify", [](mlir::PassManager &pm, + int autoBlockifySize) { + AutoBlockifyOptions opts; + opts.autoBlockifySize = autoBlockifySize; + 
pm.addPass(mlir::triton::createAutoBlockifyPass(opts));}); - m.def("add_triton_to_annotation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToAnnotationPass()); - }); + m.def("add_triton_to_structure", [](mlir::PassManager &pm, + bool enableMaskFallbackConversion, bool optimizeDynamicOffset) { + pm.addPass(mlir::triton::createTritonToStructuredPass( + enableMaskFallbackConversion, optimizeDynamicOffset)); }); - m.def("add_triton_to_linalg_incubated", - [](mlir::PassManager &pm, bool globalKernel, bool namedOps, - bool enableNd2nzOnVector, bool enableSelectAnalysis, - bool compileOn91095) { - pm.addPass(mlir::triton::Incubated::createTritonToLinalgIncubatedPass( - globalKernel, namedOps, enableNd2nzOnVector, enableSelectAnalysis, - compileOn91095)); - }); - - m.def("add_triton_to_unstructure_incubated", - [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { - TritonToUnstructureIncubatedOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass( - mlir::triton::createTritonToUnstructureIncubatedPass(opts)); - }); + m.def("add_triton_to_annotation", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createTritonToAnnotationPass());}); + + m.def("add_triton_to_linalg", [](mlir::PassManager &pm, bool globalKernel, + bool namedOps, bool enableNd2nzOnVector, bool enableSelectAnalysis, + bool compileOn91095) { + pm.addPass(mlir::triton::createTritonToLinalgPass( + globalKernel, namedOps, enableNd2nzOnVector, + enableSelectAnalysis, compileOn91095)); }); + + m.def("add_triton_to_unstructure", [](mlir::PassManager &pm, + bool compileOn91095, bool forceSimtTemplate) { + TritonToUnstructureOptions opts; + opts.compileOn91095 = compileOn91095; + opts.forceSimtTemplate = forceSimtTemplate; + pm.addPass(mlir::triton::createTritonToUnstructurePass(opts));}); m.def("add_triton_to_hfusion", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToHFusionPass()); - }); + 
pm.addPass(mlir::triton::createTritonToHFusionPass());}); - m.def("add_discrete_mask_access_conversion", - [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { - DiscreteMaskAccessConversionOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass( - mlir::triton::createDiscreteMaskAccessConversionPass(opts)); - }); + m.def("add_discrete_mask_access_conversion", [](mlir::PassManager &pm, + bool compileOn91095, bool forceSimtTemplate, bool enableSyncBlockLock) { + DiscreteMaskAccessConversionOptions opts; + opts.compileOn91095 = compileOn91095; + opts.forceSimtTemplate = forceSimtTemplate; + opts.enableSyncBlockLock = enableSyncBlockLock; + pm.addPass(mlir::triton::createDiscreteMaskAccessConversionPass(opts));}); m.def("add_triton_to_hivm", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToHIVMPass()); - }); + pm.addPass(mlir::triton::createTritonToHIVMPass());}); m.def("add_triton_to_llvm", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToLLVMPass()); - }); - + pm.addPass(mlir::triton::createTritonToLLVMPass());}); + m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createBubbleUpOperationPass()); - }); + pm.addPass(mlir::triton::createBubbleUpOperationPass());}); + + m.def("add_dag_sync", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGSyncPass());}); + + m.def("add_dag_scope", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGScopePass());}); + + m.def("add_dag_ssbuffer", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createDAGSSBufferPass());}); } // Forward declaration for ascend_ir bindings (defined in ascend_ir.cc) @@ -372,7 +362,6 @@ void init_triton_ascend(py::module &&m) { auto passes = m.def_submodule("passes"); // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { - context.allowUnregisteredDialects(); mlir::DialectRegistry registry; 
registry.insert(); context.appendDialectRegistry(registry); @@ -381,7 +370,7 @@ void init_triton_ascend(py::module &&m) { init_triton_ascend_passes_ttir(passes.def_submodule("ttir")); init_triton_ascend_ir(m.def_submodule("ascend_ir")); - + // Initialize ascend IR bindings (ascendnpu_ir_builder, scope/hivm dialects) init_ascend_ir(m.def_submodule("ir")); } diff --git a/third_party/ascend/tutorials/03-matrix-multiplication.py b/third_party/ascend/tutorials/03-matrix-multiplication.py new file mode 100644 index 0000000000..beae7d6f97 --- /dev/null +++ b/third_party/ascend/tutorials/03-matrix-multiplication.py @@ -0,0 +1,217 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Matrix Multiplication +=============== +""" + +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.cann.extension as extension + +DEV = "npu" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + msk_m = offs_am < M + msk_n = offs_bn < N + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk + a = tl.load( + a_ptrs, + mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + # Original vector operations + # # ----------------------------------------------------------- + # # Write back the block of the output matrix C with masks. 
+ # Comment out the following lines to enable split the workload to two vector cores + SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 + for s in extension.parallel(0, 2, bind_sub_block=True): + vec_sub_blk = extension.extract_slice( + accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) + ) + if ACTIVATION == "leaky_relu_custom": + vec_sub_blk = leaky_relu_custom(vec_sub_blk) + c_sub_blk = vec_sub_blk.to(tl.float16) + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_sub_blk, mask=c_mask) + + +# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu_custom(x): + return tl.where(x >= 0, x, 0.01 * x) + 1.0 + + +def torch_matmul(a, b, activation=""): + c = torch.matmul(a, b) + if activation == "leaky_relu_custom": + c = torch.where(c >= 0, c, 0.01 * c) + 1.0 + return c + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
+ + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., `torch.matmul` running on the NPU). +def test(): + activation = "leaky_relu_custom" + torch.manual_seed(0) + a = torch.randn((512, 512), device=DEV, dtype=torch.float16) + b = torch.randn((512, 512), device=DEV, dtype=torch.float16) + triton_output = matmul(a, b, activation) + torch_output = torch_matmul(a, b, activation) + print(f"triton_output_with_fp16_inputs={triton_output}") + print(f"torch_output_with_fp16_inputs={torch_output}") + torch.testing.assert_close(triton_output, torch_output, atol=1e-3, rtol=1e-3) + print("Passed") + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/third_party/ascend/tutorials/04-low-memory-dropout.py b/third_party/ascend/tutorials/04-low-memory-dropout.py new file mode 100644 index 0000000000..0947d85ec3 --- /dev/null +++ b/third_party/ascend/tutorials/04-low-memory-dropout.py @@ -0,0 +1,139 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Low-Memory Dropout +================== +""" + +import tabulate +import torch +import torch_npu + +import triton +import triton.language as tl + +DEV = "npu" + + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! 
+ output = tl.where(x_keep != 0, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +def test(): + # Input tensor + x = torch.randn(size=(10, ), device=DEV) + # Dropout mask + p = 0.5 + x_keep = (torch.rand(size=(10, ), device=DEV) > p).to(torch.int32) + # + output = dropout(x, x_keep=x_keep, p=p) + print(tabulate.tabulate([ + ["input"] + x.tolist(), + ["keep mask"] + x_keep.tolist(), + ["output"] + output.tolist(), + ])) + + + x = torch.randn(size=(10, ), device=DEV) + # Compare this to the baseline - dropout mask is never instantiated! 
+ output = seeded_dropout(x, p=0.5, seed=123) + output2 = seeded_dropout(x, p=0.5, seed=123) + output3 = seeded_dropout(x, p=0.5, seed=512) + + print( + tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist(), + ])) + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/third_party/ascend/tutorials/05-layer-norm.py b/third_party/ascend/tutorials/05-layer-norm.py new file mode 100644 index 0000000000..b7361e9300 --- /dev/null +++ b/third_party/ascend/tutorials/05-layer-norm.py @@ -0,0 +1,127 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Layer Normalization +============= +""" + +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) 
+ _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + return y + + +def _layer_norm(M, N, dtype, eps=1e-5, device='npu'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + print(f"y_tri: {y_tri}") + 
print(f"y_ref: {y_ref}") + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + _layer_norm(128, 128, torch.float16) + _layer_norm(128, 128, torch.bfloat16) + _layer_norm(128, 128, torch.float32) diff --git a/third_party/ascend/tutorials/06-demo-autotune.py b/third_party/ascend/tutorials/06-demo-autotune.py deleted file mode 100644 index dc37a9e306..0000000000 --- a/third_party/ascend/tutorials/06-demo-autotune.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
-""" -Autotune -============= -""" -import torch, torch_npu -import triton -import triton.language as tl - - -def test_triton_autotune(): - # Return a set of different kernel configurations for autotune - def get_autotune_config(): - return [ - triton.Config({'XS': 1 * 128, 'multibuffer': True}), - triton.Config({'XS': 12 * 1024, 'multibuffer': True}), - triton.Config({'XS': 12 * 1024, 'multibuffer': False}), - triton.Config({'XS': 8 * 1024, 'multibuffer': True}), - ] - - # Use @autotune decorator to automatically select the best kernel configuration - @triton.autotune(configs=get_autotune_config(), # List of configurations - key=["numel"], # the change of numel will trigger autotuning - ) - @triton.jit - def triton_calc_kernel(out_ptr0, in_ptr0, in_ptr1, numel, - XS: tl.constexpr # Block size controlling how many elements each thread block processes - ): - pid = tl.program_id(0) # Get current program ID - idx = pid * XS + tl.arange(0, XS) # Index range handled by current thread block - msk = idx < numel # Mask to avoid out-of-bound access - for i in range(10000): - tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0) # Load x0 - tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0) # Load x1 - tmp2 = tl.math.exp(tmp0) + tmp1 + i - tl.store(out_ptr0 + idx, tmp2, mask=msk) # Store result - - # Function to call the Triton kernel with autotuned configuration - def triton_calc_func(x0, x1): - n = x0.numel() - y0 = torch.empty_like(x0) - grid = lambda meta: (triton.cdiv(n, meta["XS"]), 1, 1) - triton_calc_kernel[grid](y0, x0, x1, n) - return y0 - - # Reference implementation using PyTorch for correctness check - def torch_calc_func(x0, x1): - return torch.exp(x0) + x1 + 10000 - 1 - - DEV = "npu" - DTYPE = torch.float32 - N = 192 * 1024 - x0 = torch.randn((N, ), dtype=DTYPE, device=DEV) - x1 = torch.randn((N, ), dtype=DTYPE, device=DEV) - torch_ref = torch_calc_func(x0, x1) - triton_cal = triton_calc_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - -if 
__name__ == "__main__": - test_triton_autotune() - print("success: test_triton_autotune") diff --git a/third_party/ascend/tutorials/06-fused-attention.py b/third_party/ascend/tutorials/06-fused-attention.py new file mode 100644 index 0000000000..dfc03e21b8 --- /dev/null +++ b/third_party/ascend/tutorials/06-fused-attention.py @@ -0,0 +1,365 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + # Set the processing range [lo, hi) for the current stage (in column block units) + # Causal attention, as the name implies, restricts the flow of information during computation, + # only allowing the model to see the current and previous positions. + # In other words, the output at the current position can only depend on the input at or before this position, + # and cannot access information from future positions. + # Causal attention ensures sequential order and prevents "leakage of future information." 
+ # But the following logic will also be triggered + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K is [N_CTX, HEAD_DIM], shift along the first dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator, used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Transpose K + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Construct causal mask (key index <= query index) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Add a large negative bias (-1.0e6) to masked positions + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 
format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. 
updated accumulator + tl.store(acc_ptr + block2d_acc, acc) + + m_i = m_ij # Update current block max + # Advance V and K block pointers to next BLOCK_N range + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i + return acc_ptr, l_i, m_i + + +@triton.jit +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, + stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, + stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, + stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, + stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, + Z: tl.constexpr, H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + # Total number of blocks in sequence dimension (M) + NUM_BLOCKS_M = N_CTX // BLOCK_M + # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + + # Current M-dimension block index + pid = tl.program_id(0) + + for block_idx in range(pid, NUM_BLOCKS, 20): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + # Create block pointers for Q, K, V, Output + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) 
+ K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # Initialize offsets + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + + # Initialize accumulator + if HEAD_DIM < 256: + acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + else: + acc_offset = ( + off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + + task_m_idx * BLOCK_M * HEAD_DIM + ) + acc_ptr = acc + acc_offset + + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + 
col_head_dim + accumulator = tl.load(acc_ptr + block2d_acc) + accumulator = accumulator / l_i[:, None] + + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, BM, BN): + """ + Forward computation interface: + Args: + ctx: Context object + q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM] + k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM] + v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM] + causal: Whether to enable causal attention + sm_scale: Scaling factor for QK product + BM: Q block size (BLOCK_M) + BN: K/V block size (BLOCK_N) + Returns: + out: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM] + """ + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + out = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + + + # Number of NPU cores (adjust based on hardware) + num_cores = 20 + acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[(num_cores,)]( + q, k, v, M, out, acc, sm_scale, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BM, + BLOCK_N=BN, + STAGE=stage, + **extra_kern_args) + + ctx.save_for_backward(q, k, v, out, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return out + +attention = 
_attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [ + (1, 1, 128, 128, False, torch.float16, 32, 128), + (1, 1, 128, 128, False, torch.bfloat16, 64, 128), + (1, 2, 256, 256, False, torch.bfloat16, 32, 256), + (2, 2, 128, 256, False, torch.float16, 64, 128), + (4, 32, 64, 64, False, torch.float16, 32, 64), + (4, 32, 1024, 64, False, torch.bfloat16, 64, 128), + (4, 32, 4096, 64, False, torch.float16, 128, 128), +]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): + # Filter out non-integer cases; N_CTX must be divisible by BM and BN, and HEAD_DIM must be divisible by 16. + if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0: + pytest.skip("Skipping non-divisible case") + + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + + sm_scale = 0.5 + + tri_out = attention(q, k, v, causal, sm_scale, BM, BN) + ref_out = torch_npu.npu_fusion_attention( + q, k, v, H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) + print(f"[PASSED] Attention shape:({Z}, {H}, {N_CTX}, {HEAD_DIM}), BM: {BM}, BN: {BN}, dtype: {dtype}") + + +if __name__ == "__main__": + test_op(1, 1, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 1, 128, 128, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(1, 2, 256, 256, causal=False, dtype=torch.bfloat16, BM=32, BN=256) + test_op(2, 2, 128, 256, causal=False, dtype=torch.float16, BM=64, BN=128) + test_op(4, 32, 64, 64, causal=False, 
dtype=torch.float16, BM=32, BN=64) + test_op(4, 32, 1024, 64, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(4, 32, 4096, 64, causal=False, dtype=torch.float16, BM=128, BN=128) diff --git a/third_party/ascend/tutorials/07-extern-functions.py b/third_party/ascend/tutorials/07-extern-functions.py new file mode 100644 index 0000000000..f48953b0f3 --- /dev/null +++ b/third_party/ascend/tutorials/07-extern-functions.py @@ -0,0 +1,89 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Libdevice (`tl.extra.libdevice`) function +============================== +""" +import inspect +import os +from pathlib import Path + +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +from triton.backends.ascend.compiler import get_libdevice + +DEV = "npu" + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + + +def test(): + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=DEV) + output_triton = torch.zeros(size, device=DEV) + output_torch = torch.asin(x) + assert x.device.type == DEV and output_triton.device.type == DEV + n_elements = output_torch.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) + print(output_torch) + print(output_triton) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + + current_file = inspect.getfile(inspect.currentframe()) + current_dir = Path(os.path.dirname(os.path.abspath(current_file))) + extern_libs = {'libdevice': get_libdevice()} + + output_triton = torch.empty_like(x) + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs) + torch.testing.assert_close(output_torch, output_triton, rtol=1e-4, atol=1e-4) + print(output_torch) + print(output_triton) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + +if __name__ == "__main__": + test() diff --git a/third_party/ascend/tutorials/07-profiler.py b/third_party/ascend/tutorials/07-profiler.py deleted 
file mode 100644 index f62800f902..0000000000 --- a/third_party/ascend/tutorials/07-profiler.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - - -def profiler_wrapper(fn, *args): - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, - skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, - profile_memory=False, with_stack=False, with_flops=False, with_modules=False, - experimental_config=experimental_config) as prof: - stream.synchronize() - for i in range(skip_first + (wait + warmup + active) * repeat): - fn(*args) - prof.step() - stream.synchronize() - - -def test_add(x0, x1): - - def torch_func(x0, x1): - res = x0 + x1 - return res - - @triton.jit - def triton_kernel_add(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - idx = tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.store(out_ptr0 + idx, tmp2) - - def triton_func(x0, x1): - y0 = torch.empty_like(x0) - triton_kernel_add[1, 1, 1](y0, x0, x1, N) - return y0 - - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - def wrapper_func(x0, x1): - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - - profiler_wrapper(wrapper_func, x0, x1) - - -def test_or(x0, x1): - - def torch_func(x0, x1): - res = x0 | x1 - return res - - @triton.jit - def triton_kernel_or(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - idx = 
tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 | tmp1 - tl.store(out_ptr0 + idx, tmp2) - - def triton_func(x0, x1): - y0 = torch.empty_like(x0) - triton_kernel_or[1, 1, 1](y0, x0, x1, N) - return y0 - - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - torch.testing.assert_close(triton_cal, torch_ref) - - def wrapper_func(x0, x1): - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1) - - profiler_wrapper(wrapper_func, x0, x1) - - -def test_inductor_add(x0, x1): - # torch_npu._inductor requires torch_npu 2.6.0+ experimental version - import torch_npu._inductor - - def torch_func(x0, x1): - res = x0 + x1 - return res - - compiled_func = torch.compile(torch_func, backend="inductor") - profiler_wrapper(compiled_func, x0, x1) - print("[INFO] Check ./result_profiling directory to find the kernel_details.csv file. " - " Check the columns: Input Shapes,Input Data Types,Input Formats") - - -if __name__ == "__main__": - test_case_is_inductor = False - N = 1024 - low = 1 - high = 100 - - # float32 - x0_fp32 = torch.rand((N, ), dtype=torch.float32).npu() - x1_fp32 = torch.rand((N, ), dtype=torch.float32).npu() - - # float16 - x0_fp16 = torch.rand((N, ), dtype=torch.float16).npu() - x1_fp16 = torch.rand((N, ), dtype=torch.float16).npu() - - # bfloat16 - x0_bf16 = torch.rand((N, ), dtype=torch.bfloat16).npu() - x1_bf16 = torch.rand((N, ), dtype=torch.bfloat16).npu() - - # int64 - x0_i64 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int64).npu() - x1_i64 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int64).npu() - - # int32 - x0_i32 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int32).npu() - x1_i32 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int32).npu() - - # int16 - x0_i16 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int16).npu() - x1_i16 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int16).npu() - - # int8 - 
x0_i8 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int8).npu() - x1_i8 = torch.randint(low=low, high=high, size=(N, ), dtype=torch.int8).npu() - - # bool (i1) - x0_i1 = torch.randint(low=0, high=2, size=(N, )).bool().npu() - x1_i1 = torch.randint(low=0, high=2, size=(N, )).bool().npu() - - test_cases = [ - ('fp32', x0_fp32, x1_fp32), - ('fp16', x0_fp16, x1_fp16), - ('bf16', x0_bf16, x1_bf16), - ('i64', x0_i64, x1_i64), - ('i32', x0_i32, x1_i32), - ('i16', x0_i16, x1_i16), - ('i8', x0_i8, x1_i8), - ('i1', x0_i1, x1_i1), - ] - - for dtype_name, x0, x1 in test_cases: - print(f"Running test for {dtype_name}...") - if dtype_name != 'i1': - if (test_case_is_inductor): - test_inductor_add(x0, x1) - else: - test_add(x0, x1) - else: - test_or(x0, x1) diff --git a/third_party/ascend/tutorials/08-grouped-gemm.py b/third_party/ascend/tutorials/08-grouped-gemm.py new file mode 100644 index 0000000000..4ba59bf8e0 --- /dev/null +++ b/third_party/ascend/tutorials/08-grouped-gemm.py @@ -0,0 +1,282 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Group GEMM +============================ +""" + +import torch +import torch_npu + +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +NUM_CORES = get_npu_properties()["num_aicore"] + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': NUM_CORES, + }), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + group_gemm_sizes, + g_lds, + group_size, + NUM_SM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + 
g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for _ in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + tl.store(c_ptrs, c) + tile_idx += NUM_SM + + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + device = torch.device(DEV) + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=device, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + d_a_ptrs = torch.tensor(A_addrs, device=device) + d_b_ptrs = torch.tensor(B_addrs, device=device) + d_c_ptrs = torch.tensor(C_addrs, 
device=device) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + + def grid(meta): + return (meta['NUM_SM'],) + + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +def test(): + group_m = [1024, 512, 256, 128] + group_n = [1024, 512, 256, 128] + group_k = [1024, 512, 256, 128] + group_A = [] + group_B = [] + assert len(group_m) == len(group_n) + assert len(group_n) == len(group_k) + group_size = len(group_m) + for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device=DEV, dtype=torch.float16) + B = torch.rand((K, N), device=DEV, dtype=torch.float16) + group_A.append(A) + group_B.append(B) + + tri_out = group_gemm_fn(group_A, group_B) + ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] + for i in range(group_size): + torch.testing.assert_close(ref_out[i], tri_out[i], atol=1e-2, rtol=1e-3) + print("Passed") + + +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + + def grid(meta): + return (meta['NUM_SM'],) + + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], + x_vals=[2**i for i in range(7, 11)], + line_arg='provider', + line_vals=['torch', 'triton'], + line_names=["Torch", "Triton"], + styles=[('green', '-'), ('blue', '-')], + ylabel="runtime(ms)", + plot_name="group-gemm-performance", + args={}, + )) +def benchmark(N, provider): + group_size = 4 + group_A = [] + group_B = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for _ in range(group_size): + A = torch.rand((N, N), device=DEV, dtype=torch.float16) + B = torch.rand((N, N), device=DEV, 
dtype=torch.float16) + C = torch.empty((N, N), device=DEV, dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device=DEV) + d_b_ptrs = torch.tensor(B_addrs, device=DEV) + d_c_ptrs = torch.tensor(C_addrs, device=DEV) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEV) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEV) + + quantiles = [0.5, 0.2, 0.8] + + def bench_torch(): + torch_perf_fn(group_A, group_B) + + def bench_triton(): + triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size) + + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(bench_torch, quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(bench_triton, quantiles=quantiles) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/third_party/ascend/tutorials/09-persistent-matmul.py b/third_party/ascend/tutorials/09-persistent-matmul.py new file mode 100644 index 0000000000..f80c4852a9 --- /dev/null +++ b/third_party/ascend/tutorials/09-persistent-matmul.py @@ -0,0 +1,337 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Persistent Matmul +===================== +""" + +import argparse +import time + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" +DTYPE = torch.float16 + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def get_num_compute_cores(): + return get_npu_properties()["num_aicore"] + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + bytes_per_elem = args["c_ptr"].element_size() + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * 
stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_sm = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_sm += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_sm): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id 
* GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def get_configs(dtype): + return { + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + } + }[dtype] + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) 
* triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def matmul_persistent(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + num_sms = get_num_compute_cores() + M, K = a.shape + _, N = b.shape + configs = get_configs(a.dtype) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(meta): + return (min(num_sms, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])), ) + + matmul_kernel_persistent[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=configs["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs["GROUP_SIZE_M"], + NUM_SMS=num_sms, + num_stages=configs["num_stages"], + num_warps=configs["num_warps"], + ) + return c + + +def torch_matmul(a, b): + return torch.matmul(a, b) + + +def bench(K, reps=10): + M = 8192 + N = 8192 + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + for _ in range(reps): + _ = torch_matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + _ = matmul_persistent(a, b) + time.sleep(0.01) + + +def validate(M, N, K): + a = torch.randn((M, K), device=DEV, dtype=DTYPE) + b = torch.randn((K, N), device=DEV, dtype=DTYPE) + + torch_result = torch_matmul(a, b) + naive_result = matmul(a, b) + persistent_result = matmul_persistent(a, b) + + naive_vs_torch = "✅" if torch.allclose(naive_result, 
torch_result, atol=1.0) else "❌" + persistent_vs_torch = "✅" if torch.allclose(persistent_result, torch_result, atol=1.0) else "❌" + naive_vs_persistent = "✅" if torch.allclose(naive_result, persistent_result, atol=1.0) else "❌" + + print( + f"M={M}, N={N}, K={K} verification naive vs torch: {naive_vs_torch} " + f"persistent vs torch: {persistent_vs_torch} naive vs persistent: {naive_vs_persistent}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-K", type=int, required=False, default=512) + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + args = parser.parse_args() + + if args.K and args.K_range is None: + args.K_range = [args.K, args.K] + args.K_step = 1 + + torch.manual_seed(0) + + validate(32, 32, 32) + validate(8192, 8192, 512) + + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K) diff --git a/third_party/ascend/tutorials/15-embedding_gather_demo.py b/third_party/ascend/tutorials/15-embedding_gather_demo.py deleted file mode 100644 index 84fd70ef27..0000000000 --- a/third_party/ascend/tutorials/15-embedding_gather_demo.py +++ /dev/null @@ -1,118 +0,0 @@ -# only available on 910_95 -import torch -import torch_npu -from torch import empty_strided -from torch._dynamo.testing import rand_strided -import triton -import triton.language as tl - -y0_numel = 128 -r1_numel = 50 -x2_numel = 16 -embedding_size = 1353406 - - -def profiler_wrapper(fn, *args): - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, 
torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, - skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, - profile_memory=False, with_stack=False, with_flops=False, with_modules=False, - experimental_config=experimental_config) as prof: - stream.synchronize() - for i in range(skip_first + (wait + warmup + active) * repeat): - fn(*args) - prof.step() - stream.synchronize() - - -def get_autotune_config(): - return [ - triton.Config({ - 'Y0BLOCK': 4, 'Y0BLOCK_SUB': 2, 'X2BLOCK_SUB': x2_numel, 'R1BLOCK_SUB': r1_numel, 'EMBEDDING_SIZE': - embedding_size, 'multibuffer': False - }), - ] - - -@triton.autotune(configs=get_autotune_config(), # List of configurations - key=["numel"], # the change of numel will trigger autotuning - ) -@triton.jit -def triton_unk_fused_embedding_eq_sum_where_zeros_like_0(in_ptr0, in_ptr1, out_ptr0, y0_numel, r1_numel, x2_numel, - Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, - X2BLOCK_SUB: tl.constexpr, R1BLOCK_SUB: tl.constexpr, - EMBEDDING_SIZE: tl.constexpr): - y0_offset = tl.program_id(0) * Y0BLOCK - base_y0 = tl.arange(0, Y0BLOCK_SUB) - loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB - base_r1 = tl.arange(0, R1BLOCK_SUB) - base_x2 = tl.arange(0, X2BLOCK_SUB) - r1 = base_r1[None, None, :] - r1_mask = r1 < r1_numel - x2 = base_x2[None, None, :] - x2_mask = x2 < x2_numel - # loops_x1 = (x1_numel + X2BLOCK_SUB - 1) // X2BLOCK_SUB - # loops_r2 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB - for loop_y0 in range(loops_y0): - y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None, None] - y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) - tmp0 = tl.load(in_ptr0 + (r1 + 50 * y0), r1_mask & y0_mask, other=0.0).to(tl.int32) - tmp1 = tl.full([1, 1, 1], -1, tl.int32) - tmp2 = tmp0 == tmp1 - tmp3 = tl.full([1, 1, 1], 0, tl.int32) - tmp4 = tl.where(tmp2, tmp3, tmp0) - # tmp5 = 
tl.full([Y0BLOCK_SUB, X2BLOCK_SUB, R1BLOCK_SUB], 1353406, tl.int32) - # tmp6 = tmp4 + tmp5 - # tmp7 = tmp4 < 0 - # tmp8 = tl.where(tmp7, tmp6, tmp4) - # tl.device_assert(((0 <= tmp8) & (tmp8 < 1353406)) | ~(r2_mask & y0_mask), "index out of bounds: 0 <= tmp8 < 1353406") - # tmp10 = tl.load(in_ptr1 + (x1 + 16*tmp8), r2_mask & x1_mask & y0_mask) - # 用下面这行替换上述6行 SIMT - tmp8 = tl.reshape(tmp4, [Y0BLOCK_SUB, R1BLOCK_SUB]) - tmp10 = tl.index_select(in_ptr1, tmp8, EMBEDDING_SIZE, X2BLOCK_SUB, (y0_offset + (loop_y0 * Y0BLOCK_SUB), 0, 0), - (y0_numel, r1_numel, x2_numel)) - tmp14 = tl.sum(tmp10, 1).reshape(Y0BLOCK_SUB, 1, X2BLOCK_SUB) - tl.store(out_ptr0 + (x2 + 16 * y0), tmp14, x2_mask & y0_mask) - - -def triton_func(arg34_1: torch.Tensor, arg35_1: torch.Tensor, buf0: torch.Tensor): - y0_size, _ = arg34_1.size() - grid = lambda meta: (triton.cdiv(y0_size, meta['Y0BLOCK']), ) - triton_unk_fused_embedding_eq_sum_where_zeros_like_0[grid](arg34_1, arg35_1, buf0, y0_numel, r1_numel, x2_numel) - return buf0 - - -def torch_func(x0: torch.Tensor): - return torch.sqrt(x0) - - -torch.manual_seed(0) - -arg34_1 = rand_strided((y0_numel, r1_numel), (r1_numel, 1), device='npu', dtype=torch.int64) -arg35_1 = rand_strided((embedding_size, x2_numel), (x2_numel, 1), device='npu', dtype=torch.float32) -buf0 = empty_strided((y0_numel, x2_numel), (x2_numel, 1), device='npu', dtype=torch.float32) - -output_triton = triton_func(arg34_1, arg35_1, buf0) -print("triton = ", output_triton) - -# output_torch = torch_func(x0) -# print("torch = ", output_torch) -# torch.testing.assert_close(output_triton.cpu(), output_torch.cpu()) - -# def wrapper_func(x0, x1): -# torch_ref = torch_func(x0, x1) -# triton_cal = triton_func(x0, x1) - -# profiler_wrapper(wrapper_func, x0, x1) diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir new file mode 100644 index 0000000000..b8074c87ff 
--- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir @@ -0,0 +1,263 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @backward_dkdv + +module { + tt.func public @backward_dkdv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: i32, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: f32, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: i32 {tt.divisibility = 16 : i32}, %arg28: i32 {tt.divisibility = 16 : i32}, %arg29: i32 {tt.divisibility = 16 : i32}, %arg30: i32 {tt.divisibility = 16 : i32}, %arg31: i32 {tt.divisibility = 16 : i32}, %arg32: i32) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32> + %cst_0 = arith.constant dense<0xFF800000> : tensor<32x32xf32> + %cst_1 = arith.constant 
dense<0.000000e+00> : tensor<32x32xf32> + %cst_2 = arith.constant dense<1> : tensor<32xi32> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %cst_3 = arith.constant 1.44269502 : f32 + %c1_i32 = arith.constant 1 : i32 + %alloc = memref.alloc() : memref<2x2x16x16xf16, #hivm.address_space> + %alloc_4 = memref.alloc() : memref<32x64xf32, #hivm.address_space> + %alloc_5 = memref.alloc() : memref<2x2x16x16xf16, #hivm.address_space> + %alloc_6 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %alloc_7 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %alloc_8 = memref.alloc() : memref<32x64xf32, #hivm.address_space> + %0 = tt.get_program_id x : i32 + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 10 + hivm.hir.sync_block_set[, , ] flag = 9 + hivm.hir.sync_block_set[, , ] flag = 8 + hivm.hir.sync_block_set[, , ] flag = 7 + %1 = tt.get_num_programs x : i32 + %2 = arith.mulf %arg15, %cst_3 : f32 + %3 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %7 = tt.splat %arg14 : i32 -> tensor<1x64xi32> + %8 = arith.cmpi slt, %6, %7 : tensor<1x64xi32> + %9 = tt.broadcast %8 : tensor<1x64xi1> -> tensor<32x64xi1> + %10 = tt.splat %arg27 : i32 -> tensor<32x1xi32> + %11 = arith.muli %4, %10 : tensor<32x1xi32> + %12 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<32x64xi32> + %13 = tt.splat %arg30 : i32 -> tensor<32x1xi32> + %14 = arith.muli %4, %13 : tensor<32x1xi32> + %15 = tt.splat %arg9 : i32 -> tensor<32xi32> + %16 = arith.muli %3, %15 : tensor<32xi32> + %17 = tt.splat %arg8 : i32 -> tensor<32xi32> + %18 = arith.addi %16, %17 : tensor<32xi32> + %19 = arith.subi %18, %cst_2 : tensor<32xi32> + %20 = arith.subi %arg8, %c1_i32 : i32 + %21 = 
tt.expand_dims %19 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %22 = tt.broadcast %21 : tensor<32x1xi32> -> tensor<32x32xi32> + %23 = tt.splat %2 : f32 -> tensor<32x32xf32> + %24 = tt.splat %arg15 : f32 -> tensor<32x32xf32> + scf.for %arg33 = %0 to %arg32 step %1 : i32 { + %25 = arith.divsi %arg33, %arg32 : i32 + %26 = arith.remsi %arg33, %arg32 : i32 + %27 = arith.divsi %26, %arg13 : i32 + %28 = arith.remsi %26, %arg13 : i32 + %29 = tt.addptr %arg10, %25 : !tt.ptr, i32 + %30 = tt.load %29 : !tt.ptr + %31 = tt.addptr %29, %c1_i32 : !tt.ptr, i32 + %32 = tt.load %31 : !tt.ptr + %33 = arith.subi %32, %30 : i32 + %34 = tt.addptr %arg11, %25 : !tt.ptr, i32 + %35 = tt.load %34 : !tt.ptr + %36 = tt.addptr %34, %c1_i32 : !tt.ptr, i32 + %37 = tt.load %36 : !tt.ptr + %38 = arith.subi %37, %35 : i32 + %39 = tt.splat %38 : i32 -> tensor<32x1xi32> + %40 = arith.cmpi slt, %4, %39 : tensor<32x1xi32> + %41 = tt.broadcast %40 : tensor<32x1xi1> -> tensor<32x64xi1> + %42 = arith.andi %41, %9 : tensor<32x64xi1> + %43 = arith.muli %35, %arg27 : i32 + %44 = tt.addptr %arg6, %43 : !tt.ptr, i32 + %45 = arith.muli %27, %arg28 : i32 + %46 = tt.addptr %44, %45 : !tt.ptr, i32 + %47 = arith.muli %28, %arg26 : i32 + %48 = tt.addptr %46, %47 : !tt.ptr, i32 + %49 = tt.splat %48 : !tt.ptr -> tensor<32x1x!tt.ptr> + %50 = tt.addptr %49, %11 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> + %51 = tt.broadcast %50 : tensor<32x1x!tt.ptr> -> tensor<32x64x!tt.ptr> + %52 = tt.addptr %51, %12 : tensor<32x64x!tt.ptr>, tensor<32x64xi32> + %53 = arith.muli %35, %arg30 : i32 + %54 = tt.addptr %arg7, %53 : !tt.ptr, i32 + %55 = arith.muli %27, %arg31 : i32 + %56 = tt.addptr %54, %55 : !tt.ptr, i32 + %57 = arith.muli %28, %arg29 : i32 + %58 = tt.addptr %56, %57 : !tt.ptr, i32 + %59 = tt.splat %58 : !tt.ptr -> tensor<32x1x!tt.ptr> + %60 = tt.addptr %59, %14 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> + %61 = tt.broadcast %60 : tensor<32x1x!tt.ptr> -> tensor<32x64x!tt.ptr> + %62 = tt.addptr %61, %12 : 
tensor<32x64x!tt.ptr>, tensor<32x64xi32> + %63 = arith.extsi %33 : i32 to i64 + %64 = tt.addptr %arg4, %30 : !tt.ptr, i32 + %65 = arith.muli %26, %arg23 : i32 + %66 = tt.addptr %64, %65 : !tt.ptr, i32 + %67 = tt.addptr %arg3, %30 : !tt.ptr, i32 + %68 = arith.muli %26, %arg22 : i32 + %69 = tt.addptr %67, %68 : !tt.ptr, i32 + %70:6 = scf.for %arg34 = %20 to %33 step %c32_i32 iter_args(%arg35 = %cst, %arg36 = %cst, %arg37 = %20, %arg38 = %20, %arg39 = %20, %arg40 = %20) -> (tensor<32x64xf32>, tensor<32x64xf32>, i32, i32, i32, i32) : i32 { + %73 = tt.make_tensor_ptr %66, [%63], [%c1_i64], [%arg40] {order = array} : > + %74 = tt.make_tensor_ptr %69, [%63], [%c1_i64], [%arg39] {order = array} : > + %75 = tt.load %74 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %76 = tt.expand_dims %75 {axis = 0 : i32} : tensor<32xf32> -> tensor<1x32xf32> + %77 = tt.load %73 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %78 = tt.expand_dims %77 {axis = 0 : i32} : tensor<32xf32> -> tensor<1x32xf32> + %79 = tt.splat %arg34 : i32 -> tensor<32xi32> + %80 = arith.addi %3, %79 : tensor<32xi32> + %81 = tt.expand_dims %80 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %82 = tt.broadcast %81 : tensor<1x32xi32> -> tensor<32x32xi32> + %83 = arith.cmpi sle, %22, %82 : tensor<32x32xi32> + %84 = arith.select %83, %cst_1, %cst_0 : tensor<32x32xi1>, tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_7 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %85 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf32> + %86 = arith.mulf %85, %23 : tensor<32x32xf32> + %87 = arith.addf %84, %86 : tensor<32x32xf32> + %88 = tt.broadcast %76 : tensor<1x32xf32> -> tensor<32x32xf32> + %89 = arith.subf %87, %88 : tensor<32x32xf32> + %90 = math.exp2 %89 : tensor<32x32xf32> + %91 = arith.mulf %24, %90 : tensor<32x32xf32> + %92 = tt.broadcast %78 : tensor<1x32xf32> -> tensor<32x32xf32> + 
hivm.hir.sync_block_wait[, , ] flag = 2 + %memspacecast_9 = memref.memory_space_cast %alloc_6 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %93 = bufferization.to_tensor %memspacecast_9 restrict writable : memref<32x32xf32> + %94 = arith.subf %93, %92 : tensor<32x32xf32> + %95 = arith.mulf %91, %94 : tensor<32x32xf32> + %96 = arith.truncf %90 : tensor<32x32xf32> to tensor<32x32xf16> + %97 = tt.reshape %96 : tensor<32x32xf16> -> tensor<2x16x2x16xf16> + %98 = tt.trans %97 {order = array} : tensor<2x16x2x16xf16> -> tensor<2x2x16x16xf16> + hivm.hir.sync_block_set[, , ] flag = 7 + hivm.hir.sync_block_set[, , ] flag = 8 + hivm.hir.sync_block_wait[, , ] flag = 11 + %99 = bufferization.to_memref %98 : memref<2x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%99 : memref<2x2x16x16xf16, #hivm.address_space>) outs(%alloc : memref<2x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 5 + %100 = arith.truncf %95 : tensor<32x32xf32> to tensor<32x32xf16> + %101 = tt.reshape %100 : tensor<32x32xf16> -> tensor<2x16x2x16xf16> + %102 = tt.trans %101 {order = array} : tensor<2x16x2x16xf16> -> tensor<2x2x16x16xf16> + hivm.hir.sync_block_wait[, , ] flag = 12 + %103 = bufferization.to_memref %102 : memref<2x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%103 : memref<2x2x16x16xf16, #hivm.address_space>) outs(%alloc_5 : memref<2x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + hivm.hir.sync_block_wait[, , ] flag = 4 + %memspacecast_10 = memref.memory_space_cast %alloc_4 : memref<32x64xf32, #hivm.address_space> to memref<32x64xf32> + %104 = bufferization.to_tensor %memspacecast_10 restrict writable : memref<32x64xf32> + %105 = arith.addf %104, %arg35 : tensor<32x64xf32> + hivm.hir.sync_block_wait[, , ] flag = 6 + %memspacecast_11 = memref.memory_space_cast %alloc_8 : memref<32x64xf32, #hivm.address_space> to memref<32x64xf32> + %106 = bufferization.to_tensor %memspacecast_11 restrict writable : 
memref<32x64xf32> + %107 = arith.addf %106, %arg36 : tensor<32x64xf32> + %108 = arith.addi %arg37, %c32_i32 : i32 + %109 = arith.addi %arg38, %c32_i32 : i32 + %110 = arith.addi %arg39, %c32_i32 : i32 + %111 = arith.addi %arg40, %c32_i32 : i32 + hivm.hir.sync_block_set[, , ] flag = 9 + hivm.hir.sync_block_set[, , ] flag = 10 + scf.yield %105, %107, %108, %109, %110, %111 : tensor<32x64xf32>, tensor<32x64xf32>, i32, i32, i32, i32 + } + %71 = arith.truncf %70#0 : tensor<32x64xf32> to tensor<32x64xf16> + tt.store %52, %71, %42 : tensor<32x64x!tt.ptr> + %72 = arith.truncf %70#1 : tensor<32x64xf32> to tensor<32x64xf16> + tt.store %62, %72, %42 : tensor<32x64x!tt.ptr> + } + hivm.hir.sync_block_wait[, , ] flag = 11 + hivm.hir.sync_block_wait[, , ] flag = 12 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 12 + hivm.hir.sync_block_set[, , ] flag = 11 + %1 = tt.get_num_programs x : i32 + %2 = arith.extsi %arg14 : i32 to i64 + %3 = arith.extsi %arg18 : i32 to i64 + %4 = arith.extsi %arg20 : i32 to i64 + %5 = arith.subi %arg8, %c1_i32 : i32 + %6 = arith.extsi %arg16 : i32 to i64 + %7 = arith.extsi %arg24 : i32 to i64 + scf.for %arg33 = %0 to %arg32 step %1 : i32 { + %8 = arith.divsi %arg33, %arg32 : i32 + %9 = arith.remsi %arg33, %arg32 : i32 + %10 = arith.divsi %9, %arg13 : i32 + %11 = tt.addptr %arg10, %8 : !tt.ptr, i32 + %12 = tt.load %11 : !tt.ptr + %13 = tt.addptr %11, %c1_i32 : !tt.ptr, i32 + %14 = tt.load %13 : !tt.ptr + %15 = arith.subi %14, %12 : i32 + %16 = tt.addptr %arg11, %8 : !tt.ptr, i32 + %17 = tt.load %16 : !tt.ptr + %18 = tt.addptr %16, %c1_i32 : !tt.ptr, i32 + %19 = tt.load %18 : !tt.ptr + %20 = arith.subi %19, %17 : i32 + %21 = arith.muli %17, %arg18 : i32 + %22 = tt.addptr %arg1, %21 : !tt.ptr, i32 + %23 = arith.muli %10, %arg19 : i32 + %24 = tt.addptr %22, %23 : !tt.ptr, i32 + %25 = arith.extsi %20 : i32 to i64 + %26 = tt.make_tensor_ptr %24, [%25, %2], [%3, %c1_i64], [%c0_i32, 
%c0_i32] {order = array} : > + %27 = arith.muli %17, %arg20 : i32 + %28 = tt.addptr %arg2, %27 : !tt.ptr, i32 + %29 = arith.muli %10, %arg21 : i32 + %30 = tt.addptr %28, %29 : !tt.ptr, i32 + %31 = tt.make_tensor_ptr %30, [%25, %2], [%4, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %32 = tt.load %26 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %33 = tt.load %31 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %34 = arith.muli %12, %arg16 : i32 + %35 = tt.addptr %arg0, %34 : !tt.ptr, i32 + %36 = arith.muli %9, %arg17 : i32 + %37 = tt.addptr %35, %36 : !tt.ptr, i32 + %38 = arith.extsi %15 : i32 to i64 + %39 = arith.muli %12, %arg24 : i32 + %40 = tt.addptr %arg5, %39 : !tt.ptr, i32 + %41 = arith.muli %9, %arg25 : i32 + %42 = tt.addptr %40, %41 : !tt.ptr, i32 + %43:4 = scf.for %arg34 = %5 to %15 step %c32_i32 iter_args(%arg35 = %5, %arg36 = %5, %arg37 = %5, %arg38 = %5) -> (i32, i32, i32, i32) : i32 { + %44 = tt.make_tensor_ptr %42, [%2, %38], [%c1_i64, %7], [%c0_i32, %arg36] {order = array} : > + %45 = tt.make_tensor_ptr %37, [%2, %38], [%c1_i64, %6], [%c0_i32, %arg35] {order = array} : > + %46 = tt.load %45 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %47 = tt.load %44 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %48 = tt.dot %32, %46, %cst_1 : tensor<32x64xf16> * tensor<64x32xf16> -> tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 7 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%48 : tensor<32x32xf32>) outs(%alloc_7 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %49 = tt.dot %33, %47, %cst_1 : tensor<32x64xf16> * tensor<64x32xf16> -> tensor<32x32xf32> + hivm.hir.sync_block_wait[, , ] flag = 8 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%49 : tensor<32x32xf32>) outs(%alloc_6 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %50 = tt.trans %46 {order = array} : tensor<64x32xf16> -> tensor<32x64xf16> + hivm.hir.sync_block_wait[, , 
] flag = 3 + %51 = hivm.hir.convert_layout %alloc_5 {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<2x2x16x16xf16, #hivm.address_space>) -> memref<32x32xf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %51 : memref<32x32xf16, #hivm.address_space> to memref<32x32xf16> + %52 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf16> + %53 = tt.dot %52, %50, %cst : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32> + hivm.hir.sync_block_set[, , ] flag = 12 + hivm.hir.sync_block_wait[, , ] flag = 9 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%53 : tensor<32x64xf32>) outs(%alloc_4 : memref<32x64xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 4 + %54 = tt.trans %47 {order = array} : tensor<64x32xf16> -> tensor<32x64xf16> + hivm.hir.sync_block_wait[, , ] flag = 5 + %55 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<2x2x16x16xf16, #hivm.address_space>) -> memref<32x32xf16, #hivm.address_space> + %memspacecast_9 = memref.memory_space_cast %55 : memref<32x32xf16, #hivm.address_space> to memref<32x32xf16> + %56 = bufferization.to_tensor %memspacecast_9 restrict writable : memref<32x32xf16> + %57 = tt.dot %56, %54, %cst : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32> + hivm.hir.sync_block_set[, , ] flag = 11 + hivm.hir.sync_block_wait[, , ] flag = 10 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%57 : tensor<32x64xf32>) outs(%alloc_8 : memref<32x64xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 6 + %58 = arith.addi %arg35, %c32_i32 : i32 + %59 = arith.addi %arg36, %c32_i32 : i32 + %60 = arith.addi %arg37, %c32_i32 : i32 + %61 = arith.addi %arg38, %c32_i32 : i32 + scf.yield %58, %59, %60, %61 : i32, i32, i32, i32 + } + } + hivm.hir.sync_block_wait[, , ] flag = 7 + hivm.hir.sync_block_wait[, , ] flag = 8 + hivm.hir.sync_block_wait[, , ] flag = 9 + hivm.hir.sync_block_wait[, , ] flag = 10 
+ scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} + diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir new file mode 100644 index 0000000000..b64a5b826c --- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/fixpipe_use_analysis.mlir @@ -0,0 +1,421 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @_hstu_attn_fwd + +module { + tt.func public @_hstu_attn_fwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: f32, %arg9: f32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0> : tensor<256x1xi64> + %cst_0 = arith.constant dense : tensor<256x32xi1> + %cst_1 = arith.constant dense<0> : tensor<32x1xi64> + %cst_2 = arith.constant dense : tensor<32x32xi1> + %c32_i32 = arith.constant 32 : i32 + %c2_i32 = arith.constant 2 : i32 + %c8_i64 = arith.constant 8 : i64 + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i64 = arith.constant 2 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 
+ %c128_i64 = arith.constant 128 : i64 + %c255_i32 = arith.constant 255 : i32 + %c1_i64 = arith.constant 1 : i64 + %c3_i32 = arith.constant 3 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<256x32xf16> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf16> + %cst_5 = arith.constant dense<0.000000e+00> : tensor<32x256xf32> + %cst_6 = arith.constant dense<128> : tensor<256x1xi64> + %cst_7 = arith.constant dense<256> : tensor<32x1xi64> + %cst_8 = arith.constant dense<1.000000e+00> : tensor<32x256xf32> + %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x32xf32> + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(32 : i64) : i64 + %2 = llvm.mlir.constant(64 : i64) : i64 + %3 = llvm.mlir.constant(96 : i64) : i64 + %4 = llvm.mlir.constant(0 : i32) : i32 + %5 = llvm.mlir.constant(1 : i32) : i32 + %6 = llvm.mlir.constant(2 : i64) : i64 + %7 = llvm.mlir.constant(2 : i32) : i32 + %8 = llvm.mlir.constant(4 : i32) : i32 + %9 = llvm.mlir.constant(6 : i32) : i32 + %10 = llvm.mlir.constant(1 : i64) : i64 + %11 = llvm.mlir.constant(3 : i32) : i32 + %c64_i64 = arith.constant 64 : i64 + %12 = llvm.mlir.constant(5 : i32) : i32 + %c0_i64 = arith.constant 0 : i64 + %alloc = memref.alloc() : memref<16x2x16x16xf16, #hivm.address_space> + %alloc_10 = memref.alloc() : memref<32x256xf32, #hivm.address_space> + %alloc_11 = memref.alloc() : memref<32x32xf32, #hivm.address_space> + %13 = tt.get_program_id x : i32 + %14 = tt.get_num_programs x : i32 + %15 = arith.cmpi sle, %arg10, %c32_i32 : i32 + %16 = scf.if %15 -> (i64) { + scf.yield %c2_i64 : i64 + } else { + %41 = tt.addptr %arg5, %c2_i32 : !tt.ptr, i32 + %42 = tt.load %41 : !tt.ptr + %43 = arith.extsi %42 : i32 to i64 + scf.yield %43 : i64 + } + %17 = arith.muli %16, %c8_i64 : i64 + %18 = arith.extsi %14 : i32 to i64 + %19 = arith.minsi %18, %17 : i64 + %20 = arith.divsi %17, %19 : i64 + %21 = arith.addi %20, %c1_i64 : i64 + %22 = arith.remsi %17, %19 : i64 + 
%23 = arith.extsi %13 : i32 to i64 + %24 = arith.cmpi slt, %23, %19 : i64 + %25 = arith.cmpi slt, %23, %22 : i64 + %26 = arith.muli %23, %21 : i64 + %27 = arith.muli %22, %21 : i64 + %28 = arith.subi %23, %22 : i64 + %29 = arith.muli %28, %20 : i64 + %30 = arith.addi %27, %29 : i64 + %31 = arith.select %25, %26, %30 : i64 + %32 = arith.select %24, %31, %c0_i64 : i64 + %33 = arith.select %25, %21, %20 : i64 + %34 = arith.select %24, %33, %c0_i64 : i64 + %35 = arith.cmpi sge, %23, %19 : i64 + cf.cond_br %35, ^bb1, ^bb2 + ^bb1: // 2 preds: ^bb0, ^bb2 + tt.return + ^bb2: // pred: ^bb0 + %36 = arith.cmpi sle, %34, %c0_i64 : i64 + cf.cond_br %36, ^bb1, ^bb3 + ^bb3: // pred: ^bb2 + %37 = llvm.inttoptr %0 : i64 to !llvm.ptr<11> + %38 = llvm.inttoptr %1 : i64 to !llvm.ptr<11> + %39 = llvm.inttoptr %2 : i64 to !llvm.ptr<11> + %40 = llvm.inttoptr %3 : i64 to !llvm.ptr<11> + llvm.store %4, %37 : i32, !llvm.ptr<11> + llvm.store %4, %38 : i32, !llvm.ptr<11> + llvm.store %4, %39 : i32, !llvm.ptr<11> + llvm.store %4, %40 : i32, !llvm.ptr<11> + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 14 + %41 = hivm.hir.get_sub_block_idx -> i64 + %42 = arith.muli %41, %1 : i64 + %43 = arith.addi %42, %1 : i64 + hivm.hir.sync_block_set[, , ] flag = 5 + hivm.hir.sync_block_set[, , ] flag = 4 + %44 = arith.addi %arg11, %c255_i32 : i32 + %45 = arith.divsi %44, %c256_i32 : i32 + %46 = arith.extsi %45 : i32 to i64 + %47 = arith.muli %34, %46 : i64 + %48 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %49 = arith.extsi %48 : tensor<32xi32> to tensor<32xi64> + %50 = tt.splat %arg8 : f32 -> tensor<32x256xf32> + %51 = tt.splat %arg9 : f32 -> tensor<32x256xf32> + %52 = arith.muli %47, %c2_i64 : i64 + %53 = arith.divsi %52, %6 : i64 + %54:5 = scf.for %arg13 = %c0_i64 to %52 step %c1_i64 iter_args(%arg14 = %c0_i64, %arg15 = %cst_1, %arg16 = %cst_2, %arg17 = %c0_i64, %arg18 = %c0_i64) -> (i64, tensor<32x1xi64>, tensor<32x32xi1>, i64, i64) : i64 { + 
hivm.hir.sync_block_wait[, , ] flag = 15 + %55 = llvm.inttoptr %43 : i64 to !llvm.ptr<11> + %56 = llvm.load %55 : !llvm.ptr<11> -> i32 + %57 = arith.andi %56, %5 : i32 + %58 = arith.cmpi eq, %57, %5 : i32 + %59 = arith.andi %56, %7 : i32 + %60 = arith.cmpi eq, %59, %c0_i32 : i32 + %61 = arith.andi %56, %8 : i32 + %62 = arith.cmpi eq, %61, %8 : i32 + %63 = arith.cmpi slt, %arg17, %53 : i64 + %64 = arith.andi %58, %60 : i1 + %65 = arith.andi %64, %63 : i1 + %66 = arith.cmpi slt, %arg18, %53 : i64 + %67 = arith.andi %62, %66 : i1 + %68:4 = scf.if %65 -> (i64, tensor<32x1xi64>, tensor<32x32xi1>, i64) { + %70 = arith.divsi %arg13, %46 : i64 + %71 = arith.addi %32, %70 : i64 + %72 = arith.divsi %71, %16 : i64 + %73 = arith.remsi %71, %16 : i64 + %74:2 = scf.if %15 -> (i64, i64) { + scf.yield %73, %c0_i64 : i64, i64 + } else { + %108:2 = scf.for %arg19 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg20 = %c0_i32, %arg21 = %c3_i32) -> (i32, i32) : i32 { + %115 = arith.addi %arg20, %arg21 : i32 + %116 = arith.divsi %115, %c2_i32 : i32 + %117 = tt.addptr %arg5, %116 : !tt.ptr, i32 + %118 = tt.load %117 : !tt.ptr + %119 = arith.extsi %118 : i32 to i64 + %120 = arith.cmpi sle, %119, %73 : i64 + %121 = arith.select %120, %arg21, %116 : i32 + %122 = scf.if %120 -> (i32) { + %123 = arith.addi %116, %c1_i32 : i32 + scf.yield %123 : i32 + } else { + scf.yield %arg20 : i32 + } + scf.yield %122, %121 : i32, i32 + } + %109 = arith.subi %108#0, %c1_i32 : i32 + %110 = arith.extsi %109 : i32 to i64 + %111 = tt.addptr %arg5, %110 : !tt.ptr, i64 + %112 = tt.load %111 : !tt.ptr + %113 = arith.extsi %112 : i32 to i64 + %114 = arith.subi %73, %113 : i64 + scf.yield %110, %114 : i64, i64 + } + %75 = tt.addptr %arg3, %74#0 : !tt.ptr, i64 + %76 = tt.load %75 : !tt.ptr + %77 = tt.addptr %75, %c1_i32 : !tt.ptr, i32 + %78 = tt.load %77 : !tt.ptr + %79 = arith.subi %78, %76 : i64 + %80 = arith.muli %72, %c32_i64 : i64 + %81 = arith.muli %76, %c256_i64 : i64 + %82 = arith.addi %80, %81 : i64 + %83 = 
arith.muli %74#1, %c32_i64 : i64 + %84 = tt.splat %83 : i64 -> tensor<32xi64> + %85 = arith.addi %84, %49 : tensor<32xi64> + %86 = tt.splat %79 : i64 -> tensor<32xi64> + %87 = arith.cmpi slt, %85, %86 : tensor<32xi64> + %88 = tt.expand_dims %85 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> + %89 = arith.muli %88, %cst_7 : tensor<32x1xi64> + %90 = tt.expand_dims %87 {axis = 1 : i32} : tensor<32xi1> -> tensor<32x1xi1> + %91 = tt.broadcast %90 : tensor<32x1xi1> -> tensor<32x32xi1> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_10 : memref<32x256xf32, #hivm.address_space> to memref<32x256xf32> + %92 = bufferization.to_tensor %memspacecast restrict writable : memref<32x256xf32> + %93 = arith.mulf %92, %50 : tensor<32x256xf32> + %94 = arith.subf %cst_5, %93 : tensor<32x256xf32> + %95 = math.exp %94 : tensor<32x256xf32> + %96 = arith.addf %95, %cst_8 : tensor<32x256xf32> + %97 = arith.divf %cst_8, %96 : tensor<32x256xf32> + %98 = arith.mulf %93, %97 : tensor<32x256xf32> + %99 = arith.mulf %98, %51 : tensor<32x256xf32> + %100 = arith.truncf %99 : tensor<32x256xf32> to tensor<32x256xf16> + %101 = tt.reshape %100 : tensor<32x256xf16> -> tensor<2x16x16x16xf16> + %102 = tt.trans %101 {order = array} : tensor<2x16x16x16xf16> -> tensor<16x2x16x16xf16> + hivm.hir.sync_block_set[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 6 + %103 = bufferization.to_memref %102 : memref<16x2x16x16xf16, #hivm.address_space> + hivm.hir.copy ins(%103 : memref<16x2x16x16xf16, #hivm.address_space>) outs(%alloc : memref<16x2x16x16xf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %104 = llvm.load %55 : !llvm.ptr<11> -> i32 + %105 = arith.andi %104, %9 : i32 + %106 = arith.ori %105, %7 : i32 + llvm.store %106, %55 : i32, !llvm.ptr<11> + %107 = arith.addi %arg17, %10 : i64 + scf.yield %82, %89, %91, %107 : i64, tensor<32x1xi64>, tensor<32x32xi1>, i64 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : i64, 
tensor<32x1xi64>, tensor<32x32xi1>, i64 + } + %69 = scf.if %67 -> (i64) { + hivm.hir.sync_block_wait[, , ] flag = 3 + %memspacecast = memref.memory_space_cast %alloc_11 : memref<32x32xf32, #hivm.address_space> to memref<32x32xf32> + %70 = bufferization.to_tensor %memspacecast restrict writable : memref<32x32xf32> + scf.for %arg19 = %c0 to %c32 step %c1 { + scf.for %arg20 = %c0 to %c32 step %c1 { + %extracted = tensor.extract %68#1[%arg19, %c0] {DiscreteMemAccess} : tensor<32x1xi64> + %74 = arith.addi %68#0, %extracted : i64 + %75 = arith.index_cast %arg20 : index to i32 + %76 = arith.extsi %75 : i32 to i64 + %77 = arith.addi %74, %76 : i64 + %78 = tt.addptr %arg7, %77 : !tt.ptr, i64 + %extracted_12 = tensor.extract %70[%arg19, %arg20] {DiscreteMemAccess} : tensor<32x32xf32> + %79 = arith.truncf %extracted_12 : f32 to f16 + %extracted_13 = tensor.extract %68#2[%arg19, %arg20] {DiscreteMemAccess} : tensor<32x32xi1> + tt.store %78, %79, %extracted_13 {DiscreteMemAccess} : !tt.ptr + } {ExtractedLoadOrStore} + } {ExtractedLoadOrStore} + hivm.hir.sync_block_set[, , ] flag = 5 + %71 = llvm.load %55 : !llvm.ptr<11> -> i32 + %72 = arith.andi %71, %11 : i32 + llvm.store %72, %55 : i32, !llvm.ptr<11> + %73 = arith.addi %arg18, %10 : i64 + scf.yield %73 : i64 + } else { + scf.yield %arg18 : i64 + } + hivm.hir.sync_block_set[, , ] flag = 14 + scf.yield %68#0, %68#1, %68#2, %68#3, %69 : i64, tensor<32x1xi64>, tensor<32x32xi1>, i64, i64 + } + hivm.hir.sync_block_wait[, , ] flag = 6 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 6 + %41 = arith.addi %arg11, %c255_i32 : i32 + %42 = arith.divsi %41, %c256_i32 : i32 + %43 = arith.extsi %42 : i32 to i64 + %44 = arith.muli %34, %43 : i64 + %45 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %46 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %47 = arith.extsi %45 : tensor<32xi32> to tensor<32xi64> + %48 = 
tt.expand_dims %45 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %49 = tt.broadcast %48 : tensor<1x32xi32> -> tensor<32x32xi32> + %50 = arith.extsi %46 : tensor<256xi32> to tensor<256xi64> + %51 = tt.broadcast %48 : tensor<1x32xi32> -> tensor<256x32xi32> + %52 = arith.muli %44, %c2_i64 : i64 + %53 = arith.divsi %52, %6 : i64 + %54:5 = scf.for %arg13 = %c0_i64 to %52 step %c1_i64 iter_args(%arg14 = %c0_i64, %arg15 = %cst, %arg16 = %cst_0, %arg17 = %c0_i64, %arg18 = %c0_i64) -> (i64, tensor<256x1xi64>, tensor<256x32xi1>, i64, i64) : i64 { + hivm.hir.sync_block_wait[, , ] flag = 14 + %55 = llvm.inttoptr %c32_i64 : i64 to !llvm.ptr<11> + %56 = llvm.inttoptr %c64_i64 : i64 to !llvm.ptr<11> + %57 = llvm.load %55 : !llvm.ptr<11> -> i32 + %58 = llvm.load %56 : !llvm.ptr<11> -> i32 + %59 = arith.andi %57, %5 : i32 + %60 = arith.andi %58, %5 : i32 + %61 = arith.cmpi eq, %59, %c0_i32 : i32 + %62 = arith.cmpi eq, %60, %c0_i32 : i32 + %63 = arith.andi %61, %62 : i1 + %64 = arith.andi %57, %7 : i32 + %65 = arith.andi %58, %7 : i32 + %66 = arith.cmpi eq, %64, %7 : i32 + %67 = arith.cmpi eq, %65, %7 : i32 + %68 = arith.andi %66, %67 : i1 + %69 = arith.andi %57, %8 : i32 + %70 = arith.andi %58, %8 : i32 + %71 = arith.cmpi eq, %69, %c0_i32 : i32 + %72 = arith.cmpi eq, %70, %c0_i32 : i32 + %73 = arith.andi %71, %72 : i1 + %74 = arith.cmpi slt, %arg17, %53 : i64 + %75 = arith.andi %63, %74 : i1 + %76 = arith.cmpi slt, %arg18, %53 : i64 + %77 = arith.andi %68, %73 : i1 + %78 = arith.andi %77, %76 : i1 + %79:4 = scf.if %75 -> (i64, tensor<256x1xi64>, tensor<256x32xi1>, i64) { + %81 = arith.divsi %arg13, %43 : i64 + %82 = arith.addi %32, %81 : i64 + %83 = arith.remsi %arg13, %43 : i64 + %84 = arith.divsi %82, %16 : i64 + %85 = arith.remsi %82, %16 : i64 + %86:2 = scf.if %15 -> (i64, i64) { + scf.yield %85, %c0_i64 : i64, i64 + } else { + %140:2 = scf.for %arg19 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg20 = %c0_i32, %arg21 = %c3_i32) -> (i32, i32) : i32 { + %147 = 
arith.addi %arg20, %arg21 : i32 + %148 = arith.divsi %147, %c2_i32 : i32 + %149 = tt.addptr %arg5, %148 : !tt.ptr, i32 + %150 = tt.load %149 : !tt.ptr + %151 = arith.extsi %150 : i32 to i64 + %152 = arith.cmpi sle, %151, %85 : i64 + %153 = arith.select %152, %arg21, %148 : i32 + %154 = scf.if %152 -> (i32) { + %155 = arith.addi %148, %c1_i32 : i32 + scf.yield %155 : i32 + } else { + scf.yield %arg20 : i32 + } + scf.yield %154, %153 : i32, i32 + } + %141 = arith.subi %140#0, %c1_i32 : i32 + %142 = arith.extsi %141 : i32 to i64 + %143 = tt.addptr %arg5, %142 : !tt.ptr, i64 + %144 = tt.load %143 : !tt.ptr + %145 = arith.extsi %144 : i32 to i64 + %146 = arith.subi %85, %145 : i64 + scf.yield %142, %146 : i64, i64 + } + %87 = arith.divsi %84, %c2_i64 : i64 + %88 = tt.addptr %arg3, %86#0 : !tt.ptr, i64 + %89 = tt.load %88 : !tt.ptr + %90 = tt.addptr %88, %c1_i32 : !tt.ptr, i32 + %91 = tt.load %90 : !tt.ptr + %92 = tt.addptr %arg4, %86#0 : !tt.ptr, i64 + %93 = tt.load %92 : !tt.ptr + %94 = tt.addptr %92, %c1_i32 : !tt.ptr, i32 + %95 = tt.load %94 : !tt.ptr + %96 = arith.subi %91, %89 : i64 + %97 = arith.subi %95, %93 : i64 + %98 = arith.muli %84, %c32_i64 : i64 + %99 = arith.muli %89, %c256_i64 : i64 + %100 = arith.addi %98, %99 : i64 + %101 = tt.addptr %arg0, %100 : !tt.ptr, i64 + %102 = arith.muli %87, %c32_i64 : i64 + %103 = arith.muli %93, %c128_i64 : i64 + %104 = arith.addi %102, %103 : i64 + %105 = tt.addptr %arg1, %104 : !tt.ptr, i64 + %106 = arith.muli %83, %c256_i64 : i64 + %107 = arith.muli %86#1, %c32_i64 : i64 + %108 = tt.splat %107 : i64 -> tensor<32xi64> + %109 = arith.addi %108, %47 : tensor<32xi64> + %110 = tt.splat %96 : i64 -> tensor<32xi64> + %111 = arith.cmpi slt, %109, %110 : tensor<32xi64> + %112 = tt.expand_dims %109 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> + %113 = arith.muli %112, %cst_7 : tensor<32x1xi64> + %114 = tt.splat %101 : !tt.ptr -> tensor<32x1x!tt.ptr> + %115 = tt.addptr %114, %113 : tensor<32x1x!tt.ptr>, tensor<32x1xi64> + 
%116 = tt.broadcast %115 : tensor<32x1x!tt.ptr> -> tensor<32x32x!tt.ptr> + %117 = tt.addptr %116, %49 : tensor<32x32x!tt.ptr>, tensor<32x32xi32> + %118 = tt.expand_dims %111 {axis = 1 : i32} : tensor<32xi1> -> tensor<32x1xi1> + %119 = tt.broadcast %118 : tensor<32x1xi1> -> tensor<32x32xi1> + %120 = tt.load %117, %119, %cst_4 : tensor<32x32x!tt.ptr> + %121 = tt.splat %106 : i64 -> tensor<256xi64> + %122 = arith.addi %121, %50 : tensor<256xi64> + %123 = tt.splat %97 : i64 -> tensor<256xi64> + %124 = arith.cmpi slt, %122, %123 : tensor<256xi64> + %125 = tt.expand_dims %122 {axis = 1 : i32} : tensor<256xi64> -> tensor<256x1xi64> + %126 = arith.muli %125, %cst_6 : tensor<256x1xi64> + %127 = tt.splat %105 : !tt.ptr -> tensor<256x1x!tt.ptr> + %128 = tt.addptr %127, %126 : tensor<256x1x!tt.ptr>, tensor<256x1xi64> + %129 = tt.broadcast %128 : tensor<256x1x!tt.ptr> -> tensor<256x32x!tt.ptr> + %130 = tt.addptr %129, %51 : tensor<256x32x!tt.ptr>, tensor<256x32xi32> + %131 = tt.expand_dims %124 {axis = 1 : i32} : tensor<256xi1> -> tensor<256x1xi1> + %132 = tt.broadcast %131 : tensor<256x1xi1> -> tensor<256x32xi1> + %133 = tt.load %130, %132, %cst_3 : tensor<256x32x!tt.ptr> + %134 = tt.trans %133 {order = array} : tensor<256x32xf16> -> tensor<32x256xf16> + %135 = tt.dot %120, %134, %cst_5 : tensor<32x32xf16> * tensor<32x256xf16> -> tensor<32x256xf32> + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%135 : tensor<32x256xf32>) outs(%alloc_10 : memref<32x256xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %136 = llvm.load %55 : !llvm.ptr<11> -> i32 + %137 = arith.ori %136, %5 : i32 + %138 = arith.ori %137, %5 : i32 + llvm.store %137, %55 : i32, !llvm.ptr<11> + llvm.store %138, %56 : i32, !llvm.ptr<11> + %139 = arith.addi %arg17, %10 : i64 + scf.yield %104, %126, %132, %139 : i64, tensor<256x1xi64>, tensor<256x32xi1>, i64 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : i64, tensor<256x1xi64>, 
tensor<256x32xi1>, i64 + } + %80 = scf.if %78 -> (i64) { + %81 = tensor.empty() : tensor<256x32xf16> + %82 = scf.for %arg19 = %c0 to %c256 step %c1 iter_args(%arg20 = %81) -> (tensor<256x32xf16>) { + %extracted = tensor.extract %79#1[%arg19, %c0] {DiscreteMemAccess} : tensor<256x1xi64> + %92 = arith.addi %79#0, %extracted : i64 + %93 = tt.splat %92 : i64 -> tensor<1x32xi64> + %94 = arith.extsi %48 : tensor<1x32xi32> to tensor<1x32xi64> + %95 = arith.addi %93, %94 : tensor<1x32xi64> + %96 = tt.splat %arg2 : !tt.ptr -> tensor<1x32x!tt.ptr> + %97 = tt.addptr %96, %95 : tensor<1x32x!tt.ptr>, tensor<1x32xi64> + %98 = tt.load %97 {DiscreteMemAccess} : tensor<1x32x!tt.ptr> + %inserted_slice = tensor.insert_slice %98 into %arg20[%arg19, 0] [1, 32] [1, 1] : tensor<1x32xf16> into tensor<256x32xf16> + scf.yield {DiscreteMemAccess} %inserted_slice : tensor<256x32xf16> + } {ExtractedLoadOrStore} + %83 = arith.select %79#2, %82, %cst_3 : tensor<256x32xi1>, tensor<256x32xf16> + hivm.hir.sync_block_wait[, , ] flag = 2 + %84 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<16x2x16x16xf16, #hivm.address_space>) -> memref<32x256xf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %84 : memref<32x256xf16, #hivm.address_space> to memref<32x256xf16> + %85 = bufferization.to_tensor %memspacecast restrict writable : memref<32x256xf16> + %86 = tt.dot %85, %83, %cst_9 : tensor<32x256xf16> * tensor<256x32xf16> -> tensor<32x32xf32> + hivm.hir.sync_block_set[, , ] flag = 6 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%86 : tensor<32x32xf32>) outs(%alloc_11 : memref<32x32xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + %87 = llvm.load %55 : !llvm.ptr<11> -> i32 + %88 = arith.andi %87, %12 : i32 + %89 = arith.ori %88, %8 : i32 + %90 = arith.ori %89, %8 : i32 + llvm.store %89, %55 : i32, !llvm.ptr<11> + llvm.store %90, %56 : i32, !llvm.ptr<11> + %91 = 
arith.addi %arg18, %10 : i64 + scf.yield %91 : i64 + } else { + scf.yield %arg18 : i64 + } + hivm.hir.sync_block_set[, , ] flag = 15 + scf.yield %79#0, %79#1, %79#2, %79#3, %80 : i64, tensor<256x1xi64>, tensor<256x32xi1>, i64, i64 + } + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.sync_block_wait[, , ] flag = 14 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir new file mode 100644 index 0000000000..6a076b02e7 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/if_use_analysis.mlir @@ -0,0 +1,479 @@ +// RUN: triton-opt -allow-unregistered-dialect '--triton-to-linalg=named-ops=True enable-nd2nz-on-vector=True compile-on-910-95=True' --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: failed to legalize unresolved materialization +// CHECK: module +// CHECK: func.func public @dsa_prefill_kernel + +module { + tt.func public @dsa_prefill_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: f32) 
attributes {noinline = false} { + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x192xbf16> + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> + %cst_2 = arith.constant dense<0xFF800000> : tensor<16xf32> + %cst_3 = arith.constant dense<9.99999996E-13> : tensor<16xf32> + %cst_4 = arith.constant dense<0xFF800000> : tensor<16x16xf32> + %cst_5 = arith.constant dense<0> : tensor<16x16xi8> + %cst_6 = arith.constant dense<1024> : tensor<1x16xi32> + %cst_7 = arith.constant dense<0.000000e+00> : tensor<16x16xf32> + %c1_i32 = arith.constant 1 : i32 + %cst_8 = arith.constant dense<1024> : tensor<16x1xi32> + %c16_i32 = arith.constant 16 : i32 + %cst_9 = arith.constant dense<0.000000e+00> : tensor<16xf32> + %cst_10 = arith.constant dense<0.000000e+00> : tensor<16x128xf32> + %cst_11 = arith.constant dense : tensor<16x1xi1> + %cst_12 = arith.constant dense<0> : tensor<16x1xi32> + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(32 : i64) : i64 + %2 = llvm.mlir.constant(64 : i64) : i64 + %3 = llvm.mlir.constant(96 : i64) : i64 + %4 = llvm.mlir.constant(0 : i32) : i32 + %5 = llvm.mlir.constant(1 : i32) : i32 + %c2_i32 = arith.constant 2 : i32 + %6 = llvm.mlir.constant(2 : i32) : i32 + %7 = llvm.mlir.constant(4 : i32) : i32 + %c3_i32 = arith.constant 3 : i32 + %c4_i32 = arith.constant 4 : i32 + %c6_i32 = arith.constant 6 : i32 + %8 = llvm.mlir.constant(6 : i32) : i32 + %9 = llvm.mlir.constant(3 : i32) : i32 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %10 = llvm.mlir.constant(5 : i32) : i32 + %alloc = memref.alloc() : memref<1x1x16x16xbf16, #hivm.address_space> + %alloc_13 = memref.alloc() : memref<16x16xf32, #hivm.address_space> + %alloc_14 = memref.alloc() : memref<16x128xf32, 
#hivm.address_space> + %11 = tt.get_program_id x : i32 + %12 = llvm.inttoptr %0 : i64 to !llvm.ptr<11> + %13 = llvm.inttoptr %1 : i64 to !llvm.ptr<11> + %14 = llvm.inttoptr %2 : i64 to !llvm.ptr<11> + %15 = llvm.inttoptr %3 : i64 to !llvm.ptr<11> + llvm.store %4, %12 : i32, !llvm.ptr<11> + llvm.store %4, %13 : i32, !llvm.ptr<11> + llvm.store %4, %14 : i32, !llvm.ptr<11> + llvm.store %4, %15 : i32, !llvm.ptr<11> + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 14 + %16 = hivm.hir.get_sub_block_idx -> i64 + %17 = arith.muli %16, %1 : i64 + %18 = arith.addi %17, %1 : i64 + hivm.hir.sync_block_set[, , ] flag = 5 + hivm.hir.sync_block_set[, , ] flag = 4 + %19 = arith.divsi %11, %c16_i32 : i32 + %20 = arith.remsi %11, %c16_i32 : i32 + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %22 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %23 = tt.splat %arg19 : f32 -> tensor<16x16xf32> + %24 = arith.muli %19, %arg17 : i32 + %25 = tt.splat %arg18 : i32 -> tensor<16x1xi32> + %26 = tt.splat %24 : i32 -> tensor<16x1xi32> + %27 = tt.splat %arg4 : !tt.ptr -> tensor<16x16x!tt.ptr> + %28 = arith.muli %19, %arg14 : i32 + %29 = arith.muli %20, %arg15 : i32 + %30 = arith.addi %28, %29 : i32 + %31 = tt.splat %arg16 : i32 -> tensor<16x1xi32> + %32 = tt.splat %30 : i32 -> tensor<16x1xi32> + %33 = tt.expand_dims %21 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %34 = tt.broadcast %33 : tensor<1x128xi32> -> tensor<16x128xi32> + %35 = tt.splat %arg3 : !tt.ptr -> tensor<16x128x!tt.ptr> + scf.for %arg20 = %c0_i32 to %c1024_i32 step %c16_i32 : i32 { + %36 = tt.splat %arg20 : i32 -> tensor<16xi32> + %37 = arith.addi %36, %22 : tensor<16xi32> + %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %39 = arith.cmpi slt, %38, %cst_8 : tensor<16x1xi32> + %40 = arith.addi %arg20, %c1_i32 : i32 + %41 = arith.muli %38, %25 : tensor<16x1xi32> + %42 = arith.addi %26, %41 : tensor<16x1xi32> + %43 = 
tt.broadcast %42 : tensor<16x1xi32> -> tensor<16x16xi32> + %44 = tt.broadcast %39 : tensor<16x1xi1> -> tensor<16x16xi1> + %45 = arith.muli %40, %c2_i32 : i32 + %46 = arith.divsi %45, %c16_i32 : i32 + %47 = arith.divsi %46, %6 : i32 + %48:20 = scf.for %arg21 = %c0_i32 to %45 step %c16_i32 iter_args(%arg22 = %cst_10, %arg23 = %cst_2, %arg24 = %cst_10, %arg25 = %cst_9, %arg26 = %c0_i32, %arg27 = %c0_i32, %arg28 = %cst_9, %arg29 = %cst_9, %arg30 = %cst_9, %arg31 = %cst_9, %arg32 = %cst_9, %arg33 = %c0_i32, %arg34 = %c0_i32, %arg35 = %cst_10, %arg36 = %cst_10, %arg37 = %cst_10, %arg38 = %cst_10, %arg39 = %cst_10, %arg40 = %c0_i32, %arg41 = %c0_i32) -> (tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32, i32) : i32 { + hivm.hir.sync_block_wait[, , ] flag = 15 + %57 = llvm.inttoptr %18 : i64 to !llvm.ptr<11> + %58 = llvm.load %57 : !llvm.ptr<11> -> i32 + %59 = arith.andi %58, %5 : i32 + %60 = arith.cmpi eq, %59, %5 : i32 + %61 = arith.andi %58, %6 : i32 + %62 = arith.cmpi eq, %61, %c0_i32 : i32 + %63 = arith.andi %58, %7 : i32 + %64 = arith.cmpi eq, %63, %7 : i32 + %65 = arith.cmpi slt, %arg26, %47 : i32 + %66 = arith.andi %60, %62 : i1 + %67 = arith.andi %66, %65 : i1 + %68 = arith.cmpi slt, %arg27, %47 : i32 + %69 = arith.andi %64, %68 : i1 + %70:16 = scf.if %67 -> (tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32) { + %72 = tt.splat %arg21 : i32 -> tensor<16xi32> + %73 = arith.addi %72, %22 : tensor<16xi32> + hivm.hir.sync_block_wait[, , ] flag = 1 + %memspacecast = memref.memory_space_cast %alloc_13 : memref<16x16xf32, #hivm.address_space> to 
memref<16x16xf32> + %74 = bufferization.to_tensor %memspacecast restrict writable : memref<16x16xf32> + %75 = arith.mulf %74, %23 : tensor<16x16xf32> + %76 = tt.expand_dims %73 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %77 = tt.broadcast %76 : tensor<1x16xi32> -> tensor<16x16xi32> + %78 = arith.addi %43, %77 : tensor<16x16xi32> + %79 = arith.cmpi slt, %76, %cst_6 : tensor<1x16xi32> + %80 = tt.broadcast %79 : tensor<1x16xi1> -> tensor<16x16xi1> + %81 = arith.andi %44, %80 : tensor<16x16xi1> + %82 = tt.addptr %27, %78 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %83 = tt.bitcast %82 : tensor<16x16x!tt.ptr> -> tensor<16x16x!tt.ptr> + %84 = tt.load %83, %81, %cst_5 : tensor<16x16x!tt.ptr> + %85 = arith.cmpi ne, %84, %cst_5 : tensor<16x16xi8> + %86 = arith.select %85, %75, %cst_4 : tensor<16x16xi1>, tensor<16x16xf32> + %87 = "tt.reduce"(%86) <{axis = 1 : i32}> ({ + ^bb0(%arg42: f32, %arg43: f32): + %132 = arith.maxnumf %arg42, %arg43 : f32 + tt.reduce.return %132 : f32 + }) : (tensor<16x16xf32>) -> tensor<16xf32> + %88 = tt.expand_dims %87 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %89 = tt.broadcast %88 : tensor<16x1xf32> -> tensor<16x16xf32> + %90 = arith.subf %86, %89 : tensor<16x16xf32> + %91 = math.exp %90 : tensor<16x16xf32> + %92 = "tt.reduce"(%91) <{axis = 1 : i32}> ({ + ^bb0(%arg42: f32, %arg43: f32): + %132 = arith.addf %arg42, %arg43 : f32 + tt.reduce.return %132 : f32 + }) : (tensor<16x16xf32>) -> tensor<16xf32> + %93 = math.log %92 : tensor<16xf32> + %94 = arith.addf %87, %93 : tensor<16xf32> + %95 = math.exp %arg23 : tensor<16xf32> + %96 = arith.addf %94, %cst_3 : tensor<16xf32> + %97 = math.exp %96 : tensor<16xf32> + %98 = arith.addf %95, %97 : tensor<16xf32> + %99 = math.log %98 : tensor<16xf32> + %100 = arith.cmpf une, %99, %99 : tensor<16xf32> + %101 = arith.select %100, %arg23, %99 : tensor<16xi1>, tensor<16xf32> + %102 = arith.subf %arg23, %101 : tensor<16xf32> + %103 = math.exp %102 : tensor<16xf32> + %104 = arith.cmpf 
oeq, %87, %cst_2 : tensor<16xf32> + %105 = arith.select %104, %cst_1, %103 : tensor<16xi1>, tensor<16xf32> + %106 = tt.expand_dims %105 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %107 = tt.broadcast %106 : tensor<16x1xf32> -> tensor<16x128xf32> + %108 = arith.mulf %arg22, %107 : tensor<16x128xf32> + %109 = arith.remsi %arg40, %c6_i32 : i32 + %110 = arith.cmpi eq, %109, %c0_i32 : i32 + %111 = arith.select %110, %108, %arg24 : tensor<16x128xf32> + %112:5 = scf.if %110 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg35, %arg36, %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %132 = arith.cmpi eq, %109, %c1_i32 : i32 + %133 = arith.select %132, %108, %arg35 : tensor<16x128xf32> + %134:4 = scf.if %132 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg36, %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %135 = arith.cmpi eq, %109, %c2_i32 : i32 + %136 = arith.select %135, %108, %arg36 : tensor<16x128xf32> + %137:3 = scf.if %135 -> (tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg37, %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } else { + %138 = arith.cmpi eq, %109, %c3_i32 : i32 + %139 = arith.select %138, %108, %arg37 : tensor<16x128xf32> + %140:2 = scf.if %138 -> (tensor<16x128xf32>, tensor<16x128xf32>) { + scf.yield %arg38, %arg39 : tensor<16x128xf32>, tensor<16x128xf32> + } else { + %141 = arith.cmpi eq, %109, %c4_i32 : i32 + %142 = arith.select %141, %108, %arg38 : tensor<16x128xf32> + %143 = arith.select %141, %arg39, %108 : tensor<16x128xf32> + scf.yield %142, %143 : tensor<16x128xf32>, tensor<16x128xf32> + } + scf.yield %139, %140#0, %140#1 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + 
scf.yield %136, %137#0, %137#1, %137#2 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + scf.yield %133, %134#0, %134#1, %134#2, %134#3 : tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32> + } + %113 = arith.addi %arg40, %c1_i32 : i32 + %114 = arith.subf %94, %101 : tensor<16xf32> + %115 = math.exp %114 : tensor<16xf32> + %116 = arith.remsi %arg33, %c6_i32 : i32 + %117 = arith.cmpi eq, %116, %c0_i32 : i32 + %118 = arith.select %117, %115, %arg25 : tensor<16xf32> + %119:5 = scf.if %117 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg28, %arg29, %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %132 = arith.cmpi eq, %116, %c1_i32 : i32 + %133 = arith.select %132, %115, %arg28 : tensor<16xf32> + %134:4 = scf.if %132 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg29, %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %135 = arith.cmpi eq, %116, %c2_i32 : i32 + %136 = arith.select %135, %115, %arg29 : tensor<16xf32> + %137:3 = scf.if %135 -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg30, %arg31, %arg32 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } else { + %138 = arith.cmpi eq, %116, %c3_i32 : i32 + %139 = arith.select %138, %115, %arg30 : tensor<16xf32> + %140:2 = scf.if %138 -> (tensor<16xf32>, tensor<16xf32>) { + scf.yield %arg31, %arg32 : tensor<16xf32>, tensor<16xf32> + } else { + %141 = arith.cmpi eq, %116, %c4_i32 : i32 + %142 = arith.select %141, %115, %arg31 : tensor<16xf32> + %143 = arith.select %141, %arg32, %115 : tensor<16xf32> + scf.yield %142, %143 : tensor<16xf32>, tensor<16xf32> + } + scf.yield %139, %140#0, %140#1 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + scf.yield %136, %137#0, %137#1, %137#2 : tensor<16xf32>, 
tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + scf.yield %133, %134#0, %134#1, %134#2, %134#3 : tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32> + } + %120 = arith.addi %arg33, %c1_i32 : i32 + %121 = tt.expand_dims %92 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %122 = tt.broadcast %121 : tensor<16x1xf32> -> tensor<16x16xf32> + %123 = arith.divf %91, %122 : tensor<16x16xf32> + %124 = arith.truncf %123 : tensor<16x16xf32> to tensor<16x16xbf16> + %125 = tt.reshape %124 : tensor<16x16xbf16> -> tensor<1x16x1x16xbf16> + %126 = tt.trans %125 {order = array} : tensor<1x16x1x16xbf16> -> tensor<1x1x16x16xbf16> + hivm.hir.sync_block_set[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 6 + %127 = bufferization.to_memref %126 : memref<1x1x16x16xbf16, #hivm.address_space> + hivm.hir.copy ins(%127 : memref<1x1x16x16xbf16, #hivm.address_space>) outs(%alloc : memref<1x1x16x16xbf16, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 2 + %128 = llvm.load %57 : !llvm.ptr<11> -> i32 + %129 = arith.andi %128, %8 : i32 + %130 = arith.ori %129, %6 : i32 + llvm.store %130, %57 : i32, !llvm.ptr<11> + %131 = arith.addi %arg26, %5 : i32 + scf.yield %101, %111, %118, %131, %119#0, %119#1, %119#2, %119#3, %119#4, %120, %112#0, %112#1, %112#2, %112#3, %112#4, %113 : tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32 + } else { + scf.yield %arg23, %arg24, %arg25, %arg26, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40 : tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32 + } + %71:4 = scf.if %69 -> 
(tensor<16x128xf32>, i32, i32, i32) { + %72 = arith.remsi %arg41, %c6_i32 : i32 + %73 = arith.cmpi eq, %72, %c0_i32 : i32 + %74 = scf.if %73 -> (tensor<16x128xf32>) { + scf.yield %70#1 : tensor<16x128xf32> + } else { + %90 = arith.cmpi eq, %72, %c1_i32 : i32 + %91 = scf.if %90 -> (tensor<16x128xf32>) { + scf.yield %70#10 : tensor<16x128xf32> + } else { + %92 = arith.cmpi eq, %72, %c2_i32 : i32 + %93 = scf.if %92 -> (tensor<16x128xf32>) { + scf.yield %70#11 : tensor<16x128xf32> + } else { + %94 = arith.cmpi eq, %72, %c3_i32 : i32 + %95 = scf.if %94 -> (tensor<16x128xf32>) { + scf.yield %70#12 : tensor<16x128xf32> + } else { + %96 = arith.cmpi eq, %72, %c4_i32 : i32 + %97 = arith.select %96, %70#13, %70#14 : tensor<16x128xf32> + scf.yield %97 : tensor<16x128xf32> + } + scf.yield %95 : tensor<16x128xf32> + } + scf.yield %93 : tensor<16x128xf32> + } + scf.yield %91 : tensor<16x128xf32> + } + %75 = arith.addi %arg41, %c1_i32 : i32 + %76 = arith.remsi %arg34, %c6_i32 : i32 + %77 = arith.cmpi eq, %76, %c0_i32 : i32 + %78 = scf.if %77 -> (tensor<16xf32>) { + scf.yield %70#2 : tensor<16xf32> + } else { + %90 = arith.cmpi eq, %76, %c1_i32 : i32 + %91 = scf.if %90 -> (tensor<16xf32>) { + scf.yield %70#4 : tensor<16xf32> + } else { + %92 = arith.cmpi eq, %76, %c2_i32 : i32 + %93 = scf.if %92 -> (tensor<16xf32>) { + scf.yield %70#5 : tensor<16xf32> + } else { + %94 = arith.cmpi eq, %76, %c3_i32 : i32 + %95 = scf.if %94 -> (tensor<16xf32>) { + scf.yield %70#6 : tensor<16xf32> + } else { + %96 = arith.cmpi eq, %76, %c4_i32 : i32 + %97 = arith.select %96, %70#7, %70#8 : tensor<16xf32> + scf.yield %97 : tensor<16xf32> + } + scf.yield %95 : tensor<16xf32> + } + scf.yield %93 : tensor<16xf32> + } + scf.yield %91 : tensor<16xf32> + } + %79 = arith.addi %arg34, %c1_i32 : i32 + %80 = tt.expand_dims %78 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> + %81 = tt.broadcast %80 : tensor<16x1xf32> -> tensor<16x128xf32> + hivm.hir.sync_block_wait[, , ] flag = 3 + %memspacecast = 
memref.memory_space_cast %alloc_14 : memref<16x128xf32, #hivm.address_space> to memref<16x128xf32> + %82 = bufferization.to_tensor %memspacecast restrict writable : memref<16x128xf32> + %83 = arith.mulf %82, %81 : tensor<16x128xf32> + %84 = arith.cmpf une, %83, %83 : tensor<16x128xf32> + %85 = arith.select %84, %cst_10, %83 : tensor<16x128xi1>, tensor<16x128xf32> + %86 = arith.addf %74, %85 : tensor<16x128xf32> + hivm.hir.sync_block_set[, , ] flag = 5 + %87 = llvm.load %57 : !llvm.ptr<11> -> i32 + %88 = arith.andi %87, %9 : i32 + llvm.store %88, %57 : i32, !llvm.ptr<11> + %89 = arith.addi %arg27, %5 : i32 + scf.yield %86, %89, %79, %75 : tensor<16x128xf32>, i32, i32, i32 + } else { + scf.yield %arg22, %arg27, %arg34, %arg41 : tensor<16x128xf32>, i32, i32, i32 + } + hivm.hir.sync_block_set[, , ] flag = 14 + scf.yield %71#0, %70#0, %70#1, %70#2, %70#3, %71#1, %70#4, %70#5, %70#6, %70#7, %70#8, %70#9, %71#2, %70#10, %70#11, %70#12, %70#13, %70#14, %70#15, %71#3 : tensor<16x128xf32>, tensor<16xf32>, tensor<16x128xf32>, tensor<16xf32>, i32, i32, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, i32, i32, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, tensor<16x128xf32>, i32, i32 + } + %49 = arith.cmpf une, %48#0, %48#0 : tensor<16x128xf32> + %50 = arith.select %49, %cst_10, %48#0 : tensor<16x128xi1>, tensor<16x128xf32> + %51 = arith.muli %38, %31 : tensor<16x1xi32> + %52 = arith.addi %32, %51 : tensor<16x1xi32> + %53 = tt.broadcast %52 : tensor<16x1xi32> -> tensor<16x128xi32> + %54 = arith.addi %53, %34 : tensor<16x128xi32> + %55 = tt.addptr %35, %54 : tensor<16x128x!tt.ptr>, tensor<16x128xi32> + %56 = arith.truncf %50 : tensor<16x128xf32> to tensor<16x128xbf16> + tt.store %55, %56 : tensor<16x128x!tt.ptr> + } + hivm.hir.sync_block_wait[, , ] flag = 6 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + scope.scope : () -> () { + hivm.hir.sync_block_set[, , ] flag = 6 + %16 = arith.divsi %11, %c16_i32 : 
i32 + %17 = arith.remsi %11, %c16_i32 : i32 + %18 = tt.make_range {end = 192 : i32, start = 0 : i32} : tensor<192xi32> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %21 = arith.muli %16, %arg5 : i32 + %22 = arith.muli %17, %arg6 : i32 + %23 = arith.addi %21, %22 : i32 + %24 = tt.splat %arg7 : i32 -> tensor<16x1xi32> + %25 = tt.splat %23 : i32 -> tensor<16x1xi32> + %26 = tt.expand_dims %18 {axis = 0 : i32} : tensor<192xi32> -> tensor<1x192xi32> + %27 = tt.broadcast %26 : tensor<1x192xi32> -> tensor<16x192xi32> + %28 = tt.splat %arg0 : !tt.ptr -> tensor<16x192x!tt.ptr> + %29 = arith.muli %16, %arg8 : i32 + %30 = arith.muli %17, %arg9 : i32 + %31 = arith.addi %29, %30 : i32 + %32 = tt.splat %arg10 : i32 -> tensor<16x1xi32> + %33 = tt.splat %31 : i32 -> tensor<16x1xi32> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<16x192x!tt.ptr> + %35 = arith.muli %16, %arg11 : i32 + %36 = arith.muli %17, %arg12 : i32 + %37 = arith.addi %35, %36 : i32 + %38 = tt.expand_dims %19 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + scf.for %arg20 = %c0_i32 to %c1024_i32 step %c16_i32 : i32 { + %39 = tt.splat %arg20 : i32 -> tensor<16xi32> + %40 = arith.addi %39, %20 : tensor<16xi32> + %41 = tt.expand_dims %40 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %42 = arith.muli %41, %24 : tensor<16x1xi32> + %43 = arith.addi %25, %42 : tensor<16x1xi32> + %44 = tt.broadcast %43 : tensor<16x1xi32> -> tensor<16x192xi32> + %45 = arith.addi %44, %27 : tensor<16x192xi32> + %46 = arith.cmpi slt, %41, %cst_8 : tensor<16x1xi32> + %47 = tt.addptr %28, %45 : tensor<16x192x!tt.ptr>, tensor<16x192xi32> + %48 = tt.broadcast %46 : tensor<16x1xi1> -> tensor<16x192xi1> + %49 = tt.load %47, %48, %cst_0 : tensor<16x192x!tt.ptr> + %50 = arith.addi %arg20, %c1_i32 : i32 + %51 = arith.muli %50, %c2_i32 : i32 + %52 = arith.divsi %51, %c16_i32 : i32 + %53 = arith.divsi %52, %6 : i32 + %54:4 = scf.for 
%arg21 = %c0_i32 to %51 step %c16_i32 iter_args(%arg22 = %cst_12, %arg23 = %cst_11, %arg24 = %c0_i32, %arg25 = %c0_i32) -> (tensor<16x1xi32>, tensor<16x1xi1>, i32, i32) : i32 { + hivm.hir.sync_block_wait[, , ] flag = 14 + %55 = llvm.inttoptr %c32_i64 : i64 to !llvm.ptr<11> + %56 = llvm.inttoptr %c64_i64 : i64 to !llvm.ptr<11> + %57 = llvm.load %55 : !llvm.ptr<11> -> i32 + %58 = llvm.load %56 : !llvm.ptr<11> -> i32 + %59 = arith.andi %57, %5 : i32 + %60 = arith.andi %58, %5 : i32 + %61 = arith.cmpi eq, %59, %c0_i32 : i32 + %62 = arith.cmpi eq, %60, %c0_i32 : i32 + %63 = arith.andi %61, %62 : i1 + %64 = arith.andi %57, %6 : i32 + %65 = arith.andi %58, %6 : i32 + %66 = arith.cmpi eq, %64, %6 : i32 + %67 = arith.cmpi eq, %65, %6 : i32 + %68 = arith.andi %66, %67 : i1 + %69 = arith.andi %57, %7 : i32 + %70 = arith.andi %58, %7 : i32 + %71 = arith.cmpi eq, %69, %c0_i32 : i32 + %72 = arith.cmpi eq, %70, %c0_i32 : i32 + %73 = arith.andi %71, %72 : i1 + %74 = arith.cmpi slt, %arg24, %53 : i32 + %75 = arith.andi %63, %74 : i1 + %76 = arith.cmpi slt, %arg25, %53 : i32 + %77 = arith.andi %68, %73 : i1 + %78 = arith.andi %77, %76 : i1 + %79:3 = scf.if %75 -> (tensor<16x1xi32>, tensor<16x1xi1>, i32) { + %81 = tt.splat %arg21 : i32 -> tensor<16xi32> + %82 = arith.addi %81, %20 : tensor<16xi32> + %83 = tt.expand_dims %82 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %84 = arith.muli %83, %32 : tensor<16x1xi32> + %85 = arith.addi %33, %84 : tensor<16x1xi32> + %86 = tt.broadcast %85 : tensor<16x1xi32> -> tensor<16x192xi32> + %87 = arith.addi %86, %27 : tensor<16x192xi32> + %88 = arith.cmpi slt, %83, %cst_8 : tensor<16x1xi32> + %89 = tt.addptr %34, %87 : tensor<16x192x!tt.ptr>, tensor<16x192xi32> + %90 = tt.broadcast %88 : tensor<16x1xi1> -> tensor<16x192xi1> + %91 = tt.load %89, %90, %cst_0 : tensor<16x192x!tt.ptr> + %92 = tt.trans %91 {order = array} : tensor<16x192xbf16> -> tensor<192x16xbf16> + %93 = tt.dot %49, %92, %cst_7 : tensor<16x192xbf16> * tensor<192x16xbf16> -> 
tensor<16x16xf32> + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%93 : tensor<16x16xf32>) outs(%alloc_13 : memref<16x16xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 1 + %94 = llvm.load %55 : !llvm.ptr<11> -> i32 + %95 = arith.ori %94, %5 : i32 + %96 = arith.ori %95, %5 : i32 + llvm.store %95, %55 : i32, !llvm.ptr<11> + llvm.store %96, %56 : i32, !llvm.ptr<11> + %97 = arith.addi %arg24, %5 : i32 + scf.yield %83, %88, %97 : tensor<16x1xi32>, tensor<16x1xi1>, i32 + } else { + scf.yield %arg22, %arg23, %arg24 : tensor<16x1xi32>, tensor<16x1xi1>, i32 + } + %80 = scf.if %78 -> (i32) { + %81 = tt.broadcast %79#1 : tensor<16x1xi1> -> tensor<16x128xi1> + %82 = tensor.empty() : tensor<16x128xbf16> + %83 = scf.for %arg26 = %c0 to %c16 step %c1 iter_args(%arg27 = %82) -> (tensor<16x128xbf16>) { + %extracted = tensor.extract %79#0[%arg26, %c0] {DiscreteMemAccess} : tensor<16x1xi32> + %93 = arith.muli %extracted, %arg13 : i32 + %94 = arith.addi %37, %93 : i32 + %95 = tt.splat %94 : i32 -> tensor<1x128xi32> + %96 = arith.addi %95, %38 : tensor<1x128xi32> + %97 = arith.extsi %96 : tensor<1x128xi32> to tensor<1x128xi64> + %98 = tt.splat %arg2 : !tt.ptr -> tensor<1x128x!tt.ptr> + %99 = tt.addptr %98, %97 : tensor<1x128x!tt.ptr>, tensor<1x128xi64> + %100 = tt.load %99 {DiscreteMemAccess} : tensor<1x128x!tt.ptr> + %inserted_slice = tensor.insert_slice %100 into %arg27[%arg26, 0] [1, 128] [1, 1] : tensor<1x128xbf16> into tensor<16x128xbf16> + scf.yield {DiscreteMemAccess} %inserted_slice : tensor<16x128xbf16> + } {ExtractedLoadOrStore} + %84 = arith.select %81, %83, %cst : tensor<16x128xi1>, tensor<16x128xbf16> + hivm.hir.sync_block_wait[, , ] flag = 2 + %85 = hivm.hir.convert_layout %alloc {dstLayout = #hivm.data_layout, srcLayout = #hivm.data_layout} : (memref<1x1x16x16xbf16, #hivm.address_space>) -> memref<16x16xbf16, #hivm.address_space> + %memspacecast = memref.memory_space_cast %85 : memref<16x16xbf16, 
#hivm.address_space> to memref<16x16xbf16> + %86 = bufferization.to_tensor %memspacecast restrict writable : memref<16x16xbf16> + %87 = tt.dot %86, %84, %cst_10 : tensor<16x16xbf16> * tensor<16x128xbf16> -> tensor<16x128xf32> + hivm.hir.sync_block_set[, , ] flag = 6 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.fixpipe {dma_mode = #hivm.dma_mode} ins(%87 : tensor<16x128xf32>) outs(%alloc_14 : memref<16x128xf32, #hivm.address_space>) + hivm.hir.sync_block_set[, , ] flag = 3 + %88 = llvm.load %55 : !llvm.ptr<11> -> i32 + %89 = arith.andi %88, %10 : i32 + %90 = arith.ori %89, %7 : i32 + %91 = arith.ori %90, %7 : i32 + llvm.store %90, %55 : i32, !llvm.ptr<11> + llvm.store %91, %56 : i32, !llvm.ptr<11> + %92 = arith.addi %arg25, %5 : i32 + scf.yield %92 : i32 + } else { + scf.yield %arg25 : i32 + } + hivm.hir.sync_block_set[, , ] flag = 15 + scf.yield %79#0, %79#1, %79#2, %80 : tensor<16x1xi32>, tensor<16x1xi1>, i32, i32 + } + } + hivm.hir.sync_block_wait[, , ] flag = 4 + hivm.hir.sync_block_wait[, , ] flag = 5 + hivm.hir.sync_block_wait[, , ] flag = 14 + scope.return + } {hivm.tcore_type = #hivm.tcore_type} + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir new file mode 100644 index 0000000000..2fcb46e6e3 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir @@ -0,0 +1,134 @@ +// RUN: triton-opt --auto-blockify="auto-blockify-size=5" --split-input-file %s | FileCheck %s + +// ----- + +// CHECK-LABEL: tt.func @kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) attributes {auto_blockify_size = 5 : i32} { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<5x8xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant dense<8> : tensor<5xi32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<5xi32> +// CHECK: %[[VAL_4:.*]] = tt.get_num_programs x : i32 +// CHECK: 
%[[VAL_5:.*]] = tt.get_num_programs y : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_num_programs z : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_5]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.get_program_id x {logical_block_id} : i32 +// CHECK: %[[VAL_10:.*]] = tt.get_program_id y {logical_block_id} : i32 +// CHECK: %[[VAL_11:.*]] = tt.get_program_id z {logical_block_id} : i32 +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_10]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_16:.*]] = tt.make_range {end = 5 : i32, start = 0 : i32} : tensor<5xi32> +// CHECK: %[[VAL_17:.*]] = tt.splat %[[VAL_15]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : tensor<5xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_8]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_20:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_19]] : tensor<5xi32> +// CHECK: %[[VAL_21:.*]] = arith.cmpi sge, %[[VAL_18]], %[[VAL_3]] : tensor<5xi32> +// CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_20]], %[[VAL_21]] : tensor<5xi1> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_24:.*]] = arith.divsi %[[VAL_18]], %[[VAL_23]] : tensor<5xi32> +// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_4]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_26:.*]] = arith.remsi %[[VAL_24]], %[[VAL_25]] : tensor<5xi32> +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_26]], %[[VAL_2]] : tensor<5xi32> +// CHECK: %[[VAL_28:.*]] = tt.expand_dims %[[VAL_27]] {axis = 1 : i32} : tensor<5xi32> -> tensor<5x1xi32> +// CHECK: %[[VAL_29:.*]] = tt.broadcast %[[VAL_28]] : tensor<5x1xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_30:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_31:.*]] = 
tt.expand_dims %[[VAL_30]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_32:.*]] = tt.broadcast %[[VAL_31]] : tensor<1x8xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_29]], %[[VAL_32]] : tensor<5x8xi32> +// CHECK: %[[VAL_34:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<5x8x!tt.ptr> +// CHECK: %[[VAL_35:.*]] = tt.addptr %[[VAL_34]], %[[VAL_33]] : tensor<5x8x!tt.ptr>, tensor<5x8xi32> +// CHECK: %[[VAL_36:.*]] = tt.expand_dims %[[VAL_22]] {axis = 1 : i32} : tensor<5xi1> -> tensor<5x1xi1> +// CHECK: %[[VAL_37:.*]] = tt.broadcast %[[VAL_36]] : tensor<5x1xi1> -> tensor<5x8xi1> +// CHECK: tt.store %[[VAL_35]], %[[VAL_1]], %[[VAL_37]] : tensor<5x8x!tt.ptr> +// CHECK: tt.return +// CHECK: } +tt.func @kernel(%arg0: !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c8_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<8xi32> + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %4 = arith.addi %2, %3 : tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %cst : tensor<8x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: tt.func @kernel2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) attributes {auto_blockify_size = 5 : i32} { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<8> : tensor<5xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[VAL_7:.*]] = tt.get_num_programs x : i32 +// CHECK: %[[VAL_8:.*]] = tt.get_num_programs y : i32 +// CHECK: %[[VAL_9:.*]] = tt.get_num_programs z : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_8]], %[[VAL_9]] : i32 +// CHECK: 
%[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_12:.*]] = tt.get_program_id x {logical_block_id} : i32 +// CHECK: %[[VAL_13:.*]] = tt.get_program_id y {logical_block_id} : i32 +// CHECK: %[[VAL_14:.*]] = tt.get_program_id z {logical_block_id} : i32 +// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_13]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_19:.*]] = tt.make_range {end = 5 : i32, start = 0 : i32} : tensor<5xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_18]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : tensor<5xi32> +// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_10]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_23:.*]] = arith.divsi %[[VAL_21]], %[[VAL_22]] : tensor<5xi32> +// CHECK: %[[VAL_24:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_25:.*]] = arith.remsi %[[VAL_23]], %[[VAL_24]] : tensor<5xi32> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_9]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_27:.*]] = arith.divsi %[[VAL_21]], %[[VAL_26]] : tensor<5xi32> +// CHECK: %[[VAL_28:.*]] = tt.splat %[[VAL_8]] : i32 -> tensor<5xi32> +// CHECK: %[[VAL_29:.*]] = arith.remsi %[[VAL_27]], %[[VAL_28]] : tensor<5xi32> +// CHECK: %[[VAL_30:.*]] = arith.cmpi slt, %[[VAL_29]], %[[VAL_1]] : tensor<5xi32> +// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_25]], %[[VAL_1]] : tensor<5xi32> +// CHECK: %[[VAL_32:.*]] = tt.expand_dims %[[VAL_31]] {axis = 1 : i32} : tensor<5xi32> -> tensor<5x1xi32> +// CHECK: %[[VAL_33:.*]] = tt.broadcast %[[VAL_32]] : tensor<5x1xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_34:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_35:.*]] = tt.expand_dims %[[VAL_34]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_36:.*]] = 
tt.broadcast %[[VAL_35]] : tensor<1x8xi32> -> tensor<5x8xi32> +// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_33]], %[[VAL_36]] : tensor<5x8xi32> +// CHECK: %[[VAL_38:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<5x8x!tt.ptr> +// CHECK: %[[VAL_39:.*]] = tt.addptr %[[VAL_38]], %[[VAL_37]] : tensor<5x8x!tt.ptr>, tensor<5x8xi32> +// CHECK: %[[VAL_40:.*]] = arith.subi %[[VAL_11]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_41:.*]] = arith.maxsi %[[VAL_40]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_42:.*]] = arith.index_cast %[[VAL_41]] : i32 to index +// CHECK: %[[VAL_43:.*]] = arith.minsi %[[VAL_42]], %[[VAL_3]] : index +// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_5]] to %[[VAL_43]] step %[[VAL_4]] { +// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_30]]{{\[}}%[[VAL_44]]] : tensor<5xi1> +// CHECK: scf.if %[[VAL_45]] { +// CHECK: %[[VAL_46:.*]] = tensor.extract_slice %[[VAL_39]]{{\[}}%[[VAL_44]], 0] [1, 8] [1, 1] : tensor<5x8x!tt.ptr> to tensor<8x!tt.ptr> +// CHECK: tt.store %[[VAL_46]], %[[VAL_6]] : tensor<8x!tt.ptr> +// CHECK: } +// CHECK: } {auto_blockify_loop} +// CHECK: tt.return +// CHECK: } +tt.func @kernel2(%arg0: !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %a = tt.get_program_id y : i32 + %b = arith.cmpi slt, %a, %c8_i32 : i32 + %1 = arith.muli %0, %c8_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<8xi32> + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %4 = arith.addi %2, %3 : tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr>, tensor<8xi32> + scf.if %b { + tt.store %6, %cst : tensor<8x!tt.ptr> + scf.yield + } + tt.return +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir new file mode 100644 index 0000000000..71d8fb3db1 --- /dev/null +++ 
b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir @@ -0,0 +1,201 @@ +// RUN: triton-opt %s --discrete-mask-access-conversion --split-input-file | FileCheck %s + +// CHECK-LABEL: tt.func @atomic_add_i32 +// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw add, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_add_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw add, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_fadd_f32 +// CHECK: %[[default:.*]] = arith.constant dense<0.000000e+00> : tensor<1024xf32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_fadd_f32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat 
%arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw fadd, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xf32>, tensor<1024xi1>) -> tensor<1024xf32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_max_i32 +// CHECK: %[[default:.*]] = arith.constant dense<-2147483648> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_max_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_umax_i32 +// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw umax, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_umax_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = 
tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw umax, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_min_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw min, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_min_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw min, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_umin_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = 
tt.atomic_rmw umin, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_umin_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw umin, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_and_i32 +// CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw and, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_and_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw and, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// 
CHECK-LABEL: tt.func @atomic_or_i32 +// CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw or, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_or_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw or, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi32>, tensor<1024xi1>) -> tensor<1024xi32> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_max_i16 +// CHECK: %[[default:.*]] = arith.constant dense<-32768> : tensor<1024xi16> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_max_i16(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, 
tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xi16>, tensor<1024xi1>) -> tensor<1024xi16> + tt.return +} + +// CHECK-LABEL: tt.func @atomic_max_f16 +// CHECK: %[[default:.*]] = arith.constant dense<0xFC00> : tensor<1024xf16> +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +tt.func @atomic_max_f16(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<200> : tensor<1024xi32> + %cst_0 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_0 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xf16>, tensor<1024xi1>) -> tensor<1024xf16> + tt.return +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir new file mode 100644 index 0000000000..6854e118ea --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/loadstore.mlir @@ -0,0 +1,67 @@ +// RUN: triton-opt %s --discrete-mask-access-conversion --split-input-file | FileCheck %s +// RUN: triton-opt %s --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion 
--triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' + +// CHECK-LABEL: tt.func @discrete_load +// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]] +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[other:.*]] +// CHECK: tt.store %[[store_ptr:.*]], %[[value]] +tt.func @discrete_load(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<0> : tensor<1024xi32> + %cst_0 = arith.constant dense<200> : tensor<1024xi32> + %cst_1 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %5, %3, %cst : tensor<1024x!tt.ptr> + tt.store %7, %8 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func @discrete_load_without_other +// CHECK: %[[other:.*]] = arith.constant dense<0> +// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]] +// CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[other]] +// CHECK: tt.store %[[store_ptr:.*]], %[[value]] +tt.func @discrete_load_without_other(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<0> : tensor<1024xi32> + %cst_0 = arith.constant dense<200> : tensor<1024xi32> + %cst_1 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + 
%5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %5, %3 : tensor<1024x!tt.ptr> + tt.store %7, %8 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func @discrete_store +// CHECK: %[[loaded_value:.*]] = tt.load %[[load_ptr:.*]] : tensor<1024x!tt.ptr> +// CHECK: %[[origin_value:.*]] = tt.load %[[store_ptr:.*]] : tensor<1024x!tt.ptr> +// CHECK: %[[store_value:.*]] = arith.select %[[mask:.*]], %[[loaded_value]], %[[origin_value]] +// CHECK: tt.store %[[store_ptr]], %[[store_value]] +tt.func @discrete_store(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst = arith.constant dense<0> : tensor<1024xi32> + %cst_0 = arith.constant dense<200> : tensor<1024xi32> + %cst_1 = arith.constant dense<400> : tensor<1024xi32> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = arith.cmpi slt, %0, %cst_0 : tensor<1024xi32> + %2 = arith.cmpi sgt, %0, %cst_1 : tensor<1024xi32> + %3 = arith.ori %1, %2 : tensor<1024xi1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %5 : tensor<1024x!tt.ptr> + tt.store %7, %8, %3 : tensor<1024x!tt.ptr> + tt.return +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir new file mode 100644 index 0000000000..30b3c375c2 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir @@ -0,0 +1,121 @@ +// RUN: triton-opt -allow-unregistered-dialect --triton-to-structured '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' 
'--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation --triton-to-structured --triton-to-linalg --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @matmul_kernel +// CHECK-DAG: %[[C0:.*]] = arith.constant{{.*}}0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant{{.*}}1 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant{{.*}}64 : index +// CHECK: %{{.*}} = scf.for %{{.*}} = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%{{.*}} = %{{.*}}, %[[ARG16:.*]] = %[[C0]]) -> (tensor<128x256xi32>, index) : i32 { +// CHECK: %[[INNERFOR:.*]]:3 = scf.for {{.*}} = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %[[ARG20:.*]] = %[[ARG16]]) -> (tensor<128x256xi32>, tensor<64x256xi64>, index) : i32 { +// CHECK: %{{.*}} = memref.reinterpret_cast %{{.*}} to offset: [%[[ARG20]]], sizes: [1, 64], strides: [%[[C1]], %[[C1]]] : memref to memref<1x64xi8, strided<[?, ?], offset: ?>> +// CHECK: %{{.*}} = linalg.broadcast ins(%{{.*}} : tensor<64xi8>) outs(%{{.*}} : tensor<128x64xi8>) dimensions = [0] +// CHECK: %[[RES72:.*]] = arith.addi %[[ARG20]], %[[C64]] : index +// CHECK: scf.yield %{{.*}}, %{{.*}}, %[[RES72]] : tensor<128x256xi32>, tensor<64x256xi64>, index +// CHECK: } {{{.*}}tts.simplify_tensor_iter_args.done} +// CHECK: scf.yield %[[INNERFOR]]#0, %[[INNERFOR]]#2 : tensor<128x256xi32>, index +// CHECK: } {{{.*}}tts.simplify_tensor_iter_args.done} + +module { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<64x256xi8> + %cst_0 = arith.constant dense<0> : tensor<64x256xi8> + %cst_1 = arith.constant dense<0> : tensor<128x64xi8> + %c4_i32 = 
arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1024> : tensor<1x256xi32> + %cst_3 = arith.constant dense<1> : tensor<128x1xi32> + %cst_4 = arith.constant dense<64> : tensor<128x64xi32> + %cst_5 = arith.constant dense<0> : tensor<128x256xi32> + %c3_i32 = arith.constant 3 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_6 = arith.constant dense<8192> : tensor<64x1xi32> + %c8192_i32 = arith.constant 8192 : i32 + %c64_i32 = arith.constant 64 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_7 = arith.constant dense<1024> : tensor<256xi32> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c32_i32 : i32 + %2 = arith.muli %1, %c8_i32 : i32 + %3 = arith.subi %c1_i32, %2 : i32 + %4 = arith.minsi %3, %c8_i32 : i32 + %5 = arith.remsi %0, %c32_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = arith.muli %8, %c256_i32 : i32 + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.splat %9 : i32 -> tensor<256xi32> + %12 = arith.addi %11, %10 : tensor<256xi32> + %13 = arith.remsi %12, %cst_7 : tensor<256xi32> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %15 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %16 = tt.splat %arg0 : !tt.ptr -> tensor<1x64x!tt.ptr> + %17 = tt.addptr %16, %15 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> + %18 = tt.broadcast %17 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> + %19 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %20 = tt.splat %arg4 : i32 -> tensor<64x1xi32> + %21 = arith.muli %19, %20 : tensor<64x1xi32> + %22 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %23 = tt.broadcast %21 : tensor<64x1xi32> -> tensor<64x256xi32> + %24 = tt.broadcast 
%22 : tensor<1x256xi32> -> tensor<64x256xi32> + %25 = arith.addi %23, %24 : tensor<64x256xi32> + %26 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %27 = tt.addptr %26, %25 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %28 = arith.cmpi slt, %19, %cst_6 : tensor<64x1xi32> + %29 = tt.broadcast %28 : tensor<64x1xi1> -> tensor<64x256xi1> + %30 = arith.muli %arg4, %c64_i32 : i32 + %31 = tt.splat %30 : i32 -> tensor<64x256xi32> + %32:2 = scf.for %arg6 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg7 = %cst_5, %arg8 = %18) -> (tensor<128x256xi32>, tensor<128x64x!tt.ptr>) : i32 { + %51 = arith.muli %arg6, %c32_i32 : i32 + %52 = arith.muli %arg6, %c2_i32 : i32 + %53 = arith.shli %c3_i32, %52 : i32 + %54 = tt.splat %53 : i32 -> tensor<64x256xi32> + %55 = tt.splat %52 : i32 -> tensor<64x256xi32> + %56:3 = scf.for %arg9 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg10 = %arg7, %arg11 = %arg8, %arg12 = %27) -> (tensor<128x256xi32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { + %57 = arith.addi %51, %arg9 : i32 + %58 = arith.muli %57, %c64_i32 : i32 + %59 = arith.subi %c8192_i32, %58 : i32 + %60 = tt.splat %59 : i32 -> tensor<1x64xi32> + %61 = arith.cmpi slt, %15, %60 : tensor<1x64xi32> + %62 = tt.broadcast %61 : tensor<1x64xi1> -> tensor<128x64xi1> + %63 = tt.load %arg11, %62, %cst_1 : tensor<128x64x!tt.ptr> + %64 = tt.load %arg12, %29, %cst_0 : tensor<64x256x!tt.ptr> + %65 = arith.extui %64 : tensor<64x256xi8> to tensor<64x256xi32> + %66 = arith.andi %65, %54 : tensor<64x256xi32> + %67 = arith.shrsi %66, %55 : tensor<64x256xi32> + %68 = arith.trunci %67 : tensor<64x256xi32> to tensor<64x256xi8> + %69 = arith.subi %68, %cst : tensor<64x256xi8> + %70 = tt.dot %63, %69, %arg10 : tensor<128x64xi8> * tensor<64x256xi8> -> tensor<128x256xi32> + %71 = tt.addptr %arg11, %cst_4 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %72 = tt.addptr %arg12, %31 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + scf.yield %70, %71, %72 : tensor<128x256xi32>, 
tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> + } + scf.yield %56#0, %56#1 : tensor<128x256xi32>, tensor<128x64x!tt.ptr> + } + %33 = arith.muli %7, %c128_i32 : i32 + %34 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %35 = tt.splat %33 : i32 -> tensor<128xi32> + %36 = arith.addi %35, %34 : tensor<128xi32> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %38 = tt.splat %arg5 : i32 -> tensor<128x1xi32> + %39 = arith.muli %38, %37 : tensor<128x1xi32> + %40 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %41 = tt.addptr %40, %39 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %42 = tt.expand_dims %12 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %43 = tt.broadcast %41 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %44 = tt.broadcast %42 : tensor<1x256xi32> -> tensor<128x256xi32> + %45 = tt.addptr %43, %44 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %46 = arith.cmpi slt, %37, %cst_3 : tensor<128x1xi32> + %47 = arith.cmpi slt, %42, %cst_2 : tensor<1x256xi32> + %48 = tt.broadcast %46 : tensor<128x1xi1> -> tensor<128x256xi1> + %49 = tt.broadcast %47 : tensor<1x256xi1> -> tensor<128x256xi1> + %50 = arith.andi %48, %49 : tensor<128x256xi1> + tt.store %45, %32#0, %50 : tensor<128x256x!tt.ptr> + tt.return + } +} + + + diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir new file mode 100644 index 0000000000..cbabb03d89 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir @@ -0,0 +1,18 @@ +// RUN: triton-opt %s -triton-to-hfusion | FileCheck %s + +// CHECK-LABEL: tt.func @test_fp32_to_fp16_rtz +tt.func @test_fp32_to_fp16_rtz(%arg0: tensor<1024xf32>) -> tensor<1024xf16> { + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1024xf16> + // CHECK: %[[RESULT:.*]] = hfusion.cast {mode = #hfusion.round_mode} ins(%arg0 : 
tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1024xf16>) -> tensor<1024xf16> + %0 = tt.fp_to_fp %arg0, rounding = rtz : tensor<1024xf32> -> tensor<1024xf16> + // CHECK: return %[[RESULT]] + tt.return %0 : tensor<1024xf16> +} + + +// CHECK-LABEL: tt.func @test_fp32_to_fp16_rtz_fail +tt.func @test_fp32_to_fp16_rtz_fail(%arg0: tensor<1024xf32>) -> tensor<1024xf16> { + %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16> + // CHECK: %{{.*}} = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16> + tt.return %0 : tensor<1024xf16> +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir new file mode 100644 index 0000000000..fd9b479006 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir @@ -0,0 +1,11 @@ +// RUN: triton-opt %s -triton-to-hfusion | FileCheck %s + +// CHECK: tensor.empty() : tensor<1xf32> +// CHECK: hfusion.elemwise_binary {fun = #hfusion.binary_fn} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs(%0 : tensor<1xf32>) -> tensor<1xf32> + +module { + tt.func @test_mod(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = ascend.mod %arg0, %arg1 : tensor<1xf32> tensor<1xf32> -> tensor<1xf32> + tt.return %0 : tensor<1xf32> + } +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir new file mode 100644 index 0000000000..d004352a3e --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt %s --triton-to_hivm | FileCheck %s + +// CHECK-LABEL: tt.func @triton_func +tt.func @triton_func() { + ascend.custom "sync_block_set" {str_args = ["vector", 1 : i32]} + ascend.custom 
"sync_block_wait" {str_args = ["vector", 1 : i32]} + ascend.custom "sync_block_set" {str_args = ["cube", 2 : i32]} + ascend.custom "sync_block_wait" {str_args = ["cube", 2 : i32]} + ascend.custom "sync_block_all" {str_args = ["all_cube", 1 : i32]} + ascend.custom "sync_block_all" {str_args = ["all_vector", 1 : i32]} + ascend.custom "sync_block_all" {str_args = ["all", 1 : i32]} + tt.return +} +// CHECK: hivm.hir.sync_block_set[, , ] flag = 1 +// CHECK: hivm.hir.sync_block_wait[, , ] flag = 1 +// CHECK: hivm.hir.sync_block_set[, , ] flag = 2 +// CHECK: hivm.hir.sync_block_wait[, , ] flag = 2 +// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = +// CHECK: hivm.hir.sync_block[, 1 : i16] tvector_pipe = +// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = tvector_pipe = \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir new file mode 100644 index 0000000000..f3f36e5dcc --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw.mlir @@ -0,0 +1,109 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @matmul_atomic_add +// CHECK-NOT: GenericAtomicRMW +// CHECK: tensor.extract_slice +// CHECK: hivm.hir.store ins(%{{.*}} : tensor) outs(%{{.*}} : memref) atomic = + + tt.func public @matmul_atomic_add(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_program_id z : i32 + %c16_i32 = arith.constant 16 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %3 = arith.muli %0, %c16_i32_0 : i32 + %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %5 = tt.splat %3 : i32 -> tensor<16xi32> + %6 = 
arith.addi %5, %4 : tensor<16xi32> + %c16_i32_1 = arith.constant 16 : i32 + %c16_i32_2 = arith.constant 16 : i32 + %7 = arith.muli %1, %c16_i32_2 : i32 + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %9 = tt.splat %7 : i32 -> tensor<16xi32> + %10 = arith.addi %9, %8 : tensor<16xi32> + %11 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %12 = tt.splat %arg10 : i32 -> tensor<16x1xi32> + %13 = arith.muli %11, %12 : tensor<16x1xi32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<16x1x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %16 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %17 = tt.splat %arg11 : i32 -> tensor<1x16xi32> + %18 = arith.muli %16, %17 : tensor<1x16xi32> + %19 = tt.broadcast %15 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %20 = tt.broadcast %18 : tensor<1x16xi32> -> tensor<16x16xi32> + %21 = tt.addptr %19, %20 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %22 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %23 = tt.splat %arg3 : i32 -> tensor<16x1xi32> + %24 = arith.cmpi slt, %22, %23 : tensor<16x1xi32> + %25 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %26 = tt.splat %arg4 : i32 -> tensor<1x16xi32> + %27 = arith.cmpi slt, %25, %26 : tensor<1x16xi32> + %28 = tt.broadcast %24 : tensor<16x1xi1> -> tensor<16x16xi1> + %29 = tt.broadcast %27 : tensor<1x16xi1> -> tensor<16x16xi1> + %30 = arith.andi %28, %29 : tensor<16x16xi1> + %c16_i32_3 = arith.constant 16 : i32 + %c16_i32_4 = arith.constant 16 : i32 + %31 = arith.muli %2, %c16_i32_4 : i32 + %c32_i32 = arith.constant 32 : i32 + %32 = arith.bitcast %31 : i32 to i32 + %33 = arith.bitcast %arg5 : i32 to i32 + %34 = arith.bitcast %c32_i32 : i32 to i32 + %35 = ub.poison : i32 + scf.for %arg12 = %32 to %33 step %34 : i32 { + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %37 = tt.splat %arg12 : i32 -> 
tensor<16xi32> + %38 = arith.addi %37, %36 : tensor<16xi32> + %39 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %40 = tt.splat %arg6 : i32 -> tensor<16x1xi32> + %41 = arith.muli %39, %40 : tensor<16x1xi32> + %42 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr> + %43 = tt.addptr %42, %41 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %44 = tt.expand_dims %38 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %45 = tt.splat %arg7 : i32 -> tensor<1x16xi32> + %46 = arith.muli %44, %45 : tensor<1x16xi32> + %47 = tt.broadcast %43 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %48 = tt.broadcast %46 : tensor<1x16xi32> -> tensor<16x16xi32> + %49 = tt.addptr %47, %48 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %50 = tt.expand_dims %38 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %51 = tt.splat %arg8 : i32 -> tensor<16x1xi32> + %52 = arith.muli %50, %51 : tensor<16x1xi32> + %53 = tt.splat %arg1 : !tt.ptr -> tensor<16x1x!tt.ptr> + %54 = tt.addptr %53, %52 : tensor<16x1x!tt.ptr>, tensor<16x1xi32> + %55 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %56 = tt.splat %arg9 : i32 -> tensor<1x16xi32> + %57 = arith.muli %55, %56 : tensor<1x16xi32> + %58 = tt.broadcast %54 : tensor<16x1x!tt.ptr> -> tensor<16x16x!tt.ptr> + %59 = tt.broadcast %57 : tensor<1x16xi32> -> tensor<16x16xi32> + %60 = tt.addptr %58, %59 : tensor<16x16x!tt.ptr>, tensor<16x16xi32> + %61 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %62 = tt.splat %arg3 : i32 -> tensor<16x1xi32> + %63 = arith.cmpi slt, %61, %62 : tensor<16x1xi32> + %64 = tt.expand_dims %38 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %65 = tt.splat %arg5 : i32 -> tensor<1x16xi32> + %66 = arith.cmpi slt, %64, %65 : tensor<1x16xi32> + %67 = tt.broadcast %63 : tensor<16x1xi1> -> tensor<16x16xi1> + %68 = tt.broadcast %66 : tensor<1x16xi1> -> tensor<16x16xi1> + %69 = arith.andi %67, %68 : tensor<16x16xi1> + %70 = tt.expand_dims %38 
{axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %71 = tt.splat %arg5 : i32 -> tensor<16x1xi32> + %72 = arith.cmpi slt, %70, %71 : tensor<16x1xi32> + %73 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %74 = tt.splat %arg4 : i32 -> tensor<1x16xi32> + %75 = arith.cmpi slt, %73, %74 : tensor<1x16xi32> + %76 = tt.broadcast %72 : tensor<16x1xi1> -> tensor<16x16xi1> + %77 = tt.broadcast %75 : tensor<1x16xi1> -> tensor<16x16xi1> + %78 = arith.andi %76, %77 : tensor<16x16xi1> + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0> : tensor<16x16xi32> + %79 = arith.sitofp %cst : tensor<16x16xi32> to tensor<16x16xf32> + %80 = tt.load %49, %69, %79 : tensor<16x16x!tt.ptr> + %c0_i32_5 = arith.constant 0 : i32 + %cst_6 = arith.constant dense<0> : tensor<16x16xi32> + %81 = arith.sitofp %cst_6 : tensor<16x16xi32> to tensor<16x16xf32> + %82 = tt.load %60, %78, %81 : tensor<16x16x!tt.ptr> + %cst_7 = arith.constant 0.000000e+00 : f32 + %cst_8 = arith.constant dense<0.000000e+00> : tensor<16x16xf32> + %83 = tt.dot %80, %82, %cst_8 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + %84 = tt.atomic_rmw fadd, acq_rel, gpu, %21, %83, %30 : (tensor<16x16x!tt.ptr>, tensor<16x16xf32>, tensor<16x16xi1>) -> tensor<16x16xf32> + } + tt.return + } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir new file mode 100644 index 0000000000..c7ea71622c --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir @@ -0,0 +1,46 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s + +module attributes {hacc.target = #hacc.target<"Ascend910B2">} { + tt.func public @moe_align_block_size_stage4(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , 
%arg3: !tt.ptr {tt.divisibility = 16 : i32} , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: i32) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<1xi32> + %cst_0 = arith.constant dense<0> : tensor<1xi32> + %c250_i32 = arith.constant 250 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg4, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + %3 = tt.addptr %1, %c1_i32 : !tt.ptr, i32 + %4 = tt.load %3 : !tt.ptr + scf.for %arg6 = %2 to %4 step %c16_i32 : i32 { + %22 = arith.divsi %arg6, %c16_i32 : i32 + %23 = tt.addptr %arg2, %22 : !tt.ptr, i32 + tt.store %23, %0 : !tt.ptr + } + %5 = arith.muli %0, %c250_i32 : i32 + %6 = tt.splat %0 : i32 -> tensor<1xi32> + %7 = arith.cmpi slt, %0, %arg5 : i32 + %8 = tt.splat %7 : i1 -> tensor<1xi1> + %9 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %10 = tt.splat %9 : !tt.ptr -> tensor<1x!tt.ptr> + %11 = tt.load %10, %8, %cst_0 : tensor<1x!tt.ptr> + %12 = tt.addptr %arg3, %5 : !tt.ptr, i32 + %13 = tt.splat %12 : !tt.ptr -> tensor<1x!tt.ptr> + %14 = tt.addptr %13, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %15 = tt.atomic_rmw add, acq_rel, gpu, %14, %cst, %8 : (tensor<1x!tt.ptr>, tensor<1xi32>, tensor<1xi1>) -> tensor<1xi32> + %16 = tt.splat %arg4 : !tt.ptr -> tensor<1x!tt.ptr> + %17 = tt.addptr %16, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %18 = tt.load %17, %8, %cst_0 : tensor<1x!tt.ptr> + %19 = arith.addi %15, %18 : tensor<1xi32> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr> + %21 = tt.addptr %20, %19 : tensor<1x!tt.ptr>, tensor<1xi32> + tt.store %21, %6, %8 : tensor<1x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @moe_align_block_size_stage4 + +// CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[.*]] to offset: [%[[.*]]], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: ?>> +// CHECK: %[[CAST2:.*]] = memref.alloc() : memref<1xi32> +// CHECK: memref.copy %[[CAST1]], %[[CAST2]] : 
memref<1xi32, strided<[1], offset: ?>> to memref<1xi32> \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir new file mode 100644 index 0000000000..96f84ef0f9 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/legal_stride.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt --triton-to-linalg="named-ops=True" --split-input-file %s | FileCheck %s +// CHECK-LABEL: func.func @triton_fn_broadcast_nested +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[ARG2:.*]] to offset: [%[[ARG13:.*]]], sizes: [4, 1], strides: [%c4, %[[C1]]] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[CAST2:.*]] = memref.reinterpret_cast %[[ARG3:.*]] to offset: [%[[ARG13]]], sizes: [4, 1], strides: [%c4, %[[C1]]] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + +module { + tt.func @triton_fn_broadcast_nested(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32){ + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = scf.for %arg10 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg11 = %c0) -> (index) : i32 { + %1 = scf.for %arg12 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg13 = %arg11) -> (index) : i32 { + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%arg13], sizes: [4, 1], strides: [%c4, %c0] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + %alloc = memref.alloc() : memref<4x1xf32> + memref.copy %reinterpret_cast, %alloc : memref<4x1xf32, strided<[?, ?], offset: ?>> to 
memref<4x1xf32> + %2 = bufferization.to_tensor %alloc restrict writable : memref<4x1xf32> + %reinterpret_cast_0 = memref.reinterpret_cast %arg3 to offset: [%arg13], sizes: [4, 1], strides: [%c4, %c0] : memref to memref<4x1xf32, strided<[?, ?], offset: ?>> + bufferization.materialize_in_destination %2 in writable %reinterpret_cast_0 : (tensor<4x1xf32>, memref<4x1xf32, strided<[?, ?], offset: ?>>) -> () + %3 = arith.addi %arg13, %c1 : index + scf.yield %3 : index + } + scf.yield %1 : index + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir new file mode 100644 index 0000000000..f844ca90d5 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/parse_select.mlir @@ -0,0 +1,37 @@ +// RUN: triton-opt --triton-to-linalg --split-input-file %s -verify-each 2>&1 | FileCheck %s --check-prefix=NOERR +// NOERR-NOT: parseSelect currently supports all-ones shape unless cond=i1 with dense constants +// CHECK-LABEL: func.func public @triton_for_if_load + +module { + tt.func public @triton_for_if_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<16xi32> + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32> + %cst_1 = arith.constant dense<32> : tensor<16xi32> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16xf32> + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.get_program_id x : i32 + %2 = arith.cmpi ne, %1, %c0_i32 : i32 + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %5:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) 
-> (tensor<16xi32>, tensor<16xi32>) : i32 { + %6 = arith.muli %arg2, %c16_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<16xi32> + %8 = arith.addi %arg3, %7 : tensor<16xi32> + %9 = arith.addi %arg4, %7 : tensor<16xi32> + %10 = arith.select %2, %cst_0, %cst : tensor<16xi32> + %11 = arith.addi %8, %10 : tensor<16xi32> + %12 = tt.addptr %3, %11 : tensor<16x!tt.ptr>, tensor<16xi32> + %13 = arith.cmpi slt, %11, %cst_1 : tensor<16xi32> + %14 = tt.load %12 : tensor<16x!tt.ptr> + %15 = arith.select %13, %14, %cst_2 : tensor<16xi1>, tensor<16xf32> + %16 = tt.addptr %4, %9 : tensor<16x!tt.ptr>, tensor<16xi32> + tt.store %16, %15 : tensor<16x!tt.ptr> + scf.yield %11, %9 : tensor<16xi32>, tensor<16xi32> + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir new file mode 100644 index 0000000000..909de6174d --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/CmpConverter.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +module { + tt.func public @test_cmp(%arg0: tensor<128xi32>) -> tensor<128xi1> { + %cst_12 = arith.constant dense<0> : tensor<128xi32> + %cst_13 = arith.constant dense<1> : tensor<128xi32> + %cst_14 = arith.constant dense<100> : tensor<128xi32> + %39 = arith.cmpi slt, %arg0, %cst_14 : tensor<128xi32> + %40 = arith.select %39, %cst_13, %cst_12 : tensor<128xi1>, tensor<128xi32> + %41 = arith.cmpi ne, %40, %cst_12 : tensor<128xi32> + tt.return %41 : tensor<128xi1> + } +} + +// CHECK-LABEL: tt.func public @test_cmp( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xi32>) -> tensor<128xi1> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0> : tensor<128xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : tensor<128xi32> +// CHECK: %[[VAL_3:.*]] = arith.constant 
dense<100> : tensor<128xi32> +// CHECK: %[[VAL_4:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : tensor<128xi32> +// CHECK: %[[VAL_5:.*]] = arith.select %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : tensor<128xi1>, tensor<128xi32> +// CHECK: %[[VAL_6:.*]] = arith.cmpi ne, %[[VAL_5]], %[[VAL_1]] : tensor<128xi32> +// CHECK: tt.return %[[VAL_6]] : tensor<128xi1> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir new file mode 100644 index 0000000000..7163b89f5f --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/PromotePointerIterArgsPattern.mlir @@ -0,0 +1,73 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +module { + tt.func public @test_promote_pointer_iter(%base_ptr: !tt.ptr {tt.divisibility = 16 : i32}) -> !tt.ptr { + %c1_i32 = arith.constant 1 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %true_mask = arith.constant 1 : i1 + %c0_index = arith.constant 0 : index + %c10_index = arith.constant 10 : index + %c1_index = arith.constant 1 : index + %final_ptr = scf.for %iv = %c0_index to %c10_index step %c1_index iter_args(%ptr = %base_ptr) -> (!tt.ptr) { + %data = tt.load %ptr, %true_mask, %c0_f32 : !tt.ptr + %new_ptr = tt.addptr %ptr, %c1_i32 : !tt.ptr, i32 + scf.yield %new_ptr : !tt.ptr + } + tt.return %final_ptr : !tt.ptr + } +} + +// CHECK-LABEL: tt.func public @test_promote_pointer_iter( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) -> !tt.ptr { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_3]] step 
%[[VAL_4]] iter_args(%[[VAL_7:.*]] = %[[VAL_0]]) -> (!tt.ptr) { +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_1]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_8]] : !tt.ptr +// CHECK: } +// CHECK: tt.return %[[VAL_5]] : !tt.ptr +// CHECK: } + + +// ----- + + +module { + tt.func public @test_promote_pointer_iter_advance(%base_ptr: !tt.ptr) -> !tt.ptr>{ + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1_i32 = arith.constant 1 : i32 // nonZeroConstant 需要 1 + %c0_i32_2 = arith.constant 0 : i32 + %c0_index = arith.constant 0 : index + %c10_index = arith.constant 10 : index + %c1_index = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : tensor<32xf16> + %ptr0 = tt.make_tensor_ptr %base_ptr, [%c32_i64], [%c1_i64], [%c0_i32] {order = array} : !tt.ptr> + %final_ptr = scf.for %iv = %c0_index to %c10_index step %c1_index iter_args(%ptr = %ptr0) -> !tt.ptr> { + %data = tt.load %ptr : !tt.ptr> + %new_ptr = tt.advance %ptr, [%c1_i32, %c0_i32_2] : !tt.ptr> + scf.yield %new_ptr : !tt.ptr> + } + tt.return %final_ptr : !tt.ptr> + } +} + +// CHECK-LABEL: tt.func public @test_promote_pointer_iter_advance( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> !tt.ptr> { +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_3:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], [%[[VAL_3]]], [%[[VAL_2]]], [%[[VAL_1]]] {order = array} : > +// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (!tt.ptr>) { +// CHECK: %[[VAL_12:.*]] = tt.advance %[[VAL_11]], [%[[VAL_4]], %[[VAL_1]]] : > +// CHECK: scf.yield %[[VAL_12]] : !tt.ptr> 
+// CHECK: } +// CHECK: tt.return %[[VAL_9]] : !tt.ptr> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir new file mode 100644 index 0000000000..f046891ec1 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/SplatCmpConverter.mlir @@ -0,0 +1,16 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s +module { + tt.func public @test_splat_cmp(%arg0: i32, %arg1: i32) -> tensor<128xi1> { + %0 = tt.splat %arg0 : i32 -> tensor<128xi32> + %1 = tt.splat %arg1 : i32 -> tensor<128xi32> + %2 = arith.cmpi slt, %0, %1 : tensor<128xi32> + tt.return %2 : tensor<128xi1> + } +} + +// CHECK-LABEL: tt.func public @test_splat_cmp( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) -> tensor<128xi1> { +// CHECK: %[[VAL_2:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_2]] : i1 -> tensor<128xi1> +// CHECK: tt.return %[[VAL_3]] : tensor<128xi1> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir new file mode 100644 index 0000000000..5c589cca91 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir @@ -0,0 +1,117 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @test_cmp_ult(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> + %cst_0 = arith.constant dense<512> : tensor<1024xi32> + %c1024_i32 = arith.constant 1024 : i32 + %0 = 
tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 {tt.divisibility = dense<512> : tensor<1xi32>} : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = arith.divsi %4, %cst_0 : tensor<1024xi32> + %6 = arith.cmpi ult, %5, %cst_0 : tensor<1024xi32> + %7 = arith.muli %5, %cst_0 : tensor<1024xi32> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9, %6, %cst : tensor<1024x!tt.ptr> + %11 = arith.muli %0, %c1024_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<1024xi32> + %13 = arith.addi %12, %2 : tensor<1024xi32> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %10 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_cmp_ult( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<1024> : tensor<1xi64> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<2xi32> +// CHECK: %[[VAL_5:.*]] = arith.constant 512 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_6]] {tt.divisibility = dense<512> : tensor<1xi32>} : i32 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_9]], %[[VAL_5]] : index +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_13:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : 
tensor<2xi32> +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_13]], %[[VAL_4]] : tensor<2xi32> +// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_12]] : index to i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : i32 -> tensor<2xi32> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_16]] : tensor<2xi32> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_11]] : index to i32 +// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_20]] : i32 -> tensor<2xi32> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_13]], %[[VAL_21]] : tensor<2xi32> +// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_22]], %[[VAL_4]] : tensor<2xi32> +// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_19]], %[[VAL_23]], %[[VAL_3]] : tensor<2x!tt.ptr> +// CHECK: %[[VAL_25:.*]] = tensor.empty() : tensor<2x512xf32> +// CHECK: %[[VAL_26:.*]] = linalg.broadcast ins(%[[VAL_24]] : tensor<2xf32>) outs(%[[VAL_25]] : tensor<2x512xf32>) dimensions = [1] +// CHECK: %[[VAL_27:.*]] = tensor.reshape %[[VAL_26]](%[[VAL_2]]) : (tensor<2x512xf32>, tensor<1xi64>) -> tensor<1024xf32> +// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_7]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_29:.*]] = tt.splat %[[VAL_28]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_10]] : tensor<1024xi32> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: tt.store %[[VAL_32]], %[[VAL_27]] : tensor<1024x!tt.ptr> +// CHECK: tt.return +// CHECK: } + + +// ----- + + +tt.func public @test_cmp_uge(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> + %cst_0 = arith.constant dense<511> 
: tensor<1024xi32> + %cst_1 = arith.constant dense<512> : tensor<1024xi32> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 {tt.divisibility = dense<512> : tensor<1xi32>} : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = arith.divsi %4, %cst_1 : tensor<1024xi32> + %6 = arith.cmpi uge, %cst_0, %5 : tensor<1024xi32> + %7 = arith.muli %5, %cst_1 : tensor<1024xi32> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9, %6, %cst : tensor<1024x!tt.ptr> + %11 = arith.muli %0, %c1024_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<1024xi32> + %13 = arith.addi %12, %2 : tensor<1024xi32> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %10 : tensor<1024x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_cmp_uge( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : tensor<1024xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<511> : tensor<1024xi32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] {tt.divisibility = dense<512> : tensor<1xi32>} : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_7]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_8]] : tensor<1024xi32> +// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], 
%[[VAL_4]] : tensor<1024xi32> +// CHECK: %[[VAL_12:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_3]] : tensor<1024xi32> +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_14]], %[[VAL_13]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_16:.*]] = tt.load %[[VAL_15]], %[[VAL_12]], %[[VAL_2]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_17]] : i32 -> tensor<1024xi32> +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_8]] : tensor<1024xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = tt.addptr %[[VAL_20]], %[[VAL_19]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: tt.store %[[VAL_21]], %[[VAL_16]] : tensor<1024x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir new file mode 100644 index 0000000000..4e1053793d --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseConstant.mlir @@ -0,0 +1,28 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @test_non_splat_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %non_splat_mask = arith.constant dense<[false, true]> : tensor<2xi1> + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<2xf32> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %ptr_load = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %2 = tt.load %ptr_load, %non_splat_mask, 
%c0_f32 : tensor<2x!tt.ptr> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %ptr_store = tt.addptr %3, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %ptr_store, %2 : tensor<2x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @test_non_splat_mask( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<[false, true]> : tensor<2xi1> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]], %[[VAL_2]], %[[VAL_3]] : tensor<2x!tt.ptr> +// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<2x!tt.ptr> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_4]] : tensor<2x!tt.ptr>, tensor<2xi32> +// CHECK: tt.store %[[VAL_9]], %[[VAL_7]] : tensor<2x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir new file mode 100644 index 0000000000..efdaad3204 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseMakeRange.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +module { + tt.func public @test_stride_not_one(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<4xf32> + %fake_range_mask = arith.constant dense<[false, false, false, false]> : tensor<4xi1> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 
= tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %ptr = tt.addptr %1, %0 : tensor<4x!tt.ptr>, tensor<4xi32> + %2 = tt.load %ptr, %fake_range_mask, %c0_f32 : tensor<4x!tt.ptr> + tt.store %ptr, %2 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: tt.func public @test_stride_not_one( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<4xf32> +// CHECK: %[[VAL_2:.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<4x!tt.ptr> +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<4x!tt.ptr>, tensor<4xi32> +// CHECK: tt.store %[[VAL_4]], %[[VAL_1]] : tensor<4x!tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir new file mode 100644 index 0000000000..b48d7fc288 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseRem.mlir @@ -0,0 +1,81 @@ +// RUN: triton-opt '--triton-to-structured=enable-mask-fallback-conversion=false optimize-dynamic-offset=true' --split-input-file %s | FileCheck %s + +tt.func public @kernel_with_rem_safe(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256xf32> + %cst_0 = arith.constant dense<1024> : tensor<256xi32> + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<64> : tensor<256xi32> + %cst_2 = arith.constant dense<128> : tensor<256xi32> + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %2 = arith.remsi %1, %cst_2 : tensor<256xi32> + %3 = arith.cmpi slt, %2, %cst_1 : tensor<256xi32> + %4 = arith.muli %0, %c256_i32 : i32 + %5 = tt.splat %4 : i32 -> 
tensor<256xi32> + %6 = arith.addi %5, %1 : tensor<256xi32> + %7 = arith.cmpi slt, %6, %cst_0 : tensor<256xi32> + %8 = arith.andi %3, %7 : tensor<256xi1> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %10 = tt.addptr %9, %2 : tensor<256x!tt.ptr>, tensor<256xi32> + %11 = tt.load %10, %8, %cst : tensor<256x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %13 = tt.addptr %12, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %13, %11, %8 : tensor<256x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @kernel_with_rem_safe( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : tensor<256xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<1024> : tensor<256xi32> +// CHECK: %[[VAL_4:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant dense<64> : tensor<256xi32> +// CHECK: %[[VAL_6:.*]] = arith.constant dense<128> : tensor<256xi32> +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> +// CHECK: %[[VAL_9:.*]] = arith.remsi %[[VAL_8]], %[[VAL_6]] : tensor<256xi32> +// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %[[VAL_9]], %[[VAL_5]] : tensor<256xi32> +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_11]] : i32 -> tensor<256xi32> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_8]] : tensor<256xi32> +// CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_3]] : tensor<256xi32> +// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_10]], %[[VAL_14]] : tensor<256xi1> +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<256x!tt.ptr> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_9]] : tensor<256x!tt.ptr>, tensor<256xi32> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]], 
%[[VAL_15]], %[[VAL_2]] : tensor<256x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<256x!tt.ptr> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_8]] : tensor<256x!tt.ptr>, tensor<256xi32> +// CHECK: tt.store %[[VAL_20]], %[[VAL_18]], %[[VAL_15]] : tensor<256x!tt.ptr> +// CHECK: tt.return +// CHECK: } + + +// ----- + + +tt.func public @test_remsi_with_broadcast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<2x2xf32> { + %c0_f32 = arith.constant dense<0.000000e+00> : tensor<2x2xf32> + %c4 = arith.constant dense<4> : tensor<2x2xi32> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %2 = tt.broadcast %1 : tensor<1x2xi32> -> tensor<2x2xi32> + %3 = arith.remsi %2, %c4 : tensor<2x2xi32> + %4 = arith.trunci %3 : tensor<2x2xi32> to tensor<2x2xi1> + %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %vals = tt.load %ptrs, %4, %c0_f32 : tensor<2x2x!tt.ptr> + tt.return %vals : tensor<2x2xf32> +} + +// CHECK-LABEL: tt.func public @test_remsi_with_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<2x2xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : tensor<2x2xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant dense<4> : tensor<2x2xi32> +// CHECK: %[[VAL_3:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK: %[[VAL_4:.*]] = tt.expand_dims %[[VAL_3]] {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> +// CHECK: %[[VAL_5:.*]] = tt.broadcast %[[VAL_4]] : tensor<1x2xi32> -> tensor<2x2xi32> +// CHECK: %[[VAL_6:.*]] = arith.remsi %[[VAL_5]], %[[VAL_2]] : tensor<2x2xi32> +// CHECK: %[[VAL_7:.*]] = arith.trunci %[[VAL_6]] : tensor<2x2xi32> to tensor<2x2xi1> +// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<2x2x!tt.ptr> +// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]], %[[VAL_7]], %[[VAL_1]] : tensor<2x2x!tt.ptr> +// CHECK: tt.return 
%[[VAL_9]] : tensor<2x2xf32> +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir new file mode 100644 index 0000000000..13d653f80f --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir @@ -0,0 +1,127 @@ +// RUN: triton-opt %s --bubble-up-operation | FileCheck %s + +// CHECK-LABEL: tt.func @test_subi_extract_bubbleup +tt.func @test_subi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.subi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_maxsi_extract_bubbleup +tt.func @test_maxsi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.maxsi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_minsi_extract_bubbleup +tt.func @test_minsi_extract_bubbleup(%a: tensor<128xi32>, %b: tensor<128xi32>, %i: index, %c: i32) -> i32 { + %0 = arith.minsi %a, %b : tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128xi32> + %2 = arith.muli %1, %c : i32 + tt.return %2 : i32 +} + + +// CHECK-LABEL: tt.func @test_extf_extract_bubbleup +tt.func @test_extf_extract_bubbleup(%a: tensor<128xf16>, %i: index, %c: f32) -> f32 { + %0 = arith.extf %a : tensor<128xf16> to tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_minnumf_extract_bubbleup +tt.func @test_minnumf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = arith.minnumf %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, 
%c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_maxnumf_extract_bubbleup +tt.func @test_maxnumf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = arith.maxnumf %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_cmpf_extract_bubbleup +tt.func @test_cmpf_extract_bubbleup(%a: tensor<128xf32>, %b: tensor<128xf32>, %i: index) -> i1 { + %0 = arith.cmpf olt, %a, %b : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xi1> + tt.return %1 : i1 +} + + +// CHECK-LABEL: tt.func @test_addptr_extract_bubbleup +tt.func @test_addptr_extract_bubbleup(%a: tensor<128x!tt.ptr>, %b: tensor<128xi32>, %i: index) -> !tt.ptr { + %0 = tt.addptr %a, %b : tensor<128x!tt.ptr>, tensor<128xi32> + %1 = tensor.extract %0[%i] : tensor<128x!tt.ptr> + tt.return %1 : !tt.ptr +} + + +// CHECK-LABEL: tt.func @test_ceil_extract_bubbleup +tt.func @test_ceil_extract_bubbleup(%a: tensor<128xf32>, %i: index, %c: f32) -> f32 { + %0 = math.ceil %a : tensor<128xf32> + %1 = tensor.extract %0[%i] : tensor<128xf32> + %2 = arith.mulf %1, %c : f32 + tt.return %2 : f32 +} + + +// CHECK-LABEL: tt.func @test_slice_extract_dropdim_bubbleup +tt.func @test_slice_extract_dropdim_bubbleup(%a: tensor<128x128x128xf32>, %i: index, %j: index) -> f32 { + %0 = tensor.extract_slice %a[0, %i, 0][1, 1, 128][1, 1, 1] : tensor<128x128x128xf32> to tensor<128xf32> + %1 = tensor.extract %0[%j] : tensor<128xf32> + tt.return %1 : f32 +} + + +// CHECK-LABEL: tt.func @test_expand_slice_bubbleup +tt.func @test_expand_slice_bubbleup(%a: tensor<128xf32>, %i: index, %c: f32) -> tensor<1x1xf32> { + %0 = tt.expand_dims %a {axis = 0 : i32} : tensor<128xf32> -> tensor<1x128xf32> + %1 = tensor.extract_slice %0[0, %i][1, 1][1, 1] : tensor<1x128xf32> to tensor<1x1xf32> + tt.return %1 : tensor<1x1xf32> +} + + +// CHECK-LABEL: tt.func @test_expand_slice_dropdim_bubbleup 
+tt.func @test_expand_slice_dropdim_bubbleup(%a: tensor<128x128xf32>, %i: index, %c: f32) -> tensor<128x1xf32> { + %0 = tt.expand_dims %a {axis = 2 : i32} : tensor<128x128xf32> -> tensor<128x128x1xf32> + %1 = tensor.extract_slice %0[%i, 0, 0][1, 128, 1][1, 1, 1] : tensor<128x128x1xf32> to tensor<128x1xf32> + tt.return %1 : tensor<128x1xf32> +} + + +// CHECK-LABEL: tt.func @test_splat_slice_bubbleup +tt.func @test_splat_slice_bubbleup(%a: f32, %i: index, %c: f32) -> tensor<1xf32> { + %0 = tt.splat %a : f32 -> tensor<128xf32> + %1 = tensor.extract_slice %0[%i][1][1] : tensor<128xf32> to tensor<1xf32> + tt.return %1 : tensor<1xf32> +} + + +// CHECK-LABEL: tt.func @test_makerange_slice_bubbleup +tt.func @test_makerange_slice_bubbleup(%i: index, %c: f32) -> tensor<1xi32> { + %0 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> + %1 = tensor.extract_slice %0[%i][1][1] : tensor<128xi32> to tensor<1xi32> + tt.return %1 : tensor<1xi32> +} + + +// CHECK-LABEL: tt.func @test_slice_all_bubbleup +tt.func @test_slice_all_bubbleup(%i: index, %c: f32) -> tensor<128xi32> { + %0 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> + %1 = tensor.extract_slice %0[0][128][1] : tensor<128xi32> to tensor<128xi32> + tt.return %1 : tensor<128xi32> +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir new file mode 100644 index 0000000000..7f71ec94a9 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/if_simplifier.mlir @@ -0,0 +1,45 @@ +// RUN: triton-opt --triton-to-unstructure --split-input-file %s | FileCheck %s --implicit-check-not="DiscreteMemAccess" +// CHECK-LABEL: tt.func public @triton_for_if_load +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<16xi32> +// CHECK: %[[CST0:.*]] = arith.constant dense<1> : tensor<16xi32> +// CHECK: %[[SEL:.*]] 
= arith.select %{{.*}}, %[[CST0]], %[[CST]] : tensor<16xi32> +// CHECK: %[[ADD:.*]] = arith.addi %{{.*}}, %[[SEL]] : tensor<16xi32> +// CHECK: %[[ADDPTR:.*]] = tt.addptr %{{.*}}, %[[ADD]] : tensor<16x!tt.ptr>, tensor<16xi32> + + +module { + tt.func public @triton_for_if_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16xf32> + %cst_0 = arith.constant dense<32> : tensor<16xi32> + %cst_1 = arith.constant dense<1> : tensor<16xi32> + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.get_program_id x : i32 + %2 = arith.cmpi ne, %1, %c0_i32 : i32 + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %5:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<16xi32>, tensor<16xi32>) : i32 { + %6 = arith.muli %arg2, %c16_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<16xi32> + %8 = arith.addi %arg3, %7 : tensor<16xi32> + %9 = arith.addi %arg4, %7 : tensor<16xi32> + %10 = scf.if %2 -> (tensor<16xi32>) { + %16 = arith.addi %8, %cst_1 : tensor<16xi32> + scf.yield %16 : tensor<16xi32> + } else { + scf.yield %8 : tensor<16xi32> + } + %11 = tt.addptr %3, %10 : tensor<16x!tt.ptr>, tensor<16xi32> + %12 = arith.cmpi slt, %10, %cst_0 : tensor<16xi32> + %13 = tt.load %11 : tensor<16x!tt.ptr> + %14 = arith.select %12, %13, %cst : tensor<16xi1>, tensor<16xf32> + %15 = tt.addptr %4, %9 : tensor<16x!tt.ptr>, tensor<16xi32> + tt.store %15, %14 : tensor<16x!tt.ptr> + scf.yield %10, %9 : tensor<16xi32>, tensor<16xi32> + } + tt.return + } +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir 
b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir new file mode 100644 index 0000000000..984141e3a2 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir @@ -0,0 +1,207 @@ +// RUN: triton-opt --triton-to-unstructure --split-input-file %s | FileCheck %s + +tt.func public @test_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c3_i32 = arith.constant 3 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<128> : tensor<128xi32> + %cst_0 = arith.constant dense<0> : tensor<128xi32> + %cst_1 = arith.constant dense<300> : tensor<128xi32> + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %3 = tt.splat %1 : i32 -> tensor<128xi32> + %4 = arith.addi %3, %2 : tensor<128xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %10 = tt.load %9 : tensor<128x!tt.ptr> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<128x!tt.ptr>, tensor<128xi32> + %13:3 = scf.for %arg4 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg5 = %4, %arg6 = %7, %arg7 = %12) -> (tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) : i32 { + %14:3 = scf.for %arg8 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg9 = %arg5, %arg10 = %arg6, %arg11 = %arg7) -> (tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) : i32 { + %18 = arith.cmpi slt, %arg9, %cst_1 : tensor<128xi32> + %19 = 
tt.addptr %5, %arg9 : tensor<128x!tt.ptr>, tensor<128xi32> + %20 = tt.load %19, %18, %cst_0 : tensor<128x!tt.ptr> + %21 = tt.load %arg11 : tensor<128x!tt.ptr> + %22 = arith.addi %20, %21 : tensor<128xi32> + tt.store %arg10, %22, %18 : tensor<128x!tt.ptr> + %23 = arith.addi %arg9, %cst : tensor<128xi32> + %24 = tt.addptr %arg10, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %25 = tt.addptr %arg11, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %23, %24, %25 : tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + %15 = arith.addi %14#0, %cst : tensor<128xi32> + %16 = tt.addptr %14#1, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %17 = tt.addptr %14#2, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %15, %16, %17 : tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + tt.return +} + +// CHECK-LABEL: tt.func public @test_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<128> : tensor<128xi64> +// CHECK: %[[VAL_8:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_15:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_15]], %{{.*}} : i32 +// CHECK: %[[VAL_17:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_16]] : i32 -> tensor<128xi32> +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : tensor<128xi32> +// CHECK: %[[VAL_20:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_19]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_19]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: 
%[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_24]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_26:.*]]:3 = scf.for %[[VAL_27:.*]] = %{{.*}} to %[[VAL_8]] step %{{.*}} iter_args(%[[VAL_28:.*]] = %[[VAL_19]], %[[VAL_29:.*]] = %[[VAL_21]], %[[VAL_30:.*]] = %[[VAL_25]]) -> (tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) : i32 { +// CHECK: %[[VAL_31:.*]]:3 = scf.for %[[VAL_32:.*]] = %{{.*}} to %[[VAL_8]] step %{{.*}} iter_args(%[[VAL_33:.*]] = %[[VAL_28]], %[[VAL_34:.*]] = %[[VAL_29]], %[[VAL_35:.*]] = %[[VAL_30]]) -> (tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) : i32 { +// CHECK: %[[VAL_36:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_37:.*]] = tt.addptr %[[VAL_36]], %[[VAL_34]] : tensor<128x!tt.ptr>, tensor<128xi64> +// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_33]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_39:.*]] = tt.addptr %[[VAL_20]], %[[VAL_33]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_40:.*]] = tt.load %[[VAL_39]], %[[VAL_38]], %{{.*}} : tensor<128x!tt.ptr> +// CHECK: %[[VAL_41:.*]] = tensor.empty() : tensor<128xi32> +// CHECK: %[[VAL_42:.*]] = scf.for %[[VAL_43:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[VAL_44:.*]] = %[[VAL_41]]) -> (tensor<128xi32>) { +// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_35]]{{\[}}%[[VAL_43]]] {DiscreteMemAccess} : tensor<128xi64> +// CHECK: %[[VAL_46:.*]] = tt.addptr %[[VAL_3]], %[[VAL_45]] : !tt.ptr, i64 +// CHECK: %[[VAL_47:.*]] = tt.load %[[VAL_46]] {DiscreteMemAccess} : !tt.ptr +// CHECK: %[[VAL_49:.*]] = tensor.insert_slice %{{.*}} into %[[VAL_44]]{{\[}}%[[VAL_43]]] [1] [1] : tensor<1xi32> into tensor<128xi32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_49]] : tensor<128xi32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_40]], %[[VAL_42]] : tensor<128xi32> +// CHECK: tt.store %[[VAL_37]], %[[VAL_50]], %[[VAL_38]] : tensor<128x!tt.ptr> +// 
CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_33]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_52:.*]] = arith.addi %[[VAL_34]], %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_35]], %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_55:.*]]#0, %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_55]]#1, %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_55]]#2, %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_54]], %[[VAL_56]], %[[VAL_57]] : tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: tt.return +// CHECK: } + +// ----- + +tt.func public @test_kernel2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c300_i32 = arith.constant 300 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<128> : tensor<128xi32> + %cst_0 = arith.constant dense<0> : tensor<128xi32> + %cst_1 = arith.constant dense<300> : tensor<128xi32> + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %3 = tt.splat %1 : i32 -> tensor<128xi32> + %4 = arith.addi %3, %2 : tensor<128xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + %10 = tt.load %9 : tensor<128x!tt.ptr> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x!tt.ptr> + %12 = 
tt.addptr %11, %10 : tensor<128x!tt.ptr>, tensor<128xi32> + %13:3 = scf.while (%arg4 = %cst, %arg5 = %4, %arg6 = %7, %arg7 = %12) : (tensor<128xi32>, tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) -> (tensor<128x!tt.ptr>, tensor<128xi32>, tensor<128x!tt.ptr>) { + %14 = "tt.reduce"(%arg4) <{axis = 0 : i32}> ({ + ^bb0(%arg8: i32, %arg9: i32): + %16 = arith.addi %arg8, %arg9 : i32 + tt.reduce.return %16 : i32 + }) : (tensor<128xi32>) -> i32 + %15 = arith.cmpi slt, %14, %c300_i32 : i32 + scf.condition(%15) %arg7, %arg5, %arg6 : tensor<128x!tt.ptr>, tensor<128xi32>, tensor<128x!tt.ptr> + } do { + ^bb0(%arg4: tensor<128x!tt.ptr>, %arg5: tensor<128xi32>, %arg6: tensor<128x!tt.ptr>): + %14:4 = scf.while (%arg7 = %c0_i32, %arg8 = %arg6, %arg9 = %arg4, %arg10 = %arg5) : (i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>, tensor<128xi32>) -> (tensor<128xi32>, i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>) { + %18 = arith.cmpi slt, %arg7, %c3_i32 : i32 + scf.condition(%18) %arg10, %arg7, %arg8, %arg9 : tensor<128xi32>, i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } do { + ^bb0(%arg7: tensor<128xi32>, %arg8: i32, %arg9: tensor<128x!tt.ptr>, %arg10: tensor<128x!tt.ptr>): + %18 = arith.cmpi slt, %arg7, %cst_1 : tensor<128xi32> + %19 = tt.addptr %5, %arg7 : tensor<128x!tt.ptr>, tensor<128xi32> + %20 = tt.load %19, %18, %cst_0 : tensor<128x!tt.ptr> + %21 = tt.load %arg10 : tensor<128x!tt.ptr> + %22 = arith.addi %20, %21 : tensor<128xi32> + tt.store %arg9, %22, %18 : tensor<128x!tt.ptr> + %23 = arith.addi %arg7, %cst : tensor<128xi32> + %24 = arith.addi %arg8, %c1_i32 : i32 + %25 = tt.addptr %arg9, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %26 = tt.addptr %arg10, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %24, %25, %26, %23 : i32, tensor<128x!tt.ptr>, tensor<128x!tt.ptr>, tensor<128xi32> + } + %15 = arith.addi %14#0, %cst : tensor<128xi32> + %16 = tt.addptr %14#2, %cst : tensor<128x!tt.ptr>, tensor<128xi32> + %17 = tt.addptr %14#3, %cst : 
tensor<128x!tt.ptr>, tensor<128xi32> + scf.yield %14#0, %15, %16, %17 : tensor<128xi32>, tensor<128xi32>, tensor<128x!tt.ptr>, tensor<128x!tt.ptr> + } + tt.return +} + +// CHECK-LABEL: tt.func public @test_kernel2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<128> : tensor<128xi64> +// CHECK: %[[VAL_8:.*]] = arith.constant 300 : i32 +// CHECK: %[[VAL_16:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %{{.*}} : i32 +// CHECK: %[[VAL_18:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_17]] : i32 -> tensor<128xi32> +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : tensor<128xi32> +// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_22:.*]] = arith.extsi %[[VAL_20]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_20]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.extsi %[[VAL_25]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[VAL_27:.*]]:3 = scf.while (%[[VAL_28:.*]] = %{{.*}}, %[[VAL_29:.*]] = %[[VAL_20]], %[[VAL_30:.*]] = %[[VAL_22]], %[[VAL_31:.*]] = %[[VAL_26]]) : (tensor<128xi32>, tensor<128xi32>, tensor<128xi64>, tensor<128xi64>) -> (tensor<128xi64>, tensor<128xi32>, tensor<128xi64>) { +// CHECK: %[[VAL_32:.*]] = "tt.reduce"(%[[VAL_28]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0(%[[VAL_33:.*]]: i32, %[[VAL_34:.*]]: i32): +// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_33]], %[[VAL_34]] : i32 +// CHECK: tt.reduce.return 
%[[VAL_35]] : i32 +// CHECK: }) : (tensor<128xi32>) -> i32 +// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_32]], %[[VAL_8]] : i32 +// CHECK: scf.condition(%[[VAL_36]]) %[[VAL_31]], %[[VAL_29]], %[[VAL_30]] : tensor<128xi64>, tensor<128xi32>, tensor<128xi64> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_37:.*]]: tensor<128xi64>, %[[VAL_38:.*]]: tensor<128xi32>, %[[VAL_39:.*]]: tensor<128xi64>): +// CHECK: %[[VAL_40:.*]]:4 = scf.while (%[[VAL_41:.*]] = %{{.*}}, %[[VAL_42:.*]] = %[[VAL_39]], %[[VAL_43:.*]] = %[[VAL_37]], %[[VAL_44:.*]] = %[[VAL_38]]) : (i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32>) -> (i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32>) { +// CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_41]], %{{.*}} : i32 +// CHECK: scf.condition(%[[VAL_45]]) %[[VAL_41]], %[[VAL_42]], %[[VAL_43]], %[[VAL_44]] : i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_46:.*]]: i32, %[[VAL_47:.*]]: tensor<128xi64>, %[[VAL_48:.*]]: tensor<128xi64>, %[[VAL_49:.*]]: tensor<128xi32>): +// CHECK: %[[VAL_50:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<128x!tt.ptr> +// CHECK: %[[VAL_51:.*]] = tt.addptr %[[VAL_50]], %[[VAL_47]] : tensor<128x!tt.ptr>, tensor<128xi64> +// CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[VAL_49]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_53:.*]] = tt.addptr %[[VAL_21]], %[[VAL_49]] : tensor<128x!tt.ptr>, tensor<128xi32> +// CHECK: %[[VAL_54:.*]] = tt.load %[[VAL_53]], %[[VAL_52]], %{{.*}} : tensor<128x!tt.ptr> +// CHECK: %[[VAL_55:.*]] = tensor.empty() : tensor<128xi32> +// CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[VAL_58:.*]] = %[[VAL_55]]) -> (tensor<128xi32>) { +// CHECK: %[[VAL_59:.*]] = tensor.extract %[[VAL_48]]{{\[}}%[[VAL_57]]] {DiscreteMemAccess} : tensor<128xi64> +// CHECK: %[[VAL_60:.*]] = tt.addptr %[[VAL_3]], %[[VAL_59]] : !tt.ptr, i64 +// CHECK: %[[VAL_61:.*]] = tt.load %[[VAL_60]] {DiscreteMemAccess} : !tt.ptr +// CHECK: 
%[[VAL_62:.*]] = tt.splat %[[VAL_61]] : i32 -> tensor<1xi32> +// CHECK: %[[VAL_63:.*]] = tensor.insert_slice %[[VAL_62]] into %[[VAL_58]]{{\[}}%[[VAL_57]]] [1] [1] : tensor<1xi32> into tensor<128xi32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_63]] : tensor<128xi32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_64:.*]] = arith.addi %[[VAL_54]], %[[VAL_56]] : tensor<128xi32> +// CHECK: tt.store %[[VAL_51]], %[[VAL_64]], %[[VAL_52]] : tensor<128x!tt.ptr> +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_49]], %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_46]], %{{.*}} : i32 +// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_47]], %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_48]], %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_65]] : i32, tensor<128xi64>, tensor<128xi64>, tensor<128xi32> +// CHECK: } +// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_70:.*]]#3, %{{.*}} : tensor<128xi32> +// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_70]]#1, %[[VAL_4]] : tensor<128xi64> +// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_70]]#2, %[[VAL_4]] : tensor<128xi64> +// CHECK: scf.yield %[[VAL_70]]#3, %[[VAL_69]], %[[VAL_71]], %[[VAL_72]] : tensor<128xi32>, tensor<128xi32>, tensor<128xi64>, tensor<128xi64> +// CHECK: } +// CHECK: tt.return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir new file mode 100644 index 0000000000..dbfa654c23 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s --triton-to-unstructure | FileCheck %s + +// CHECK-LABEL: tt.func@test_unstructure_splatandloadscenario +// CHECK: %[[EXT:.*]] = tensor.extract %{{.*}}[%{{.*}}] {DiscreteMemAccess} : tensor<128x!tt.ptr> +// CHECK: %[[VAL1:.*]] = tt.load %[[EXT]] : !tt.ptr 
+// CHECK: %[[VAL2:.*]] = tt.splat %[[VAL1]] : f32 -> tensor<128xf32> +tt.func@test_unstructure_splatandloadscenario(%base: !tt.ptr) -> tensor<128xf32> { + %offset = arith.constant 10 : i64 + %offset_tensor = tt.splat %offset : i64 -> tensor<128xi64> + %base_tensor = tt.splat %base : !tt.ptr -> tensor<128x!tt.ptr> + %ptr = tt.addptr %base_tensor, %offset_tensor : tensor<128x!tt.ptr>, tensor<128xi64> + %val = tt.load %ptr : tensor<128x!tt.ptr> + tt.return %val : tensor<128xf32> +} \ No newline at end of file diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir new file mode 100644 index 0000000000..db36ad34c1 --- /dev/null +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir @@ -0,0 +1,82 @@ +// RUN: triton-opt --triton-to-unstructure %s | FileCheck %s + +tt.func public @indirect_mix_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<16> : tensor<1x8xi32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c8_i32 : i32 + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %3 = tt.splat %1 : i32 -> tensor<8xi32> + %4 = arith.addi %3, %2 : tensor<8xi32> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %7 = tt.addptr %6, %5 : tensor<16x!tt.ptr>, tensor<16xi32> + %8 = tt.load %7 : tensor<16x!tt.ptr> + %9 = tt.expand_dims %2 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %10 = tt.splat %arg3 : i32 -> tensor<1x8xi32> + %11 = arith.muli %9, %10 : tensor<1x8xi32> + %12 = tt.expand_dims %8 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> + %13 = arith.extsi %11 : 
tensor<1x8xi32> to tensor<1x8xi64> + %14 = tt.broadcast %13 : tensor<1x8xi64> -> tensor<16x8xi64> + %15 = tt.broadcast %12 : tensor<16x1xi64> -> tensor<16x8xi64> + %16 = arith.addi %14, %15 : tensor<16x8xi64> + %17 = tt.splat %arg2 : !tt.ptr -> tensor<16x8x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<16x8x!tt.ptr>, tensor<16x8xi64> + %19 = tt.load %18 : tensor<16x8x!tt.ptr> + %20 = math.exp %19 : tensor<16x8xf32> + %21 = tt.expand_dims %4 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %22 = arith.muli %21, %cst : tensor<1x8xi32> + %23 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %24 = tt.broadcast %22 : tensor<1x8xi32> -> tensor<16x8xi32> + %25 = tt.broadcast %23 : tensor<16x1xi32> -> tensor<16x8xi32> + %26 = arith.addi %24, %25 : tensor<16x8xi32> + %27 = tt.splat %arg0 : !tt.ptr -> tensor<16x8x!tt.ptr> + %28 = tt.addptr %27, %26 : tensor<16x8x!tt.ptr>, tensor<16x8xi32> + tt.store %28, %20 : tensor<16x8x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @indirect_mix_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_7:.*]] = arith.constant dense<16> : tensor<1x8xi32> +// CHECK: %[[VAL_9:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %{{.*}} : i32 +// CHECK: %[[VAL_11:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_10]] : i32 -> tensor<8xi32> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : tensor<8xi32> +// CHECK: %[[VAL_14:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> +// CHECK: %[[VAL_15:.*]] = 
tt.splat %[[VAL_1]] : !tt.ptr -> tensor<16x!tt.ptr> +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_15]], %[[VAL_14]] : tensor<16x!tt.ptr>, tensor<16xi32> +// CHECK: %[[VAL_17:.*]] = tt.load %[[VAL_16]] : tensor<16x!tt.ptr> +// CHECK: %[[VAL_18:.*]] = tt.expand_dims %[[VAL_11]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_3]] : i32 -> tensor<1x8xi32> +// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : tensor<1x8xi32> +// CHECK: %[[VAL_21:.*]] = tt.expand_dims %[[VAL_17]] {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> +// CHECK: %[[VAL_22:.*]] = arith.extsi %[[VAL_20]] : tensor<1x8xi32> to tensor<1x8xi64> +// CHECK: %[[VAL_23:.*]] = tt.broadcast %[[VAL_22]] : tensor<1x8xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_24:.*]] = tt.broadcast %[[VAL_21]] : tensor<16x1xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : tensor<16x8xi64> +// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<16x8xf32> +// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_26]]) -> (tensor<16x8xf32>) { +// CHECK: %[[VAL_30:.*]] = tensor.extract_slice %[[VAL_25]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] {DiscreteMemAccess} : tensor<16x8xi64> to tensor<1x8xi64> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1x8x!tt.ptr>, tensor<1x8xi64> +// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] {DiscreteMemAccess} : tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_34:.*]] = tensor.insert_slice %[[VAL_33]] into %[[VAL_29]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] : tensor<1x8xf32> into tensor<16x8xf32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_34]] : tensor<16x8xf32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_35:.*]] = math.exp %[[VAL_27]] : tensor<16x8xf32> +// CHECK: %[[VAL_36:.*]] = tt.expand_dims %[[VAL_13]] {axis = 0 : i32} : 
tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : tensor<1x8xi32> +// CHECK: %[[VAL_38:.*]] = tt.expand_dims %[[VAL_14]] {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> +// CHECK: %[[VAL_39:.*]] = tt.broadcast %[[VAL_37]] : tensor<1x8xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_40:.*]] = tt.broadcast %[[VAL_38]] : tensor<16x1xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_39]], %[[VAL_40]] : tensor<16x8xi32> +// CHECK: %[[VAL_42:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<16x8x!tt.ptr> +// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_42]], %[[VAL_41]] : tensor<16x8x!tt.ptr>, tensor<16x8xi32> +// CHECK: tt.store %[[VAL_43]], %[[VAL_35]] : tensor<16x8x!tt.ptr> +// CHECK: tt.return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/unittest/affine_map/affine_map.py b/third_party/ascend/unittest/affine_map/affine_map.py new file mode 100644 index 0000000000..5d8460aed6 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c2 = ascend_ir.affine_expr.get_constant(2) + + expr = (d0 + c2) * d1 + print("expr:", expr) + print("expr pure affine:", expr.is_pure_affine()) + print("expr hashable:", hash(expr)) + + m0 = ascend_ir.affine_map.get_identity(2) + m1 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + m2 = ascend_ir.affine_map.get(2, 0, [d0 + d1, d1]) + m3 = ascend_ir.affine_map.get_constant(7) + minor = ascend_ir.affine_map.get_minor_identity(3, 2) + + print("m0:", m0) + print("m1:", m1) + print("m2:", m2) + print("m1 inverse:", m1.inverse_permutation()) + print("m2 submap[1]:", m2.get_sub_map([1])) + print("m2 compose m1:", m2.compose(m1)) + 
print("m1 as dict:", m1.to_dict()) + print("m3 constant:", m3, "value=", m3.get_constant_result()) + print("minor identity:", minor) + print("m2 results:", [str(x) for x in m2.get_results()]) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py b/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py new file mode 100644 index 0000000000..a3e248a454 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_buffer_type_demo.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + builder = ascend_ir.ascendnpu_ir_builder(ctx) + f32 = builder.get_float_ty() + ub_space = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + + # Build a memref type using an explicit affine map layout. + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + memref_ty = builder.get_buffer_ty_with_affine_map([8, 16], f32, transpose_map, ub_space) + map_attr = builder.get_affine_map_attr(transpose_map) + + print("affine map:", transpose_map) + print("affine map attr:", map_attr) + print("memref type:", memref_ty) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py b/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py new file mode 100644 index 0000000000..4748e69a02 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_complex_expr_demo.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + s0 = ascend_ir.affine_expr.get_symbol(0) + c3 = 
ascend_ir.affine_expr.get_constant(3) + c4 = ascend_ir.affine_expr.get_constant(4) + + # Complex expressions with symbols and integer arithmetic. + tiled_row = (d0 + s0).floordiv(c4) + tiled_col = (d1 + c3).ceildiv(c4) + inner = (d0 + d1).mod(c4) + + map_a = ascend_ir.affine_map.get(2, 1, [tiled_row, tiled_col, inner]) + map_b = ascend_ir.affine_map.get(2, 0, [d1, d0]) + map_comp = map_a.compose(map_b) + + print("map_a:", map_a) + print("map_b:", map_b) + print("map_a composed with map_b:", map_comp) + print("map_a results:", [str(r) for r in map_a.get_results()]) + print("map_a submap [0, 2]:", map_a.get_sub_map([0, 2])) + print("map_a metadata:", map_a.to_dict()) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py b/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py new file mode 100644 index 0000000000..d8be55d37e --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_indexing_map_demo.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + builder = ascend_ir.ascendnpu_ir_builder(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c8 = ascend_ir.affine_expr.get_constant(8) + + # Example indexing maps: transpose and a tiled/reduced projection. 
+ map_in0 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + map_in1 = ascend_ir.affine_map.get(2, 0, [d0, d1]) + map_out = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c8), d1.mod(c8)]) + + indexing_map_attr = builder.get_affine_map_array_attr([map_in0, map_in1, map_out]) + print("indexing_map attr:", indexing_map_attr) + + ub_space = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + f32 = builder.get_float_ty() + memref_ty = builder.get_buffer_ty_with_affine_map([16, 32], f32, map_in0, ub_space) + print("buffer type with map_in0:", memref_ty) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py b/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py new file mode 100644 index 0000000000..9bb8e09667 --- /dev/null +++ b/third_party/ascend/unittest/affine_map/affine_map_parse_demo.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +def main(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + identity_map = ascend_ir.affine_map.get_identity(2) + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + + print("identity map:", identity_map) + print(" dims:", identity_map.get_num_dims()) + print(" symbols:", identity_map.get_num_symbols()) + print(" results:", identity_map.get_num_results()) + print(" is_identity:", identity_map.is_identity()) + print(" is_permutation:", identity_map.is_permutation()) + + print("transpose map:", transpose_map) + print(" is_identity:", transpose_map.is_identity()) + print(" is_permutation:", transpose_map.is_permutation()) + print(" as python object:", transpose_map.to_dict()) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/autotune_ut/01-vector-add.py b/third_party/ascend/unittest/autotune_ut/01-vector-add.py new file mode 100644 index 0000000000..555b961d94 --- /dev/null +++ 
b/third_party/ascend/unittest/autotune_ut/01-vector-add.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Vector Add +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.backends.ascend.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key=["n_elements"] +) +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. 
We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + return output + + +def test_add(size: int): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + print(f"Vector Add {size} PASSED!") + + +if __name__ == "__main__": + test_add(98432) diff --git a/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py new file mode 100644 index 0000000000..09c7de0f80 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py @@ -0,0 +1,107 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Fused Softmax +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.backends.ascend.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key=["n_rows", "n_cols"], +) +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) * XBLOCK + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): + # The stride represents how much we need to increase the pointer to advance 1 row + row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] + col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] + xmask = row_offsets < n_rows + ymask = col_offsets < n_cols + mask = xmask & ymask + input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to( + XBLOCK_SUB, BLOCK_SIZE + ) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = ( + tl.sum(numerator, axis=1) + .reshape(XBLOCK_SUB, 1) + .broadcast_to(XBLOCK_SUB, BLOCK_SIZE) + ) + softmax_output = numerator / denominator + # Write back output to DRAM + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +def softmax_torch(x): + return torch.softmax(x, axis=-1) + + +def softmax_autotune(x): + n_rows, n_cols = x.shape + BLOCK_SIZE = n_cols + + # Allocate output + y = torch.empty_like(x) + # Create a number of persistent programs. 
+ softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)]( + y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE + ) + return y + + +def test_softmax(shape, dtype): + x = torch.randn(shape, dtype=dtype, device="npu") + y_torch = softmax_torch(x) + y_triton = softmax_autotune(x) + assert torch.allclose(y_triton, y_torch) + print(f"Fused Softmax {shape} {dtype} PASSED!") + + +if __name__ == "__main__": + test_softmax((16896, 1024), torch.float32) diff --git a/third_party/ascend/unittest/autotune_ut/03-layer-norm.py b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py new file mode 100644 index 0000000000..42e6b5b1bd --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py @@ -0,0 +1,140 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Layer Normalization +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.backends.ascend.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key=["M", "N"], +) +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, + M, # number of columns in X + eps, # epsilon to avoid division by zero + XBLOCK_SIZE: tl.constexpr, + RBLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row_begin = tl.program_id(0) * XBLOCK_SIZE + row_idx = row_begin + tl.arange(0, XBLOCK_SIZE) + row_mask = row_idx < M + row_offsets = row_idx[:, None] * stride + # Compute mean + _mean = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + _mean += a + mean = tl.sum(_mean, axis=1, keep_dims=True) / N + # Compute variance + _var = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + x = tl.where(mask, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=1, keep_dims=True) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row_idx[:, None], mean, mask=row_mask[:, None]) + tl.store(Rstd + row_idx[:, None], rstd, mask=row_mask[:, None]) + # Normalize and apply linear 
transformation + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + w = tl.load(W + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) + b = tl.load(B + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + row_offsets + col_idx[None, :], y, mask=mask) + + +def layer_norm_torch(args): + x, w_shape, weight, bias, eps, dtype = args + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + +def layer_norm_autotune(args): + x, weight, bias, eps = args + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M,), dtype=torch.float32, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + # enqueue kernel + _layer_norm_fwd_fused[lambda meta: (triton.cdiv(M, meta["XBLOCK_SIZE"]), 1, 1)]( # + x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, M, eps # + ) + return y + + +def test_layer_norm(shape, dtype, eps=1e-5): + M, N = shape + device = "npu" + x_shape = shape + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device=device) + bias = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + y_torch = layer_norm_torch((x, w_shape, weight, bias, eps, dtype)) + y_triton = layer_norm_autotune((x, weight, bias, eps)) + assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + test_layer_norm((128, 32), torch.float16) diff --git a/third_party/ascend/unittest/autotune_ut/04-libentry.py b/third_party/ascend/unittest/autotune_ut/04-libentry.py new file mode 100644 index 
0000000000..72949d9cf9 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/04-libentry.py @@ -0,0 +1,101 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Vector Add with Libentry +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +import triton.backends.ascend.runtime +from triton.runtime.libentry import libentry +from triton.backends.ascend.testing import do_bench_npu + + +# NB: Inserting any other decorator between @triton.autotune and @triton.jit disables +# parallel compilation during autotuning. 
To enable parallel compilation, apply @triton.autotune +# directly around @triton.jit (i.e., nest autotune as the outermost decorator on the JIT-compiled function) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1 * 1024, 'multibuffer': True}), + triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': True}), + triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': False}), + triton.Config({'BLOCK_SIZE': 8 * 1024, 'multibuffer': True}), + ], + key=["n_elements"] +) +@libentry() +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. 
+ tl.store(output_ptr + offsets, output, mask=mask) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + return output + + +def test_add(size: int): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + print(f"Vector Add {size} with libentry PASSED!") + + +if __name__ == "__main__": + test_add(98432) diff --git a/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py new file mode 100644 index 0000000000..1ded361ae2 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py @@ -0,0 +1,175 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os + +import pytest +import torch +import torch_npu +import triton +import triton.backends.ascend.runtime +import triton.language as tl + + +@triton.autotune( + configs=[], + key={"x": "n_elements"}, + hints={ + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SIZE_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } +) +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_SUB: tl.constexpr, +): + offset = tl.program_id(0) * BLOCK_SIZE + loops1 = (BLOCK_SIZE + BLOCK_SIZE_SUB - 1) // BLOCK_SIZE_SUB + for loop in range(0, loops1): + x0 = offset + loop * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE_SUB) + mask = x0 < n_elements + x = tl.load(x_ptr + x0, mask) + y = tl.load(y_ptr + x0, mask) + output = x + y + tl.store(output_ptr + x0, output) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + return output + + +@pytest.mark.autotune +@pytest.mark.parametrize('size', [2048, ]) +def test_add(size: int): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + + +@pytest.mark.autotune +def test_add_no_reduction_axes(): + try: + @triton.autotune( + configs=[], + key={"x": "n_elements"}, + hints={ + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SIZE_SUB"}, + "low_dim_axes": ["x"], + } + ) + @triton.jit + def add_kernel_exception(): + pass + except ValueError as e: + assert 
"reduction_axes must be a list" in str(e) + + +@pytest.mark.autotune +def test_add_no_low_dim_axes(): + try: + @triton.autotune( + configs=[], + key={"x": "n_elements"}, + hints={ + "split_params": {"x": "BLOCK_SIZE"}, + "tiling_params": {"x": "BLOCK_SIZE_SUB"}, + "reduction_axes": [], + } + ) + @triton.jit + def add_kernel_exception(): + pass + except ValueError as e: + assert "low_dim_axes must be a list" in str(e) + + +@pytest.mark.autotune +def test_add_no_tiling_params(): + try: + @triton.autotune( + configs=[], + key={"x": "n_elements"}, + hints={ + "split_params": {"x": "BLOCK_SIZE"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + ) + @triton.jit + def add_kernel_exception(): + pass + except ValueError as e: + assert "tiling_params must be a dict" in str(e) + + +@pytest.mark.autotune +def test_add_no_split_params(): + try: + @triton.autotune( + configs=[], + key={"x": "n_elements"}, + hints={ + "tiling_params": {"x": "BLOCK_SIZE_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + ) + @triton.jit + def add_kernel_exception(): + pass + except ValueError as e: + assert "split_params must be a dict" in str(e) + + +@pytest.mark.autotune +def test_add_no_keyname(): + try: + @triton.autotune( + configs=[], + key={"x0": "n_elements"}, + hints={ + "tiling_params": {"x": "BLOCK_SIZE_SUB"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + } + ) + @triton.jit + def add_kernel_exception(): + pass + except ValueError as e: + assert "All keys in 'key' must be valid axis names" in str(e) diff --git a/third_party/ascend/unittest/autotune_ut/test_common.py b/third_party/ascend/unittest/autotune_ut/test_common.py index d512d3358e..50502ec774 100644 --- a/third_party/ascend/unittest/autotune_ut/test_common.py +++ b/third_party/ascend/unittest/autotune_ut/test_common.py @@ -20,6 +20,7 @@ import unittest.mock as mock import pytest +import torch def MockAutoTilingTunerRun(self, *args, **kwargs): @@ -82,5 +83,23 @@ def normalize_axis_list(axis_list: list, 
sym_to_sem: dict) -> list: @pytest.fixture def mock_autotuner(): - with mock.patch("triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", new=MockAutoTilingTunerRun): + with mock.patch( + "triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", + new=MockAutoTilingTunerRun + ): yield + + +def generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.randn(size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + elif dtype == 'uint8': + return torch.randint(low=0, high=255, size=shape, dtype=torch.uint8) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) diff --git a/third_party/ascend/unittest/autotune_ut/test_customized_config.py b/third_party/ascend/unittest/autotune_ut/test_customized_config.py new file mode 100644 index 0000000000..ee73db9f57 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_customized_config.py @@ -0,0 +1,99 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import os
+
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.backends.ascend.runtime
+import triton.language as tl
+
+os.environ['TRITON_PRINT_AUTOTUNING'] = '0'
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 32}),
+        triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 64}),
+        triton.Config({'XBLOCK': 396, 'XBLOCK_SUB': 6}),
+    ],
+    key=["n_rows", "n_cols"],
+    hints={
+        "auto_gen_config": False,  # NOTE(review): presumably restricts tuning to the configs listed above - confirm
+    }
+)
+@triton.jit
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    XBLOCK: tl.constexpr,
+    XBLOCK_SUB: tl.constexpr,
+):
+    row_start = tl.program_id(0) * XBLOCK  # first row handled by this program
+    for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB):  # walk the XBLOCK rows in tiles of XBLOCK_SUB rows
+        row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None]
+        col_offsets = tl.arange(0, BLOCK_SIZE)[None, :]
+        xmask = row_offsets < n_rows
+        ymask = col_offsets < n_cols
+        mask = xmask & ymask  # keep only in-bounds rows and columns
+        input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets)
+        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))  # masked lanes read -inf
+        row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(
+            XBLOCK_SUB, BLOCK_SIZE
+        )
+        numerator = tl.exp(row_minus_max)
+        denominator = (
+            tl.sum(numerator, axis=1)
+            .reshape(XBLOCK_SUB, 1)
+            .broadcast_to(XBLOCK_SUB, BLOCK_SIZE)
+        )
+        softmax_output = numerator / denominator
+        output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets)
+        tl.store(output_ptrs, softmax_output, mask=mask)
+
+
+def softmax_torch(x):  # reference result computed by torch
+    return torch.softmax(x, axis=-1)
+
+
+def softmax_autotune(x):  # runs softmax_kernel over the 2-D tensor x and returns a new tensor
+    n_rows, n_cols = x.shape
+    BLOCK_SIZE = n_cols  # the kernel spans a whole row with tl.arange(0, BLOCK_SIZE)
+    y = torch.empty_like(x)
+    softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)](  # grid depends on the config's XBLOCK
+        y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
+    )
+    return y
+
+
+@pytest.mark.autotune
+@pytest.mark.parametrize('shape,dtype', [((16896, 1024), torch.float32), ])
+def test_softmax(shape, dtype):  # compares the Triton result with torch.softmax on an "npu" tensor
+    x = torch.randn(shape, dtype=dtype, device="npu")
+    y_torch = softmax_torch(x)
+    y_triton = softmax_autotune(x)
+    torch.testing.assert_close(y_torch, y_triton, rtol=1e-03, atol=1e-03, equal_nan=True)
diff --git a/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py
new file mode 100644
index 0000000000..6c0be27f0d
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py
@@ -0,0 +1,59 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import triton
+import triton.language as tl
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_low_dim_axis_parse_base_case1(mock_autotuner):  # axis parse of a 1-D element-wise add kernel
+    import triton.backends.ascend.runtime  # NOTE(review): imported for its side effects, presumably installs the Ascend autotuner - confirm
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_low_dim_axis_parse_base_case1(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE  # <- Separate assignment
+
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # expected parse result, compared against act_res below
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_low_dim_axis_parse_base_case1[grid]()  # called without tensors; presumably served by the mocked run() - confirm
+
+    check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/autotune_ut/test_mask_parse.py b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py
new file mode 100644
index 0000000000..1b7d2383da
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py
@@ -0,0 +1,164 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+import triton.language as tl
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_triton_dot_case1(mock_autotuner):
+    """
+    The current operator is only used for axis analysis test cases.
+    CV fused operators do not support autotuning for now.
+    """
+    import triton.backends.ascend.runtime  # also binds the name `triton` used by the decorators below
+
+    @triton.autotune(
+        configs=[],
+        key=["M", "N", "K"]
+    )
+    @triton.jit
+    def triton_dot_case1(
+        A, B, C,
+        M: tl.constexpr,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        MBLOCK: tl.constexpr,
+        NBLOCK: tl.constexpr,
+        MBLOCK_SUB: tl.constexpr,
+        NBLOCK_SUB: tl.constexpr,
+        KBLOCK_SUB: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
+        base_m = pid_m * MBLOCK
+        base_n = pid_n * NBLOCK
+
+        loops_m = (MBLOCK + MBLOCK_SUB - 1) // MBLOCK_SUB
+        loops_n = (NBLOCK + NBLOCK_SUB - 1) // NBLOCK_SUB
+        loops_k = (K + KBLOCK_SUB - 1) // KBLOCK_SUB
+
+        for loop_m in range(loops_m):
+            for loop_n in range(loops_n):
+                acc = tl.zeros((MBLOCK_SUB, NBLOCK_SUB), dtype=tl.float32)
+
+                mdx = base_m + loop_m * MBLOCK_SUB + tl.arange(0, MBLOCK_SUB)[:, None]
+                ndx = base_n + loop_n * NBLOCK_SUB + tl.arange(0, NBLOCK_SUB)[None, :]
+
+                for loop_k in range(loops_k):
+                    kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB)
+                    kdx_m = kdx[None, :]  # <-
+                    A_ptr = A + mdx * K + kdx_m
+                    a_mask = (mdx < M) & (kdx_m < K)  # Use the result of the subscript in the mask comparison
+                    a = tl.load(A_ptr, mask=a_mask, other=0.0)
+
+                    kdx_n = kdx[:, None]
+                    B_ptr = B + kdx_n * N + ndx
+                    b_mask = (kdx_n < K) & (ndx < N)
+                    b = tl.load(B_ptr, mask=b_mask, other=0.0)
+
+                    acc += tl.dot(a, b)
+
+                C_ptr = C + mdx * N + ndx
+                c_mask = (mdx < M) & (ndx < N)
+                tl.store(C_ptr, acc, mask=c_mask)
+
+    ref_res = {  # expected parse result, compared against act_res below
+        "keys": {"x": "M", "y": "N", "z": "K"},
+        "split_params": {"x": "MBLOCK", "y": "NBLOCK"},
+        "tiling_params": {"x": "MBLOCK_SUB", "y": "NBLOCK_SUB", "z": "KBLOCK_SUB"},
+        "low_dim_axes": ["y", "z"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["MBLOCK"], meta["NBLOCK"])
+    act_res = triton_dot_case1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+@pytest.mark.skip  # NOTE(review): skipped without a reason string
+def test_triton_dot_case2(mock_autotuner):
+    """
+    The current operator is only used for axis analysis test cases.
+    CV fused operators do not support autotuning for now.
+    """
+    import triton.backends.ascend.runtime  # also binds the name `triton` used by the decorators below
+
+    @triton.autotune(
+        configs=[],
+        key=["M", "N", "K"]
+    )
+    @triton.jit
+    def triton_dot_case2(
+        A, B, C,
+        M: tl.constexpr,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        MBLOCK: tl.constexpr,
+        NBLOCK: tl.constexpr,
+        MBLOCK_SUB: tl.constexpr,
+        NBLOCK_SUB: tl.constexpr,
+        KBLOCK_SUB: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
+        base_m = pid_m * MBLOCK
+        base_n = pid_n * NBLOCK
+
+        loops_m = (MBLOCK + MBLOCK_SUB - 1) // MBLOCK_SUB
+        loops_n = (NBLOCK + NBLOCK_SUB - 1) // NBLOCK_SUB
+        loops_k = (K + KBLOCK_SUB - 1) // KBLOCK_SUB
+
+        for loop_m in range(loops_m):
+            for loop_n in range(loops_n):
+                acc = tl.zeros((MBLOCK_SUB, NBLOCK_SUB), dtype=tl.float32)
+
+                mdx = base_m + loop_m * MBLOCK_SUB + tl.arange(0, MBLOCK_SUB)[:, None]
+                ndx = base_n + loop_n * NBLOCK_SUB + tl.arange(0, NBLOCK_SUB)[None, :]
+
+                for loop_k in range(loops_k):
+                    kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB)
+                    A_ptr = A + mdx * K + kdx[None, :]  # <-
+                    a_mask = (mdx < M) & (kdx[None, :] < K)  # Compute the subscript directly in the mask comparison
+                    a = tl.load(A_ptr, mask=a_mask, other=0.0)
+
+                    B_ptr = B + kdx[:, None] * N + ndx
+                    b_mask = (kdx[:, None] < K) & (ndx < N)
+                    b = tl.load(B_ptr, mask=b_mask, other=0.0)
+
+                    acc += tl.dot(a, b)
+
+                C_ptr = C + mdx * N + ndx
+                c_mask = (mdx < M) & (ndx < N)
+                tl.store(C_ptr, acc, mask=c_mask)
+
+    ref_res = {  # same expectation as case1; only the kernel's subscript style differs
+        "keys": {"x": "M", "y": "N", "z": "K"},
+        "split_params": {"x": "MBLOCK", "y": "NBLOCK"},
+        "tiling_params": {"x": "MBLOCK_SUB", "y": "NBLOCK_SUB", "z": "KBLOCK_SUB"},
+        "low_dim_axes": ["y", "z"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["MBLOCK"], meta["NBLOCK"])
+    act_res = triton_dot_case2[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py
new file mode 100644
index 0000000000..af7945bf9d
--- /dev/null
+++
b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py @@ -0,0 +1,99 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+import os
+import shutil
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.backends.ascend.runtime
+import triton.language as tl
+from triton.tools.get_ascend_devices import is_compile_on_910_95
+
+import test_common
+
+os.environ['TRITON_ALWAYS_COMPILE'] = '1'
+os.environ['TRITON_AUTOTUNE_PARALLEL_COMPILE'] = '0'
+
+
+def case_torch(x):  # reference result: 2-D transpose computed by torch
+    return torch.permute(x, (1, 0))
+
+
+@triton.autotune(
+    configs=[],
+    key=['xnumel', 'ynumel'],
+    hints={
+        "auto_gen_config": True,
+    }
+)
+@triton.jit
+def triton_permute_2d(output_ptr,
+                      x_ptr,
+                      xnumel: tl.constexpr,
+                      ynumel: tl.constexpr,
+                      XBLOCK: tl.constexpr,
+                      YBLOCK: tl.constexpr, ):
+    xpid = tl.program_id(0)
+    ypid = tl.program_id(1)
+
+    x_off = xpid * XBLOCK + tl.arange(0, XBLOCK)[:, None]
+    y_off = ypid * YBLOCK + tl.arange(0, YBLOCK)[None, :]
+    mask = (x_off < xnumel) & (y_off < ynumel)
+    offs = y_off + x_off * ynumel  # input is addressed with row stride ynumel
+    b = tl.load(x_ptr + offs, mask=mask)
+    ox_off = ypid * YBLOCK + tl.arange(0, YBLOCK)[:, None]
+    oy_off = xpid * XBLOCK + tl.arange(0, XBLOCK)[None, :]
+    o_mask = (ox_off < ynumel) & (oy_off < xnumel)
+    o_offs = oy_off + ox_off * xnumel  # output is addressed with row stride xnumel
+    ret = tl.permute(b, (1, 0))
+    tl.store(output_ptr + o_offs, ret, mask=o_mask)
+
+
+def case_triton(x_cal, is_simt_only=False):  # launches triton_permute_2d and returns the (ynumel, xnumel) output
+    xnumel = x_cal.shape[0]
+    ynumel = x_cal.shape[1]
+    output = torch.randint(1, (ynumel, xnumel), dtype=x_cal.dtype, device=x_cal.device)  # high=1, so the buffer starts as zeros
+    if is_simt_only:
+        (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)]
+         (output, x_cal, xnumel, ynumel, force_simt_only=True))
+    else:
+        (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)]
+         (output, x_cal, xnumel, ynumel))
+    return output
+
+
+@pytest.mark.parametrize('shape', [(1024, 32), (32, 8)])
+@pytest.mark.parametrize('dtype', ['bfloat16'])
+def test_permute(shape, dtype):
+    x_cal = test_common.generate_tensor(shape, dtype).npu()
+    torch_output = case_torch(x_cal)
+    triton_output = case_triton(x_cal)
+    torch.testing.assert_close(torch_output, triton_output, rtol=1e-03, atol=1e-03, equal_nan=True)
+
+
+@pytest.mark.skipif(not is_compile_on_910_95, reason="only support A5")  # NOTE(review): if is_compile_on_910_95 is a function, `not` is always False and this never skips - confirm
+@pytest.mark.parametrize('shape', [(1024, 32)])
+@pytest.mark.parametrize('dtype', ['bfloat16'])
+def test_permute_simt(shape, dtype):
+    x_cal = test_common.generate_tensor(shape, dtype).npu()
+    torch_output = case_torch(x_cal)
+    triton_output = case_triton(x_cal, True)
+    torch.testing.assert_close(torch_output, triton_output, rtol=1e-03, atol=1e-03, equal_nan=True)
diff --git a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py
index 2893bf3473..b1c2a8c97b 100644
--- a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py
+++ b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py
@@ -22,18 +22,18 @@
 from test_common import check_axes_parse_res, mock_autotuner
 
 
-def test_triton_max_last_dim_case(mock_autotuner):
+def test_triton_max_last_dim_case1(mock_autotuner):
     import triton.backends.ascend.runtime
 
     @triton.autotune(configs=[], key=["x0_numel", "r1_numel"])
     @triton.jit
-    def triton_max_last_dim(
-            in_ptr0,
-            out_ptr0,
-            x0_numel,
-            r1_numel,
-            X0BLOCK: tl.constexpr,
-            X0BLOCK_SUB: tl.constexpr,
+    def triton_max_last_dim1(
+        in_ptr0,
+        out_ptr0,
+        x0_numel,
+        r1_numel,
+        X0BLOCK: tl.constexpr,
+        X0BLOCK_SUB: tl.constexpr,
         R1BLOCK_SUB: tl.constexpr,
     ):
         x0_offset = tl.program_id(0) * X0BLOCK
@@ -50,7 +50,8 @@ def triton_max_last_dim(
                 r1_mask = r1 < r1_numel
                 tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf"))
                 block_val = tl.maximum(block_val, tmp)
-            block_res = tl.max(block_val, axis=1)[:, None]
+            # Reduce along axis = 1 (the last dimension in this 2D tensor)
+            block_res = tl.max(block_val, axis=1)[:, None]  # <- explicit positive axis index
             tl.store(out_ptr0 + x0, block_res, x0_mask)
 
     ref_res = {
@@ -60,6 +61,103 @@ def triton_max_last_dim(
         "low_dim_axes": ["ry"],
         "reduction_axes": ["ry"],
     }
-    act_res = triton_max_last_dim[(1, )]()
+    grid = lambda meta: (meta["X0BLOCK"],)
+    act_res = triton_max_last_dim1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+def test_triton_max_last_dim_case2(mock_autotuner):
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["x0_numel", "r1_numel"]
+    )
+    @triton.jit
+    def triton_max_last_dim2(
+        in_ptr0,
+        out_ptr0,
+        x0_numel,
+        r1_numel,
+        X0BLOCK: tl.constexpr,
+        X0BLOCK_SUB: tl.constexpr,
+        R1BLOCK_SUB: tl.constexpr,
+    ):
+        x0_offset = tl.program_id(0) * X0BLOCK
+        base_x0 = tl.arange(0, X0BLOCK_SUB)
+        loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB
+        base_r1 = tl.arange(0, R1BLOCK_SUB)
+        loops_r1 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB
+        for loop_x0 in range(loops_x0):
+            x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0[:, None]
+            x0_mask = x0 < min(X0BLOCK + x0_offset, x0_numel)
+            block_val = tl.full([X0BLOCK_SUB, R1BLOCK_SUB], float("-inf"), tl.float32)
+            for loop_r1 in range(loops_r1):
+                r1 = (loop_r1 * R1BLOCK_SUB) + base_r1[None, :]
+                r1_mask = r1 < r1_numel
+                tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf"))
+                block_val = tl.maximum(block_val, tmp)
+            # Reduce along axis=-1 (the last dimension, equivalent to axis=1 in 2D)
+            block_res = tl.max(block_val, axis=-1)[:, None]  # <- negative axis index (last dim)
+            tl.store(out_ptr0 + x0, block_res, x0_mask)
+
+    ref_res = {
+        "keys": {"x": "x0_numel", "ry": "r1_numel"},
+        "split_params": {"x": "X0BLOCK"},
+        "tiling_params": {"x": "X0BLOCK_SUB", "ry": "R1BLOCK_SUB"},
+        "low_dim_axes": ["ry"],
+        "reduction_axes": ["ry"],
+    }
+    grid = lambda meta: (meta["X0BLOCK"],)
+    act_res = triton_max_last_dim2[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+def test_triton_max_last_dim_case3(mock_autotuner):
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["x0_numel", "r1_numel"]
+    )
+    @triton.jit
+    def triton_max_last_dim3(
+        in_ptr0,
+        out_ptr0,
+        x0_numel,
+        r1_numel,
+        X0BLOCK: tl.constexpr,
+        X0BLOCK_SUB: tl.constexpr,
+        R1BLOCK_SUB: tl.constexpr,
+    ):
+        x0_offset = tl.program_id(0) * X0BLOCK
+        base_x0 = tl.arange(0, X0BLOCK_SUB)
+        loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB
+        base_r1 = tl.arange(0, R1BLOCK_SUB)
+        loops_r1 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB
+        for loop_x0 in range(loops_x0):
+            x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0[:, None]
+            x0_mask = x0 < min(X0BLOCK + x0_offset, x0_numel)
+            block_val = tl.full([X0BLOCK_SUB, R1BLOCK_SUB], float("-inf"), tl.float32)
+            for loop_r1 in range(loops_r1):
+                r1 = (loop_r1 * R1BLOCK_SUB) + base_r1[None, :]
+                r1_mask = r1 < r1_numel
+                tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf"))
+                block_val = tl.maximum(block_val, tmp)
+            # Reduce along axis=1, passed as a positional argument (not keyword `axis=...`)
+            block_res = tl.max(block_val, 1)[:, None]  # <- axis index passed positionally
+            tl.store(out_ptr0 + x0, block_res, x0_mask)
+
+    ref_res = {
+        "keys": {"x": "x0_numel", "ry": "r1_numel"},
+        "split_params": {"x": "X0BLOCK"},
+        "tiling_params": {"x": "X0BLOCK_SUB", "ry": "R1BLOCK_SUB"},
+        "low_dim_axes": ["ry"],
+        "reduction_axes": ["ry"],
+    }
+    grid = lambda meta: (meta["X0BLOCK"],)
+    act_res = triton_max_last_dim3[grid]()
 
     check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py
new file mode 100644
index 0000000000..a2d70314a8
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py
@@ -0,0 +1,167 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+
+import unittest.mock as mock
+
+import triton
+import triton.language as tl
+
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_split_axis_parse_base_case1(mock_autotuner):  # program id stored in a variable before scaling
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_split_axis_parse_base_case1(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE  # <- Separate assignment
+
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # expected parse result, compared against act_res below
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_split_axis_parse_base_case1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+def test_split_axis_parse_base_case2(mock_autotuner):  # program id scaled inline, result still named
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_split_axis_parse_base_case2(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        block_start = tl.program_id(axis=0) * BLOCK_SIZE  # <- Computed inline but still named
+
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # same expectation as case1
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_split_axis_parse_base_case2[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+def test_split_axis_parse_base_case3(mock_autotuner):  # offsets built in a single expression
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_split_axis_parse_base_case3(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # <- Fully fused
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # same expectation as case1
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_split_axis_parse_base_case3[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+def test_grid_stride_loop_block_only_tiling_semantics(mock_autotuner):
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["N", "index_len"]
+    )
+    @triton.jit
+    def triton_grid_stride_loop_block_only_tiling_semantics(
+        input_ptr,
+        output_ptr,
+        index_ptr,
+        N: tl.constexpr,
+        index_len: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        grid_x = tl.num_programs(axis=0)
+        grid_y = tl.num_programs(axis=1)
+        for x in range(pid_x * BLOCK_M, index_len, grid_x * BLOCK_M):  # step is num_programs * block size
+            row_offsets = x + tl.arange(0, BLOCK_M)
+            indices = tl.load(index_ptr + row_offsets, mask=row_offsets < index_len, other=0)
+            for y in range(pid_y * BLOCK_N, N, grid_y * BLOCK_N):  # step is num_programs * block size
+                col_offsets = y + tl.arange(0, BLOCK_N)
+                col_mask = col_offsets < N
+                inp_offset = indices[:, None] * N + col_offsets[None, :]
+                out_offset = row_offsets[:, None] * N + col_offsets[None, :]
+                selected = tl.load(input_ptr + inp_offset, mask=col_mask[None, :], other=0.0)
+                tl.store(output_ptr + out_offset, selected, mask=col_mask[None, :])
+
+    act_res = triton_grid_stride_loop_block_only_tiling_semantics[(1, 1)]()
+    assert act_res["split_params"] == {}  # neither block size is expected to be reported as a split parameter
+    assert act_res["tiling_params"] == {"y": "BLOCK_M", "x": "BLOCK_N"}
diff --git a/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py
new file mode 100644
index 0000000000..104535c781
--- /dev/null
+++ b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py
@@ -0,0 +1,135 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import pytest
+import triton
+import triton.language as tl
+from test_common import check_axes_parse_res, mock_autotuner
+
+
+def test_tiling_axis_parse_base_case1(mock_autotuner):  # sub-tile loop driven by a precomputed trip count
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_tiling_axis_parse_base_case1(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr
+    ):
+        offset = tl.program_id(axis=0) * BLOCK_SIZE
+        base = tl.arange(0, BLOCK_SUB)
+        loops = (BLOCK_SIZE + BLOCK_SUB - 1) // BLOCK_SUB  # <-
+        for loop in range(loops):
+            offsets = offset + (loop * BLOCK_SUB) + base
+            mask = offsets < min(BLOCK_SIZE + offset, n_elements)
+
+            x = tl.load(x_ptr + offsets, mask=mask)
+            y = tl.load(y_ptr + offsets, mask=mask)
+            output = x + y
+
+            tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # expected parse result, compared against act_res below
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {"x": "BLOCK_SUB"},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_tiling_axis_parse_base_case1[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+@pytest.mark.skip  # NOTE(review): skipped without a reason string
+def test_tiling_axis_parse_base_case2(mock_autotuner):  # sub-tile loop written as range(0, BLOCK_SIZE, BLOCK_SUB)
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_tiling_axis_parse_base_case2(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr
+    ):
+        offset = tl.program_id(axis=0) * BLOCK_SIZE
+        base = tl.arange(0, BLOCK_SUB)
+        for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB):
+            offsets = offset + offset_sub + base[:]  # <-
+            mask = offsets < min(BLOCK_SIZE + offset, n_elements)
+
+            x = tl.load(x_ptr + offsets, mask=mask)
+            y = tl.load(y_ptr + offsets, mask=mask)
+            output = x + y
+
+            tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # same expectation as case1
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {"x": "BLOCK_SUB"},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_tiling_axis_parse_base_case2[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
+
+
+@pytest.mark.skip  # NOTE(review): skipped without a reason string
+def test_tiling_axis_parse_base_case3(mock_autotuner):  # as case2, but the [:] slice is applied where base is defined
+    import triton.backends.ascend.runtime
+
+    @triton.autotune(
+        configs=[],
+        key=["n_elements"]
+    )
+    @triton.jit
+    def triton_tiling_axis_parse_base_case3(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr
+    ):
+        offset = tl.program_id(axis=0) * BLOCK_SIZE
+        base = tl.arange(0, BLOCK_SUB)[:]  # <-
+        for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB):
+            offsets = offset + offset_sub + base
+            mask = offsets < min(BLOCK_SIZE + offset, n_elements)
+
+            x = tl.load(x_ptr + offsets, mask=mask)
+            y = tl.load(y_ptr + offsets, mask=mask)
+            output = x + y
+
+            tl.store(output_ptr + offsets, output, mask=mask)
+
+    ref_res = {  # same expectation as case1
+        "keys": {"x": "n_elements"},
+        "split_params": {"x": "BLOCK_SIZE"},
+        "tiling_params": {"x": "BLOCK_SUB"},
+        "low_dim_axes": ["x"],
+        "reduction_axes": [],
+    }
+    grid = lambda meta: (meta["BLOCK_SIZE"],)
+    act_res = triton_tiling_axis_parse_base_case3[grid]()
+
+    check_axes_parse_res(act_res, ref_res)
diff --git a/third_party/ascend/unittest/custom_op/builtin_ops_demo.py b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py
new file mode 100644
index 0000000000..232b54fccc
--- /dev/null
+++ b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+import subprocess
+import triton
+import triton.language as tl
+import triton.language.extra.cann.extension as al
+from triton.compiler.compiler import ASTSource
+from triton.compiler.code_generator import ast_to_ttir
+from triton._C.libtriton import ir
+from triton._C.libtriton.ascend import ir as ascend_ir
+from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg
+
+
+@triton.jit
+def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):  # calls each "__builtin_*" custom op once
+    i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+    x = tl.load(x_ptr + i, mask=i < n)
+    y = tl.load(y_ptr + i, mask=i < n)
+    index = tl.full([8], 0, tl.int32)
+    value = tl.full([8, 64], 0, tl.float32)
+    tmp = tl.full([8], 0, tl.float32)
+    x = al.custom("__builtin_index_select",
+                  x_ptr, index,
+                  dim=0,
+                  bound=100,
+                  end_offset=(2, 2),
+                  start_offset=(0, 0),
+                  src_stride=(4, 1),
+                  out=x)
+    al.custom("__builtin_index_put",
+              x_ptr, index, value,
+              dim=0,
+              bound=12,
+              dst_shape=(1, 2, 3),
+              dst_offset=(4, 5, 6),
+              dst_stride=(8, 4, 1))
+    tmp = al.custom("__builtin_gather_load",
+                    y_ptr, index,
+                    bound=100,
+                    dim=0,
+                    src_stride=(1,),
+                    index_shape=(3,),
+                    offsets=(0,),
+                    out=tmp)
+    al.custom("__builtin_scatter_store",
+              out_ptr, value, index,
+              1, 0, (1, ), (2, ), (1, ))  # the only call here that passes its options positionally
+    y = al.custom("__builtin_indirect_load", x_ptr, index, mask=i < n, other=y, out=y)
+    al.custom("__builtin_indirect_store", out_ptr, index, value)
+    tl.store(out_ptr + i, y, mask=i < n)
+
+
+if __name__ == "__main__":
+    src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256})
+    context = ir.context()
+    ir.load_dialects(context)
+    ascend_ir.load_dialects(context)
+    options = NPUOptions()
+    try:
+        ttir = ast_to_ttir(my_kernel, src, context, options, {}, {})  # lower the kernel to TTIR and print it
+        print("=== TTIR ===")
+        print(ttir)
+        metadata = {
+            **options.__dict__,
+        }
+        linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True)
+        print("=== MLIR (linalg) ===")
+        print(linalg)
+    except subprocess.CalledProcessError as ex:  # only subprocess failures are reported; other errors propagate
+        print(ex.stdout.decode())
+        print(ex.stderr.decode())
+        print("failed")
diff --git a/third_party/ascend/unittest/custom_op/custom_op_demo.py b/third_party/ascend/unittest/custom_op/custom_op_demo.py
new file mode 100644
index 0000000000..4817658627
--- /dev/null
+++ b/third_party/ascend/unittest/custom_op/custom_op_demo.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+import subprocess
+import os
+import triton
+import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +@al.register_custom_op +class min_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_MTE2 + mode = al.MODE.SIMD + + symbol = 'min_custom_op_impl' + bitcode = os.path.abspath(__file__) + + +@al.register_custom_op +class simple_custom_op: + # name is optional, use class name by default. + name = 'simple_custom_op' + + # required attributes. + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + symbol = 'simple_custom_op_impl' + bitcode = os.path.abspath(__file__) + + # __init__ method is optional, but it can be used for better user experience + # when provided. for example, you can validate arguments here. + def __init__(self, x, y, dim=0, out=None): + assert x.shape == y.shape, "x and y should have same shape" + assert isinstance(dim, int), "dim should be const integer" + assert out, "out is required" + + +@al.register_custom_op +class _example_custom_op: + name = 'example_custom_op' + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + symbol = 'example_custom_op_impl' + bitcode = os.path.abspath(__file__) + + def __init__(self, src, index, offset: tl.int64, axis, out=None): + # support validate arguments in __init__ method. + assert isinstance(src, tl.tensor), "src should be tensor" + assert index.dtype.is_int(), "index should be integer tensor" + assert isinstance(offset, int), "offset should be integer" + assert isinstance(axis, int), "axis should be integer" + + # support multi-output by using tuple or list. + assert isinstance(out, tuple) and len(out) == 2, "out should be tuple of 2 items" + + # setup the symbol name of the function that will be called at runtime. 
+ rank = len(index.shape) + self.symbol = f"{self.name}_{rank}d_{src.dtype.cname}_{index.dtype.cname}" + + # setup source and compile command if it is implemented by user source code. + self.source = f"workspace/example_custom_op_impl.cce" + self.compile = "bisheng -O2 -std=c++17 -o $@ -c $<" + + # dynamic set argument type. + self.arg_type['axis'] = index.dtype + + +@al.builtin +def example_op(src, index, offset, axis, _builder=None): + # you can wrap a custom op as a builtin operation, + # output can be provided here to make it easy to use. + x = tl.semantic.full(src.shape, 0, tl.float32, _builder) + y = tl.semantic.full(index.shape, 0, tl.float32, _builder) + return al.custom_semantic(_example_custom_op.name, + src, index, offset, axis, out=(x, y), _builder=_builder) + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + y = al.custom("min_custom_op", x, x_ptr, y_ptr + i, al.int64(0), (1, 2, 3), [4.1, 5.2], out=y) + y = al.custom("simple_custom_op", x, y, dim=1, out=y) + index = tl.full((2, 3), 0, tl.int64) + x, y = al.custom("example_custom_op", x, index, offset=1, axis=0, out=(x, y)) + result, _ = example_op(x, index, offset=2, axis=1) + tl.store(out_ptr + i, result, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + metadata = { + **options.__dict__, + } + linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + 
print("failed") diff --git a/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py new file mode 100644 index 0000000000..4469d7eec0 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Demo: declare scratch/extra buffers on a custom op via `extra_buffers` (dtype, size) +# and read back the sizes from lowered HIVM MLIR (`extra_buffers_sizes` attribute). + +from __future__ import annotations + +import os +import re +import subprocess + +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + +# Scratch buffers requested by the custom kernel (element type + length in elements). 
+SCRATCH_SPEC = [ + (tl.float32, 1024), + (tl.bfloat16, 512), + (tl.int32, 256), +] + + +@al.register_custom_op +class demo_extra_buffer_op: + """Custom op that advertises extra device buffers for the NPU compiler / runtime.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "demo_extra_buffer_op_impl" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = list(SCRATCH_SPEC) + + +@triton.jit +def kernel_extra_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("demo_extra_buffer_op", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +def compile_to_linalg_mlir(kernel, signature: dict, constants: dict) -> str | None: + src = ASTSource(kernel, signature, constants) + ctx = ir.context() + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + options = NPUOptions() + try: + ttir = ast_to_ttir(kernel, src, ctx, options, {}, {}) + meta = {**options.__dict__} + return str(ttir_to_linalg(ttir, meta, options, named_ops=True)) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + return None + + +def extract_extra_buffer_sizes_from_mlir(mlir: str) -> list[int]: + """ + Parse `extra_buffers_sizes` from HIVM custom op text. + """ + # Parse [1024, 512, 256, ...] 
+ m = re.search(r"extra_buffers_sizes\s*=\s*\[([^\]]+)\]", mlir) + if m: + raw = m.group(1).replace(" ", "") + return [int(x) for x in raw.split(",") if x] + + return [] + + +def main() -> None: + expected_sizes = [size for _, size in SCRATCH_SPEC] + print("Declared extra_buffers (dtype, element_count):") + for dt, sz in SCRATCH_SPEC: + print(f" {dt} -> {sz} elements") + + mlir = compile_to_linalg_mlir( + kernel_extra_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 128}, + ) + if not mlir: + print("Compilation failed.") + return + + parsed = extract_extra_buffer_sizes_from_mlir(mlir) + print("\nParsed extra_buffers_sizes from MLIR:", parsed) + if parsed == expected_sizes: + print("OK: MLIR sizes match the Python extra_buffers specification.") + elif parsed: + print("Note: parsed sizes differ from spec; inspect MLIR spelling below.") + else: + print( + "Could not parse extra_buffers_sizes automatically; " + "search the dump for 'extra_buffers_sizes'." + ) + + print("\n--- MLIR excerpt (lines containing hivm.hir.custom) ---") + for line in mlir.splitlines(): + if "hivm.hir.custom" in line and "demo_extra_buffer_op" in line: + print(line) + + +if __name__ == "__main__": + main() diff --git a/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py new file mode 100644 index 0000000000..11e6dcf3cf --- /dev/null +++ b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_complex_demo.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +import os +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +def 
_make_indexing_maps(): + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c8 = ascend_ir.affine_expr.get_constant(8) + + # Input maps use transpose and identity-like projections. + in0 = ascend_ir.affine_map.get(2, 0, [d1, d0]) + in1 = ascend_ir.affine_map.get(2, 0, [d0, d1]) + + # Output map models tiled coordinates. + out = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c8), d1.mod(c8)]) + return [in0, in1, out] + + +@al.register_custom_op +class complex_indexing_map_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "complex_indexing_map_custom" + # Fake path: this example checks IR lowering only. + bitcode = os.path.abspath(__file__) + + def __init__(self, x, y, out=None): + assert out is not None + self.indexing_map = _make_indexing_maps() + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + out = al.custom("complex_indexing_map_custom_op", x, y, out=x) + tl.store(out_ptr + i, out, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + linalg = ttir_to_linalg(ttir, {**options.__dict__}, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py new file mode 100644 index 0000000000..6d42c5b832 --- /dev/null 
+++ b/third_party/ascend/unittest/custom_op/custom_op_indexing_map_compose_demo.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +import os +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +def _compose_indexing_maps(): + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c4 = ascend_ir.affine_expr.get_constant(4) + + # Base permutation map. + perm = ascend_ir.affine_map.get_permutation([1, 0]) + # Tile map (row-major tile decomposition). + tile = ascend_ir.affine_map.get(2, 0, [d0.floordiv(c4), d1.mod(c4)]) + # Compose tile with permutation to build a different output indexing. + out = tile.compose(perm) + + in0 = ascend_ir.affine_map.get_identity(2) + in1 = perm + return [in0, in1, out] + + +@al.register_custom_op +class compose_indexing_map_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "compose_indexing_map_custom" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, y, out=None): + assert out is not None + self.indexing_map = _compose_indexing_maps() + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + out = al.custom("compose_indexing_map_custom_op", x, y, out=y) + tl.store(out_ptr + i, out, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = 
ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + linalg = ttir_to_linalg(ttir, {**options.__dict__}, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/unittest/custom_op/test_gather_load.py b/third_party/ascend/unittest/custom_op/test_gather_load.py new file mode 100644 index 0000000000..dc92988d28 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/test_gather_load.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def test_gather_load_kernel(src_ptr, index_ptr, out_ptr): + # index tile shape: (2, 2) + cols = tl.arange(0, 2)[None, :] # [[0, 1]] + rows = tl.arange(0, 2)[:, None] # [[0],[1]] + mask = (rows < 2) & (cols < 2) + + # load index tile to UB + index = tl.load(index_ptr + rows * 2 + cols, mask) + + # gather load from GM to UB + dst = tl.full(index.shape, 0, tl.float32) + gathered = al.custom("__builtin_gather_load", + src_ptr, index, + bound=4, + dim=0, + src_stride=(2, 1), + index_shape=(2, 2), + offsets=(0, 0), + out=dst) + + # store result to GM + tl.store(out_ptr + rows * 2 + cols, gathered, mask) + + +if __name__ == "__main__": + src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], device='npu') + index = torch.tensor([[0, 1], [2, 3]], device='npu') + out = torch.empty((2, 2), device='npu', dtype=torch.float32) + test_gather_load_kernel[(1,)](src, index, out) + print("result: ", out) # [[1., 4.], [5., 8.]] diff --git a/third_party/ascend/unittest/custom_op/test_index_select.py b/third_party/ascend/unittest/custom_op/test_index_select.py new file mode 100644 index 0000000000..97c3e72502 --- /dev/null +++ b/third_party/ascend/unittest/custom_op/test_index_select.py @@ -0,0 +1,49 @@ +import pytest +import 
torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def builtin_index_select_kernel(src_ptr, index_ptr, out_ptr): + # Define 2x2 tile indices for output tensor + r = tl.arange(0, 2)[:, None] # Row indices: shape [2, 1] + c = tl.arange(0, 2)[None, :] # Column indices: shape [1, 2] + + # Load index tensor (shape [2]) from GM to UB + idx = tl.load(index_ptr + tl.arange(0, 2)) + # Initialize empty 2x2 output tile in UB (default value: 0) + dst = tl.full((2, 2), 0, dtype=tl.float32) + + # Invoke __builtin_index_select custom op to gather elements + out_tile = al.custom( + "__builtin_index_select", + src_ptr, # Pointer to source tensor in GM + idx, # Index tensor (in UB) for gathering + dim=0, # Dimension to gather along + bound=4, # Upper bound for valid index values (out-of-bound check) + end_offset=(2, 2),# End offsets of each dimension for the index tensor + start_offset=(0, 0), # Start offsets of each dimension for the source tensor + src_stride=(4, 1),# Stride of each dimension for the source tensor in GM + out=dst # Output tensor (in UB) to store gathered elements + ) + + # Store the gathered tile from UB to output tensor in GM + tl.store(out_ptr + r * 2 + c, out_tile) + + +if __name__ == "__main__": + src = torch.tensor( + [[10., 11., 12., 13.], + [20., 21., 22., 23.], + [30., 31., 32., 33.], + [40., 41., 42., 43.]], + device="npu", + dtype=torch.float32, + ) + index = torch.tensor([2, 0], device="npu", dtype=torch.int32) + out = torch.empty((2, 2), device="npu", dtype=torch.float32) + ref = torch.index_select(src, 0, index.to(torch.int64))[:, :2] + builtin_index_select_kernel[(1,)](src, index, out) + torch.testing.assert_close(out, ref) # ref: [[30., 31.], [10., 11.]] diff --git a/third_party/ascend/unittest/generalization_cases/acc_util.py b/third_party/ascend/unittest/generalization_cases/acc_util.py deleted file mode 100644 index b1295885a6..0000000000 --- 
a/third_party/ascend/unittest/generalization_cases/acc_util.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import numpy as np -import torch -import torch_npu - -eval_standard = { - torch.float32: { - "rtol": 1e-6, - "small_value": 1e-6, - "small_value_atol": 1e-9, - "etol": 1e-4, - }, - torch.float16: { - "rtol": 1e-3, - "small_value": 1e-3, - "small_value_atol": 1e-5, - "etol": 1e-3, - }, - torch.bfloat16: { - "rtol": 4e-3, - "small_value": 1e-3, - "small_value_atol": 1e-5, - "etol": 1e-3, - }, -} - - -def assert_close(gold: torch.Tensor, act: torch.Tensor, eval_type: str = 'DEFAULT'): - gold = gold.cpu() - act = act.cpu() - if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: - assert gold.dtype == torch.float32, "golden should be f32" - assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" - eps = eval_standard[act.dtype]['small_value'] - rtol = eval_standard[act.dtype]['rtol'] - atol = eval_standard[act.dtype]['small_value_atol'] - if eval_type == 'DEFAULT': - ae = torch.abs(act - gold) - re = ae / torch.abs(gold) - mask = torch.abs(gold) < eps - - print(f"count ae > {atol}: {(ae > atol).sum()}") - print(f"count re > {rtol}: {(re > rtol).sum()}") - - not_close = torch.where(mask, ae > atol, re > rtol) - print(f"count not_close = {torch.sum(not_close).item()}") - print(f"not_close.numel = {not_close.numel()}, gold.numel = {gold.numel()}") - print(f"not close ratio = {torch.sum(not_close).item() / not_close.numel()}") - if not torch.any(not_close): - return False - - assert torch.sum( - not_close).item() < not_close.numel() * eps, "actual tensor are not close enough with golden tensor,\ -you can use 'benchmark_compare_close' function to compare again!" - - elif eval_type == 'ABS': - act = act.to(gold.dtype) - assert torch.equal(gold, act), "actual tensor and golden tensor are not binary equal!" - else: - assert 0, "ERROR! 
invalid eval_type" - return False - - -def benchmark_compare_close(gold: torch.Tensor, act: torch.Tensor, std: torch.tensor): - assert act.dtype == std.dtype, "standard tensor's dtype must equal to actual tensor's dtype!" - if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: - assert gold.dtype == torch.float32, "golden should be f32" - assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" - - gold = gold.cpu() - act = act.cpu() - std = std.cpu() - - eps = eval_standard[act.dtype]['small_value'] - atol = eval_standard[act.dtype]['small_value_atol'] - - mask = torch.abs(gold) <= eps - small_count = mask.sum().item() - - def calculate_relative_errors_except_small(tensor): - re = torch.abs(gold - tensor) / torch.abs(gold) - return torch.where(mask, 0, re) - - act_re = calculate_relative_errors_except_small(act) - std_re = calculate_relative_errors_except_small(std) - act_ae = torch.abs(gold - std) - std_ae = torch.abs(gold - std) - - # 小值域的定义为golden小于某个阈值 eps - act_small_error_count = (mask & (act_ae > atol)).sum().item() - std_small_error_count = (mask & (std_ae > atol)).sum().item() - act_total = act.numel() - std_total = std.numel() - - act_small_error_ratio = act_small_error_count / act_total - std_small_error_ratio = std_small_error_count / std_total - - def calculate_rmse(tensor): - dlt2 = (tensor - gold)**2 - dlt2_except_small_mean = torch.where(mask, 0, dlt2).sum() / small_count - return torch.sqrt(dlt2_except_small_mean) - - act_rmse = calculate_rmse(act) - std_rmse = calculate_rmse(std) - - print(f"act_re.max = {act_re.max()}, std_re.max = {std_re.max()}, limit ratio = 10") - print(f"act_re.sum = {act_re.sum()}, std_re.sum = {std_re.sum()}, limit_ratio = 2") - print( - f"act_small_error_ratio = {act_small_error_ratio}, std_small_error_ratio = {std_small_error_ratio}, limit_ratio = 2" - ) - print(f"act_rmse = {act_rmse}, std_rmse = {std_rmse}, limit_ratio = 2") - - # 
条件 1:actual 与 golden 相对误差最大值超过 10 倍 standard 与 golden 相对误差最大值 - assert act_re.max() <= 10 * std_re.max(), "actual re max > stdandard re max's 10 times" - - # 条件 2:actual 与 golden 相对误差均值超过 2 倍 standard 与 golden 相对误差均值 - assert act_re.sum() <= 2 * std_re.sum(), "actual re sum > stdandard re sum's 2 times" - - # 条件 3:actual 小值域 ERROR 占比超过 standard 的两倍 - assert act_small_error_ratio <= 2 * std_small_error_ratio, "act_small_error_ratio > std_small_error_ratio 's 2 times" - - # 条件 4:actual 均方根误差差于 standard 的两倍 - assert act_rmse <= 2 * std_rmse, "act_rmse > std_rmse 's 2 times" - - return False diff --git a/third_party/ascend/unittest/generalization_cases/test_abs.py b/third_party/ascend/unittest/generalization_cases/test_abs.py deleted file mode 100644 index c5d06d0d40..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_abs.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, avoid_not_support -import math -import logging - - -def torch_pointwise(x0): - if x0.dtype != torch.uint32: - return torch.abs(x0) - else: - return torch.abs(x0.to(torch.float32)) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.abs(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_abs_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, 
:, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.abs(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_abs_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_abs_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_advance.py b/third_party/ascend/unittest/generalization_cases/test_advance.py deleted file mode 100644 index 8c7a75dc46..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_advance.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(5, ), - block_shape=(XB, ), - order=(0, ), - ) - bbptr = tl.advance(block_ptr_in, (-5, )) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffset = tl.program_id(0) - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(6 + xoffset, 5), - block_shape=(XB, YB), - order=(1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5)) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(3, 1, 2), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-3, -1, -2)) - # XB,YB,1 - X = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def triton_advance_4d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: 
tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(6, 5, 4, 3), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5, -4, -3)) - x = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -@triton.jit -def triton_advance_5d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(6, 5, 4, 3, 2), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - bbptr = tl.advance(block_ptr_in, (-6, -5, -4, -3, -2)) - x = tl.load(bbptr) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - 
-temporarily_not_support_dtype = ['bool'] - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_npu(dtype, shape): - if dtype in temporarily_not_support_dtype: - return - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - - a = x - blocks = list(x.size()) - strides = list(x.stride()) - grid = (1, ) - if len(shape) == 5: - triton_advance_5d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 4: - triton_advance_4d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 3: - fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) - elif len(shape) == 2: - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) - else: - fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) - else: - fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) - - torch.testing.assert_close(output, a) diff --git a/third_party/ascend/unittest/generalization_cases/test_and.py b/third_party/ascend/unittest/generalization_cases/test_and.py deleted file mode 100644 index 4bac287eaf..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_and.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x, y): - res = x & y - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X & Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_and_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = 
offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val & y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_and_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = 
{shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x & y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_and_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_argmax.py b/third_party/ascend/unittest/generalization_cases/test_argmax.py deleted file mode 100644 index edbf8b9d8d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_argmax.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import math -import pytest -import torch -import torch_npu -import numpy as np -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - -logger = logging.getLogger(__name__) - - -# <<<<<<< test_argmax_1d -def torch_argmax(x0, dim, keepdim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - return torch.argmax(x0, dim=dim, keepdim=keepdim).npu() - - -@triton.jit -def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.argmax(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_argmax_1d(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmax_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_argmax(x0, dim=0, keepdim=True) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_1d - - -# <<<<<<< test_argmax_2d -@triton.jit -def triton_argmax_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.argmax(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, 
None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_argmax_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=torch.int32).npu() - triton_argmax_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_argmax(x0, dim=dim, keepdim=False) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_2d - - -# <<<<<<< test_argmax_3d -def torch_argmax_3d(x0, no_reduce_dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - if no_reduce_dim == 0: - return torch.argmax(torch.max(x0, 1)[0], 1).npu() - elif no_reduce_dim == 1: - return torch.argmax(torch.max(x0, 0)[0], 1).npu() - elif no_reduce_dim == 2: - return torch.argmax(torch.max(x0, 0)[0], 0).npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_argmax_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 0) - ret = tl.argmax(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmax_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - 
xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 0) - ret = tl.argmax(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmax_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.max(x, 1) - ret = tl.argmax(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_argmax_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_argmax_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_argmax_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_argmax_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_argmax_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=torch.int32).npu() - triton_argmax_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_argmax_3d(x0, no_reduce_dim) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_3d - - -# <<<<<<< test_argmax_4d -def torch_argmax_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype 
in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmax(x0, dim) - - -@triton.jit -def argmax_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmax_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmax_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_argmax_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_argmax_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 4, 8), - (2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmax_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmax_4d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmax_4d(x0, triton_res, 
shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_4d - - -# <<<<<<< test_argmax_5d -def torch_argmax_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmax(x0, dim) - - -@triton.jit -def argmax_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmax_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmax_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def 
triton_argmax_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_argmax_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 2, 4, 8), - (2, 2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmax_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmax_5d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmax_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmax_5d - - -# <<<<<<< test_argmax_1d_bool -@triton.jit -def triton_argmax_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1) - tmp4 = tl.argmax(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_argmax_1d_bool(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype) - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmax_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel) - np_res = np.argmax(x0.numpy()) - np.equal(triton_res.item(), np_res) - - -# >>>>>>> test_argmax_1d_bool diff --git a/third_party/ascend/unittest/generalization_cases/test_argmin.py b/third_party/ascend/unittest/generalization_cases/test_argmin.py deleted file mode 100644 index 36a671d1ba..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_argmin.py +++ /dev/null @@ -1,360 +0,0 @@ -# 
Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import math -import pytest -import torch -import torch_npu -import numpy as np -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - -logger = logging.getLogger(__name__) - - -# <<<<<<< test_argmin_1d -def torch_argmin(input_tensor, dim, keepdim): - return torch.argmin(input_tensor, dim=dim, keepdim=keepdim) - - -@triton.jit -def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.argmin(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_argmin_1d(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmin_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_argmin(x0, dim=0, keepdim=True) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_1d - - -# <<<<<<< test_argmin_2d -@triton.jit -def triton_argmin_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.argmin(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - 
-@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_argmin_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=torch.int32).npu() - triton_argmin_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_argmin(x0, dim=dim, keepdim=False) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_2d - - -# <<<<<<< test_argmin_3d -def torch_argmin_3d(x0, no_reduce_dim): - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - if no_reduce_dim == 0: - return torch.argmin(torch.min(x0, 1)[0], 1) - elif no_reduce_dim == 1: - return torch.argmin(torch.min(x0, 0)[0], 1) - elif no_reduce_dim == 2: - return torch.argmin(torch.min(x0, 0)[0], 0) - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_argmin_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 0) - ret = tl.argmin(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmin_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, 
ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 0) - ret = tl.argmin(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_argmin_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.min(x, 1) - ret = tl.argmin(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_argmin_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_argmin_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_argmin_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_argmin_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_argmin_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=torch.int32).npu() - triton_argmin_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_argmin_3d(x0, no_reduce_dim) - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_3d - - -# <<<<<<< test_argmin_4d -def torch_argmin_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - 
return torch.argmin(x0, dim) - - -@triton.jit -def argmin_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmin_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmin_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_argmin_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_argmin_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 4, 8), - (2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmin_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmin_4d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmin_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - 
test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_4d - - -# <<<<<<< test_argmin_5d -def torch_argmin_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.argmin(x0, dim) - - -@triton.jit -def argmin_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_argmin_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - argmin_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_argmin_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - 
triton_argmin_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 2, 2, 4, 8), - (2, 2, 3, 4, 8), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize('dim', [0]) -def test_argmin_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_argmin_5d(x0, dim).to(torch.int32) - triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() - triton_argmin_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp("int32", triton_res, torch_res) - - -# >>>>>>> test_argmin_5d - - -# <<<<<<< test_argmin_1d_bool -@triton.jit -def triton_argmin_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1) - tmp4 = tl.argmin(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_argmin_1d_bool(dtype, shape): - dtype_size = get_dtype_size(dtype) - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") - return - x0 = test_common.generate_tensor(shape, dtype) - triton_res = torch.empty(1, dtype=torch.int32).npu() - numel = shape[0] - triton_argmin_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel) - np_res = np.argmin(x0.numpy()) - np.equal(triton_res.item(), np_res) - - -# >>>>>>> test_argmin_1d_bool diff --git a/third_party/ascend/unittest/generalization_cases/test_associative_scan.py b/third_party/ascend/unittest/generalization_cases/test_associative_scan.py deleted file mode 100644 index 249abc09fb..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_associative_scan.py +++ /dev/null @@ -1,523 +0,0 @@ -# Copyright (c) Huawei Technologies 
Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import random -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, get_dtype_size - - -def combine_fn_test_torch(a, b, combine_fn): - if combine_fn == 'maximum_fn': - return torch.maximum(a, b) # 最大值 - elif combine_fn == 'minimum_fn': - return torch.minimum(a, b) # 最小值 - elif combine_fn == 'bitwise_xor_fn': - return a ^ b # 按位异或 - elif combine_fn == 'bitwise_or_fn': - return a | b # 按位异 - elif combine_fn == 'bitwise_and_fn': - return a & b # 按位与 - else: - pytest.skip("The combine_fn is not within the following scope , skipping.") - - -def torch_func_scan(input: torch.Tensor, dim: int, combine_fn='maximum', reverse=False): - """ - PyTorch 实现 associative_scan,语义与 Triton 完全对齐 - 支持任意 combine_fn(如 a|b, a&b, min, max 等) - """ - dim = dim % input.ndim - - if reverse: - input = input.flip(dim) - - N = input.size(dim) - - tensors = torch.unbind(input, dim=dim) - - outputs = [] - - carry = tensors[0] - outputs.append(carry) - - for i in range(1, N): - carry = combine_fn_test_torch(tensors[i], carry, combine_fn) - outputs.append(carry) - - output = torch.stack(outputs, dim=dim) - - if reverse: - output = output.flip(dim) - - return output - - -@triton.jit -def bitwise_and_fn(a, b): - return a & b - - -@triton.jit -def bitwise_or_fn(a, b): - return a | b - - -@triton.jit -def bitwise_xor_fn(a, b): - return a ^ b - - -@triton.jit -def minimum_fn(a, b): - return tl.minimum(a, b) - - -@triton.jit -def maximum_fn(a, b): - return tl.maximum(a, b) - - -@triton.jit -def triton_kernel_1d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - XBLOCK: tl.constexpr, - combine_fn_name: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - idx = tl.arange(0, XBLOCK) - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, 
reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_2d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - combine_fn_name: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx = idx_x[:, None] * numel_r + idx_r[None, :] - x = tl.load(in_ptr0 + idx) - - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_3d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - 
XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, - combine_fn_name: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_4d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - combine_fn_name: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, 
combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -@triton.jit -def triton_kernel_5d_scan( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, - combine_fn_name: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - if combine_fn_name == "maximum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) - elif combine_fn_name == "minimum_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) - elif combine_fn_name == "bitwise_or_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) - elif combine_fn_name == "bitwise_xor_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) - elif combine_fn_name == "bitwise_and_fn": - ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) - tl.store(out_ptr0 + idx, ret) - - -def triton_func_scan(x, dim, combine_fn, reverse): - res = torch.empty_like(x) - shape = x.size() - - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], 
x.shape[0], combine_fn) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1], combine_fn) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2], combine_fn) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], combine_fn) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d_scan[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4], - combine_fn) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - 
test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape2d) -@pytest.mark.parametrize("dim", [1]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_2d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape3d) -@pytest.mark.parametrize("dim", [2]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_3d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape4d) -@pytest.mark.parametrize("dim", [3]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_4d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, 
shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize("shape", TestUtils.test_shape5d) -@pytest.mark.parametrize("dim", [4]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_5d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", random.sample(TestUtils.test_shape2d, 5)) -@pytest.mark.parametrize("dim", [1]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) 
-@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_2d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape3d) -@pytest.mark.parametrize("dim", [2]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape4d) -@pytest.mark.parametrize("dim", [3]) -@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape5d) -@pytest.mark.parametrize("dim", [4]) 
-@pytest.mark.parametrize("combine_fn", ['maximum_fn', 'minimum_fn']) -@pytest.mark.parametrize("reverse", [False]) -def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - x_gold = x - cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) - - -@pytest.mark.parametrize("dtype", ['float16', 'float32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", ['bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [False]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_scan_float_invalid(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) - - -@pytest.mark.parametrize("dtype", ['int32']) -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("dim", [0]) -@pytest.mark.parametrize("combine_fn", - ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) -@pytest.mark.parametrize("reverse", [True]) -@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, - "reverse=True is not yet supported for scan op") -def test_scan_float_invalid_reverse(dtype, shape, dim, combine_fn, reverse): - should_skip_due_to_mem(dtype, shape) - torch.manual_seed(0) - x = test_common.generate_tensor(shape=shape, dtype=dtype) - - x_npu = x.npu() - triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) diff --git 
a/third_party/ascend/unittest/generalization_cases/test_atomic_add.py b/third_party/ascend/unittest/generalization_cases/test_atomic_add.py deleted file mode 100644 index d55448ae66..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_add.py +++ /dev/null @@ -1,576 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import triton - -import triton.language as tl - -import numpy as np -import test_common -from test_common import TestUtils - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] - - -@triton.jit -def atomic_add(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - offset = tl.program_id(0) * BLOCK_SIZE - index = offset + tl.arange(0, BLOCK_SIZE)[:] - xmask = index < n_elements - - tmp0 = tl.load(in_ptr0 + (index), xmask) - tmp1 = tl.load(out_ptr0 + (index), xmask) - tl.atomic_add(out_ptr1 + (index), tmp0, xmask) - tl.atomic_add(out_ptr1 + (index), tmp1, xmask) - - -@triton.jit -def atomic_add_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic add: y += x (broadcasted) - tl.atomic_add(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_add(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -def promote_dtype(x_dtype, y_dtype): - """ - 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 - """ - # 如果两个数据类型一致,直接返回 - if x_dtype == y_dtype: - return y_dtype - - # 构建类型的优先级列表(从低到高) - priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32] - - # 查找两种类型在优先级列表中的位置 - x_priority = priority.index(x_dtype) - y_priority = priority.index(y_dtype) - - # 如果y的优先级比x小,则提升到x的类型 - if y_priority < x_priority: - return x_dtype - else: - return y_dtype - - -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) 
-@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -@pytest.mark.parametrize('x_shape, y_shape, BLOCK_SIZE', test_cases) -def test_atomic_add_broadcast_combined(x_dtype_str, y_dtype_str, x_shape, y_shape, BLOCK_SIZE): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - # 先构造 x0 - x0 = torch.full(x_shape, 83.0000, dtype=x_dtype).npu() - - y_raw_dtype = eval('torch.' + y_dtype_str) - - out_dtype = promote_dtype(x_dtype, y_raw_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - - # 构造y和out - y = torch.full(y_shape, -105, dtype=y_raw_dtype).npu() - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - # 保存副本用于验证 - x_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - # 计算网格大小和元素总数 - n_elements = y.numel() - grid = (n_elements // BLOCK_SIZE, ) # 自动计算需要的线程块数量 - - # 调用 Triton 核函数 - atomic_add_broadcast[grid](x_ptr=x0, y_ptr=y, out_ptr=out, n_elements=n_elements, BLOCK_SIZE=BLOCK_SIZE) - - # 验证结果:y += x (广播加法) - expected = out_temp + y_temp + x_temp - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - x0 = test_common.generate_tensor(shape, x_dtype_str).npu() - x1 = test_common.generate_tensor(shape, y_dtype_str).npu() - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - y = torch.full(x1.shape, 0, dtype=out_dtype).npu() - - # 保存副本用于验证 - x0_temp = x0.clone() - x1_temp = x1.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] - atomic_add[shape[0], 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[1]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - atomic_add[grid_size, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - expected = y_temp + x1_temp + x0_temp - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_3d(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - x0 = test_common.generate_tensor(shape, x_dtype_str).npu() - x1 = test_common.generate_tensor(shape, y_dtype_str).npu() - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - y = torch.full(x1.shape, 0, dtype=out_dtype).npu() - - # 保存副本用于验证 - x0_temp = x0.clone() - x1_temp = x1.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_add[1, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[0] * shape[1] * shape[2]) - - expected = y_temp + x1_temp + x0_temp - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_add_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_add(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_add_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = x1 + x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_add_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_5d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' 
+ x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' + y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 5: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - atomic_add_5d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1), (1, 1, 2, 2)], - [(1, 1, 1, 
1), (1, 1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_4d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' + y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 4: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1 = triton_shape1 - - atomic_add_4d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr, - YB1: tl.constexpr, ZB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) - offsets1 = offsets1[:, :, None] 
+ tl.arange(0, ZB1)[None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2), (1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_3d_2(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 3: - triton_shape1.append(1) - XB1, YB1, ZB1 = triton_shape1 - - atomic_add_3d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] - - offsets1 = tl.arange(0, XB1) * (YB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tl.atomic_add(out_ptr + offsets1, tmp0) - tl.atomic_add(out_ptr + offsets1, tmp1) - - -@pytest.mark.parametrize('param_list', [ - [(1, 2), (2, 2)], - [(1, 1), (2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@pytest.mark.parametrize('y_dtype_str', filtered_dtype) -def test_atomic_add_2d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(x0_shape, dtype=eval('torch.' 
+ x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() - else: - y = torch.randn(y_shape, dtype=eval('torch.' + y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 2: - triton_shape1.append(1) - XB1, YB1 = triton_shape1 - - atomic_add_2d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - XB1=XB1, - YB1=YB1, - ) - - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [ - ['uint8', (32, 32), 2], - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], -]) -def test_atomic_add_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] / ncore - split_size = shape[0] // ncore - x0_value = 3 - x0_cpu = torch.full(shape, x0_value, dtype=eval(f'torch.{dtype}')).cpu() - x0 = x0_cpu.to("npu") - x1_cpu = torch.full((split_size, shape[1]), 4, dtype=eval(f'torch.{dtype}')).cpu() - x1 = x1_cpu.to("npu") - y_cpu = torch.full((split_size, shape[1]), -10, dtype=eval(f'torch.{dtype}')).cpu() - y = y_cpu.to("npu") - - x1_np = x1_cpu.numpy() - y_ref_np = x1_np + 0 - x1_ref_np = x1_np + ncore * x0_value - - x1_ref = torch.from_numpy(x1_ref_np).npu() - y_ref = torch.from_numpy(y_ref_np).npu() - - @triton.jit - def atomic_add_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - 
tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_add_uint[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, x1, x1_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_and.py b/third_party/ascend/unittest/generalization_cases/test_atomic_and.py deleted file mode 100644 index 0ef250741c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_and.py +++ /dev/null @@ -1,562 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'float16', 'float32', 'bfloat16', 'bool'}] - - -@triton.jit -def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask) - - -@triton.jit -def atomic_and_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y &= x (broadcasted) - tl.atomic_and(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_and(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - # OR的时候任何位和0做OR都不变 任何位和1做AND也都不变,所以为了保持不变 不能用0 只能用1 - y = torch.full(shape, torch.iinfo(x_dtype).max, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_and[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - atomic_and[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) - - expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape_ub_overflow) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "ub overflow") -def test_atomic_and_ub_overflow(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - -@triton.jit -def atomic_and_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_and(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), 
- (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_and_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() - - x1_ref = x1 & x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_and_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_and_5d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB * ZB * MB * NB) - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_5d(x_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(x0_shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(y_shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 5: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - atomic_and_5d[(2, )]( - x_ptr=x, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_and_4d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB * ZB * MB) - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 2, 1), (1, 1, 2, 2)], - [(1, 1, 1, 1), (1, 1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_4d(x_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(x0_shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(y_shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 4: - triton_shape1.append(1) - XB1, YB1, ZB1, MB1 = triton_shape1 - - atomic_and_4d[(2, )]( - x_ptr=x, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - ) - - expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_and_3d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr, - YB1: tl.constexpr, ZB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB * ZB) - offsets = tl.arange(0, XB) * (YB * ZB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 1, 1), (1, 1, 2)], - [(1, 1, 2), (1, 2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_3d_2(x_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(x0_shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(y_shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 3: - triton_shape1.append(1) - XB1, YB1, ZB1 = triton_shape1 - - atomic_and_3d[(2, )]( - x_ptr=x, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - ) - - expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_and_2d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr): - base = tl.program_id(0) * (XB * YB) - offsets = tl.arange(0, XB) * (YB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] - - offsets1 = tl.arange(0, XB1) * (YB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] - - based_offsets = offsets + base - - tmp0 = tl.load(x_ptr + based_offsets) - tl.atomic_and(out_ptr + offsets1, tmp0) - - -@pytest.mark.parametrize('param_list', [ - [(1, 2), (2, 2)], - [(1, 1), (2, 2)], -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_and_2d(x_dtype_str, param_list): - x0_shape, y_shape = param_list - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(x0_shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(y_shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - - triton_shape1 = [*y_shape] - while len(triton_shape1) < 2: - triton_shape1.append(1) - XB1, YB1 = triton_shape1 - - atomic_and_2d[(2, )]( - x_ptr=x, - out_ptr=out, - XB=XB, - YB=YB, - XB1=XB1, - YB1=YB1, - ) - - expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@triton.jit -def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr, - mode: tl.constexpr = 0): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - if mode == 0: - tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, 'acq_rel', 'cta') - elif mode == 1: - tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "test") - elif mode == 2: - tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "acq_rel", "test") - - -invalid_types_float = ['float16', 'float32', 'bfloat16'] - - -@pytest.mark.parametrize("sigtype", invalid_types_float) -@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "must be signless-integer-like") -def test_invalid_types_float(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - atomic_and[1, 1, 1](x, y, 1, 1, 32) - - -default_types = ['int8'] - - -@pytest.mark.parametrize("sigtype", default_types) -@pytest.mark.parametrize("test_type", ["sem", "scope"]) 
-@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Memory semantic test not supported") -def test_invalid_sem_scope(sigtype, test_type): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - if test_type == "sem": - atomic_and[1, 1, 1](x, y, 1, 1, 32, 1) - elif test_type == "scope": - atomic_and[1, 1, 1](x, y, 1, 1, 32, 2) - - -@triton.jit -def _atomic_and_ss(in_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr, SEM: tl.constexpr, SCOPE: tl.constexpr): - pid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = pid < n_cols - val = tl.load(in_ptr + pid, mask) - tl.atomic_and(out_ptr + pid, val, mask, sem=SEM, scope=SCOPE) - - -SEMS = ("relaxed", "acquire", "release", "acq_rel") -SCOPES = ("cta", "gpu", "sys") - - -@pytest.mark.parametrize("sem", SEMS) -@pytest.mark.parametrize("scope", SCOPES) -def test_atomic_sem_vs_scope(sem: str, scope: str): - n_cols = 1024 - BLOCK = 128 - grid = (triton.cdiv(n_cols, BLOCK), ) - - inp = torch.full((n_cols, ), 0xFF, dtype=torch.int32, device="npu") - - base = torch.full_like(inp, 0xFF) - _atomic_and_ss[grid](inp, base, n_cols, BLOCK_SIZE=BLOCK, SEM="acq_rel", SCOPE="gpu") - - cur = torch.full_like(inp, 0xFF) - _atomic_and_ss[grid](inp, cur, n_cols, BLOCK_SIZE=BLOCK, SEM=sem, SCOPE=scope) - - torch.testing.assert_close(cur, base) - - -@pytest.mark.parametrize('param_list', [ - ['uint8', (32, 32), 2], - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], - ['uint64', (32, 32), 2], -]) -def test_atomic_and_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] // ncore - split_size = shape[0] // ncore - - val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu() - val = val_cpu.to("npu") - - pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu() - pointer = pointer_cpu.to("npu") - 
pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu() - pointer_old = pointer_old_cpu.to("npu") - pointer_ref_cpu = pointer_cpu.clone() - - for i in range(ncore - 1): - pointer_ref_cpu &= val_cpu[(i * split_size):((i + 1) * split_size)] - - pointer_ref_last = pointer_ref_cpu.clone() - pointer_ref_cpu &= val_cpu[((ncore - 1) * split_size):(ncore * split_size)] - pointer_ref = pointer_ref_cpu.to("npu") - - @triton.jit - def atomic_and_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_and(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_and_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py b/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py deleted file mode 100644 index eab2568755..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils -import numpy as np - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'int8', 'bool'}] - - -@triton.jit -def atomic_cas(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tmp1 = tl.load(in_ptr1 + (in_index), xmask) - tl.atomic_cas(out_ptr0 + (out_index), tmp1, tmp0) - - -@triton.jit -def atomic_cas_ndim(x_ptr, y_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr, - DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): - sub_idx = tl.program_id(1) - base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE - base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE - offsets_src = tl.arange(0, BLOCK_SIZE) + base_src - offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst - mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 - tmp = tl.load(x_ptr + offsets_src, mask) - tmp_c = tl.load(y_ptr + offsets_src, mask) - tl.atomic_cas(out_ptr + offsets_dst, tmp_c, tmp) - - -@triton.jit -def atomic_cas_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y |= x (broadcasted) - tl.atomic_cas(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_cas(out_ptr + y_indices, 
x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - c = torch.randint(low=0, high=2, size=x_shape, dtype=x_dtype).npu() - y = torch.randint(low=0, high=2, size=shape, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - c_temp = c.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_cas[shape[0] * 2, 1, 1](x, c, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - # value - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - # compare - c_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - c_concat[0:n_elements] = c[0:n_elements] - c_concat[aligned_size:(aligned_size + n_elements)] = c[n_elements:(n_elements * 2)] - atomic_cas[grid_size * 2, 1, 1](x_concat, c_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, - BLOCK_NUM=grid_size) - - expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp) - expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected) - torch.testing.assert_close(y, expected) - - -# 3d 
-@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - y = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - c_temp = c.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_cas[2, 1, 1](x, c, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp) - expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected) - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_cas_multi_d(in_ptr0, in_ptr1, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - MB: tl.constexpr, NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tmp1 = tl.load(in_ptr1 + offsets) - tl.atomic_cas(out_ptr0 + offsets, tmp1, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_cas_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, 
dtype=eval('torch.' + dtype)).npu() - c = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu() - x1 = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu() - - x1_ref = torch.where(x1 == c, 3, x1) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - - atomic_cas_multi_d[(1, )](x0, c, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1, 2), - (10, 1, 15, 1, 7), - (1, 1, 1, 1, 257), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_5d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() - - x_temp = x.clone() - c_temp = c.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - BLOCK_SIZE = 256 - ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_cas_ndim[(2 * XB * YB * ZB * MB, ncore)]( - x_ptr=x, - y_ptr=c, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=XB, - DIM1=YB, - DIM2=ZB, - DIM3=MB, - DIM4=NB, - ) - - expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) - expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1), - (1, 1, 2, 2), - (1, 3, 2, 7), - (1, 3, 2, 651), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_4d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() - - x_temp = x.clone() - c_temp = c.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - BLOCK_SIZE = 256 - ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_cas_ndim[(2 * XB * YB * ZB, ncore)]( - x_ptr=x, - y_ptr=c, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=XB, - DIM2=YB, - DIM3=ZB, - DIM4=MB, - ) - - expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) - expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1), - (1, 1, 2), - (1, 31, 275), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_3d_2(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() - - x_temp = x.clone() - c_temp = c.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - BLOCK_SIZE = 256 - ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_cas_ndim[(2 * XB * YB, ncore)]( - x_ptr=x, - y_ptr=c, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=XB, - DIM3=YB, - DIM4=ZB, - ) - - expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) - expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 2), - (1, 1), - (257, 1), - (257, 2), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_2d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() - - x_temp = x.clone() - c_temp = c.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - BLOCK_SIZE = 256 - ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_cas_ndim[(2 * XB, ncore)]( - x_ptr=x, - y_ptr=c, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=XB, - DIM4=YB, - ) - - expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) - expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_cas_1d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - c_temp = c.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB = triton_shape[0] - BLOCK_SIZE = 256 - ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_cas_ndim[(2, ncore)]( - x_ptr=x, - y_ptr=c, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=1, - DIM4=XB, - ) - - expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) - expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [ - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], - ['uint64', (32, 32), 2], -]) -def test_atomic_cas_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] // ncore - split_size = shape[0] // ncore - - import random - cmp_val = [random.randint(0, 10) for _ in range(ncore)] - - cmp_cpu_parts = [] - for i in range(ncore): - part = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')) * cmp_val[i] - cmp_cpu_parts.append(part) - cmp_cpu = torch.cat(cmp_cpu_parts, dim=0) - cmp = cmp_cpu.to("npu") - - val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu() - val = val_cpu.to("npu") - - pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu() - pointer = pointer_cpu.to("npu") - pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu() - pointer_old = pointer_old_cpu.to("npu") - pointer_ref_cpu = pointer_cpu.clone() - - pointer_ref_np = pointer_cpu.numpy() - val_np = val_cpu.numpy() - for i in range(ncore): - 
val_subview_np = val_np[(i * split_size):((i + 1) * split_size)] - pointer_ref_np = np.where(pointer_ref_np == cmp_val[i], val_subview_np, pointer_ref_np) - pointer_ref_cpu = torch.from_numpy(pointer_ref_np) - pointer_ref = pointer_ref_cpu.to("npu") - - @triton.jit - def atomic_cas_uint(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - val = tl.load(in_ptr0 + (x0), xmask) - cmp = tl.load(in_ptr1 + (x0), xmask) - tmp1 = tl.atomic_cas(out_ptr0 + (x1), cmp, val) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_cas_uint[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_max.py b/third_party/ascend/unittest/generalization_cases/test_atomic_max.py deleted file mode 100644 index 87347cb2fd..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_max.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import random -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_test_fn_atomic_max_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - index = xoffset + tl.arange(0, BLOCK_SIZE)[:] - mask = index < n_elements - inp0 = tl.load(in_ptr0 + (index), mask) - inp1 = tl.load(in_ptr1 + (index), mask) - tmp1 = tl.atomic_max(out_ptr1 + (index), inp0, mask) - tmp2 = tl.atomic_max(out_ptr1 + (index), inp1, mask) - - -def promote_dtype(x_dtype, y_dtype): - """ - 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 - """ - # 如果两个数据类型一致,直接返回 - if x_dtype == y_dtype: - return y_dtype - - # 构建类型的优先级列表(从低到高) - priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32] - - # 查找两种类型在优先级列表中的位置 - x_priority = priority.index(x_dtype) - y_priority = priority.index(y_dtype) - - # 如果y的优先级比x小,则提升到x的类型 - if y_priority < x_priority: - return x_dtype - else: - return y_dtype - - -# torch.max do not support int -@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5)) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_max(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - - x0 = test_common.generate_tensor(shape, x_dtype_str) - x1 = test_common.generate_tensor(shape, y_dtype_str) - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(x1.shape, 0, dtype=out_dtype) - - out_ref = torch.maximum(out, x0) - out_ref = torch.maximum(out_ref, x1) - out_ref = out_ref.npu() - x0 = x0.npu() - x1 = x1.npu() - out = out.npu() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] - triton_test_fn_atomic_max_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - triton_test_fn_atomic_max_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - torch.testing.assert_close(out, out_ref) - - -# 3d -testlist = [ - (1, 22, 39), - (27, 1, 39), - (27, 22, 1), - (1, 1, 23), - (23, 1, 1), - (1, 23, 1), - (27, 5, 3), - (2, 29, 4), - (7, 31, 7), - (3, 5, 8), - (7, 17, 15), - (25, 5, 16), - (23, 5, 31), - (7, 11, 32), - (7, 11, 33), - (2, 3, 255), - (3, 3, 256), - (3, 2, 257), -] - - -@pytest.mark.parametrize('shape', random.sample(testlist, 5)) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_max_3d(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - ncore = 1 - split_size = shape[0] // ncore - x0 = test_common.generate_tensor(shape, x_dtype_str) - x1 = test_common.generate_tensor(shape, y_dtype_str) - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - y = torch.full(shape, 0, dtype=out_dtype) - - out_ref = torch.full_like(x0, 0, dtype=out_dtype) - out_ref = torch.maximum(out_ref, x0) - out_ref = torch.maximum(out_ref, x1) - x0 = x0.npu() - x1 = x1.npu() - y = y.npu() - - n_elements = shape[0] * shape[1] * shape[2] - triton_test_fn_atomic_max_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2]) - y = y.cpu() - torch.testing.assert_close(y, out_ref) - - -@triton.jit -def atomic_max_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_max(out_ptr0 + offsets, tmp0) - - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_max_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = torch.maximum(x1, x0) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_max_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_max_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - MB: tl.constexpr, NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tmp1 = tl.load(out_ptr0 + offsets) - tl.atomic_max(out_ptr1 + offsets, tmp0) - tl.atomic_max(out_ptr1 + offsets, tmp1) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_max_4d_5d_2(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(shape, dtype=eval('torch.' + x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu() - else: - x1 = torch.randn(shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: - y = torch.full(shape, torch.iinfo(out_dtype).min, dtype=out_dtype).npu() - else: - y = torch.full(shape, float('-inf'), dtype=out_dtype).npu() - - y_tmp = y - x1_ref = torch.maximum(y_tmp, x0) - x1_ref = torch.maximum(x1_ref, x1) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_max_multi_d_2[(1, )](x0, x1, y, *triton_shape) - torch.testing.assert_close(y, x1_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_min.py b/third_party/ascend/unittest/generalization_cases/test_atomic_min.py deleted file mode 100644 index a74e99058f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_min.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import random -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_test_fn_atomic_min_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - index = xoffset + tl.arange(0, BLOCK_SIZE)[:] - mask = index < n_elements - inp0 = tl.load(in_ptr0 + (index), mask) - inp1 = tl.load(in_ptr1 + (index), mask) - tmp1 = tl.atomic_min(out_ptr1 + (index), inp0, mask) - tmp2 = tl.atomic_min(out_ptr1 + (index), inp1, mask) - - -def promote_dtype(x_dtype, y_dtype): - """ - 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 - """ - # 如果两个数据类型一致,直接返回 - if x_dtype == y_dtype: - return y_dtype - - # 构建类型的优先级列表(从低到高) - priority = [torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16, torch.float32] - - # 查找两种类型在优先级列表中的位置 - x_priority = priority.index(x_dtype) - y_priority = priority.index(y_dtype) - - # 如果y的优先级比x小,则提升到x的类型 - if y_priority < x_priority: - return x_dtype - else: - return y_dtype - - -# torch.min do not support int -@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5)) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_min(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - x0 = test_common.generate_tensor(shape, x_dtype_str) - x1 = test_common.generate_tensor(shape, y_dtype_str) - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: # 判断是否是整数类型 - out = torch.full(x1.shape, torch.iinfo(out_dtype).max, dtype=out_dtype) - else: - out = torch.full(x1.shape, torch.finfo(out_dtype).max, dtype=out_dtype) - - out_ref = torch.minimum(out, x0) - out_ref = torch.minimum(out_ref, x1) - out_ref = out_ref.npu() - x0 = x0.npu() - x1 = x1.npu() - out = out.npu() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] - triton_test_fn_atomic_min_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - triton_test_fn_atomic_min_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - torch.testing.assert_close(out, out_ref) - - -# 3d -testlist = [ - (1, 22, 39), - (27, 1, 39), - (27, 22, 1), - (1, 1, 23), - (23, 1, 1), - (1, 23, 1), - (27, 5, 3), - (2, 29, 4), - (7, 31, 7), - (3, 5, 8), - (7, 17, 15), - (25, 5, 16), - (23, 5, 31), - (7, 11, 32), - (7, 11, 33), - (2, 3, 255), - (3, 3, 256), - (3, 2, 257), -] - - -@pytest.mark.parametrize('shape', random.sample(testlist, 5)) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_min_3d(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - ncore = 1 - split_size = shape[0] // ncore - x0 = test_common.generate_tensor(shape, x_dtype_str) - x1 = test_common.generate_tensor(shape, y_dtype_str) - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: - y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype) - else: - y = torch.full(shape, float('inf'), dtype=out_dtype) - - y_tmp = y - x1_ref = torch.minimum(y_tmp, x0) - x1_ref = torch.minimum(x1_ref, x1) - x0 = x0.npu() - x1 = x1.npu() - y = y.npu() - - n_elements = shape[0] * shape[1] * shape[2] - triton_test_fn_atomic_min_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2]) - y = y.cpu() - torch.testing.assert_close(y, x1_ref) - - -@triton.jit -def atomic_min_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_min(out_ptr0 + offsets, tmp0) - - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_min_4d_5d(dtype, shape): - x0_value = 1 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' 
+ dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() - - x1_ref = torch.minimum(x1, x0) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_min_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@triton.jit -def atomic_min_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - MB: tl.constexpr, NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tmp1 = tl.load(out_ptr0 + offsets) - tl.atomic_min(out_ptr1 + offsets, tmp0) - tl.atomic_min(out_ptr1 + offsets, tmp1) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) -def test_atomic_min_4d_5d_2(x_dtype_str, y_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - - if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: - x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu() - else: - x0 = torch.randn(shape, dtype=eval('torch.' 
+ x_dtype_str)).npu() - - if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: - x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu() - else: - x1 = torch.randn(shape, dtype=eval('torch.' + y_dtype_str)).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: - y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype).npu() - else: - y = torch.full(shape, float('inf'), dtype=out_dtype).npu() - - y_tmp = y - x1_ref = torch.minimum(y_tmp, x0) - x1_ref = torch.minimum(x1_ref, x1) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_min_multi_d_2[(1, )](x0, x1, y, *triton_shape) - torch.testing.assert_close(y, x1_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_or.py b/third_party/ascend/unittest/generalization_cases/test_atomic_or.py deleted file mode 100644 index 4e5493b362..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_or.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils - -filtered_dtype = [ - dtype for dtype in TestUtils.full_dtype - if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'} -] - - -@triton.jit -def atomic_or(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tl.atomic_or(out_ptr0 + (out_index), tmp0, xmask) - - -@triton.jit -def atomic_or_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr, - DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): - sub_idx = tl.program_id(1) - base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE - base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE - offsets_src = tl.arange(0, BLOCK_SIZE) + base_src - offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst - mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 - tmp = tl.load(x_ptr + offsets_src, mask) - tl.atomic_or(out_ptr + offsets_dst, tmp, mask) - - -@triton.jit -def atomic_or_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = 
tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y |= x (broadcasted) - tl.atomic_or(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_or(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_or[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - atomic_or[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) - - expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) 
-@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_or[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_or_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_or(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_or_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = x1 | x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_or_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1, 2), - (10, 1, 15, 1, 7), - (1, 1, 1, 1, 257), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_5d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - BLOCK_SIZE = 256 - ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_or_ndim[(2 * XB * YB * ZB * MB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=XB, - DIM1=YB, - DIM2=ZB, - DIM3=MB, - DIM4=NB, - ) - - expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1), - (1, 1, 2, 2), - (1, 3, 2, 7), - (1, 3, 2, 651), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_4d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - BLOCK_SIZE = 256 - ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_or_ndim[(2 * XB * YB * ZB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=XB, - DIM2=YB, - DIM3=ZB, - DIM4=MB, - ) - - expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1), - (1, 1, 2), - (1, 31, 275), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_3d_2(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - BLOCK_SIZE = 256 - ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_or_ndim[(2 * XB * YB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=XB, - DIM3=YB, - DIM4=ZB, - ) - - expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 2), - (1, 1), - (257, 1), - (257, 2), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_2d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - BLOCK_SIZE = 256 - ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_or_ndim[(2 * XB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=XB, - DIM4=YB, - ) - - expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_or_1d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB = triton_shape[0] - BLOCK_SIZE = 256 - ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_or_ndim[(2, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=1, - DIM4=XB, - ) - - expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [ - ['uint8', (32, 32), 2], - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], - ['uint64', (32, 32), 2], -]) -def test_atomic_or_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] // ncore - split_size = shape[0] // ncore - - val_cpu = torch.randint(low=0, high=10, size=shape, 
dtype=eval(f'torch.{dtype}')).cpu() - val = val_cpu.to("npu") - - pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu() - pointer = pointer_cpu.to("npu") - pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu() - pointer_old = pointer_old_cpu.to("npu") - pointer_ref_cpu = pointer_cpu.clone() - - for i in range(ncore - 1): - pointer_ref_cpu |= val_cpu[(i * split_size):((i + 1) * split_size)] - - pointer_ref_last = pointer_ref_cpu.clone() - pointer_ref_cpu |= val_cpu[((ncore - 1) * split_size):(ncore * split_size)] - pointer_ref = pointer_ref_cpu.to("npu") - - @triton.jit - def atomic_or_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_or(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_or_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py deleted file mode 100644 index 740378e929..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py +++ /dev/null @@ -1,434 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils - -filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'bool'}] - - -@triton.jit -def atomic_xchg(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tl.atomic_xchg(out_ptr0 + (out_index), tmp0, xmask) - - -@triton.jit -def atomic_xchg_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr, - DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): - sub_idx = tl.program_id(1) - base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE - base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE - offsets_src = tl.arange(0, BLOCK_SIZE) + base_src - offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst - mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 - tmp = tl.load(x_ptr + offsets_src, mask) - tl.atomic_xchg(out_ptr + offsets_dst, tmp, mask) - - -@triton.jit -def atomic_xchg_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y |= x (broadcasted) - tl.atomic_xchg(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_xchg(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 1, 4), 4), - ((1, 
1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_xchg[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - atomic_xchg[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) - - expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_xchg[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_xchg_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_xchg(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_xchg_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = x0 - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_xchg_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1, 1, 2), - (10, 1, 15, 1, 7), - (1, 1, 1, 1, 257), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_5d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - BLOCK_SIZE = 256 - ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB * ZB * MB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=XB, - DIM1=YB, - DIM2=ZB, - DIM3=MB, - DIM4=NB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1, 1), - (1, 1, 2, 2), - (1, 3, 2, 7), - (1, 3, 2, 651), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_4d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - BLOCK_SIZE = 256 - ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB * ZB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=XB, - DIM2=YB, - DIM3=ZB, - DIM4=MB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 1, 1), - (1, 1, 2), - (1, 31, 275), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_3d_2(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - BLOCK_SIZE = 256 - ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB * YB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=XB, - DIM3=YB, - DIM4=ZB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [ - (1, 2), - (1, 1), - (257, 1), - (257, 2), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_2d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - BLOCK_SIZE = 256 - ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2 * XB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=XB, - DIM4=YB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shaape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xchg_1d(x_dtype_str, shaape): - shape = shaape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB = triton_shape[0] - BLOCK_SIZE = 256 - ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xchg_ndim[(2, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=1, - DIM4=XB, - ) - - expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [['uint8', (32, 32), 2], ['uint16', - (32, 32), 2], ['uint32', - (32, 32), 2], ['uint64', (32, 32), 2]]) -def test_atomic_xchg_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] // ncore - split_size = shape[0] // ncore - - val_cpu = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).cpu() - val = 
val_cpu.to("npu") - - pointer_cpu = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).cpu() - pointer = pointer_cpu.to("npu") - - pointer_ref = pointer.clone() - pointer_old_cpu = torch.full_like(val_cpu, -10).cpu() - pointer_old = pointer_old_cpu.to("npu") - pointer_old_ref = pointer_old.clone() - - pointer_ref = val[((ncore - 1) * split_size):(ncore * split_size)].clone() - pointer_old_ref[0:split_size] = pointer - pointer_old_ref[split_size:((ncore - 1) * split_size)] = val[0:(ncore - 2) * split_size] - - @triton.jit - def atomic_xchg_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_xchg(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x0), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_xchg_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) - - pointer_cpu = pointer.cpu() - pointer_ref_cpu = pointer_ref.cpu() - assert (pointer_cpu == pointer_ref_cpu).all() diff --git a/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py deleted file mode 100644 index 4a83697261..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py +++ /dev/null @@ -1,441 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import triton - -import triton.language as tl - -import test_common -from test_common import TestUtils - -filtered_dtype = [ - dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'bool'} -] - - -@triton.jit -def atomic_xor(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): - in_offset = tl.program_id(0) * BLOCK_SIZE - out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE - in_index = in_offset + tl.arange(0, BLOCK_SIZE) - out_index = out_offset + tl.arange(0, BLOCK_SIZE) - xmask = in_index < n_elements - - tmp0 = tl.load(in_ptr0 + (in_index), xmask) - tl.atomic_xor(out_ptr0 + (out_index), tmp0, xmask) - - -@triton.jit -def atomic_xor_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, DIM0: tl.constexpr, - DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): - sub_idx = tl.program_id(1) - base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE - base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE - offsets_src = tl.arange(0, BLOCK_SIZE) + base_src - offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst - mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 - tmp = tl.load(x_ptr + offsets_src, mask) - tl.atomic_xor(out_ptr + offsets_dst, tmp, mask) - - -@triton.jit -def atomic_xor_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - x = tl.load(x_ptr) # x is scalar or 1D, no mask needed - - # Compute y indices - y_offset = pid * BLOCK_SIZE - y_indices = y_offset + tl.arange(0, BLOCK_SIZE) - y_mask = y_indices < n_elements - - y_value = tl.load(y_ptr + y_indices, y_mask) - # Atomic or: y |= x (broadcasted) - tl.atomic_xor(out_ptr + y_indices, y_value, mask=y_mask) - tl.atomic_xor(out_ptr + y_indices, x, mask=y_mask) - - -# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) -test_cases = [ - ((1, 1, 1, 1), (1, 1, 
1, 4), 4), - ((1, 1, 1, 3), (1, 5, 1, 3), 5), - ((3, ), (2, 3, 3, 3, 3), 81), - ((3, ), (2, 3, 3, 3), 27), - ((3, ), (2, 3, 3), 9), - ((3, ), (2, 3), 3), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - - if len(shape) == 1 and shape[0] == 1: # golden 问题,手动验证 - return - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - if len(shape) == 2: - n_elements = shape[0] * shape[1] * 2 - atomic_xor[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) - elif len(shape) == 1: - n_elements = shape[0] - BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 - grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 - aligned_size = grid_size * BLOCK_SIZE - x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() - x_concat[0:n_elements] = x[0:n_elements] - x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] - atomic_xor[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) - - expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -# 3d -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_3d(x_dtype_str, shape): - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = torch.full(shape, 0, dtype=x_dtype).npu() - - # 保存副本用于验证 - x_temp = x.clone() - y_temp = y.clone() - - n_elements = shape[0] * shape[1] * shape[2] - atomic_xor[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) - - expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)] - torch.testing.assert_close(y, expected) - - -@triton.jit -def atomic_xor_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - tmp0 = tl.load(in_ptr0 + offsets) - tl.atomic_xor(out_ptr0 + offsets, tmp0) - - -# multi_d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 8, 4), - (8, 4, 2, 4), - (2, 8, 2, 2), - (2, 4, 8, 4, 2), - (8, 4, 2, 4, 4), - (2, 8, 2, 2, 2), -]) -@pytest.mark.parametrize('dtype', filtered_dtype) -def test_atomic_xor_4d_5d(dtype, shape): - x0_value = 3 - x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() - x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() - - x1_ref = x1 ^ x0_value - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - atomic_xor_multi_d[(1, )](x0, x1, *triton_shape) - test_common.validate_cmp(dtype, x1, x1_ref) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1, 2), - (10, 1, 15, 1, 7), - (1, 1, 1, 1, 257), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_5d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - XB, YB, ZB, MB, NB = triton_shape - BLOCK_SIZE = 256 - ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xor_ndim[(2 * XB * YB * ZB * MB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=XB, - DIM1=YB, - DIM2=ZB, - DIM3=MB, - DIM4=NB, - ) - - expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1, 1), - (1, 1, 2, 2), - (1, 3, 2, 7), - (1, 3, 2, 651), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_4d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 4: - triton_shape.append(1) - XB, YB, ZB, MB = triton_shape - - BLOCK_SIZE = 256 - ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xor_ndim[(2 * XB * YB * ZB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=XB, - DIM2=YB, - DIM3=ZB, - DIM4=MB, - ) - - expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 1, 1), - (1, 1, 2), - (1, 31, 275), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_3d_2(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 3: - triton_shape.append(1) - XB, YB, ZB = triton_shape - BLOCK_SIZE = 256 - ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xor_ndim[(2 * XB * YB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=XB, - DIM3=YB, - DIM4=ZB, - ) - - expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [ - (1, 2), - (1, 1), - (257, 1), - (257, 2), -]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_2d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' 
+ x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB, YB = triton_shape - BLOCK_SIZE = 256 - ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xor_ndim[(2 * XB, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=XB, - DIM4=YB, - ) - - expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', [(1, ), (9, ), (256, ), (257, ), (65535, ), (65536, )]) -@pytest.mark.parametrize('x_dtype_str', filtered_dtype) -def test_atomic_xor_1d(x_dtype_str, shape): - shape = shape - - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - x_shape = list(shape[:]) - x_shape[0] *= 2 - x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() - out = torch.full(shape, 0, dtype=x_dtype).npu() - - x_temp = x.clone() - out_temp = out.clone() - - triton_shape = [*shape] - while len(triton_shape) < 2: - triton_shape.append(1) - XB = triton_shape[0] - BLOCK_SIZE = 256 - ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE - - atomic_xor_ndim[(2, ncore)]( - x_ptr=x, - out_ptr=out, - NCORE=ncore, - BLOCK_SIZE=BLOCK_SIZE, - DIM0=1, - DIM1=1, - DIM2=1, - DIM3=1, - DIM4=XB, - ) - - expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('param_list', [ - ['uint8', (32, 32), 2], - ['uint16', (32, 32), 2], - ['uint32', (32, 32), 2], - ['uint64', (32, 32), 2], -]) -def test_atomic_xor_uint(param_list): - dtype, shape, ncore = param_list - block_size = shape[0] * shape[1] // ncore - split_size = shape[0] // ncore - - val_value = 3 - val_cpu = torch.full(shape, val_value, 
dtype=eval(f'torch.{dtype}')).cpu() - val = val_cpu.to("npu") - - pointer_value = 5 - pointer_cpu = torch.full((split_size, shape[1]), pointer_value, dtype=eval(f'torch.{dtype}')).cpu() - pointer = pointer_cpu.to("npu") - pointer_old_cpu = torch.full_like(pointer_cpu, -10).cpu() - pointer_old = pointer_old_cpu.to("npu") - - pointer_result = pointer_value - for _ in range(ncore): - pointer_result ^= val_value - - pointer_ref_cpu = torch.full_like(pointer_cpu, pointer_result).cpu() - pointer_ref = pointer_ref_cpu.to("npu") - - @triton.jit - def atomic_xor_uint(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): - xoffset = tl.program_id(0) * BLOCK_SIZE - xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] - yindex = tl.arange(0, BLOCK_SIZE)[:] - xmask = xindex < n_elements - x0 = xindex - x1 = yindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) - tmp1 = tl.atomic_xor(out_ptr0 + (x1), tmp0, xmask) - tl.store(out_ptr1 + (x1), tmp1, xmask) - - n_elements = shape[0] * shape[1] - atomic_xor_uint[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) - test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_broadcast.py b/third_party/ascend/unittest/generalization_cases/test_broadcast.py deleted file mode 100644 index e9f7a46d8b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_broadcast.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): - xidx = tl.arange(0, XS)[None, :] - base = tl.load(x_ptr + xidx) - out = base.broadcast_to((YS, XS)) - oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_1d(shape, dtype): - XS = shape[0] - YS = 4 - - x = test_common.generate_tensor((XS, ), dtype=dtype).npu() - std = torch.broadcast_to(x, (YS, XS)) - output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() - fn_broadcast_1d[1, 1, 1](output, x, XS, YS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): - zoffset = tl.program_id(0) * ZS - zidx = tl.arange(0, ZS)[None, :] - base = tl.load(x_ptr + zoffset + zidx) - out = base.broadcast_to((YS, ZS)) - oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_2d(shape, dtype): - XS = shape[0] - ZS = shape[1] - YS = 4 - NUMEL = XS * ZS - - x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() # randn not support int type - std = torch.broadcast_to(x, (XS, YS, ZS)) - output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() - fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = 
tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim0(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() - ans = x0.repeat(L, 1, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim1(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() - ans = x0.repeat(1, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim2(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() - ans = x0.repeat(1, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim01(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() - ans = x0.repeat(L, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim02(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() - ans = x0.repeat(L, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim12(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() - ans = x0.repeat(1, M, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def fn_broadcast_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, F_X: tl.constexpr, - F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, T_N: tl.constexpr, T_X: tl.constexpr, - T_Y: tl.constexpr): - from_offsets = tl.arange(0, F_L) - if F_M is not None: - from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] - if F_N is not None: - from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] - if F_X is not None: - from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] - if F_Y is not None: - from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] - - to_offsets = tl.arange(0, T_L) - if T_M is not None: - to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] - if T_N is not None: - to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] - if T_X is not None: - to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] - if T_Y is not None: - to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] - - from_data = tl.load(from_ptr + from_offsets) - to_data = tl.load(to_ptr + to_offsets) - ret_data = tl.broadcast(from_data, to_data) - - tl.store(to_ptr + to_offsets, ret_data) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shapes', [ - [(1, 64, 16, 1), (2, 64, 16, 2)], - [(8, 1, 1, 2), (8, 8, 4, 2)], -]) -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -def test_broadcast_to_4d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - 
triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -@pytest.mark.parametrize('shapes', [ - [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], - [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], -]) -def test_broadcast_to_5d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@triton.jit -def fn_broadcast(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - x1 = tl.load(out_ptr0 + odx) - ret = tl.broadcast(x, x1) - tl.store(out_ptr0 + odx, ret) - - -XS: tl.constexpr = 2 -YS: tl.constexpr = 4 -ZS: tl.constexpr = 8 - - -@pytest.mark.parametrize('dtype', - ["uint8", "int8", "int16", "int32", "int64", "float16", "float32", "bfloat16", "bool"]) -def test_broadcast_alltype(dtype): - input = test_common.generate_tensor((1, YS, ZS), dtype).npu() - ans = input.repeat(XS, 1, 1) - output = torch.zeros((XS, YS, ZS), dtype=eval('torch.' 
+ dtype)).npu() - fn_broadcast[1, 1, 1](input, output, XS, YS, ZS) - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py b/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py deleted file mode 100644 index 4ec6173874..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): - xidx = tl.arange(0, XS)[None, :] - base = tl.load(x_ptr + xidx) - out = base.broadcast_to((YS, XS)) - oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_1d(shape, dtype): - XS = shape[0] - YS = 4 - - x = test_common.generate_tensor((XS, ), dtype=dtype).npu() - std = torch.broadcast_to(x, (YS, XS)) - output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() - fn_broadcast_1d[1, 1, 1](output, x, XS, YS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): - zoffset = tl.program_id(0) * ZS - zidx = tl.arange(0, ZS)[None, :] - base = tl.load(x_ptr + zoffset + zidx) - out = base.broadcast_to((YS, ZS)) - oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx - tl.store(output_ptr + oidx, out) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_npu_2d(shape, dtype): - XS = shape[0] - ZS = shape[1] - YS = 4 - NUMEL = XS * ZS - - x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() - std = torch.broadcast_to(x, (XS, YS, ZS)) - output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() - fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) - test_common.validate_cmp(dtype, std, output) - - -@triton.jit -def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = 
tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim0(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() - ans = x0.repeat(L, 1, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim1(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() - ans = x0.repeat(1, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim2(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() - ans = x0.repeat(1, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim01(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() - ans = x0.repeat(L, M, 1) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim02(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() - ans = x0.repeat(L, 1, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() - triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = x.broadcast_to(L, M, N) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_broadcast_to_dim12(shape, dtype): - L, M, N = shape - x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() - ans = x0.repeat(1, M, N) - output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) - test_common.validate_cmp(dtype, output, ans) - - -@triton.jit -def fn_broadcast_to_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, - F_X: tl.constexpr, F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, - T_N: tl.constexpr, T_X: tl.constexpr, T_Y: tl.constexpr): - from_offsets = tl.arange(0, F_L) - if F_M is not None: - from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] - if F_N is not None: - from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] - if F_X is not None: - from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] - if F_Y is not None: - from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] - - to_offsets = tl.arange(0, T_L) - if T_M is not None: - to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] - if T_N is not None: - to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] - if T_X is not None: - to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] - if T_Y is not None: - to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] - - from_data = tl.load(from_ptr + from_offsets) - if F_Y is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X, T_Y)) - elif F_X is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X)) - elif F_N is not None: - ret_data = from_data.broadcast_to((T_L, T_M, T_N)) - elif F_M is not None: - ret_data = from_data.broadcast_to((T_L, T_M)) - else: - ret_data = from_data.broadcast_to((T_L)) - - tl.store(to_ptr + to_offsets, ret_data) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shapes', [ - [(1, 64, 16, 1), (2, 64, 16, 2)], - [(8, 1, 1, 2), (8, 8, 4, 2)], -]) -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) 
-def test_broadcast_to_4d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) -@pytest.mark.parametrize('shapes', [ - [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], - [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], -]) -def test_broadcast_to_5d(shapes, dtype): - from_shape, to_shape = shapes - dtype = eval(f"torch.{dtype}") - - x = torch.randint(0, 8, from_shape, dtype=dtype).npu() - y = torch.randint(0, 8, to_shape, dtype=dtype).npu() - expected = x.expand(to_shape) - - grid = (1, ) - triton_from_shape = [*from_shape] - triton_to_shape = [*to_shape] - while len(triton_from_shape) < 5: - triton_from_shape.append(None) - triton_to_shape.append(None) - fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) - assert (torch.equal(y, expected)) - - -XS: tl.constexpr = 2 -YS: tl.constexpr = 4 -ZS: tl.constexpr = 8 -NUMEL: tl.constexpr = XS * ZS - - -@triton.jit -def fn_broadcast_to(output_ptr, input_ptr, length): - col_offsets = tl.arange(0, NUMEL) - input = tl.load(input_ptr + col_offsets) - result = input.reshape((XS, 1, ZS)).broadcast_to((XS, YS, ZS)).reshape((XS * YS * ZS)) - brc_col_offsets = tl.arange(0, NUMEL * YS) - tl.store(output_ptr + brc_col_offsets, result) - - -@pytest.mark.parametrize('dtype', - ["uint8", "int8", "int16", "int32", "int64", "float16", "float32", "bfloat16", "bool"]) -def test_broadcast_to_alltype(dtype): - length = NUMEL - input = test_common.generate_tensor((XS, 
1, ZS), dtype).npu() - output = test_common.generate_tensor((XS, YS, ZS), dtype).npu() - fn_broadcast_to[1, 1, 1](output, input, length, debug=True) - assert (torch.equal(output, input.repeat(1, YS, 1))) diff --git a/third_party/ascend/unittest/generalization_cases/test_cast.py b/third_party/ascend/unittest/generalization_cases/test_cast.py deleted file mode 100644 index 3e1608cb97..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cast.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import random -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -@triton.jit -def cast_to_bool(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int1) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i8(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int8) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * 
x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i32(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_i64(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int64) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_fp32(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + 
zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.float32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_fp16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.float16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_bf16(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, - ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.bfloat16) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_uint32(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - 
X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.uint32) - tl.store(output_ptr + idx, ret) - - -@triton.jit -def cast_to_int64(output_ptr, x_ptr, x_stride, y_stride, z_stride, DIM: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - if DIM == 1: - xidx = tl.arange(0, XB) - idx = xidx * x_stride - elif DIM == 2: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride - elif DIM == 3: - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride - - X = tl.load(x_ptr + idx) - ret = tl.cast(X, dtype=tl.int64) - tl.store(output_ptr + idx, ret) - - -triton_func_map = { - "bool": cast_to_bool, "int8": cast_to_i8, "int16": cast_to_i16, "int32": cast_to_i32, "float16": cast_to_fp16, - "bfloat16": cast_to_bf16, "float32": cast_to_fp32, "uint32": cast_to_uint32, "int64": cast_to_int64 -} - - -def structParam(x0): - dim = x0.dim() - stride0, stride1, stride2 = 0, 0, 0 - shape0, shape1, shape2 = 0, 0, 0 - if dim >= 1: - stride0 = x0.stride(0) - shape0 = x0.shape[0] - if dim >= 2: - stride1 = x0.stride(1) - shape1 = x0.shape[1] - if dim == 3: - stride2 = x0.stride(2) - shape2 = x0.shape[2] - return dim, stride0, stride1, stride2, shape0, shape1, shape2 - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('srcDtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dstDtype', TestUtils.full_dtype) -def test_cast(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): 
- print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - triton_func = triton_func_map.get(dstDtype, None) - assert triton_func is not None, f"triton_func not Found, srcDtype:{srcDtype}, dstDtype:{dstDtype}" - triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() - dim, stride0, stride1, stride2, XB, YB, ZB = structParam(x0) - assert 0 <= dim <= 3, f"dim out of range [0, 3], dim:{dim}" - triton_func[1, 1, 1](triton_res, x0, stride0, stride1, stride2, dim, XB, YB, ZB) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -@triton.jit -def cast_to_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - dtype = output_ptr.type.element_ty - - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + offsets) - ret = tl.cast(X, dtype=dtype) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (6, 2, 4, 2), - (4, 2, 8, 4), - (4, 3, 8, 4), -]) -@pytest.mark.parametrize('srcDtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('dstDtype', ['int8', 'float16', 'float32']) -def test_cast_4d(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - 
print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - - triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - cast_to_multi_d[grid](triton_res, x0, *triton_shape) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 6, 2, 4, 2), - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 4), -]) -@pytest.mark.parametrize('srcDtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('dstDtype', ['int8', 'float16', 'float32']) -def test_cast_5d(srcDtype, dstDtype, shape): - if srcDtype == dstDtype: - return - srcBytes = get_dtype_size(srcDtype) - dstBytes = get_dtype_size(dstDtype) - dtype_size = max(srcBytes, dstBytes) - if dstDtype == 'int8': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): - print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, srcDtype) - torch_res = x0.to(eval("torch." + dstDtype)) - x0 = x0.npu() - - triton_res = torch.empty(shape, dtype=eval("torch." 
+ dstDtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - cast_to_multi_d[grid](triton_res, x0, *triton_shape) - test_common.validate_cmp(dstDtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(3, ), (3, 3), (3, 3, 3)]: - for srcDtype in ['int8', 'float32', 'bool']: - for dstDtype in ['int8', 'float32', 'bool']: - test_cast(srcDtype, dstDtype, shape) diff --git a/third_party/ascend/unittest/generalization_cases/test_cdiv.py b/third_party/ascend/unittest/generalization_cases/test_cdiv.py deleted file mode 100644 index 4f9afe73c6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cdiv.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math -import logging - - -def torch_cdiv(x0, x1, dtype): - return (x0 + x1 - 1) // x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.cdiv(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_cdiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, 
None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.cdiv(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - # 生成数据, cdiv int8 溢出的行为triton与torch_cpu不一致 - x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - y = (y.abs() // 2 + 1) - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.' 
+ dtype)) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_cdiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() - y = test_common.generate_tensor(shape, dtype).npu() - y = (y.abs() // 2 + 1) - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_cdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_ceil.py b/third_party/ascend/unittest/generalization_cases/test_ceil.py deleted file mode 100644 index bb0e925658..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ceil.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_ceil(x0): - res = torch.ceil(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.ceil(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_ceil_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < 
SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.ceil(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_ceil(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_ceil_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_ceil(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ceil_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_common.py b/third_party/ascend/unittest/generalization_cases/test_common.py deleted file mode 100644 index e6cf112f74..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_common.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import os -import re -import torch -import torch_npu -import math -import logging -from typing import AnyStr -import pytest -import functools -import numpy as np - -_float_dtypes = ['float32', 'float16', 'bfloat16'] -_int_dtypes = ['int32', 'int64', 'int16', 'int8'] -_uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] - -log_level = os.getenv("LOG_LEVEL", "WARN").upper() -level_mapping = { - "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARNING, "ERROR": logging.ERROR, "CRITICAL": - logging.CRITICAL -} - -logging.basicConfig(level=level_mapping.get(log_level, logging.WARNING), - format="[%(asctime)s][%(levelname)s] %(message)s") - -bisheng_not_support_dtypes = { - 'abs': [], 'eq': [], 'ne': [], 'flip': ['int64', - 'bfloat16'], 'load_store': ['int64'], 'permute2d': ['int64'], 'permute3d': - ['int64'], 'trans2d': ['int64'], 'trans3d': ['int64'], 'matmul': ['int16', 'int32', 'uint32', 'int64', 'bool'] -} - -tritonascend_not_support_dtypes = { - 'abs': ['bool'], - 'eq': ['bool'], - 'ne': ['bool'], - 'flip': ['bool'], - 'load_store': ['bool'], - 'permute2d': ['bool'], - 'permute3d': ['bool'], - 'trans2d': ['bool'], - 'trans3d': ['bool'], -} - - -def avoid_not_support(op: AnyStr): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(shape, dtype, *args, **kwargs): - if dtype in bisheng_not_support_dtypes.get(op, []): - logging.warn(f'skiped bisheng not support dtype:{dtype}') - return - if dtype in tritonascend_not_support_dtypes.get(op, []): - logging.warn(f'skiped triton ascend not support dtype:{dtype}') - return - return test_func(shape, dtype, *args, **kwargs) - - return wrapper - - return decorator - - -def get_shape1d(in_shape1d): - result = [] - for i in in_shape1d: - v = tuple((i, )) - result.append(v) - return result - - -def get_shape2d(in_shape1d, custom_shape): - result = [] - for a in in_shape1d: - for b in custom_shape: - t1 = tuple((a, b)) - t2 = tuple((b, a)) - if t1 not in result: - result.append(t1) - if t2 
not in result: - result.append(t2) - return result - - -def get_shape3d(): - return [(1, 22, 39), (27, 1, 39), (27, 22, 1), (23, 1, 1), (1, 23, 1), (1, 1, 23), (37, 5, 3), (2, 29, 4), - (7, 31, 7), (3, 5, 8), (7, 17, 15), (23, 5, 16), (23, 5, 31), (7, 11, 32), (7, 11, 33), (2, 3, 255), - (3, 3, 256), (3, 2, 257)] - - -def get_shape1_2_3d(in_shape1d, custom_shape): - return get_shape1d(in_shape1d) + get_shape2d(in_shape1d, custom_shape) + get_shape3d() - - -class TestUtils: - in_shape1d = [1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 37, 741] - custom_shape = [3, 13, 32, 256] - batch = [1, 2, 3, 4, 5, 8] - test_shape1d = get_shape1d(in_shape1d) - test_shape2d = get_shape2d(in_shape1d, custom_shape) - test_shape3d = [ - (1, 22, 39), - (27, 1, 39), - (27, 22, 1), - (1, 1, 23), - (23, 1, 1), - (1, 23, 1), - (37, 5, 3), - (2, 29, 4), - (7, 31, 7), - (3, 5, 8), - (7, 17, 15), - (25, 5, 16), - (23, 5, 31), - (7, 11, 32), - (7, 11, 33), - (2, 3, 255), - (3, 3, 256), - (3, 2, 257), - ] - test_shape4d = [(8, 4, 8, 8), (1, 11, 16, 2)] - test_shape5d = [(2, 3, 4, 5, 6), (1, 3, 4, 5, 6), (3, 6, 2, 4, 4)] - test_shape6d = [(2, 3, 5, 6, 3, 2)] - test_shape7d = [(1, 2, 3, 4, 3, 2, 2)] - test_shape_ub_overflow = [(10, 50, 1000)] - test_shape8d = [(1, 2, 3, 2, 5, 3, 7, 2), (1, 3, 2, 5, 6, 7, 2, 1), (2, 3, 7, 3, 2, 3, 2, 3)] - full_shape_4_8d = test_shape4d + test_shape5d + test_shape6d + test_shape7d + test_shape8d - - full_shape = test_shape1d + test_shape2d + test_shape3d - test_shape1_2_3d = full_shape - full_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32', 'bool'] - ub_size = 98304 * 2 - dtype_list = full_dtype - - -def get_dtype_size(dtype): - torch_dtype = eval('torch.' 
+ dtype) - bits = 0 - if torch_dtype == torch.bool: - bits = 8 - elif torch.is_floating_point(torch.tensor(0, dtype=torch_dtype)): - bits = torch.finfo(torch_dtype).bits - else: - bits = torch.iinfo(torch_dtype).bits - return bits // 8 - - -def check_ub_mem_overflow(dtype, shape): - bytes = get_dtype_size(dtype) - if bytes * math.prod(shape) > TestUtils.ub_size: - logging.warning(f'dtype:{dtype} shape:{shape} mem overflow') - return True - return False - - -def generate_numpy(shape, dtype, low=None, high=None): - if dtype in _int_dtypes + _uint_dtypes: - iinfo = np.iinfo(getattr(np, dtype)) - low = iinfo.min if low is None else max(low, iinfo.min) - high = iinfo.max if high is None else min(high, iinfo.max) - dty = getattr(np, dtype) - return np.random.randint(low, high, shape, dtype=dty) - elif dtype == 'float16' or dtype == 'float32': - return np.random.normal(0, 1, shape).astype(dtype) - elif dtype == 'bfloat16': - return (np.random.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') - elif dtype == 'bool': - return np.random.randint(low=0, high=2, size=shape).astype(bool) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.randn(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'uint32': - return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' 
+ dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - elif dtype == 'uint8': - return torch.randint(low=0, high=255, size=shape, dtype=torch.uint8) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def generate_tensor_int_withSigns(shape, dtype): - if dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=-32768, high=32767, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=-128, high=127, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def get_triton_sig_typename(dtype): - if dtype == 'float32': - tyname = "*fp32" - elif dtype == 'int32': - tyname = "*i32" - elif dtype == 'int64': - tyname = "*i64" - elif dtype == 'float16': - tyname = "*fp16" - elif dtype == 'int16': - tyname = "*i16" - elif dtype == 'int8': - tyname = "*i8" - elif dtype == 'bool': - tyname = "*i1" - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - return tyname - - -# Relative error: abs(x_ref - x_cal) / abs(x_ref) -# Absolute error: abs(x_ref - x_cal) - - -# calculation type operators require different error range -# It is a stricter verification and not satisfied now, save it here -def validate_cal(dtype, y_cal, y_ref): - if dtype == 'float16': - if torch.mean(y_ref) < 0.001: - assert torch.abs(y_cal - y_ref) < 0.001, "|y_cal - y_ref| < 0.001 is required !" - else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - # all true - assert diff.all(), "Relative error is less than 0.001 !" - if dtype == 'float32': - if torch.mean(y_ref) < 0.0001: - assert torch.abs(y_cal - y_ref) < 0.0001, "|y_cal - y_ref| < 0.0001 is required !" 
- else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.0001 - assert diff.all(), "Relative error is less than 0.001 !" - elif dtype == 'bfloat16': - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - assert diff.all(), "Relative error is less than 0.001 !" - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'uint8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -# moving and comparison ops require no precision error -def validate_cmp(dtype, y_cal, y_ref): - y_cal = y_cal.npu() - y_ref = y_ref.npu() - if dtype == 'float16': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'bfloat16': - torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, - equal_nan=True) - elif dtype == 'float32': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'uint8' or dtype == 'uint16' or dtype == 'uint32' or dtype == 'uint64': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal.cpu(), y_ref.cpu()) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def validate_cmp_with_expection(dtype, y_cal, y_ref, expect): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - if expect: - assert torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - else: - assert not torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8' \ - or dtype == 'uint8' or dtype == 'uint16' or dtype == 'uint32' 
or dtype == 'uint64': - if expect: - assert torch.equal(y_cal, y_ref) - else: - assert not torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def raises_with_match(expected_exception, match_pattern): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - with pytest.raises(expected_exception, match=match_pattern): - return test_func(*args, **kwargs) - - return wrapper - - return decorator - - -def capture_output(expected_output): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - capsys = kwargs.pop('capsys', None) - if capsys is None: - try: - capsys = pytest.fixture(capsys)() - except: - raise RuntimeError("This decorator requires pytest's capsys fixture") - test_func(capsys, *args, **kwargs) - captured = capsys.readouterr() - # pybind11::scoped_ostream_redirect captures std::cout with \x00 inserted - # for now, no idea how to eliminate \x00 from C++ side. - cleaned = re.sub(r"\x00", "", captured.out) - assert expected_output in cleaned - - return wrapper - - return decorator diff --git a/third_party/ascend/unittest/generalization_cases/test_cos.py b/third_party/ascend/unittest/generalization_cases/test_cos.py deleted file mode 100644 index 9f1980a515..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cos.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_cos(x0): - res = torch.cos(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.cos(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_cos_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) 
- if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.cos(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_cos(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_cos_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_cos(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_cos_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_count_dim0.py b/third_party/ascend/unittest/generalization_cases/test_count_dim0.py deleted file mode 100644 index 826f649909..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_count_dim0.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
-import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, N) + tl.program_id(2) * N - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + 
nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8']) -def test_count_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), dtype=torch.float32).npu() - count[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_gt_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_gt(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), dtype=torch.float32).npu() - count_gt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_lt_dim0_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_lt(x0, cmp_val, 0, dtype) - - output = torch.zeros((shape[1], ), dtype=torch.float32).npu() - count_lt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git 
a/third_party/ascend/unittest/generalization_cases/test_count_dim1.py b/third_party/ascend/unittest/generalization_cases/test_count_dim1.py deleted file mode 100644 index ebd19cf7ab..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_count_dim1.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, M) + tl.program_id(1) * M - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < MNUMEL - nmask = nblk_idx < NNUMEL - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * NNUMEL 
+ nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8']) -def test_count_dim1_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_gt_dim1_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_gt(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count_gt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) -def test_count_lt_dim1_common(shape, dtype): - rblock = shape[1] - xblock = shape[0] - x0 = test_common.generate_tensor(shape, dtype).npu() - - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - - ans = standard_count_lt(x0, cmp_val, 1, dtype) - - output = torch.zeros((shape[0], ), dtype=torch.float32).npu() - count_lt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) - - test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git 
a/third_party/ascend/unittest/generalization_cases/test_cumprod.py b/third_party/ascend/unittest/generalization_cases/test_cumprod.py deleted file mode 100644 index 9af5216d9b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cumprod.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -from triton.runtime.libentry import libentry - -from test_common import TestUtils, validate_cmp, get_dtype_size - - -def torch_func(x, dim, reverse): - is_bf16 = x.dtype == torch.bfloat16 - if is_bf16: - x = x.to(torch.float32) - if reverse: - x = torch.flip(x, [dim]) - res = torch.cumprod(x, dim=dim) - if is_bf16: - res = res.to(torch.bfloat16) - return res - - -@libentry() -@triton.jit -def triton_kernel_1d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - XBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - idx = tl.arange(0, XBLOCK) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_2d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx = idx_x[:, None] * numel_r + idx_r[None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_3d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be 
equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_4d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_5d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumprod(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -def convert_cumprod_dtype(x: torch.Tensor) -> torch.Tensor: - """ - 根据 cumprod 类型转换规则,返回转换后的张量。 - """ - dtype_map = { - torch.int8: torch.int64, - torch.int16: torch.int64, - torch.int32: torch.int64, - torch.int64: torch.int64, - torch.bfloat16: torch.bfloat16, - torch.float16: torch.float16, - torch.float32: torch.float32, - torch.bool: torch.int64, - } - - target_dtype = 
dtype_map.get(x.dtype, None) - if target_dtype is None: - raise ValueError(f"Unsupported input dtype for cumprod conversion: {x.dtype}") - - return x.to(target_dtype) - - -def triton_func(x, dim, reverse): - x = convert_cumprod_dtype(x) - - res = torch.empty_like(x) - shape = x.size() - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[0]) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1]) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2]) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3]) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def cumprod_generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.rand(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.' 
+ dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -# reverse=True not support; -@pytest.mark.parametrize("dtype", TestUtils.full_dtype) -@pytest.mark.parametrize("shape", TestUtils.full_shape) -@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) -@pytest.mark.parametrize("reverse", [False]) -def test_cumprod(dtype, shape, dim, reverse): - should_skip_due_to_mem(dtype, shape) - - x = cumprod_generate_tensor(shape=shape, dtype=dtype) - x_npu = x.npu() - - triton_res = triton_func(x_npu, dim, reverse) - - x_gold = x - cpu_res = torch_func(x_gold, dim, reverse) - - validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_cumsum.py b/third_party/ascend/unittest/generalization_cases/test_cumsum.py deleted file mode 100644 index 06ef04eceb..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_cumsum.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -from triton.runtime.libentry import libentry - -import acc_util -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_func(x, dim, reverse): - if reverse: - x = torch.flip(x, [dim]) - res = torch.cumsum(x, dim=dim) - return res - - -@libentry() -@triton.jit -def triton_kernel_1d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - XBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - idx = tl.arange(0, XBLOCK) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_2d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx = idx_x[:, None] * numel_r + idx_r[None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_3d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - numel_x: tl.constexpr, - numel_r: tl.constexpr, - numel_z: tl.constexpr, - XBLOCK: tl.constexpr, - RBLOCK: tl.constexpr, - ZBLOCK: tl.constexpr, -): - tl.static_assert(numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel") - tl.static_assert(numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel") - tl.static_assert(numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel") - idx_x = tl.arange(0, XBLOCK) - idx_r = tl.arange(0, RBLOCK) - idx_z = 
tl.arange(0, ZBLOCK) - idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_4d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + - zidx[None, None, :, None] * MB + midx[None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -@libentry() -@triton.jit -def triton_kernel_5d( - out_ptr0, - in_ptr0, - dim: tl.constexpr, - reverse: tl.constexpr, - XB: tl.constexpr, - YB: tl.constexpr, - ZB: tl.constexpr, - MB: tl.constexpr, - NB: tl.constexpr, -): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + - zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + - nidx[None, None, None, None, :]) - x = tl.load(in_ptr0 + idx) - ret = tl.cumsum(x, axis=dim, reverse=reverse) - tl.store(out_ptr0 + idx, ret) - - -def convert_cumsum_dtype(x: torch.Tensor) -> torch.Tensor: - """ - 根据 cumsum 类型转换规则,返回转换后的张量。 - """ - dtype_map = { - torch.int8: torch.int64, - torch.int16: torch.int64, - torch.int32: torch.int64, - torch.int64: torch.int64, - torch.bfloat16: torch.bfloat16, - torch.float16: torch.float16, - torch.float32: torch.float32, - torch.bool: torch.int64, - } - - target_dtype = dtype_map.get(x.dtype, None) - if target_dtype is None: - raise ValueError(f"Unsupported input dtype for cumsum 
conversion: {x.dtype}") - - return x.to(target_dtype) - - -def triton_func(x, dim, reverse): - x = convert_cumsum_dtype(x) - - res = torch.empty_like(x) - shape = x.size() - if len(shape) == 1: - if dim >= 1: - pytest.skip("dim >= 1 for 1D tensor, skipping.") - triton_kernel_1d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[0]) - elif len(shape) == 2: - if dim >= 2: - pytest.skip("dim >= 2 for 2D tensor, skipping.") - triton_kernel_2d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1]) - elif len(shape) == 3: - if dim >= 3: - pytest.skip("dim >= 3 for 3D tensor, skipping.") - triton_kernel_3d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], - x.shape[2]) - elif len(shape) == 4: - if dim >= 4: - pytest.skip("dim >= 4 for 4D tensor, skipping.") - triton_kernel_4d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3]) - elif len(shape) == 5: - if dim >= 5: - pytest.skip("dim >= 5 for 5D tensor, skipping.") - triton_kernel_5d[1, 1, 1](res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]) - else: - pytest.skip(f"Unsupported tensor dimension: {len(shape)}") - - return res - - -def cumsum_generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.rand(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' 
+ dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - - if dtype in ('int8', 'bool'): - threshold = TestUtils.ub_size / 13 - else: - threshold = TestUtils.ub_size / 6 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -# reverse=True not support; - - -@pytest.mark.parametrize("dtype", TestUtils.full_dtype) -@pytest.mark.parametrize("shape", TestUtils.full_shape) -@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) -@pytest.mark.parametrize("reverse", [False]) -def test_cumsum(dtype, shape, dim, reverse): - should_skip_due_to_mem(dtype, shape) - - x = cumsum_generate_tensor(shape=shape, dtype=dtype) - x_npu = x.npu() - - triton_res = triton_func(x_npu, dim, reverse) - - x_gold = x - cpu_res = torch_func(x_gold, dim, reverse) - - test_common.validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py b/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py deleted file mode 100644 index fb17fcb23f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import numpy as np -import torch -import logging -import pytest -import test_common -from test_common import TestUtils - - -def torch_invert(x0, ddtype): - if 'float' in str(ddtype): - x0 = x0.to(torch.int32) - y_ref = ~x0 - y_ref = y_ref.to(ddtype) - else: - y_ref = ~x0 - return y_ref - - -@triton.jit -def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X - Y - tl.debug_barrier() - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_invert_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, 
:] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = ~x_val - tl.debug_barrier() - tl.store(output_ptr + offsets, ret, mask=masks) - - -test_shape_1d_2d_3d = [(1, ), (2, ), (1, 1), (3, 13), (1, 1, 1), (4, 3, 8)] -test_shape_4_5d = [(1, 1, 1, 1), (2, 2, 2, 2), (1, 1, 1, 1, 1), (2, 2, 2, 2, 1)] - - -@pytest.mark.parametrize('shape', test_shape_1d_2d_3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x - y - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_sub[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', test_shape_1d_2d_3d + test_shape_4_5d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_invert_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_invert(x, eval('torch.' 
+ dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_device_print.py b/third_party/ascend/unittest/generalization_cases/test_device_print.py deleted file mode 100644 index 5421db96ed..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_device_print.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import sys -import os -import subprocess -import tempfile -import textwrap - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] - - -@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) -def test_device_print_int32(sigtype): - - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - temp_script = f.name - - f.write( - textwrap.dedent(f""" -import torch -import torch_npu -import triton -import triton.language as tl -import os -import sys - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.device_print("OUTPUT = ", tmp2) - tl.store(out_ptr0 + idx, tmp2) - -def main(): - shape = (8,) - XS = 8 - dtype = torch.{sigtype} - - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, 
dtype=dtype).npu() - - XVALS_INT = [0, -128, 127, -32768, 32767, -2147483648, 2147483647, 2147483648] - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - out = torch.empty_like(x0) - - triton_kernel[1,](out, x0, x1, XS) - - print("Kernel execution completed") - - return out - -if __name__ == "__main__": - result = main() - print(f"Result shape: {{result.shape}}") - """)) - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - torch_ref = x0 + x1 - if 'int' in sigtype: - torch_ref_str = ','.join([str(int(val)) for val in torch_ref.cpu().numpy()]) - else: - values = torch_ref.cpu() - if values.dtype == torch.bfloat16: - values = values.float() - torch_ref_str = ','.join([f"{float(val):.6f}" for val in values.numpy()]) - - result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) - - captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr - - ##with open(f"manual_capture_{sigtype}.txt", "w") as f: - ##f.write(captured_output) - ##f.write(f"torch_ref:{torch_ref_str}") - - if os.path.exists(temp_script): - os.remove(temp_script) - - assert torch_ref_str in captured_output diff --git a/third_party/ascend/unittest/generalization_cases/test_div_rn.py b/third_party/ascend/unittest/generalization_cases/test_div_rn.py deleted file mode 100644 index f0ce253288..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_div_rn.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging -import math - - -def torch_divRn(x0, x1): - return x0 / x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.div_rn(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_div_rn_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < 
SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.div_rn(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_divRn(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_div_rn_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - ans = torch_divRn(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_div_rn_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py b/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py deleted file mode 100644 index c40360e45a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import contextlib -import itertools -import re -import math -import textwrap -import os -import inspect -import pathlib -import test_common -import numpy as np -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -from numpy.random import RandomState -from triton.language.extra import libdevice - - -@triton.jit -def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr, - stride_b1: tl.constexpr, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr, acc_num: tl.constexpr): - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K - str_a0: tl.constexpr = stride_a0 - a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, str_a0)[None, :] * stride_a1 - b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 - - a = tl.load(a_ptr) - b = tl.load(b_ptr) - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if a_scale is not None: - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - a_scale = tl.load(scale_a_ptr) - if b_scale is not None: - scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - b_scale = tl.load(scale_b_ptr) - accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) - if acc_num is not None: - for _ in range(acc_num): - accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) - - out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] - tl.store(out_ptr, accumulator.to(a.dtype)) - - -def golden_ref(x, scale_x, y, scale_y): - shape_expand_x = x.shape[-1] // 
scale_x.shape[-1] - if x.dtype == torch.bfloat16: - upscale_x = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int16) - upscale_x = (upscale_x + 127 << 7).view(torch.bfloat16) - else: - scale_fp32 = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int32) - scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) - upscale_x = scale_fp32.to(torch.float16) - upscale_y = None - if scale_y is None: - upscale_y = torch.ones_like(y) - else: - scale_y = scale_y.T - shape_expand_y = y.shape[0] // scale_y.shape[0] - if y.dtype == torch.bfloat16: - upscale_y = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int16) - upscale_y = (upscale_y + 127 << 7).view(torch.bfloat16) - else: - scale_fp32 = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int32) - scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) - upscale_y = scale_fp32.to(torch.float16) - ret = torch.matmul(x * upscale_x, y * upscale_y) - return ret - - -@pytest.mark.parametrize("M, N, K, rhs_scale, normal_type, acc_num, num_warps", - [(M, N, K, rhs_scale, normal_type, acc_num, 4) - for M, N, K in itertools.product([16, 32, 64, 128], [16, 32, 64, 128], [32, 64]) - for rhs_scale in [False, True] - for normal_type in ["bf16", "fp16"] - for acc_num in [None, 1, 2]]) -def test_scaled_dot(M, N, K, rhs_scale, normal_type, num_warps, acc_num): - device = "npu" - - # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid - # overflow when scaling. 
- comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 - - torch.manual_seed(0) - - def make_arg(shape, ty): - if ty == "bf16" or ty == "fp16": - comp_dtype = torch.float16 if ty == "fp16" else torch.bfloat16 - ret = torch.randn(shape, dtype=comp_dtype, device=device) - # Clamp to avoid relative error issues - ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) - else: - ret = torch.randint(256, shape, dtype=torch.int8, device=device) - return ret - - type_a = normal_type - type_b = type_a - - x = make_arg((M, K), type_a) - y = make_arg((K, N), type_b) - - min_scale, max_scale = (0, 142) if type_a == torch.bfloat16 else (124, 131) - scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device=device) - min_scale, max_scale = (0, 142) if type_b == torch.bfloat16 else (124, 131) - scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device=device) - - if not rhs_scale: - scale_y = None - - kernel_kwargs = {"num_warps": num_warps} - z = x.new_empty((M, N), dtype=x.dtype) - pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, acc_num, - **kernel_kwargs) - z_ref = golden_ref(x, scale_x, y, scale_y) - if acc_num is not None: - z_ref = z_ref * (acc_num + 1) - - atol = 1e-5 - rtol = 1e-2 - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("B, M, N, K", [(1, 32, 64, 64)]) -def test_4d_dot(B, M, N, K): - device = "npu" - torch.manual_seed(0) - - x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) - y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) - - x2d = x4d.view(-1, N) # shape (B*B*M, N) - y2d = y4d.view(-1, K) # shape (B*B*N, K) - scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), dtype=torch.int8, device=device) - scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), dtype=torch.int8, device=device) - - z = torch.empty((x2d.shape[0], y2d.shape[0]), 
dtype=x2d.dtype, device=device) - acc_num = None - dot_scale_kernel[(1, )](x2d, *x2d.stride(), scale_x, y2d, *y2d.stride(), None, z, x2d.shape[0], y2d.shape[0], K, - "fp16", "fp16", None, num_warps=4) - z_ref = golden_ref(x2d, scale_x, y2d, None) - if acc_num is not None: - z_ref = z_ref * (acc_num + 1) - - atol = 1e-5 - rtol = 1e-2 - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("B, M, N, K", [(2, 16, 16, 32)]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, - r"lhs last dimension .* must equal rhs penultimate dimension") -def test_2d_dot_invaild_shape(B, M, N, K): - device = "npu" - torch.manual_seed(0) - - x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) - y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) - - x2d = x4d.view(-1, N) # shape (B*B*M, N) - y2d = y4d.view(-1, K) # shape (B*B*N, K) - scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), dtype=torch.int8, device=device) - scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), dtype=torch.int8, device=device) - - z = torch.empty((x2d.shape[0], y2d.shape[0]), dtype=x2d.dtype, device=device) - acc_num = None - dot_scale_kernel[(1, )](x2d, *x2d.stride(), scale_x, y2d, *y2d.stride(), None, z, x2d.shape[0], y2d.shape[0], K, - "fp16", "fp16", None, num_warps=4) - - -VALID_MAIN_DTYPES = { - torch.float16, # fp16 - torch.bfloat16, # bf16 -} - -ALL_DTYPES = { - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.float32, # fp32 - torch.bool, -} -ILLEGAL_MAIN_DTYPES = ALL_DTYPES - VALID_MAIN_DTYPES - -ILLEGAL_SCALE_DTYPES = { - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.bfloat16, - torch.bool, -} - -from itertools import product - - -def is_legal_dtype(lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype): - return (lhs_dtype in VALID_MAIN_DTYPES and rhs_dtype in VALID_MAIN_DTYPES and lhs_scale_dtype is torch.int8 - and rhs_scale_dtype is 
torch.int8) - - -illegal_cases = [] -for lhs, rhs, lhs_s, rhs_s in product( - VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, - VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, - {torch.int8} | ILLEGAL_SCALE_DTYPES, - {torch.int8} | ILLEGAL_SCALE_DTYPES, -): - - if not is_legal_dtype(lhs, rhs, lhs_s, rhs_s): - illegal_cases.append((lhs, rhs, lhs_s, rhs_s)) - -illegal_cases = sorted(set(illegal_cases), key=lambda t: tuple(str(i) for i in t)) - - -@pytest.mark.parametrize( - "lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype", - illegal_cases, -) -@test_common.raises_with_match(Exception, r"(?i)invalid|unsupported|dtype") -def test_invalid_dtype_should_fail(lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype): - device = "npu" - M, N, K = 32, 32, 64 - num_warps = 4 - - def make_tensor(shape, dtype): - return torch.randn(shape, dtype=dtype, device=device) \ - if dtype.is_floating_point else \ - torch.randint(-10, 10, shape, dtype=dtype, device=device) - - def make_scale(shape, dtype): - return torch.randint(-10, 10, shape, dtype=dtype, device=device) - - x = make_tensor((M, K), lhs_dtype) - y = make_tensor((K, N), rhs_dtype) - lhs_scale = make_scale((M, K // 32), lhs_scale_dtype) - rhs_scale = make_scale((N, K // 32), rhs_scale_dtype) - z = torch.empty((M, N), dtype=lhs_dtype, device=device) - - dot_scale_kernel[(1, )]( - x, - *x.stride(), - lhs_scale, - y, - *y.stride(), - rhs_scale, - z, - M, - N, - K, - str(lhs_dtype).split('.')[-1], - str(rhs_dtype).split('.')[-1], - None, - num_warps=num_warps, - ) diff --git a/third_party/ascend/unittest/generalization_cases/test_eq.py b/third_party/ascend/unittest/generalization_cases/test_eq.py deleted file mode 100644 index 94292ac21c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_eq.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math -import logging - - -def torch_eq(x0, x1): - if x0.dtype != torch.uint32: - return x0 == x1 - else: - return x0.to(torch.float32) == x1.to(torch.float32) - - -@triton.jit -def triton_eq(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = XBLOCK // XBLOCK_SUB - for loop1 in range(loops1): - x_index = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N) - tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N) - tmp2 = tmp0 == tmp1 - tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N) - - -@triton.jit -def triton_eq_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - 
offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val == y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_eq(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - # 生成数据 - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # torch结果 - torch_res = torch_eq(x0, x1).to(eval('torch.' + dtype)) - # triton结果 - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - N = triton_res.numel() - triton_eq[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) - # 比较结果 - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_dtype = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_dtype, triton_res, torch_res) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_eq_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_eq(x, y).to(eval('torch.' 
+ dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_eq_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_erf.py b/third_party/ascend/unittest/generalization_cases/test_erf.py deleted file mode 100644 index b82945f0f0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_erf.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_erf(x0): - res = torch.erf(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.erf(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_erf_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr, - BLOCK_TOTAL: tl.constexpr): - - pid = tl.program_id(0) - start_idx = pid * BLOCK_TOTAL - local_idx = tl.arange(0, BLOCK_TOTAL) - global_idx = start_idx + local_idx - total_elements = SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - masks = global_idx < total_elements - - dim1_base = SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - dim2_base = SHAPE_2 * SHAPE_3 * SHAPE_4 - dim3_base = SHAPE_3 * SHAPE_4 - dim4_base = SHAPE_4 - - idx_0 = (global_idx // dim1_base) % SHAPE_0 - idx_1 = (global_idx // dim2_base) % SHAPE_1 - idx_2 = (global_idx // dim3_base) % SHAPE_2 - idx_3 = (global_idx // dim4_base) % SHAPE_3 - idx_4 = global_idx % SHAPE_4 - - offsets = idx_0 * STRIDE_0 + idx_1 * STRIDE_1 + idx_2 * STRIDE_2 + idx_3 * STRIDE_3 + idx_4 * STRIDE_4 - 
- x_val = tl.load(x_ptr + offsets, mask=masks) - ret = tl.erf(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_erf(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_erf_4d_5d(shape, dtype): - logging.debug(f"Testing erf for shape={shape}, dtype={dtype}") - - x = test_common.generate_tensor(shape, dtype).npu() - output = torch.empty_like(x) - - ans = torch_erf(x) - - shape_5d = list(shape) - strides_5d = list(x.stride()) - while len(shape_5d) < 5: - shape_5d.append(1) - strides_5d.append(1) - - MAX_BLOCK_ELEMENTS = 1024 - total_elements = x.numel() - - block_5d = [1] * 5 - for i in reversed(range(5)): - if shape_5d[i] == 0: - continue - max_block_i = min(shape_5d[i], MAX_BLOCK_ELEMENTS // (torch.prod(torch.tensor(block_5d)).item())) - block_5d[i] = max_block_i - if 
torch.prod(torch.tensor(block_5d)).item() >= MAX_BLOCK_ELEMENTS: - break - block_total = torch.prod(torch.tensor(block_5d)).item() - - grid = (triton.cdiv(total_elements, block_total), ) - logging.debug(f"Grid={grid}, block_5d={block_5d}, block_total={block_total}") - - triton_erf_4d_5d[grid](output, x, *block_5d, *shape_5d, *strides_5d, block_total) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_exp.py b/third_party/ascend/unittest/generalization_cases/test_exp.py deleted file mode 100644 index 52233f89c8..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_exp.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x0): - res = torch.exp(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.exp(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_exp_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if 
BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.exp(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_exp_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_exp_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_exp2.py b/third_party/ascend/unittest/generalization_cases/test_exp2.py deleted file mode 100644 index b8e8aa3122..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_exp2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_exp2(x0): - res = torch.pow(2, x0, out=None) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.exp2(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_exp2_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) 
- if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.exp2(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_exp2(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_exp2_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_exp2(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_exp2_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_expand_dims.py b/third_party/ascend/unittest/generalization_cases/test_expand_dims.py deleted file mode 100644 index f9a85f044c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_expand_dims.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yidx = tl.arange(0, YB) - - X = tl.load(x_ptr + yidx) - - ret = tl.expand_dims(X, 1) - - oidx = yidx[:, None] + tl.arange(0, 1)[None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_expand_dims_1d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(1) - - output = torch.randint(1, (shape[0], 1), dtype=eval('torch.' + dtype)).npu() - - fn_npu_1d[1, 1, 1](output, x, YB=shape[0], ZB=1, debug=True) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yoffs = tl.program_id(0) - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) - - idx = yidx[:, None] * ZB + zidx[None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 1) - - oidx = yidx[:, None, None] * ZB + tl.arange(0, 1)[None, :, None] + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_expand_dims_2d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(1) - - output = torch.randint(1, (shape[0], 1, shape[1]), dtype=eval('torch.' 
+ dtype)).npu() - - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, YB=1, ZB=shape[1]) - else: - fn_npu_2d[1, 1, 1](output, x, YB=shape[0], ZB=shape[1]) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 2) - - oidx = xidx[:, None, None, None] * YB * ZB + yidx[None, :, None, None] * ZB + tl.arange( - 0, 1)[None, None, :, None] + zidx[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -def test_expand_dims_3d(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - a = x.unsqueeze(2) - - output = torch.randint(1, (shape[0], shape[1], 1, shape[2]), dtype=eval('torch.' 
+ dtype)).npu() - - fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) - - torch.testing.assert_close(output, a) - - -@triton.jit -def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr): - in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + in_offsets) - - ret = tl.expand_dims(X, DIM).reshape(XB * YB * ZB * MB * NB) - - out_offstes = tl.arange(0, XB * YB * ZB * MB * NB) - tl.store(output_ptr + out_offstes, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('shape', [ - (2, 64, 16, 2), - (8, 8, 4, 2), - (8, 8, 4, 1), -]) -@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3]) -def test_npu_4d(shape, dtype, dim): - x = test_common.generate_tensor(shape, dtype).npu() - expected = x.unsqueeze(dim) - - output = torch.empty_like(expected) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) - - torch.testing.assert_close(output, expected) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) -@pytest.mark.parametrize('shape', [ - (2, 32, 3, 16, 2), - (8, 8, 3, 4, 2), - (8, 8, 3, 4, 1), -]) -@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3, 4]) -def test_npu_5d(shape, dtype, dim): - x = test_common.generate_tensor(shape, dtype).npu() - expected = x.unsqueeze(dim) - - output = torch.empty_like(expected) - 
- triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) - - torch.testing.assert_close(output, expected) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.expand_dims(X, 2) - - oidx = xidx[:, None, None, None] * YB * ZB + yidx[None, :, None, None] * ZB + tl.arange( - 0, 1)[None, None, :, None] + zidx[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -paras = [ - ('bfloat16', eval('torch.bfloat16'), 1, 255, 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 125, 1, 256, 16), - ('uint16', eval('torch.uint16'), 1, 256, 2, 2, 3), - ('uint32', eval('torch.uint32'), 1, 256, 8, 8, 4), - ('uint64', eval('torch.uint64'), 1, 256, 8, 8, 4), - ('bool', eval('torch.bool'), 0, 2, 1, 1, 2), -] - - -@pytest.mark.parametrize('para_type,data_type,low,top,XB,YB,ZB', paras) -def test_expand_dims(para_type, data_type, low, top, XB, YB, ZB): - x = torch.randint(low=low, high=top, size=(XB, YB, ZB), dtype=data_type).npu() - a = x.unsqueeze(2) - output = torch.randint(1, (XB, YB, 1, ZB), dtype=data_type).npu() - fn_npu_[1, 1, 1](output, x, XB=XB, YB=YB, ZB=ZB, debug=True) - test_common.validate_cmp(para_type, output, a) diff --git a/third_party/ascend/unittest/generalization_cases/test_fdiv.py b/third_party/ascend/unittest/generalization_cases/test_fdiv.py deleted file mode 100644 index 099a82387f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_fdiv.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_fdiv(x0, x1): - res = x0 / x1 - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.fdiv(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_fdiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if 
BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.fdiv(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_fdiv(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_fdiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_fdiv(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_fdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_full_op.py b/third_party/ascend/unittest/generalization_cases/test_full_op.py deleted file mode 100644 index e74314d52a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_full_op.py +++ /dev/null @@ -1,1096 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import test_common - -from test_common import TestUtils -import torch -import torch_npu -import pytest -import math -import random - - -@triton.jit -def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int8) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.uint32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.int64) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, 
Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.float16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.float32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=100, dtype=tl.bfloat16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): - xidx = tl.arange(0, X) - yidx = tl.arange(0, Y) - zidx = tl.arange(0, Z) - ret = tl.full((X, Y, Z), value=0, dtype=tl.int1) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int8) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_2d(output_ptr, Y: 
tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.uint32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.int64) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.float16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.float32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=100, dtype=tl.bfloat16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): - yoffs = tl.program_id(0) * Y - yidx = tl.arange(0, Y) + yoffs - zidx = tl.arange(0, Z) - ret = tl.full((Y, Z), value=0, dtype=tl.int1) - oidx = 
yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int8) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_uint32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.uint32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.int64) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.float16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.float32) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=100, dtype=tl.bfloat16) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_1d(output_ptr, Z: tl.constexpr): - zidx = tl.arange(0, Z) - ret = tl.full((Z, ), value=0, dtype=tl.int1) - oidx = zidx - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape1d = TestUtils.test_shape1d -test_shape2d = TestUtils.test_shape2d -test_shape3d = 
TestUtils.test_shape3d - -# 定义 dtype 到 (test_func, test_sigtype) 的映射 -dtype_mapping3d = { - 'int8': (fn_npu_int8_3d, torch.int8), - 'int16': (fn_npu_int16_3d, torch.int16), - 'int32': (fn_npu_int32_3d, torch.int32), - 'uint32': (fn_npu_uint32_3d, torch.uint32), - 'int64': (fn_npu_int64_3d, torch.int64), - 'float16': (fn_npu_fp16_3d, torch.float16), - 'float32': (fn_npu_fp32_3d, torch.float32), - 'bfloat16': (fn_npu_bf16_3d, torch.bfloat16), - 'bool': (fn_npu_bool_3d, torch.bool), -} -dtype_mapping2d = { - 'int8': (fn_npu_int8_2d, torch.int8), - 'int16': (fn_npu_int16_2d, torch.int16), - 'int32': (fn_npu_int32_2d, torch.int32), - 'uint32': (fn_npu_uint32_2d, torch.uint32), - 'int64': (fn_npu_int64_2d, torch.int64), - 'float16': (fn_npu_fp16_2d, torch.float16), - 'float32': (fn_npu_fp32_2d, torch.float32), - 'bfloat16': (fn_npu_bf16_2d, torch.bfloat16), - 'bool': (fn_npu_bool_2d, torch.bool), -} -dtype_mapping1d = { - 'int8': (fn_npu_int8_1d, torch.int8), - 'int16': (fn_npu_int16_1d, torch.int16), - 'int32': (fn_npu_int32_1d, torch.int32), - 'uint32': (fn_npu_uint32_1d, torch.uint32), - 'int64': (fn_npu_int64_1d, torch.int64), - 'float16': (fn_npu_fp16_1d, torch.float16), - 'float32': (fn_npu_fp32_1d, torch.float32), - 'bfloat16': (fn_npu_bf16_1d, torch.bfloat16), - 'bool': (fn_npu_bool_1d, torch.bool), -} - -# 生成测试用例 -testlist = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape1d - for func, dtype in [dtype_mapping1d[sigtype]] # 直接解包映射结果 - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape2d - for func, dtype in [dtype_mapping2d[sigtype]] # 直接解包映射结果 - ] - -testlist += [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape3d - for func, dtype in [dtype_mapping3d[sigtype]] # 直接解包映射结果 - ] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist) -def test_npu(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 3: - 
if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2]), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], shape[1], shape[2]), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1], shape[2]), dtype=dtype).npu() - testfunc[(1, 1, 1)](output, shape[0], shape[1], shape[2], debug=True) - if len(shape) == 2: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1]), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], shape[1]), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], shape[1]), dtype=dtype).npu() - shape0 = shape[0] - shape1 = shape[1] - if x.numel() * x.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - testfunc[grid](output, shape0, shape1, debug=True) - if len(shape) == 1: - if dtype == torch.bool: - x = torch.full((shape[0], ), 0, dtype=dtype).npu() - else: - x = torch.full((shape[0], ), 100, dtype=dtype).npu() - output = torch.randint(1, (shape[0], ), dtype=dtype).npu() - testfunc[1, 1, 1](output, shape[0], debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr): - dtype = output_ptr.type.element_ty - - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if (YB * ZB * MB * NB) > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if (ZB * MB * NB) > 1: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if (MB * NB) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if NB > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - if (YB * ZB * MB * NB) == 1: - ret = tl.full((XB, ), value=100, dtype=dtype) - elif (ZB * MB * NB) == 1: - ret = tl.full((XB, YB), value=100, dtype=dtype) - elif (MB * NB) == 1: - ret = tl.full((XB, YB, ZB), value=100, dtype=dtype) - elif NB == 1: - ret = 
tl.full((XB, YB, ZB, MB), value=100, dtype=dtype) - else: - ret = tl.full((XB, YB, ZB, MB, NB), value=100, dtype=dtype) - - tl.store(output_ptr + offsets, ret) - - -testlist_multi_d = [ - (fn_npu_multi_d, 'float32', torch.float32, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float32, (2, 4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float16, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.float16, (2, 4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.int8, (4, 2, 16, 16)), - (fn_npu_multi_d, 'float32', torch.int8, (2, 4, 2, 16, 16)), -] - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist_multi_d) -def test_npu_4d_5d(testfunc, sigtype, dtype, shape): - x = torch.full(shape, 100, dtype=dtype).npu() - - print(f"shape = {x.shape}") - print(x.dtype) - print(torch.flatten(x)[0:16]) - - output = torch.randint(1, shape, dtype=dtype).npu() - - print(f"output.dtype={output.dtype}") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - testfunc[(1, )](output, *triton_shape) - print(torch.flatten(output)[0:16]) - - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: 
tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, 
None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * 
D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - - ret = tl.full((A, B, C, D, E, F), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None] * B * C * D * E * F + - bidx[None, :, None, None, None, None] * C * D * E * F + cidx[None, None, :, None, None, None] * D * E * F + - didx[None, None, None, :, None, None] * E * F + eidx[None, None, None, None, :, None] * F + - fidx[None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape6d = TestUtils.test_shape6d -dtype_mapping6d = { - 'int8': (fn_npu_int8_6d, torch.int8), - 'int16': (fn_npu_int16_6d, torch.int16), - 'int32': (fn_npu_int32_6d, torch.int32), - 'int64': (fn_npu_int64_6d, torch.int64), - 'float16': (fn_npu_fp16_6d, torch.float16), - 'float32': (fn_npu_fp32_6d, torch.float32), - 'bfloat16': (fn_npu_bf16_6d, torch.bfloat16), - 'bool': (fn_npu_bool_6d, torch.bool), -} - -testlist6d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape6d - for func, dtype in [dtype_mapping6d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist6d) -def test_npu_6d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 6: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]), 0, dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, 
shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: 
tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = 
tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, 
None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - - ret = tl.full((A, B, C, D, E, F, G), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None, None] * B * C * D * E * F * G + - bidx[None, :, None, None, None, None, None] * C * D * E * F * G + - cidx[None, None, :, None, None, None, None] * D * E * F * G + - didx[None, None, None, :, None, None, None] * E * F * G + - eidx[None, None, None, None, :, None, None] * F * G + fidx[None, None, None, None, None, :, None] * G + - gidx[None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape7d = TestUtils.test_shape7d -dtype_mapping7d = { - 'int8': (fn_npu_int8_7d, torch.int8), - 'int16': (fn_npu_int16_7d, torch.int16), - 'int32': (fn_npu_int32_7d, torch.int32), - 'int64': (fn_npu_int64_7d, torch.int64), - 'float16': (fn_npu_fp16_7d, torch.float16), - 'float32': (fn_npu_fp32_7d, torch.float32), - 'bfloat16': (fn_npu_bf16_7d, torch.bfloat16), - 'bool': (fn_npu_bool_7d, torch.bool), -} - -testlist7d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape7d - for func, dtype in [dtype_mapping7d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist7d) -def test_npu_7d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 7: - if 
dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6]), 0, dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) - - -@triton.jit -def fn_npu_bf16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.bfloat16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int8_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int8) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E 
* F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int32) - - oidx = (aidx[:, None, None, None, None, None, None, 
None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_int64_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int64) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, 
D, E, F, G, H), value=10, dtype=tl.float16) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_fp32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.float32) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_bool_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr, E: tl.constexpr, - F: tl.constexpr, G: tl.constexpr, H: tl.constexpr): - - aidx = tl.arange(0, A) - bidx = tl.arange(0, B) - cidx = tl.arange(0, C) - didx = tl.arange(0, D) - eidx = tl.arange(0, E) - 
fidx = tl.arange(0, F) - gidx = tl.arange(0, G) - hidx = tl.arange(0, H) - - ret = tl.full((A, B, C, D, E, F, G, H), value=0, dtype=tl.int1) - - oidx = (aidx[:, None, None, None, None, None, None, None] * B * C * D * E * F * G * H + - bidx[None, :, None, None, None, None, None, None] * C * D * E * F * G * H + - cidx[None, None, :, None, None, None, None, None] * D * E * F * G * H + - didx[None, None, None, :, None, None, None, None] * E * F * G * H + - eidx[None, None, None, None, :, None, None, None] * F * G * H + - fidx[None, None, None, None, None, :, None, None] * G * H + - gidx[None, None, None, None, None, None, :, None] * H + hidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape8d = TestUtils.test_shape8d -dtype_mapping8d = { - 'int8': (fn_npu_int8_8d, torch.int8), - 'int16': (fn_npu_int16_8d, torch.int16), - 'int32': (fn_npu_int32_8d, torch.int32), - 'int64': (fn_npu_int64_8d, torch.int64), - 'float16': (fn_npu_fp16_8d, torch.float16), - 'float32': (fn_npu_fp32_8d, torch.float32), - 'bfloat16': (fn_npu_bf16_8d, torch.bfloat16), - 'bool': (fn_npu_bool_8d, torch.bool), -} - -testlist8d = [(func, sigtype, dtype, shape) - for sigtype in test_dtype - for shape in test_shape8d - for func, dtype in [dtype_mapping8d[sigtype]]] - - -@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist8d) -def test_npu_8d(testfunc, sigtype, dtype, shape): - x = 0 - output = 0 - if len(shape) == 8: - if dtype == torch.bool: - x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6], shape[7]), 0, - dtype=dtype).npu() - else: - x = torch.full(shape, 10, dtype=dtype).npu() - output = torch.randint(1, shape, dtype=dtype).npu() - testfunc[1, 1, 1](output, *shape, debug=True) - test_common.validate_cmp(sigtype, output, x) diff --git a/third_party/ascend/unittest/generalization_cases/test_ge_op.py 
b/third_party/ascend/unittest/generalization_cases/test_ge_op.py deleted file mode 100644 index d23da78f38..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ge_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_ge_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 >= x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_ge_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, 
BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val >= y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_ge(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.ge(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' 
+ sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_ge_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_ge_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_ge_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_ge_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.ge(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ge_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_add.py b/third_party/ascend/unittest/generalization_cases/test_general_add.py deleted file mode 100644 index 8ddb6adeb0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_add.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_add(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_add_broadcast(in_ptr0, in_ptr1, out_ptr0, X_SHAPE_0: tl.constexpr, X_SHAPE_1: tl.constexpr, - X_SHAPE_2: tl.constexpr, X_SHAPE_3: tl.constexpr, X_SHAPE_4: tl.constexpr, - Y_SHAPE_0: tl.constexpr, Y_SHAPE_1: tl.constexpr, Y_SHAPE_2: tl.constexpr, - Y_SHAPE_3: tl.constexpr, Y_SHAPE_4: tl.constexpr): - x_idx0 = tl.arange(0, X_SHAPE_0) - x_idx1 = tl.arange(0, X_SHAPE_1) - x_idx2 = tl.arange(0, X_SHAPE_2) - x_idx3 = tl.arange(0, X_SHAPE_3) - x_idx4 = tl.arange(0, X_SHAPE_4) - - y_idx0 = tl.arange(0, Y_SHAPE_0) - y_idx1 = tl.arange(0, Y_SHAPE_1) - y_idx2 = tl.arange(0, Y_SHAPE_2) - y_idx3 = tl.arange(0, Y_SHAPE_3) - y_idx4 = tl.arange(0, Y_SHAPE_4) - - xidx = x_idx0[:, None, None, None, None] * X_SHAPE_1 * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx1[None, :, None, None, None] * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx2[None, None, :, None, None] * X_SHAPE_3 * X_SHAPE_4 + \ - x_idx3[None, None, None, :, None] * X_SHAPE_4 + x_idx4[None, None, None, None, :] - - yidx = y_idx0[:, None, None, None, None] * Y_SHAPE_1 * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ - y_idx1[None, :, None, None, None] * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ - y_idx2[None, None, :, None, None] * Y_SHAPE_3 * 
Y_SHAPE_4 + \ - y_idx3[None, None, None, :, None] * Y_SHAPE_4 + y_idx4[None, None, None, None, :] - - X = tl.load(in_ptr0 + xidx) - Y = tl.load(in_ptr1 + yidx) - ret = X + Y - - tl.store(out_ptr0 + xidx, ret) - - -@triton.jit -def triton_add_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val + y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_add(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = 
test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x + y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_add[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_add[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_add[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_add[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_add[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_add[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_add[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_add_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_add_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -def promote_dtype(x_dtype, y_dtype): - """ - 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 - """ - # 如果两个数据类型一致,直接返回 - if x_dtype == y_dtype: - return y_dtype - - # 构建类型的优先级列表(从低到高) - priority = [ - torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.bfloat16, torch.float32 - ] - - # 查找两种类型在优先级列表中的位置 - x_priority = priority.index(x_dtype) - y_priority = priority.index(y_dtype) - - # 如果y的优先级比x小,则提升到x的类型 - if y_priority < x_priority: - return x_dtype - else: - return y_dtype - - -@pytest.mark.parametrize('param_list', - [[(5, 1, 1, 1, 1), - (5, 1, 1, 2, 1)], [(2, 1), (2, 4)], [(2, 1, 1), (2, 4, 2)], [(2, 1, 1, 1), (2, 4, 2, 2)], - [(2, 1, 1, 1, 1), - (2, 4, 2, 2, 2)], [(1, ), (4, )], [(1, 2, 1), (1, 2, 3)], [(1, 1, 1, 1), (7, 1, 1, 1)]]) -@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_add_broadcast(param_list, x_dtype_str, y_dtype_str): - x_shape, y_shape = param_list - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' 
+ y_dtype_str) - - x = test_common.generate_tensor(x_shape, x_dtype_str).npu() - y = test_common.generate_tensor(y_shape, y_dtype_str).npu() - if y.numel() > x.numel(): - tmp = y - y = x - x = tmp - ans = x + y - while x.dim() < 5: - x = x.unsqueeze(-1) - while y.dim() < 5: - y = y.unsqueeze(-1) - bf2fpFlag = False - out_dtype = promote_dtype(x_dtype, y_dtype) - if (x_dtype == torch.bfloat16 and y_dtype == torch.float16) or \ - (x_dtype == torch.float16 and y_dtype == torch.bfloat16): - out_dtype = torch.float32 - bf2fpFlag = True - out_dtype = str(out_dtype).split('.')[-1] - out = test_common.generate_tensor(x.shape, out_dtype).npu() - - triton_add_broadcast[1, 1, 1](x, y, out, *x.shape, *y.shape) - while out.dim() > ans.dim(): - out = out.squeeze(-1) - - if bf2fpFlag: - torch.testing.assert_close(out, ans, rtol=1e-3, atol=1e-3) - else: - torch.testing.assert_close(out, ans) - - -@triton.jit -def add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, - NB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 - offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) 
- out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) - offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XB1: tl.constexpr, - YB1: tl.constexpr, ZB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] - - offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) - offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] - - tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@triton.jit -def add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, XB1: tl.constexpr, YB1: tl.constexpr): - offsets = tl.arange(0, XB) * (YB) - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] - - offsets1 = tl.arange(0, XB1) * (YB1) - offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] - - 
tmp0 = tl.load(x_ptr + offsets) - tmp1 = tl.load(y_ptr + offsets1) - tmp2 = tl.load(out_ptr + offsets1) - out = tmp2 + tmp1 + tmp0 - tl.store(out_ptr + offsets1, out) - - -@pytest.mark.parametrize('param_list', [ - [(5, 1, 1, 1, 1), (5, 1, 1, 2, 1)], -]) -@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_add_2d_to_5d(x_dtype_str, y_dtype_str, param_list): - x0_shape, y_shape = param_list - ndim = max(len(x0_shape), len(y_shape)) - # 获取原始类型 - x_dtype = eval('torch.' + x_dtype_str) - y_dtype = eval('torch.' + y_dtype_str) - - x0 = test_common.generate_tensor(x0_shape, x_dtype_str).npu() - y = test_common.generate_tensor(y_shape, y_dtype_str).npu() - - out_dtype = promote_dtype(x_dtype, y_dtype) - if out_dtype == torch.bfloat16: - out_dtype = torch.float32 - out = torch.full(y_shape, 0, dtype=out_dtype).npu() - - x0_temp = x0.clone() - y_temp = y.clone() - out_temp = out.clone() - - triton_shape = [*x0_shape] - while len(triton_shape) < ndim: - triton_shape.append(1) - - triton_shape1 = [*y_shape] - while len(triton_shape1) < ndim: - triton_shape1.append(1) - - # 按维度分支 - if ndim == 2: - XB, YB = triton_shape - XB1, YB1 = triton_shape1 - - add_2d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - XB1=XB1, - YB1=YB1, - ) - - elif ndim == 3: - XB, YB, ZB = triton_shape - XB1, YB1, ZB1 = triton_shape1 - - add_3d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - ) - - elif ndim == 4: - XB, YB, ZB, MB = triton_shape - XB1, YB1, ZB1, MB1 = triton_shape1 - - add_4d[(1, )]( - x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - ) - - elif ndim == 5: - XB, YB, ZB, MB, NB = triton_shape - XB1, YB1, ZB1, MB1, NB1 = triton_shape1 - - add_5d[(1, )]( - 
x_ptr=x0, - y_ptr=y, - out_ptr=out, - XB=XB, - YB=YB, - ZB=ZB, - MB=MB, - NB=NB, - XB1=XB1, - YB1=YB1, - ZB1=ZB1, - MB1=MB1, - NB1=NB1, - ) - - else: - raise ValueError(f"Unsupported tensor dim: {ndim}") - expected = out_temp + y_temp + x0_temp - torch.testing.assert_close(out, expected) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_add_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - #numpy result - ans_numpy = np_x0 + np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_add[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_clamp.py b/third_party/ascend/unittest/generalization_cases/test_general_clamp.py deleted file mode 100644 index d6d91c20e7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_clamp.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -def torch_clamp(x0, min_, max_): - res = torch.clamp(x0, min_, max_) - return res - - -@triton.jit -def tt_clamp_1d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - idx = tl.arange(0, XB) - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_clamp_2d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_clamp_3d(in_ptr, out_ptr, min_ptr, max_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - min_ = tl.load(min_ptr + idx) - max_ = tl.load(max_ptr + idx) - ret = tl.clamp(x, min_, max_) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def triton_clamp_4d_5d(x_ptr, output_ptr, min_ptr, max_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, 
BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - min_ = tl.load(min_ptr + offsets) - max_ = tl.load(max_ptr + offsets) - ret = tl.clamp(x_val, min_, max_) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_clamp(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - a = test_common.generate_tensor(shape, dtype) - b = test_common.generate_tensor(shape, dtype) - min_ = torch.min(a, b).npu() - max_ = torch.max(a, b).npu() - - grid = (1, 1, 1) - - y_cal = torch.empty(shape, dtype=eval('torch.' 
+ dtype), device="npu") - - y_ref = torch_clamp(x, min_, max_) - if len(shape) == 1: - tt_clamp_1d[grid](x, y_cal, min_, max_, x.numel(), 1, 1, x.numel(), 1, 1) - elif len(shape) == 2: - xnumel, ynumel, znumel = shape + (1, ) - XB, YB, ZB = xnumel, ynumel, znumel - if x.numel() * x.element_size() > 8192: - grid = (1, ynumel, 1) - YB = 1 - tt_clamp_2d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) - - elif len(shape) == 3: - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - tt_clamp_3d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_clamp_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - a = test_common.generate_tensor(shape, dtype) - b = test_common.generate_tensor(shape, dtype) - min_ = torch.min(a, b).npu() - max_ = torch.max(a, b).npu() - - output = torch.empty(shape, dtype=eval('torch.' + dtype), device="npu") - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_clamp(x, min_, max_) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_clamp_4d_5d[grid](x, output, min_, max_, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_div.py b/third_party/ascend/unittest/generalization_cases/test_general_div.py deleted file mode 100644 index fd4e252177..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_div.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_div(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X / Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_div_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, 
None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val / y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_div(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - y[y == 0] = 1 - - ans = x / y - output = torch.zeros_like(ans) - if len(shape) == 1: - triton_div[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_div[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_div[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_div[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_div[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_div[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_div[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - # change dtype beacuse of triton processing, triton div op will change from int to float - if dtype in ['int8', 'int16', 'int32', 'int64']: - dtype = 'float32' - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 
'bfloat16']) -def test_div_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - y[y == 0] = 1 - - new_shape = shape - if dtype == 'int8' or dtype == 'int16' or dtype == 'int32' or dtype == 'int64': - output = torch.randint(1, new_shape, dtype=eval('torch.float32')).npu() - dtype = 'float32' - else: - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - ans = x / y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_div_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_floor.py b/third_party/ascend/unittest/generalization_cases/test_general_floor.py deleted file mode 100644 index 38bc1621ac..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_floor.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + tl.floor(Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_floor_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) 
> 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val + tl.floor(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_floor(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + torch.floor(y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_floor_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, 
dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x + torch.floor(y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_floor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py b/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py deleted file mode 100644 index 9e5cf1c6ef..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_floordiv(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X // Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_floordiv_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, 
None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val // y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_floordiv(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - - new_shape = shape - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - y[y == 0] = 1 - ans = x // y - ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) - ans = ans + ans_mask - - if len(shape) == 1: - triton_floordiv[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_floordiv[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_floordiv[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_floordiv[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_floordiv[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_floordiv[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_floordiv[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', 
TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_floordiv_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - y[y == 0] = 1 - ans = x // y - ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) - ans = ans + ans_mask - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_floordiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = y.masked_fill(y == 0, 1) - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_floordiv[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_fma.py b/third_party/ascend/unittest/generalization_cases/test_general_fma.py deleted file mode 100644 index eb255558ca..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_fma.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - Z = tl.load(z_ptr + idx) - - ret = tl.fma(X, Y, Z) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_fma_4d_5d(output_ptr, x_ptr, y_ptr, z_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, 
:, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - z_val = tl.load(z_ptr + offsets, masks) - ret = tl.fma(x_val, y_val, z_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) # math.fma do not support int dtype -def test_fma(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) - ans = ans.to(torch.bfloat16) - else: - ans = x * y + z - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_fma_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = 
test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) - ans = ans.to(torch.bfloat16) - else: - ans = x * y + z - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_fma_4d_5d[grid](output, x, y, z, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_gather.py b/third_party/ascend/unittest/generalization_cases/test_general_gather.py deleted file mode 100644 index ee2b8bc437..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_gather.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import numpy as np -import torch -import torch_npu -import triton -import triton.language as tl -import test_common -import pytest -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -@pytest.mark.parametrize("src_shape, indices_shape, axis", [ - ([2, 2], [4, 2], 0), - ([3, 3], [1, 3], 0), - ([3, 4], [4, 4], 0), - ([4, 4], [8, 4], 0), - ([4, 32], [4, 16], 1), - ([4, 64], [4, 32], 1), - ([128, 64], [128, 128], 1), -]) -def test_gather(src_shape, indices_shape, axis): - - @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, - src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, - idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): - output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), 
output.shape[0], output.shape[1], output.stride(0), output.stride(1)) - return output - - DEV = "npu" - src = torch.randn(src_shape, device=DEV) - indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) - - dtype_size = get_dtype_size('int32') - if dtype_size * math.prod(src.shape) >= (TestUtils.ub_size / 8): - print(f"dtype:int32 shape:{src.shape} mem overflow") - return - - ref = torch.gather(src, axis, indices) - result = triton_gather(src, axis, indices) - torch.testing.assert_close(result, ref, rtol=0, atol=0) - - -@triton.jit -def gather_kernel_multi_d(src_ptr, idx_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - MB: tl.constexpr, NB: tl.constexpr, I_XB: tl.constexpr, I_YB: tl.constexpr, - I_ZB: tl.constexpr, I_MB: tl.constexpr, I_NB: tl.constexpr, DIMS: tl.constexpr, - AXIS: tl.constexpr): - in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - idx_offsets = tl.arange(0, I_XB) * (I_YB * I_ZB * I_MB * I_NB) - if DIMS > 1: - idx_offsets = idx_offsets[:, None] + tl.arange(0, I_YB)[None, :] * (I_ZB * I_MB * I_NB) - if DIMS > 2: - idx_offsets = idx_offsets[:, :, None] + tl.arange(0, I_ZB)[None, None, :] * (I_MB * I_NB) - if DIMS > 3: - idx_offsets = idx_offsets[:, :, :, None] + tl.arange(0, I_MB)[None, None, None, :] * I_NB - if DIMS > 4: - idx_offsets = idx_offsets[:, :, :, :, None] + tl.arange(0, I_NB)[None, None, None, None, :] - - src = tl.load(src_ptr + in_offsets) - idx = tl.load(idx_ptr + idx_offsets) - - out = tl.gather(src, idx, AXIS) - - tl.store(out_ptr + idx_offsets, out) - - -def triton_gather_multi_d(src: torch.Tensor, 
axis: int, indices: torch.Tensor): - output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - - s_shape = [*(src.shape)] - while len(s_shape) < 5: - s_shape.append(1) - i_shape = [*(indices.shape)] - while len(i_shape) < 5: - i_shape.append(1) - gather_kernel_multi_d[(1, )](src, indices, output, *s_shape, *i_shape, len(src.shape), axis) - return output - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize("src_shape, indices_shape, axis", [ - ((2, 2, 4, 8), (2, 2, 4, 8), 0), - ((2, 2, 4, 1), (2, 2, 4, 1), 3), - ((2, 3, 4, 8), (2, 3, 4, 8), 1), - ((2, 3, 4, 8), (2, 3, 4, 8), 2), - ((2, 2, 2, 4, 1), (2, 2, 2, 4, 1), 4), - ((2, 2, 2, 4, 8), (2, 2, 2, 4, 8), 1), - ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 2), - ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 0), -]) -def test_gather_4d_5d(src_shape, indices_shape, axis): - DEV = "npu" - src = torch.randn(src_shape, device=DEV) - indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) - - ref = torch.gather(src, axis, indices) - result = triton_gather_multi_d(src, axis, indices) - torch.testing.assert_close(result, ref, rtol=0, atol=0) - - -if __name__ == "__main__": - test_gather([4, 64], [4, 32], 1) - print("success: test_gather") diff --git a/third_party/ascend/unittest/generalization_cases/test_general_interleave.py b/third_party/ascend/unittest/generalization_cases/test_general_interleave.py deleted file mode 100644 index cce95fef86..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_interleave.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import logging -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - zoffs2 = tl.program_id(2) * ZB * 2 - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - zidx2 = tl.arange(0, 2 * ZB) + zoffs2 - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.interleave(X, Y) - - oidx = xidx[:, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None] * ZNUMEL * 2 + zidx2[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_interleave_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp4 = tl.arange(0, 2 * BLOCK_3)[None, None, None, :] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.interleave(x_val, y_val) - - out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * 
STRIDE_2 * 2 + tmp4 * STRIDE_3 - out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp4 < 2 * SHAPE_3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_interleave_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - tmp5 = tl.arange(0, 2 * BLOCK_4)[None, None, None, None, :] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.interleave(x_val, y_val) - - out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * STRIDE_2 * 2 + tmp3 * STRIDE_3 * 2 + tmp5 * STRIDE_4 - out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp5 < 2 * SHAPE_4) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_interleave(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' 
+ dtype)).npu() - new_shape = shape[:-1] + (2 * shape[-1], ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1).reshape(new_shape) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_interleave_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape[:-1] + (2 * shape[-1], ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1).reshape(new_shape) - - blocks = list(x.size()) - strides = list(x.stride()) - - grid = (1, ) - if len(shape) == 4: - triton_interleave_4d[grid](output, x, y, *blocks, *blocks, *strides) - else: - triton_interleave_5d[grid](output, x, y, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.interleave(X, Y) - - oidx = xidx[:, None, None] * YB * ZB * 2 + yidx[None, :, None] * ZB * 2 + tl.arange(0, 2 * ZB)[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', [ - ('bfloat16', eval('torch.bfloat16'), 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 256, 16), - ('bool', eval('torch.bool'), 1, 1, 2), -]) -def test_interleave_u(para_type, data_type, XB, YB, ZB): - x = torch.full((XB, YB, ZB), 100, dtype=data_type).npu() - y = torch.full((XB, YB, ZB), 30, dtype=data_type).npu() - output = torch.randint(1, (XB, YB, ZB * 2), dtype=data_type).npu() - ans = torch.stack((x, y), dim=-1).reshape(XB, YB, ZB * 2) - fn_npu_dtype[1, 1, 1](output, x, y, XB, YB, ZB) - test_common.validate_cmp(para_type, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_join.py b/third_party/ascend/unittest/generalization_cases/test_general_join.py deleted file mode 100644 index a1d8cd3cd0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_join.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - ret = tl.join(X, Y) - - oidx = xidx[:, None, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None, None] * ZNUMEL * 2 + \ - zidx[None, None, :, None] * 2 + tl.arange(0, 2)[None, None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_join_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.join(x_val, y_val) - - out_tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - out_tmp2 = tl.arange(0, 
BLOCK_2)[None, None, :, None, None] - out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - out_tmp4 = tl.arange(0, 2)[None, None, None, None, :] - out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 * STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ - + out_tmp3 * STRIDE_3 * 2 + out_tmp4 - out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ - & (out_tmp3 < SHAPE_3) & (out_tmp4 < 2) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_join_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - - ret = tl.join(x_val, y_val) - - out_tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None, None] - out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None, None] - out_tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None, None] - out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None, None] - out_tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :, None] - out_tmp5 = tl.arange(0, 2)[None, None, None, None, None, :] - out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 
* STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ - + out_tmp3 * STRIDE_3 * 2 + out_tmp4 * STRIDE_4 * 2 + out_tmp5 - out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ - & (out_tmp3 < SHAPE_3) & (out_tmp4 < SHAPE_4) & (out_tmp5 < 2) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_join(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = shape + (2, ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_join_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape + (2, ), dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.stack((x, y), dim=-1) - - blocks = list(x.size()) - strides = list(x.stride()) - - grid = (1, ) - if len(shape) == 4: - triton_join_4d[grid](output, x, y, *blocks, *blocks, *strides) - else: - triton_join_5d[grid](output, x, y, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - - idx = xidx[:, None] * YB + yidx[None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.join(X, Y) - - oidx = xidx[:, None, None] * YB * 2 + yidx[None, :, None] * 2 + tl.arange(0, 2)[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', [ - ('bfloat16', eval('torch.bfloat16'), 8, 8, 4), - ('uint8', eval('torch.uint8'), 1, 256, 16), - ('bool', eval('torch.bool'), 1, 1, 2), -]) -def test_join_u(para_type, data_type, XB, YB, ZB): - x = torch.full((XB, YB), 100, dtype=data_type).npu() - y = torch.full((XB, YB), 30, dtype=data_type).npu() - - ans = torch.stack((x, y), dim=-1) - output = torch.randint(1, (XB, YB, 2), dtype=data_type).npu() - fn_npu_dtype[1, 1, 1](output, x, y, XB, YB, ZB, debug=True) - test_common.validate_cmp(para_type, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_log.py b/third_party/ascend/unittest/generalization_cases/test_general_log.py deleted file mode 100644 index 8bcafa2274..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_log.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import triton.language.extra.ascend.libdevice as libdevice -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.log(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_log_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, 
:, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.log(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log(x).to(eval('torch.' + dtype)) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log(x).to(eval('torch.' 
+ dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_log_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_log2.py b/third_party/ascend/unittest/generalization_cases/test_general_log2.py deleted file mode 100644 index 0a0321466d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_log2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import triton.language.extra.ascend.libdevice as libdevice -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.log2(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_log2_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, 
:, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.log2(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log2(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log2(x).to(eval('torch.' + dtype)) - - if len(shape) == 1: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) - elif len(shape) == 2: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_log2_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - 
output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_log2_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.log2(x).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_log2_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_maximum.py b/third_party/ascend/unittest/generalization_cases/test_general_maximum.py deleted file mode 100644 index e7de75d0d3..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_maximum.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_maximum(x, y): - return torch.maximum(x, y) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.maximum(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_maximum_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = 
masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.maximum(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_maximum(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_maximum(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_maximum_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_maximum(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_maximum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_minimum.py b/third_party/ascend/unittest/generalization_cases/test_general_minimum.py deleted file mode 100644 index 1b419c0e5e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_minimum.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_minimum(x, y): - return torch.minimum(x, y) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.minimum(X, Y) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_minimum_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 
1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.minimum(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) -def test_minimum(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_minimum(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) -def test_minimum_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_minimum(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_minimum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_mul.py b/third_party/ascend/unittest/generalization_cases/test_general_mul.py deleted file mode 100644 index 29666d72e4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_mul.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_mul(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X * Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_mul_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, 
BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val * y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mul(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x * y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_mul[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_mul[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_mul[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_mul[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_mul[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_mul[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_mul[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mul_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, 
dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x * y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_mul_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_mul_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - #numpy result - ans_numpy = np_x0 * np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_mul[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_ravel.py b/third_party/ascend/unittest/generalization_cases/test_general_ravel.py deleted file mode 100644 index 2e6735709d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_ravel.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.ravel(X) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_ravel_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = 
offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.ravel(x_val) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_ravel(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.ravel(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_ravel_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = 
torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.ravel(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ravel_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_dtype(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.ravel(X) - - oidx = tl.arange(0, XB * YB * ZB) - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 4), - ('uint8', torch.uint8, 1, 256, 16), - ('bool', torch.bool, 1, 1, 2), -]) -def test_ravel_u(sigtype, dtype, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), sigtype).npu() - ans = torch.ravel(x) - output = test_common.generate_tensor((1, XB * YB * ZB), sigtype).npu() - output = output.reshape(-1) - fn_npu_dtype[1, 1, 1](output, x, XB, YB, ZB) - test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_reshape.py b/third_party/ascend/unittest/generalization_cases/test_general_reshape.py deleted file mode 100644 index 25b661b0ad..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_reshape.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import logging - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.reshape(X, (ZB * YB * XB, )) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_reshape_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if 
BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.reshape(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4, )) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_reshape(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.reshape(-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_reshape_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.reshape(-1) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_reshape_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py b/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py deleted file mode 100644 index 54a0984fec..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.rsqrt(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_rsqrt_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = 
masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.rsqrt(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_rsqrt( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).abs().npu() - y = test_common.generate_tensor(shape, dtype).abs().npu() - z = test_common.generate_tensor(shape, dtype).abs().npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.rsqrt(x) - - if len(shape) == 1: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) - elif len(shape) == 2: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_rsqrt_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' 
+ dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_rsqrt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.rsqrt(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rsqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py b/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py deleted file mode 100644 index 91737647e6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.sigmoid(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sigmoid_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = 
offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sigmoid(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sigmoid( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) - else: - ans = torch.sigmoid(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sigmoid_invalid_dtype_case(dtype): - x = 
test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sigmoid_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if (x.dtype == torch.bfloat16): - ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) - else: - ans = torch.sigmoid(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sigmoid_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sin.py b/third_party/ascend/unittest/generalization_cases/test_general_sin.py deleted file mode 100644 index f52d0405de..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sin.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import numpy as np -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.sin(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sin_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 
- masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sin(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -import logging - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sin( - dtype, - shape, -): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sin(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_dtypes = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_dtypes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sin_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' 
+ dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sin_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sin(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sin_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_softmax.py b/third_party/ascend/unittest/generalization_cases/test_general_softmax.py deleted file mode 100644 index ce4b34d4a1..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_softmax.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -# 实际实现与官网定义不符,可能和triton submodule版本有关, 当前的submodule 不接受指定dim,都是按第0维做softmax -# arith.maximum 不支持类似 1x3 -> 3 和 1 -> 1 的reduce -import triton -import triton.language as tl -import torch -import logging -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_softmax_d0(x1): - res = torch.softmax(x1, axis=0).to(x1.dtype) - return res - - -@triton.jit -def tt_softmax_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.softmax(x) - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_softmax_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - a = tl.load(in_ptr + idx) - ret = tl.softmax(a) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def tt_softmax_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel 
* znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - a = tl.load(in_ptr + idx) - ret = tl.softmax(a) - - tl.store(out_ptr + idx, ret) - - -@triton.jit -def triton_softmax_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.softmax(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_softmax(dtype, shape): - logging.log(logging.DEBUG, f"shape = {shape}", flush=True) - torch.manual_seed(0) - x = torch.rand(shape, dtype=eval('torch.' 
+ dtype), device="npu") * 10 - grid = (1, 1, 1) - - y_cal = torch.rand(shape, dtype=eval('torch.' + dtype), device="npu") - - y_ref = torch_softmax_d0(x) - if len(shape) == 1: - tt_softmax_1d[grid](x, y_cal, x.numel(), 1, 1, x.numel(), 1, 1) - elif len(shape) == 2: - xnumel, ynumel, znumel = shape + (1, ) - XB, YB, ZB = xnumel, ynumel, znumel - if x.numel() * x.element_size() > 8192: - grid = (1, ynumel, 1) - YB = 1 - tt_softmax_2d[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB) - - elif len(shape) == 3: - mx = max(shape[1], shape[2]) - if mx == shape[1]: - tt_softmax_3d[1, shape[1], 1](x, y_cal, shape[0], shape[1], shape[2], shape[0], 1, shape[2]) - else: - tt_softmax_3d[1, 1, shape[2]](x, y_cal, shape[0], shape[1], shape[2], shape[0], shape[1], 1) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_softmax_invalid_dtype_case(dtype): - x0 = test_common.generate_tensor((1, ), dtype).npu() - - y_cal = torch.zeros((1, ), dtype=eval('torch.' + dtype)).npu() - tt_softmax_1d[1, 1, 1](x0, y_cal, 0, 0, 0, 1, 0, 0) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_softmax_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_softmax_d0(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_softmax_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_split.py b/third_party/ascend/unittest/generalization_cases/test_general_split.py deleted file mode 100644 index 9efedcf71a..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_split.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx=xidx[:,None,None,None]*YNUMEL*ZNUMEL*2+yidx[None,:,None,None]*ZNUMEL*2+ \ - zidx[None,None,:,None]*2 + tl.arange(0,2)[None,None,None,:] - - X = tl.load(x_ptr + idx) - - xx, yy = tl.split(X) - - oidx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -import logging - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_split(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - xx = torch.stack((x, y), dim=-1) - - a, b = torch.split(xx, 1, dim=-1) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - a = a.reshape(XB, YB, ZB) - b = b.reshape(XB, YB, ZB) - output = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' + dtype)).npu() - output1 = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' 
+ dtype)).npu() - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, xx, output1, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, a, output) - test_common.validate_cmp(dtype, b, output1) - - -@triton.jit -def fn_npu_4_8d(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, WB: tl.constexpr, - VB: tl.constexpr, UB: tl.constexpr, TB: tl.constexpr, SB: tl.constexpr): - - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - widx = tl.arange(0, WB) - vidx = tl.arange(0, VB) - uidx = tl.arange(0, UB) - tidx = tl.arange(0, TB) - sidx = tl.arange(0, SB) - - idx = (xidx[:, None, None, None, None, None, None, None, None] * YB * ZB * WB * VB * UB * TB * SB * 2 + - yidx[None, :, None, None, None, None, None, None, None] * ZB * WB * VB * UB * TB * SB * 2 + - zidx[None, None, :, None, None, None, None, None, None] * WB * VB * UB * TB * SB * 2 + - widx[None, None, None, :, None, None, None, None, None] * VB * UB * TB * SB * 2 + - vidx[None, None, None, None, :, None, None, None, None] * UB * TB * SB * 2 + - uidx[None, None, None, None, None, :, None, None, None] * TB * SB * 2 + - tidx[None, None, None, None, None, None, :, None, None] * SB * 2 + - sidx[None, None, None, None, None, None, None, :, None] * 2 + - tl.arange(0, 2)[None, None, None, None, None, None, None, None, :]) - - X = tl.load(x_ptr + idx) - xx, yy = tl.split(X) - - oidx = (xidx[:, None, None, None, None, None, None, None] * YB * ZB * WB * VB * UB * TB * SB + - yidx[None, :, None, None, None, None, None, None] * ZB * WB * VB * UB * TB * SB + - zidx[None, None, :, None, None, None, None, None] * WB * VB * UB * TB * SB + - widx[None, None, None, :, None, None, None, None] * VB * UB * TB * SB + - vidx[None, None, None, None, :, None, None, None] * UB * TB * SB + - 
uidx[None, None, None, None, None, :, None, None] * TB * SB + - tidx[None, None, None, None, None, None, :, None] * SB + sidx[None, None, None, None, None, None, None, :]) - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape_4_8d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_split_4_8d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - xx = torch.stack((x, y), dim=-1) - - a, b = torch.split(xx, 1, dim=-1) - - if len(shape) == 1: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, 1, shape[0] - elif len(shape) == 2: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, shape[0], shape[1] - elif len(shape) == 3: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, shape[0], shape[1], shape[2] - elif len(shape) == 4: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, shape[0], shape[1], shape[2], shape[3] - elif len(shape) == 5: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4] - elif len(shape) == 6: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5] - elif len(shape) == 7: - XB, YB, ZB, WB, VB, UB, TB, SB = 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6] - else: - XB, YB, ZB, WB, VB, UB, TB, SB = shape - - a = a.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) - b = b.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) - - output = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' + dtype)).npu() - output1 = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' 
+ dtype)).npu() - - grid = (1, 1, 1) - fn_npu_4_8d[grid](output, xx, output1, XB, YB, ZB, WB, VB, UB, TB, SB) - - test_common.validate_cmp(dtype, a, output) - test_common.validate_cmp(dtype, b, output1) - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - xx, yy = tl.split(X) - - oidx = xidx[:, None] * YB + yidx[None, :] - - tl.store(output_ptr + oidx, xx) - tl.store(output_ptr1 + oidx, yy) - - -@pytest.mark.parametrize('para_type, data_type, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 2), - ('uint8', torch.uint8, 1, 256, 2), - ('bool', torch.bool, 1, 1, 2), -]) -def test_split_u(para_type, data_type, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), para_type).npu() - a, b = torch.split(x, 1, dim=-1) - a = a.reshape(XB, YB) - b = b.reshape(XB, YB) - - output = test_common.generate_tensor((XB, YB), para_type).npu() - output1 = test_common.generate_tensor((XB, YB), para_type).npu() - fn_npu_[1, 1, 1](output, x, output1, XB, YB, ZB, debug=True) - - test_common.validate_cmp(para_type, a, output) - test_common.validate_cmp(para_type, b, output1) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_sub.py b/third_party/ascend/unittest/generalization_cases/test_general_sub.py deleted file mode 100644 index a831d6c421..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_sub.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest - -import triton -import triton.language as tl -import torch -import test_common -from test_common import TestUtils -import logging -import numpy as np - - -@triton.jit -def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X - Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sub_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, 
BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val - y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - ans = x - y - output = torch.zeros_like(ans) - - if len(shape) == 1: - triton_sub[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - triton_sub[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - triton_sub[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - triton_sub[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - triton_sub[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_sub[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - triton_sub[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_sub_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = 
test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x - y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sub_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['uint16', 'uint32', 'uint64']) -def test_sub_uint(shape, dtype): - torch_dtype = eval('torch.' + dtype) - np_x0 = test_common.generate_numpy(shape, dtype) - np_x1 = test_common.generate_numpy(shape, dtype) - np_x2 = test_common.generate_numpy(shape, dtype) - - x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() - x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() - x2 = torch.from_numpy(np_x2).to(torch_dtype).npu() - - #numpy result - ans_numpy = np_x0 - np_x1 - z_ref1 = torch.from_numpy(ans_numpy).npu() - - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_sub[1, 1, shape[0]](triton_res, x0, x1, x2, 1, 1, 1, 1, 1, shape[0]) - test_common.validate_cmp(dtype, z_ref1, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py b/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py deleted file mode 100644 index ec4a7c8bc0..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - -full_dtype = test_common._float_dtypes + test_common._int_dtypes + test_common._uint_dtypes -temporarily_not_support_dtype = ['bool'] - - -@triton.jit -def triton_tensor_descriptor_2d( - out_ptr, - x_ptr, - M: tl.constexpr, - N: tl.constexpr, - M_BLOCK: tl.constexpr, - N_BLOCK: tl.constexpr, -): - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - moffset = tl.program_id(0) * M_BLOCK - noffset = tl.program_id(1) * N_BLOCK - block = in_desc.load([moffset, noffset]) - out_desc.store([moffset, noffset], block) - - -@triton.jit -def triton_tensor_descriptor_3d( - out_ptr, - x_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - stride_m: tl.constexpr, - stride_n: tl.constexpr, - stride_k: tl.constexpr, - M_BLOCK: tl.constexpr, - N_BLOCK: tl.constexpr, - K_BLOCK: tl.constexpr, -): - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[M, N, K], - strides=[stride_m, stride_n, stride_k], - block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[M, N, K], - strides=[stride_m, stride_n, stride_k], - block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], - ) - moffset = tl.program_id(0) * M_BLOCK - noffset = tl.program_id(1) * N_BLOCK - koffset = tl.program_id(2) * K_BLOCK - block = in_desc.load([moffset, noffset, koffset]) - out_desc.store([moffset, noffset, koffset], block) - - -@triton.jit -def triton_tensor_descriptor_4d( - out_ptr, - x_ptr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: 
tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, -): - pid0 = tl.program_id(0) - pid1 = tl.program_id(1) - pid2 = tl.program_id(2) - idx2 = pid2 // BLOCK_3 - idx3 = pid2 % BLOCK_3 - o1 = pid0 * BLOCK_0 - o2 = pid1 * BLOCK_1 - o3 = idx2 * BLOCK_2 - o4 = idx3 * BLOCK_3 - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3], - ) - block = in_desc.load([o1, o2, o3, o4]) - out_desc.store([o1, o2, o3, o4], block) - - -@triton.jit -def triton_tensor_descriptor_5d( - out_ptr, - x_ptr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, -): - pid0 = tl.program_id(0) - pid1 = tl.program_id(1) - pid2 = tl.program_id(2) - idx3 = pid2 // (BLOCK_3 * BLOCK_4) - idx4 = (pid2 // BLOCK_4) % BLOCK_3 - idx5 = pid2 % BLOCK_4 - o1 = pid0 * BLOCK_0 - o2 = pid1 * BLOCK_1 - o3 = idx3 * BLOCK_2 - o4 = idx4 * BLOCK_3 - o5 = idx5 * BLOCK_4 - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4], - strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4], - block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4], - ) - block = 
in_desc.load([o1, o2, o3, o4, o5]) - out_desc.store([o1, o2, o3, o4, o5], block) - - -@triton.jit -def triton_tensor_descriptor_function_2d( - out_ptr, - x_ptr, - M: tl.constexpr, - N: tl.constexpr, - M_BLOCK: tl.constexpr, - N_BLOCK: tl.constexpr, -): - in_desc = tl.make_tensor_descriptor( - x_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - out_desc = tl.make_tensor_descriptor( - out_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[M_BLOCK, N_BLOCK], - ) - moffset = tl.program_id(0) * M_BLOCK - noffset = tl.program_id(1) * N_BLOCK - block = tl.load_tensor_descriptor(in_desc, [moffset, noffset]) - tl.store_tensor_descriptor(out_desc, [moffset, noffset], block) - - -@pytest.mark.parametrize('dtype', full_dtype) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_tensor_descriptor_load_store_nd(dtype, shape): - """test tensor_descriptor load/store for nd tensor""" - - if dtype in temporarily_not_support_dtype: - pytest.skip(f"{dtype} not supported") - - inp = test_common.generate_numpy(shape, dtype) - inp = torch.from_numpy(inp).npu() - out = inp.new_empty(shape) - blocks = list(inp.size()) - strides = list(inp.stride()) - grid = (1, ) - dims = len(shape) - - # 如果最后一维小于16字节,则跳过 - itemsize = torch.tensor([], dtype=inp.dtype).element_size() - if blocks[-1] * itemsize < 16: - pytest.skip(f"last dimension must be at least 16 bytes, but got {blocks[-1] * itemsize} bytes") - - if dims == 2: - if inp.numel() * inp.element_size() > 8192: - triton_tensor_descriptor_2d[shape[0], 1, 1](out, inp, 1, shape[1], 1, shape[1]) - else: - triton_tensor_descriptor_2d[grid](out, inp, *shape, *blocks) - test_common.validate_cmp(dtype, inp, out) - elif dims == 3: - triton_tensor_descriptor_3d[grid](out, inp, *shape, *strides, *blocks) - test_common.validate_cmp(dtype, inp, out) - elif dims == 4: - triton_tensor_descriptor_4d[grid](out, inp, *shape, *strides, *blocks) - test_common.validate_cmp(dtype, inp, out) - elif dims == 5: - 
triton_tensor_descriptor_5d[grid](out, inp, *shape, *strides, *blocks) - test_common.validate_cmp(dtype, inp, out) - else: - pytest.skip(f"{dims}d not supported") - - -@pytest.mark.parametrize("dtype", test_common._uint_dtypes) -def test_tensor_descriptor_in_function(dtype): - """test tensor_descriptor load/store in function""" - - if dtype in temporarily_not_support_dtype: - pytest.skip(f"{dtype} not supported") - - M, N = 32, 128 - inp = test_common.generate_numpy((M, N), dtype) - inp = torch.from_numpy(inp).npu() - out = inp.new_empty((M, N)) - - M_BLOCK = 8 - N_BLOCK = 32 - grid_m = M // M_BLOCK - grid_n = N // N_BLOCK - - triton_tensor_descriptor_function_2d[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) - test_common.validate_cmp(dtype, inp, out) diff --git a/third_party/ascend/unittest/generalization_cases/test_general_view.py b/third_party/ascend/unittest/generalization_cases/test_general_view.py deleted file mode 100644 index 7f0f9b1532..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_general_view.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import logging -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.view(X, (ZB * YB * XB, )) - - oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def triton_view_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & 
(tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.view(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4, )) - - pid0 = tl.program_id(0) - - flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) - out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx - out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_view(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() - new_shape = (x.numel(), ) - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.view(new_shape) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - if xnumel > 1: - grid = (XB, 1, 1) - XB = 1 - elif ynumel > 1: - grid = (1, YB, 1) - YB = 1 - else: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_view_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() - - output = torch.randint(1, (x.numel(), ), dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x.view(x.numel(), ) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_view_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_gt_op.py b/third_party/ascend/unittest/generalization_cases/test_gt_op.py deleted file mode 100644 index 7079457b2e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_gt_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_gt_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 > x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_gt_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, 
:] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val > y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_gt(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.gt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' 
+ sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_gt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_gt_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_gt_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_gt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.gt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_gt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_invert.py b/third_party/ascend/unittest/generalization_cases/test_invert.py deleted file mode 100644 index 698cb8f11c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_invert.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_invert(x0, ddtype): - if 'float' in str(ddtype): - x0 = x0.to(torch.int32) - y_ref = ~x0 - y_ref = y_ref.to(ddtype) - else: - y_ref = ~x0 - return y_ref - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = ~X - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_invert_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = 
masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = ~x_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ddtype = eval('torch.' + dtype) - ans = torch_invert(x, ddtype) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) 
-@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_invert_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_invert(x, eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_le_op.py b/third_party/ascend/unittest/generalization_cases/test_le_op.py deleted file mode 100644 index d305395417..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_le_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_le_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 <= x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_le_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, 
BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val <= y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_le(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.le(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' 
+ sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_le_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_le_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_le_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_le_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.le(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_le_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_load_store.py b/third_party/ascend/unittest/generalization_cases/test_load_store.py deleted file mode 100644 index 82013d9f38..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_load_store.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr): - idx = tl.arange(0, YB) - X = tl.load(x_ptr + idx) - tl.store(output_ptr + idx, X) - - -def torch_fn_npu_1d(x): - return x - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - pid = tl.program_id(0) - y_idx = tl.arange(0, YB)[:, None] + pid * YB - z_idx = tl.arange(0, ZB)[None, :] - idx = y_idx * ZB + z_idx - - X = tl.load(x_ptr + idx) - - tl.store(output_ptr + idx, X) - - -def torch_fn_npu_2d(x): - return x - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - y = tl.arange(0, YB)[:, None, None] - z = tl.arange(0, ZB)[None, :, None] - k = tl.arange(0, KB)[None, None, :] - - idx = y * ZB * KB + z * KB + k - - X = tl.load(x_ptr + idx) - - tl.store(output_ptr + idx, X) - - -def torch_fn_npu_3d(x): - return x - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_npu(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - triton_res = torch.empty(shape, dtype=data_type).npu() - torch_res = x - if len(shape) == 1: - torch_res = torch_fn_npu_1d(x) - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - # uint32 转成 float32算精度,因为torch_npu不支持uint32类型张量的slice - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_type = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3], torch_res[:2 * shape[0] // 3]) - elif len(shape) == 2: - torch_res = torch_fn_npu_2d(x) - fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1]) - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_type = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3], - torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3]) - elif len(shape) == 3: - torch_res = torch_fn_npu_3d(x) - fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_type = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3], - torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3]) - - -# require: all data (4d and 5d) can be placed into but without ub overflow -@triton.jit -def triton_load_store_multi_d(in_ptr0, out_ptr0, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, STRIDE_0: 
tl.constexpr, STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - tmp_in = tl.load(in_ptr0 + offsets, masks) - tmp_out = tmp_in - tl.store(out_ptr0 + offsets, tmp_out, masks) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('param_list', [ - ['float32', (8, 4, 16, 16)], - ['float16', (8, 4, 16, 16)], - ['int8', (8, 4, 16, 16)], - ['float32', (8, 8, 4, 4)], - ['float16', (8, 8, 4, 4)], - ['int8', (8, 8, 4, 4)], - ['float32', (3, 8, 2, 16, 16)], - ['float16', (3, 8, 2, 16, 16)], - ['int8', (9, 8, 8, 16, 16)], - ['float32', (11, 8, 8, 4, 4)], - ['float16', (11, 8, 8, 4, 4)], - ['int8', (11, 8, 8, 4, 4)], -]) -def test_load_store_4d_5d(param_list): - # 生成数据 - dtype, shape = param_list - x0 = test_common.generate_tensor(shape, dtype).npu() - # torch结果 - y_expect = x0 - y_actual = test_common.generate_tensor(shape, dtype).npu() - # triton结果 - blocks = list(x0.size()) - shapes = list(x0.stride()) - while len(blocks) < 5: - blocks.append(1) - shapes.append(1) - 
triton_load_store_multi_d[(1, )](x0, y_actual, *blocks, *blocks, *shapes) - # 比较结果 - test_common.validate_cmp(dtype, y_actual, y_expect) diff --git a/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py deleted file mode 100644 index b89ca7f08f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils, generate_tensor -import logging - - -@triton.jit -def triton_logical_and_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - idx = lblk_idx[:, None] * M + mblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx[:, None] * M + mblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB - mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB - nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_and(x1) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_and_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, 
BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val.logical_and(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -support_typelist = [ - 'bool', -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', support_typelist) -def test_logical_and(shape, sigtype): - logging.debug(f"dtype:{sigtype} shape:{shape}") - dtype = eval('torch.' 
+ sigtype) - x0 = generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.logical_and(x0, x1) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 1: - triton_logical_and_1d[1, 1, 1](x0, x1, output, shape[0]) - elif len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_logical_and_2d[grid](x0, x1, output, shape0, shape1) - elif len(shape) == 3: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - triton_logical_and_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - triton_logical_and_3d[1, shape[1], 1](x0, x1, output, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_logical_and_3d[1, 1, shape[2]](x0, x1, output, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_logical_and_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.logical_and(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_logical_and_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py deleted file mode 100644 index f470de056f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils, generate_tensor -import logging - - -@triton.jit -def triton_logical_or_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): - pid = tl.program_id(0) - lblk_idx = tl.arange(0, L) + pid * L - mblk_idx = tl.arange(0, M) - idx = lblk_idx[:, None] * M + mblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx[:, None] * M + mblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB - mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB - nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0.logical_or(x1) - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_logical_or_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * 
STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val.logical_or(y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -support_typelist = [ - 'bool', -] - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('sigtype', support_typelist) -def test_logical_or(shape, sigtype): - logging.debug(f"dtype:{sigtype} shape:{shape}") - dtype = eval('torch.' 
+ sigtype) - x0 = generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.logical_or(x0, x1) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 1: - triton_logical_or_1d[1, 1, 1](x0, x1, output, shape[0]) - elif len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_logical_or_2d[grid](x0, x1, output, shape0, shape1) - elif len(shape) == 3: - mx = max(shape[0], shape[1], shape[2]) - if mx == shape[0]: - triton_logical_or_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif mx == shape[1]: - triton_logical_or_3d[1, shape[1], 1](x0, x1, output, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - triton_logical_or_3d[1, 1, shape[2]](x0, x1, output, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['bool']) -def test_logical_or_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.logical_or(x, y) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_logical_or_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_lshift_op.py b/third_party/ascend/unittest/generalization_cases/test_lshift_op.py deleted file mode 100644 index b70020aca5..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_lshift_op.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_lshift_1d(in_ptr0, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_2d(in_ptr0, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_3d(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 << 2 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lshift_4d_5d(x_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, 
:] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = x_val << 2 - tl.store(output_ptr + offsets, ret, mask=masks) - - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - -typelist = [ - 'int8', - 'int16', - 'int32', - 'int64', -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_lshift(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = x0 << 2 - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - shape0 = shape[0] - shape1 = shape[1] - shape2 = shape[2] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lshift_3d[grid](x0, output, shape0, shape1, shape2) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lshift_2d[grid](x0, output, shape0, shape1) - if len(shape) == 1: - triton_lshift_1d[1, 1, 1](x0, output, shape[0]) - 
test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_lshift_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x << 2 - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_lshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_lshift_1d[1, 1, 1](x, output, N) diff --git a/third_party/ascend/unittest/generalization_cases/test_lt_op.py b/third_party/ascend/unittest/generalization_cases/test_lt_op.py deleted file mode 100644 index 8f013d7c9f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_lt_op.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils -import logging - - -@triton.jit -def triton_lt_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - lblk_idx = tl.arange(0, L) - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - x1 = tl.load(in_ptr1 + idx) - ret = x0 < x1 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_lt_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, 
:] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val < y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_lt(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = torch.where(torch.lt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' 
+ sigtype)) - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - triton_lt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 8192: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_lt_2d[grid](x0, x1, output, shape0, shape1) - if len(shape) == 1: - triton_lt_1d[1, 1, 1](x0, x1, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_lt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.where(torch.lt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_lt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py b/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py deleted file mode 100644 index 0fdc244e79..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, avoid_not_support, get_dtype_size - - -@triton.jit -def matmul_kernel( - a_ptr, - b_ptr, - c_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - acc_dtype: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - matxa_ptr_in = tl.make_block_ptr(a_ptr, (M, K), (K, 1), (0, 0), (M, K), order=(1, 0)) - matxb_ptr_in = tl.make_block_ptr(b_ptr, (K, N), (N, 1), (0, 0), (K, N), order=(1, 0)) - matxc_ptr_in = tl.make_block_ptr(c_ptr, (M, N), (N, 1), (0, 0), (M, N), order=(1, 0)) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - a = tl.load(matxa_ptr_in) - b = tl.load(matxb_ptr_in) - accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store(matxc_ptr_in, c) - - -@avoid_not_support('matmul') -@pytest.mark.parametrize('shape', [(16, 32)]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_matmul(shape, dtype): - M, N, K = shape[0], shape[0], shape[1] - - BLOCK_M, BLOCK_N, BLOCK_K = M, N, K - a = test_common.generate_tensor((M, K), dtype) - b = test_common.generate_tensor((K, N), dtype) - - triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu() - accumulator_type = tl.float32 - - matmul_kernel[ - 1, - ](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, BLOCK_M, BLOCK_N, BLOCK_K, enable_nd2nz_on_vector=False) - - print("PASSED") diff --git a/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py b/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py deleted file mode 100644 index 4c95d5623b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl - -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, ), - strides=(1, ), - offsets=(0, ), - block_shape=(XB, ), - order=(0, ), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xoffset = tl.program_id(0) - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB), - strides=(YB, 1), - offsets=(xoffset, 0), - block_shape=(XB, YB), - order=(1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - X = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, ZB), - order=(2, 1, 0), - ) - tl.store(block_ptr_out, X) - - -@triton.jit -def triton_make_block_ptr_4d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: 
tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - x = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), - offsets=(0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), - order=(3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -@triton.jit -def triton_make_block_ptr_5d( - output_ptr, - x_ptr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - BLOCK_4: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - x = tl.load(block_ptr_in) - - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), - strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), - offsets=(0, 0, 0, 0, 0), - block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), - order=(4, 3, 2, 1, 0), - ) - tl.store(block_ptr_out, x) - - -temporarily_not_support_dtype = ['bool'] - - -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_npu(dtype, shape): - if dtype in temporarily_not_support_dtype: - return - x 
= test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - - a = x - blocks = list(x.size()) - strides = list(x.stride()) - grid = (1, ) - if len(shape) == 5: - triton_make_block_ptr_5d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 4: - triton_make_block_ptr_4d[grid](output, x, *blocks, *blocks, *strides) - elif len(shape) == 3: - fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) - elif len(shape) == 2: - if x.numel() * x.element_size() > 8192: - fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) - else: - fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) - else: - fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) - torch.testing.assert_close(output, a) diff --git a/third_party/ascend/unittest/generalization_cases/test_matmul.py b/third_party/ascend/unittest/generalization_cases/test_matmul.py deleted file mode 100644 index edeca4f170..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_matmul.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import acc_util -import test_common -from test_common import TestUtils, avoid_not_support, get_dtype_size - - -@triton.jit -def matmul_kernel( - a_ptr, - b_ptr, - c_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - acc_dtype: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_n = tl.cdiv(N, BLOCK_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) - offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) - offs_k = tl.arange(0, BLOCK_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) - accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) - a_ptrs += BLOCK_K * stride_ak - b_ptrs += BLOCK_K * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) - - offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - 
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -@avoid_not_support('matmul') -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_matmul(shape, dtype): - M, N, K = shape[0], shape[0], shape[1] - # 32byte/Dtype_bytes - kalign = 32 // get_dtype_size(dtype) - BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32) - a = test_common.generate_tensor((M, K), dtype) - b = test_common.generate_tensor((K, N), dtype) - - if dtype == "int8": - triton_res = torch.zeros((M, N), dtype=torch.int32).npu() - accumulator_type = tl.int32 - else: - triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu() - accumulator_type = tl.float32 - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ) - - matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, a.stride(0), a.stride(1), b.stride(0), - b.stride(1), triton_res.stride(0), triton_res.stride(1), BLOCK_M, BLOCK_N, BLOCK_K) - - a_gold = a.to(torch.float32) - b_gold = b.to(torch.float32) - cpu_res = torch.mm(a_gold, b_gold) - - if dtype == "int8": - # torch_npu do not support int8 matmul - a_npu = a.npu().to(torch.float32) - b_npu = b.npu().to(torch.float32) - torch_res = torch.mm(a_npu, b_npu) - triton_res = triton_res.to(torch.float32) - else: - a_npu = a.npu() - b_npu = b.npu() - torch_res = torch.mm(a_npu, b_npu) - - try: - print("starting compare of cpu vs triton:") - acc_util.assert_close(cpu_res, triton_res) - except Exception as e: - print(e) - print("starting compare of cpu vs triton vs torch_npu:") - acc_util.benchmark_compare_close(cpu_res, triton_res, torch_res) - print("PASSED") - - -@avoid_not_support('matmul') -@pytest.mark.parametrize('batch', TestUtils.batch) -@pytest.mark.parametrize('shape', 
TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_batch_matmul(shape, dtype, batch): - M, N, K = shape[0], shape[0], shape[1] - # 32byte/Dtype_bytes - kalign = 32 // get_dtype_size(dtype) - BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32) - - aa = test_common.generate_tensor((batch, M, K), dtype) - bb = test_common.generate_tensor((batch, K, N), dtype) - - if dtype == "int8": - final_triton_res = torch.zeros((batch, M, N), dtype=torch.int32).npu() - accumulator_type = tl.int32 - else: - final_triton_res = torch.zeros((batch, M, N), dtype=eval('torch.' + dtype)).npu() - accumulator_type = tl.float32 - - for i in range(0, batch): - if dtype == "int8": - triton_res = torch.zeros((M, N), dtype=torch.int32).npu() - else: - triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu() - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ) - a = aa[i] - b = bb[i] - matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, a.stride(0), a.stride(1), - b.stride(0), b.stride(1), triton_res.stride(0), triton_res.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K) - final_triton_res[i] = triton_res - - a_gold = aa.to(torch.float32) - b_gold = bb.to(torch.float32) - cpu_res = torch.bmm(a_gold, b_gold) - - if dtype == "int8": - a_npu = aa.npu().to(torch.float32) - b_npu = bb.npu().to(torch.float32) - final_triton_res = final_triton_res.to(torch.float32) - else: - a_npu = aa.npu() - b_npu = bb.npu() - torch_res = torch.bmm(a_npu, b_npu) - - try: - print("starting compare of cpu vs triton:") - acc_util.assert_close(cpu_res, final_triton_res) - except Exception as e: - print(e) - print("starting compare of cpu vs triton vs torch_npu:") - acc_util.benchmark_compare_close(cpu_res, final_triton_res, torch_res) - print("PASSED") - - -if __name__ == "__main__": - test_matmul((16, 32), 'float32') - test_matmul((16, 32), 'int8') - test_batch_matmul(2, (16, 32), 'float32') - 
test_batch_matmul(2, (16, 32), 'int8') diff --git a/third_party/ascend/unittest/generalization_cases/test_max.py b/third_party/ascend/unittest/generalization_cases/test_max.py deleted file mode 100644 index 6029b77d9b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_max.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_max_1d -def torch_max(x0, dim, keepdim): - inp = x0 if x0.device == "cpu" else x0.cpu() - return torch.max(inp, dim=dim, keepdim=keepdim)[0].npu() - - -@triton.jit -def triton_max_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.max(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_max_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." + dtype)).npu() - numel = shape[0] - triton_max_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_max(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_1d - - -# <<<<<<< test_max_2d -@triton.jit -def triton_max_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.max(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0, 1]) -def test_max_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype == 'int8' or dtype 
== 'bool': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch." + dtype)).npu() - triton_max_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_max(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_2d - - -# <<<<<<< test_max_3d -def torch_max_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch.max(torch.max(inp, 1)[0], 1)[0].npu() - elif no_reduce_dim == 1: - return torch.max(torch.max(inp, 0)[0], 1)[0].npu() - elif no_reduce_dim == 2: - return torch.max(torch.max(inp, 0)[0], 0)[0].npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_max_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 0) - ret = tl.max(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_max_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 0) - ret = tl.max(tmp, 1) - oidx = yidx - 
tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_max_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.max(x, 1) - ret = tl.max(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_max_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_max_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_max_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_max_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_max_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch." 
+ dtype)).npu() - triton_max_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_max_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_3d - - -# <<<<<<< test_max_4d -def torch_max_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.max(x0, dim=dim)[0] - - -@triton.jit -def max_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_max_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - max_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_max_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_max_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) 
-@pytest.mark.parametrize('dim', [0]) -def test_max_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_max_4d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_max_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_4d - - -# <<<<<<< test_max_5d -def torch_max_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.max(x0, dim=dim)[0] - - -@triton.jit -def max_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_max_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ 
- None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - max_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_max_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_max_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_max_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_max_5d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_max_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_max_5d diff --git a/third_party/ascend/unittest/generalization_cases/test_min.py b/third_party/ascend/unittest/generalization_cases/test_min.py deleted file mode 100644 index 2c080e3b18..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_min.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_min_1d -def torch_min(x0, dim, keepdim): - inp = x0 if x0.device == "cpu" else x0.cpu() - return torch.min(inp, dim=dim, keepdim=keepdim)[0].npu() - - -@triton.jit -def triton_min_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.min(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -def test_min_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." 
+ dtype)).npu() - numel = shape[0] - triton_min_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_min(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_1d - - -# <<<<<<< test_min_2d -@triton.jit -def triton_min_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.min(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0, 1]) -def test_min_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype == 'int8' or dtype == 'bool': - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch." 
+ dtype)).npu() - triton_min_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_min(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_2d - - -# <<<<<<< test_min_3d -def torch_min_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch.min(torch.min(inp, 1)[0], 1)[0].npu() - elif no_reduce_dim == 1: - return torch.min(torch.min(inp, 0)[0], 1)[0].npu() - elif no_reduce_dim == 2: - return torch.min(torch.min(inp, 0)[0], 0)[0].npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_min_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.min(x, 0) - ret = tl.min(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_min_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.min(x, 0) - ret = tl.min(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_min_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) 
- tmp = tl.min(x, 1) - ret = tl.min(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_min_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_min_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_min_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_min_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_min_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch." + dtype)).npu() - triton_min_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_min_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_3d - - -# <<<<<<< test_min_4d -def torch_min_4d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.min(x0, dim=dim)[0] - - -@triton.jit -def min_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB 
// MB) - o_idx = tl.arange(0, XB * YB * ZB * MB // MB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_min_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - - idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[ - None, None, :, None] * MB + midx[None, None, None, :] - - x = tl.load(in_ptr + idx) - - min_4d(out_ptr, x, XB, YB, ZB, MB, DIM) - - -def triton_min_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): - triton_min_kernel_4d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_min_4d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_min_4d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_min_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_4d - - -# <<<<<<< test_min_5d -def torch_min_5d(x0, dim): - x0 = x0 if x0.device == "cpu" else x0.cpu() - if x0.dtype in (torch.int8, torch.int16, torch.int32): - x0 = x0.to(torch.int64) - return torch.min(x0, dim=dim)[0] - - -@triton.jit -def min_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, - DIM: tl.constexpr): - if DIM == 0: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // XB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 1: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // YB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 2: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB 
// ZB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) - tl.store(out_ptr + o_idx, ret) - elif DIM == 3: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // MB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) - tl.store(out_ptr + o_idx, ret) - else: - ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // NB) - o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) - tl.store(out_ptr + o_idx, ret) - - -@triton.jit -def triton_min_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIM: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - midx = tl.arange(0, MB) - nidx = tl.arange(0, NB) - - idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[ - None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] - - x = tl.load(in_ptr + idx) - - min_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) - - -def triton_min_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): - triton_min_kernel_5d[(1, )](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [(2, 2, 2, 4, 8)]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dim', [0]) -def test_min_5d(dtype, shape, dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_min_5d(x0, dim) - triton_res = torch.empty_like(torch_res).npu() - triton_min_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) - - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_min_5d diff --git a/third_party/ascend/unittest/generalization_cases/test_mod.py b/third_party/ascend/unittest/generalization_cases/test_mod.py deleted file mode 100644 index ce15ea3d84..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_mod.py +++ /dev/null @@ -1,227 +0,0 @@ -# 
Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x, y): - res = x % y - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X % Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_mod_4d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_SIZE: tl.constexpr, - SUB_BLOCK: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): - base_idx = tl.arange(0, SUB_BLOCK) - pid_tensor = tl.full((SUB_BLOCK, ), pid * BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) - tmp0 = (pid_tensor + base_idx)[:, None, None, None] - tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None] - tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None] - tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :] - offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = tmp0 < SHAPE_0 - x = tl.load(x_ptr + offsets, mask=masks) - y = tl.load(y_ptr + offsets, mask=masks) - ret = x % y - tl.store(output_ptr + offsets, ret, mask=masks) - - -@triton.jit -def triton_mod_5d( - output_ptr, - x_ptr, - y_ptr, - BLOCK_SIZE: tl.constexpr, - SUB_BLOCK: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - 
SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr, -): - pid = tl.program_id(0) - for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): - base_idx = tl.arange(0, SUB_BLOCK) - pid_tensor = tl.full((SUB_BLOCK, ), pid * BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) - tmp0 = (pid_tensor + base_idx)[:, None, None, None, None] - tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None, None] - tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None, None] - tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :, None] - tmp4 = tl.arange(0, SHAPE_4)[None, None, None, None, :] - offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = tmp0 < SHAPE_0 - x = tl.load(x_ptr + offsets, mask=masks) - y = tl.load(y_ptr + offsets, mask=masks) - ret = x % y - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - z[z <= 0] = 1 - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - 
fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + [(25, 2, 3, 31), (2, 2, 39, 23), (17, 27, 3, 3), - (3, 2, 27, 37)]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mod_4d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - - n = x.numel() - block_size = min(triton.next_power_of_2(n), 64) - sub_block_size = 1 - grid = (triton.cdiv(n, block_size), ) - print(" ") - print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}") - print(f"=== grid : {grid}") - triton_mod_4d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride())) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape5d + [(32, 5, 3, 1, 8)]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_mod_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - if dtype in ['int8', 'int16', 'int32', 'int64']: - x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - x[x <= 0] = 1 - y[y <= 0] = 1 - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu(), y.cpu()) - ans = ans.npu() - - n = x.numel() - block_size = min(triton.next_power_of_2(n), 32) - sub_block_size = 1 - grid = (triton.cdiv(n, block_size), ) - print(" ") - print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}") - print(f"=== grid : {grid}") - triton_mod_5d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride())) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_ne.py b/third_party/ascend/unittest/generalization_cases/test_ne.py deleted file mode 100644 index a05220da45..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_ne.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math -import logging - - -def torch_ne(x0, x1): - if x0.dtype != torch.uint32: - return x0 != x1 - else: - return x0.to(torch.float32) != x1.to(torch.float32) - - -@triton.jit -def triton_ne(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = XBLOCK // XBLOCK_SUB - for loop1 in range(loops1): - x_index = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N) - tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N) - tmp2 = tmp0 != tmp1 - tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N) - - -@triton.jit -def triton_ne_4d_5d(x_ptr, y_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - 
offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val != y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) -def test_ne(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - # 生成数据 - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # torch结果 - torch_res = torch_ne(x0, x1).to(eval('torch.' + dtype)) - # triton结果 - triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - N = triton_res.numel() - triton_ne[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) - # 比较结果 - torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) - triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) - cmp_dtype = dtype if dtype != 'uint32' else 'float32' - test_common.validate_cmp(cmp_dtype, triton_res, torch_res) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_ne_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_ne(x, y).to(eval('torch.' 
+ dtype)) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_ne_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_neg.py b/third_party/ascend/unittest/generalization_cases/test_neg.py deleted file mode 100644 index 07c738d28f..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_neg.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x): - res = -x - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = -X - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_neg_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, 
:, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = -x_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu()) - ans = ans.npu() - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_neg_4d_5d(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - new_shape 
= shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_pointwise(x.cpu()) - ans = ans.npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_neg_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'bool', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_not.py b/third_party/ascend/unittest/generalization_cases/test_not.py deleted file mode 100644 index 21397985d6..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_not.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_not(x0): - res = torch.bitwise_not(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = not (X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_not_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - 
masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = not (x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_not(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_not_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_not(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_not_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_or.py b/third_party/ascend/unittest/generalization_cases/test_or.py deleted file mode 100644 index 9861b8daf7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_or.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_or(x0, x1): - return x0 | x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X | Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_or_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets 
= offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val | y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - ans = torch_or(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - elif max(shape[0], shape[1], shape[2]) == shape[1]: - fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_or_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = 
torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x | y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_or_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py deleted file mode 100644 index 70d41abc22..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, xnumel: tl.constexpr): - idx = tl.arange(0, xnumel) - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (0, )) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute_1d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.randint(1, shape, dtype=data_type).npu() - torch_res = torch.permute(x, (0, )) - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr): - pid = tl.program_id(0) - yidx = tl.arange(0, YB) + pid * YB - zidx = tl.arange(0, ZB) - idx = yidx[:, None] * znumel + zidx[None, :] - - # XB,YB,1 - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (1, 0)) - - oidx = zidx[:, None] * ynumel + yidx[None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - ynumel = shape[0] - YB = 1 - znumel = shape[1] - ZB = shape[1] - - data_type = eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=(shape[0], shape[1]), dtype=data_type).npu() - - triton_res = torch.randint(1, (shape[1], shape[0]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0)) - fn_npu_021[shape[0], 1, 1](triton_res, x, YB, ZB, ynumel, znumel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(37, 3)]: - for dtype in TestUtils.dtype_list: - test_permute(shape, dtype) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_3d.py b/third_party/ascend/unittest/generalization_cases/test_permute_3d.py deleted file mode 100644 index 9696e09f32..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_3d.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (1, 0, 2)) - - oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (2, 1, 0)) - - oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.permute(X, (0, 2, 1)) - - oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ["int8", 'int16', 'int32', 'float16', 'float32', 'bfloat16', 'int64']) -def test_permute_3d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - # not support yet: need bisheng support later - # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() - # torch_res = torch.permute(x, (2, 1, 0)) - # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - # test_common.validate_cmp(dtype, triton_res, torch_res) - - triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() - torch_res = torch.permute(x, (0, 2, 1)) - fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py deleted file mode 100644 index 615ff3cd6e..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def triton_permute_4d( - output_ptr, - x_ptr, - PERM: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3 - ret = tl.permute(x_val, (1, 0, 2, 3)) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - elif PERM == 1: # 
0, 2, 1, 3 - ret = tl.permute(x_val, (0, 2, 1, 3)) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - else: # 0, 1, 3, 2 - ret = tl.permute(x_val, (0, 1, 3, 2)) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - - s3 = 1 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) - elif PERM == 1: # 0, 2, 1, 3 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) - else: # 0, 1, 3, 2 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_permute_5d(output_ptr, x_ptr, PERM: tl.constexpr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] - - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] - - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, 
None, :, None] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None, None] - - tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] - tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3, 4 - ret = tl.permute(x_val, 1, 0, 2, 3, 4) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 1: # 0, 2, 1, 3, 4 - ret = tl.permute(x_val, 0, 2, 1, 3, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 2: # 0, 1, 3, 2, 4 - ret = tl.permute(x_val, 0, 1, 3, 2, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - shape4 = SHAPE_4 - else: # 0, 1, 2, 4, 3 - ret = tl.permute(x_val, 0, 1, 2, 4, 3) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_2 - shape3 = SHAPE_4 - shape4 = SHAPE_3 - - s4 = 1 - s3 = s4 * shape4 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3, 4 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 1: # 0, 2, 1, 3, 4 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 2: # 0, 1, 3, 2, 4 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) - else: # 0, 1, 2, 4, 3 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 - out_masks 
= (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode -def test_permute_4d_5d(shape, dtype, perm): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() - grid = (1, ) - if len(shape) == 4: - blocks = list(x.size()) - strides = list(x.stride()) - if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (1, 0, 2, 3)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 2, 1, 3)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 1, 3, 2)) - triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - else: - pass - else: - blocks = list(x.size()) - strides = list(x.stride()) - - if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' 
+ dtype)).npu() - ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) - elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) - elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) - else: # 0, 1, 2, 4, 3; exchange axis 3, 4 - output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) - triton_permute_5d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_rand.py b/third_party/ascend/unittest/generalization_cases/test_rand.py deleted file mode 100644 index 8e66e48bfa..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_rand.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import numpy as np -import scipy - - -@triton.jit -def kernel_rand(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.rand(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 - tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 - - -@triton.jit -def triton_rand_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, 
None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.rand(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randn(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.randn(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 - tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 - - -@triton.jit -def triton_randn_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, 
None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randn(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randint(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(block_size): - global_offset = block_offset + inner_idx - rand_vals = tl.randint(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 - tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 - - -@triton.jit -def triton_randint_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, 
None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randint(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -@triton.jit -def kernel_randint4x(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): - block_offset = tl.program_id(0) * XBLOCK - indices = tl.arange(0, 4) - block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset - for inner_idx in range(0, block_size + 4, step=4): - global_offset = block_offset + inner_idx - rand_vals = tl.randint4x(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 - mask = (global_offset + indices) < (block_offset + block_size) - tl.store(x_ptr + global_offset + indices, rand_vals, mask) # 存储随机数 - - -@triton.jit -def triton_randint4x_4d_5d(output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - # 1D program_id for flatten multi-d offset - pid = tl.program_id(0) - # base offset for dimension 0 - offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 - mask = tl.arange(0, BLOCK_0) < SHAPE_0 - # nested offset expansion - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, 
BLOCK_3)[None, None, None, :] * STRIDE_3 - mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - ret = tl.randint4x(5, offsets, 10) - tl.store(output_ptr + offsets, ret, mask=mask) - - -# With alpha=0.01, z=-3.0902, N=100, we have (1-0.01)+(-3.0902)*sqrt(0.01*(1-0.01)/100)=0.9593, -# so there must be 96 cases for each shape to have pvalue larger than 0.01. -# There is higher possibility to fail with small shapes, so we will use large shape. -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_rand_case(shape): - y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() - - numel = y_calf.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - for _ in range(100): - ref = np.random.random_sample(shape).flatten() - kernel_rand[ncore, 1, 1](y_calf, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randn_case(shape): - y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() - - numel = y_calf.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - for _ in range(100): - ref = np.random.standard_normal(shape).flatten() - kernel_randn[ncore, 1, 1](y_calf, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randint_case(shape): - y_cali = torch.zeros(shape, 
dtype=eval('torch.int32')).npu() - - numel = y_cali.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - ii32 = np.iinfo(np.int32) - for _ in range(100): - ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() - kernel_randint[ncore, 1, 1](y_cali, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', [ - (256, 256), - (512, 512), - (1024, 1024), -]) -def test_randint4x_case(shape): - y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() - - numel = y_cali.numel() - ncore = 1 if numel < 32 else 32 - xblock = math.ceil(numel / ncore) - - correctness = 0 - ii32 = np.iinfo(np.int32) - for _ in range(100): - ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() - kernel_randint4x[ncore, 1, 1](y_cali, 10, numel, xblock) - - pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue - if pvalue > 0.01: - correctness += 1 - - assert correctness > 95 - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -def test_rand_4d_5d(shape): - x = torch.zeros(shape, dtype=eval('torch.float32')).npu() - y = torch.zeros(shape, dtype=eval('torch.int32')).npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rand_4d_5d[grid](x, *blocks, *blocks, *strides) - triton_randn_4d_5d[grid](x, *blocks, *blocks, *strides) - triton_randint_4d_5d[grid](y, *blocks, *blocks, *strides) - triton_randint4x_4d_5d[grid](y, *blocks, *blocks, *strides) diff --git a/third_party/ascend/unittest/generalization_cases/test_range.py b/third_party/ascend/unittest/generalization_cases/test_range.py deleted file mode 100644 index 992076fc6d..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_range.py +++ /dev/null 
@@ -1,187 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import torch -import test_common -import logging - -import triton.language as tl -from test_common import TestUtils - - -@triton.jit -def triton_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - for _ in tl.range(2, 5, 2): - ret = ret + X - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_static_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, - XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X + Y - for _ in tl.static_range(2, 5, 2): - ret = ret + X - - tl.store(output_ptr + idx, ret) - - -test_shape = [(1, ), (2, ), (1, 1), (3, 4), (1, 1, 1), (2, 4, 8)] - - -@pytest.mark.parametrize('shape', test_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_range(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if dtype == 'bfloat16': - ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) - else: - ans = x + y + x + x - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', test_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -def test_static_range(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - if dtype == 'bfloat16': - ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) - else: - ans = x + y + x + x - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if dtype == 'int8': - if x.numel() * x.element_size() >= 512: - grid = (1, 1, ZB) - ZB = 1 - else: - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - triton_static_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_reduce.py b/third_party/ascend/unittest/generalization_cases/test_reduce.py deleted file mode 100644 index 14ab4696cb..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_reduce.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import random -import pytest -import torch -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_reduce(x1, dim): - if x1.dtype == torch.float16 or x1.dtype == torch.float32: - res = torch.sum(x1.to(torch.float32), dim=dim).to(x1.dtype) - else: - res = torch.sum(x1, dim=dim).to(x1.dtype) - return res - - -@triton.jit -def _reduce_combine(a, b): - return a + b - - -@triton.jit -def tt_reduce_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - if dim == 0: - oidx = yidx - else: - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_1d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: 
tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_2d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_reduce_3d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - tl.store(out_ptr, ret) - - -@triton.jit -def tt_reduce_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.reduce(x, dim, _reduce_combine) - - if dim == 0: - oidx = yidx[:, None] * znumel + zidx[None, :] - elif dim == 1: - oidx = xidx[:, None] * znumel + zidx[None, :] - else: - oidx = xidx[:, None] * ynumel + 
yidx[None, :] - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 0, _reduce_combine) - ret = tl.reduce(tmp, 0, _reduce_combine) - oidx = zidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 0, _reduce_combine) - ret = tl.reduce(tmp, 1, _reduce_combine) - oidx = yidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_reduce_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.reduce(x, 1, _reduce_combine) - ret = tl.reduce(tmp, 1, _reduce_combine) - oidx = xidx - - tl.store(out_ptr + oidx, ret) - - -def is_legal_combine(shape, dims): - return dims is None or (len(shape) == 3) or \ - (len(dims) == 1 and dims[0] < len(shape)) - - -dims_map = {(0, 1): tt_reduce_3d_0_1, (1, 2): tt_reduce_3d_1_2, (0, 2): tt_reduce_3d_0_2} - -shape_map = { - 1: {"append_shape": (1, 1), "func": tt_reduce_1d}, 2: {"append_shape": (1, ), 
"func": tt_reduce_2d}, 3: - {"append_shape": (), "func": tt_reduce_3d} -} - - -def reduce_check_ub_mem_overflow(dtype, shape): - dtype_size = get_dtype_size(dtype) - if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip("dtype:{dtype} shape:{shape} mem overflow, skipping.") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6): - pytest.skip("dtype:{dtype} shape:{shape} mem overflow, skipping.") - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (0, 1), (1, 2), (0, 2)]) -def test_reduce(dtype, shape, dims): - if not is_legal_combine(shape, dims): - return - - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - grid = (1, 1, 1) - - y_ref = torch_reduce(x, dims) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - if dims is None: - reduce_check_ub_mem_overflow(dtype, shape) - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if len(shape) == 1: - tt_reduce_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 2: - tt_reduce_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 3: - tt_reduce_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - elif len(dims) == 1: # 1d reduce, 1-3d shape - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if (len(shape) == 2) and (x.numel() * x.element_size() > 8192): - if dims[0] == 0: - grid = (1, ynumel, 1) - YB = 1 - else: - grid = (xnumel, 1, 1) - XB 
= 1 - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - else: # 3d shape, 2d reduce - tt_kernel = dims_map[dims] - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@triton.jit -def triton_reduce_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.reduce(x, DIM, _reduce_combine), REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.reduce(x, DIM, _reduce_combine) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, )]) -def test_reduce_4d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_reduce(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, ), (4, )]) -def test_reduce_5d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_reduce(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_relu.py b/third_party/ascend/unittest/generalization_cases/test_relu.py deleted file mode 100644 index 21880163bf..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_relu.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -import triton.language.extra.ascend.libdevice as libdevice -from test_common import TestUtils -import math - - -def torch_relu(x0, x1): - res = x0 + torch.relu(x1) - return res - - -@triton.jit -def triton_relu(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - xoffset = tl.program_id(0) * XBLOCK - for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): - x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] - xmask = x_index < xnumel - tmp0 = tl.load(in_ptr0 + x_index, xmask) - tmp1 = tl.load(in_ptr1 + x_index, xmask) - tmp2 = tmp0 + libdevice.relu(tmp1) - tl.store(out_ptr0 + x_index, tmp2, xmask) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_relu(dtype, shape): - # 生成数据 - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - # torch结果 - torch_res = torch_relu(x0, x1) - # triton结果 - triton_res = test_common.generate_tensor(shape, dtype).npu() - triton_relu[ncore, 1, 1](x0, x1, triton_res, x0.numel(), xblock, xblock_sub) - # 比较结果 - test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_rshift_op.py b/third_party/ascend/unittest/generalization_cases/test_rshift_op.py deleted file mode 100644 index 33b6fffecd..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_rshift_op.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common -from test_common import TestUtils - - -@triton.jit -def triton_rshift_1d(in_ptr0, out_ptr0, L: tl.constexpr): - lblk_idx = tl.arange(0, L) - idx = lblk_idx[:] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = lblk_idx[:] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_2d(in_ptr0, out_ptr0, M: tl.constexpr, N: tl.constexpr): - moffs = tl.program_id(0) * M - mblk_idx = tl.arange(0, M) + moffs - nblk_idx = tl.arange(0, N) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = mblk_idx[:, None] * N + nblk_idx[None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_3d(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): - loffs = tl.program_id(0) * L - lblk_idx = tl.arange(0, L) + loffs - mblk_idx = tl.arange(0, M) - nblk_idx = tl.arange(0, N) - idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - x0 = tl.load(in_ptr0 + idx) - ret = x0 >> 2 - odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] - tl.store(out_ptr0 + odx, ret) - - -@triton.jit -def triton_rshift_4d_5d(x_ptr, output_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, 
:] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = x_val >> 2 - tl.store(output_ptr + offsets, ret, mask=masks) - - -dtype_mapping = { - 'int8': (torch.int8), - 'int16': (torch.int16), - 'int32': (torch.int32), - 'uint32': (torch.uint32), - 'int64': (torch.int64), - 'float16': (torch.float16), - 'float32': (torch.float32), - 'bfloat16': (torch.bfloat16), - 'bool': (torch.bool), -} - -typelist = [ - 'int8', - 'int16', - 'int32', - 'int64', -] - - -# @pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) -@pytest.mark.parametrize('sigtype', typelist) -def test_lshift(sigtype, shape): - dtype = dtype_mapping[sigtype] - x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() - # ncore, xblock, xblock_sub = 2, 32768, 1024 - y_ref = x0 >> 2 - output = torch.zeros(shape, dtype=dtype).npu() - if len(shape) == 3: - shape0 = shape[0] - shape1 = shape[1] - shape2 = shape[2] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_rshift_3d[grid](x0, output, shape0, shape1, shape2) - if len(shape) == 2: - shape0 = shape[0] - shape1 = shape[1] - if x0.numel() * x0.element_size() >= 1024: - grid = (shape0, 1, 1) - shape0 = 1 - else: - grid = (1, 1, 1) - triton_rshift_2d[grid](x0, output, shape0, shape1) - if 
len(shape) == 1: - triton_rshift_1d[1, 1, 1](x0, output, shape[0]) - test_common.validate_cmp(sigtype, output, y_ref) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_rshift_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x >> 2 - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_rshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - triton_rshift_1d[1, 1, 1](x, output, N) diff --git a/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py b/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py deleted file mode 100644 index defb936aa7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common - - -def torch_(x0, x1, op_type): - if op_type == 'mul': - return torch.tensor(x0 * x1) - elif op_type == 'lshift': - return torch.tensor(x0 << x1) - elif op_type == 'eq': - return torch.tensor(x0 == x1) - else: - raise TypeError('Invalid op_type') - - -@triton.jit -def scalar_mul(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.float32, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.float32, [])) - ret = scalar0 * scalar1 - tl.store(out_ptr0, ret) - - -@triton.jit -def scalar_lshift(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int32, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int32, [])) - ret = scalar0 << scalar1 - tl.store(out_ptr0, ret) - - -@triton.jit -def scalar_eq(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): - scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int16, [])) - scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int16, [])) - ret = scalar0 == scalar1 - tl.store(out_ptr0, ret) - - -@pytest.mark.parametrize('param_list', [ - ['float32', 'mul', (1, ), 3.14, 6.66], - ['int32', 'lshift', (1, ), 6, 7], - ['bool', 'eq', (1, ), 5, 5], -]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "0d block_type is forbidden") -def test_case(param_list): - dtype, op_type, shape, lval, rval = param_list - ans = torch_(lval, rval, op_type) - ret = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - - if op_type == 'mul': - scalar_mul[1, 1, 1](ret, lval, rval) - elif op_type == 'lshift': - scalar_lshift[1, 1, 1](ret, lval, rval) - elif op_type == 'eq': - scalar_eq[1, 1, 1](ret, lval, rval) - - test_common.validate_cmp(dtype, ans, ret) diff --git a/third_party/ascend/unittest/generalization_cases/test_sort.py b/third_party/ascend/unittest/generalization_cases/test_sort.py deleted file mode 100644 index 543915ec2b..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sort.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import pytest -import torch -import numpy as np -import triton.language as tl -import test_common -from test_common import TestUtils - - -# ---------------------- -# 1D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_1d(X, Z, M: tl.constexpr, descending: tl.constexpr): - off = tl.arange(0, M) - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape1d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_1d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - M = x.shape[0] - sort_kernel_1d[(1, )](x, triton_res, M, descending) - assert torch.equal(torch_res, triton_res) - - -# ---------------------- -# 2D sort kernel (split by rows, not cutting M axis) -# ---------------------- -@triton.jit -def sort_kernel_2d(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): - pid = tl.program_id(0) - offx = tl.arange(0, M) - offy = pid * M - off2d = offx + offy - x = tl.load(X + off2d) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off2d, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape2d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_2d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = 
np_sorted[:, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - N, M = x.shape - # 每行一个 block - sort_kernel_2d[(N, )](x, triton_res, N, M, descending) - assert torch.equal(torch_res, triton_res), (torch_res, triton_res) - - -# ---------------------- -# 3D sort kernel (split by D0, D1, not cutting D2) -# ---------------------- -@triton.jit -def sort_kernel_3d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % D1 - batch_id = pid // D1 - - off2 = tl.arange(0, D2) - off1 = row_id * D2 - off0 = batch_id * D1 * D2 - off = off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) # 一整行排序 - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape3d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_3d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - D0, D1, D2 = x.shape - # 每个 (D0,D1) 对应一个 block - sort_kernel_3d[(D0 * D1, )](x, triton_res, D0, D1, D2, descending) - assert torch.equal(torch_res, triton_res), (torch_res, triton_res) - - -# ---------------------- -# 4D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_4d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, D3: tl.constexpr, - descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % 
D2 - col_id = (pid // D2) % D1 - batch_id = pid // (D1 * D2) - - off3 = tl.arange(0, D3) - off2 = row_id * D3 - off1 = col_id * D2 * D3 - off0 = batch_id * D1 * D2 * D3 - off = off3 + off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape4d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", "bool"]) -def test_sort_4d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - D0, D1, D2, D3 = x.shape - sort_kernel_4d[(D0 * D1 * D2, )](x, triton_res, D0, D1, D2, D3, descending) - assert torch.equal(torch_res, triton_res) - - -# ---------------------- -# 5D sort kernel -# ---------------------- -@triton.jit -def sort_kernel_5d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, D3: tl.constexpr, D4: tl.constexpr, - descending: tl.constexpr): - pid = tl.program_id(0) - row_id = pid % D3 - col_id = (pid // D3) % D2 - depth_id = (pid // (D2 * D3)) % D1 - batch_id = pid // (D1 * D2 * D3) - - off4 = tl.arange(0, D4) - off3 = row_id * D4 - off2 = col_id * D3 * D4 - off1 = depth_id * D2 * D3 * D4 - off0 = batch_id * D1 * D2 * D3 * D4 - off = off4 + off3 + off2 + off1 + off0 - - x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=0) - tl.store(Z + off, x) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("shape", TestUtils.test_shape5d) -@pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16", 
"bool"]) -def test_sort_5d(shape, descending, dtype): - if dtype == "bool": - x = test_common.generate_tensor(shape, dtype) - np_sorted = np.sort(x) - if descending: - np_sorted = np_sorted[:, :, :, :, ::-1].copy() - torch_res = torch.from_numpy(np_sorted).npu() - else: - x = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch.sort(x, descending=descending)[0] - - x = x.npu() - triton_res = torch.zeros_like(x) - D0, D1, D2, D3, D4 = x.shape - sort_kernel_5d[(D0 * D1 * D2 * D3, )](x, triton_res, D0, D1, D2, D3, D4, descending) - assert torch.equal(torch_res, triton_res) diff --git a/third_party/ascend/unittest/generalization_cases/test_sqrt.py b/third_party/ascend/unittest/generalization_cases/test_sqrt.py deleted file mode 100644 index 49055ff811..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sqrt.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging - -import triton -import triton.language as tl -import torch -import numpy as np -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sqrt(x0): - res = torch.sqrt(x0) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.sqrt(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sqrt_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, 
None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sqrt(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_sqrt(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sqrt_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' + dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sqrt_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch.sqrt(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py b/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py deleted file mode 100644 index a1add886b4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sqrt_rn(x0): - tmp = torch.sqrt(x0) - return tmp - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = tl.sqrt_rn(X) - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_sqrt_rn_4d_5d(output_ptr, x_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: 
- offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - ret = tl.sqrt_rn(x_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_sqrt_rn(x) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'int8', - 'int16', - 'int32', - 'uint32', - 'int64', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_sqrt_rn_invalid_dtype_case(dtype): - x = test_common.generate_tensor((1, ), dtype).npu() - y = test_common.generate_tensor((1, ), dtype).npu() - z = test_common.generate_tensor((1, ), dtype).npu() - - output = torch.randint(1, (1, ), dtype=eval('torch.' 
+ dtype)).npu() - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) -def test_sqrt_rn_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = torch_sqrt_rn(x) - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_sqrt_rn_4d_5d[grid](output, x, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py b/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py deleted file mode 100644 index d383648c4c..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import test_common - -import os - -os.environ["TRITON_ALWAYS_COMPILE"] = "1" -os.environ["PYTEST_ADDOPTS"] = "-sv" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] - - -def torch_func(x0, x1): - res = x0 + x1 - return res - - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.static_print(XBLOCK) - tl.static_print(tmp2) - tl.static_assert(XBLOCK == 8) - tl.store(out_ptr0 + idx, tmp2) 
- - -def triton_func(x0, x1, XS): - out = torch.empty_like(x0) - triton_kernel[ - 1, - ](out, x0, x1, XS) - return out - - -@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) -def test_static_print_and_assert(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - captured = capsys.readouterr() - - if sigtype == "float32": - assert "fp32" in captured.out - if sigtype == "float16": - assert "fp16" in captured.out - if sigtype == "bfloat16": - assert "bf16" in captured.out - if "int" in sigtype: - assert sigtype in captured.out - assert "8" in captured.out - - test_common.validate_cmp(sigtype, triton_cal, torch_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum.py b/third_party/ascend/unittest/generalization_cases/test_sum.py deleted file mode 100644 index e3caa4edbe..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum.py +++ /dev/null @@ -1,332 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import random -import pytest -import torch -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, get_dtype_size - - -def torch_sum(x1, dim): - if x1.dtype == torch.float16 or x1.dtype == torch.bfloat16: - res = torch.sum(x1.to(torch.float32), dim=dim, keepdim=False).to(x1.dtype) - else: - res = torch.sum(x1, dim=dim, keepdim=False).to(x1.dtype) - return res - - -@triton.jit -def tt_sum_1d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_2d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - if dim == 0: - oidx = yidx - else: - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_1d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - idx = tl.arange(0, XB) - x = tl.load(in_ptr + idx) - ret = tl.sum(x, 
dim) - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_2d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - idx = xidx[:, None] * ynumel + yidx[None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - tl.store(out_ptr + tl.arange(0, 1), ret) - - -@triton.jit -def tt_sum_3d_dim_none(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - tl.store(out_ptr, ret) - - -@triton.jit -def tt_sum_3d(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - ret = tl.sum(x, dim) - - if dim == 0: - oidx = yidx[:, None] * znumel + zidx[None, :] - elif dim == 1: - oidx = xidx[:, None] * znumel + zidx[None, :] - else: - oidx = xidx[:, None] * ynumel + yidx[None, :] - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: 
tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 0) - ret = tl.sum(tmp, 0) - oidx = zidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 0) - ret = tl.sum(tmp, 1) - oidx = yidx - - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def tt_sum_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, XB: tl.constexpr, - YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - - x = tl.load(in_ptr + idx) - - tmp = tl.sum(x, 1) - ret = tl.sum(tmp, 1) - oidx = xidx - - tl.store(out_ptr + oidx, ret) - - -def is_legal_combine(shape, dims): - return dims is None or (len(shape) == 3) or \ - (len(dims) == 1 and dims[0] < len(shape)) - - -dims_map = {(0, 1): tt_sum_3d_0_1, (1, 2): tt_sum_3d_1_2, (0, 2): tt_sum_3d_0_2} - -shape_map = { - 1: {"append_shape": (1, 1), "func": tt_sum_1d}, 2: {"append_shape": (1, ), "func": tt_sum_2d}, 3: - {"append_shape": (), "func": tt_sum_3d} -} - - -def reduce_check_ub_mem_overflow(dtype, shape): - dtype_size = get_dtype_size(dtype) - if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): - pytest.skip("dtype:{dtype} shape:{shape} mem 
overflow, skipping.") - elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6): - pytest.skip("dtype:{dtype} shape:{shape} mem overflow, skipping.") - - -@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (0, 1), (1, 2), (0, 2)]) -def test_sum(dtype, shape, dims): - if not is_legal_combine(shape, dims): - return - - torch.manual_seed(0) - x = test_common.generate_tensor(shape, dtype).npu() - grid = (1, 1, 1) - - y_ref = torch_sum(x, dims) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - if dims is None: - reduce_check_ub_mem_overflow(dtype, shape) - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if len(shape) == 1: - tt_sum_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 2: - tt_sum_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - if len(shape) == 3: - tt_sum_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) - - test_common.validate_cmp(dtype, y_cal, y_ref) - - elif len(dims) == 1: # 1d sum, 1-3d shape - append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] - xnumel, ynumel, znumel = shape + append_shape - XB, YB, ZB = xnumel, ynumel, znumel - if (len(shape) == 2) and (x.numel() * x.element_size() > 8192): - if dims[0] == 0: - grid = (1, ynumel, 1) - YB = 1 - else: - grid = (xnumel, 1, 1) - XB = 1 - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - test_common.validate_cmp(dtype, y_cal, y_ref) - else: # 3d shape, 2d sum - tt_kernel = dims_map[dims] - xnumel, ynumel, znumel = shape - XB, YB, ZB = xnumel, ynumel, znumel - tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) - 
test_common.validate_cmp(dtype, y_cal, y_ref) - - -@triton.jit -def triton_sum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.sum(x, DIM), REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.sum(x, DIM) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, )]) -def test_sum_4d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_sum(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', TestUtils.full_dtype) -@pytest.mark.parametrize('dims', [None, (0, ), (1, ), (2, ), (3, ), (4, )]) -def test_sum_5d(dtype, shape, dims): - torch.manual_seed(0) - - x = test_common.generate_tensor(shape, dtype).npu() - dim = dims[0] if dims is not None else None - - y_ref = torch_sum(x, dim) - y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py deleted file mode 100644 index 9ef39548d7..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, get_dtype_size -import math - - -def torch_sum(x0): - res = torch.sum(x0, 0) - return res - - -@triton.jit -def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr, RBLOCK_SUB: tl.constexpr): - xindex = tl.arange(0, XBLOCK) - xmask = xindex[:, None] < xnumel - for roffset_sub in range(0, RBLOCK, RBLOCK_SUB): - rindex = roffset_sub + tl.arange(0, RBLOCK_SUB) - x0 = xindex - r1 = rindex - rmask = rindex < rnumel - tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK * x0[:, None])), xmask & rmask) - tmp2 = tl.reshape(tmp0, [XBLOCK, RBLOCK_SUB]) - tmp4 = tl.sum(tmp2, 0) - tl.store(out_ptr1 + (rindex), tmp4, rmask) - - -def should_skip_due_to_mem(dtype, shape): - dtype_size = get_dtype_size(dtype) - total_mem = dtype_size * math.prod(shape) - threshold = TestUtils.ub_size / 1.5 - - if total_mem >= threshold: - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'int32']) -def test_case(dtype, shape): - should_skip_due_to_mem(dtype, shape) - x0 = test_common.generate_tensor(shape, dtype).npu() - - rblock = shape[1] - xblock = shape[0] - ncore = 1 #if numel <= 32 else 32 - rblock_sub = rblock #if xblock <= 16 else 16 - RBLOCK_tl = 256 if rblock > 1 else 1 - - y_ref = torch_sum(x0) - y_cal = torch.zeros(shape[1], dtype=eval('torch.' 
+ dtype)).npu() - triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, rblock, rblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py deleted file mode 100644 index dd304da524..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_sum(x0): - res = torch.sum(x0, 1) - return res - - -@triton.jit -def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr): - xoffset = tl.program_id(0) * XBLOCK - rindex = tl.arange(0, RBLOCK)[None, :] - rmask = rindex < rnumel - for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): - xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB) - x0 = xindex - r1 = rindex - xmask = xindex[:, None] < xnumel - xmask_prime = xindex < xnumel - tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK * x0[:, None])), rmask & xmask) - tmp2 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) - tmp4 = tl.sum(tmp2, 1) - tl.store(out_ptr1 + (xindex), tmp4, xmask_prime) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int32']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - rblock = shape[1] - xblock = shape[0] - ncore = 1 #if numel <= 32 else 32 - xblock_sub = xblock if xblock <= 16 else 16 - RBLOCK_tl = 256 if rblock > 1 else 1 - - y_ref = torch_sum(x0) - y_cal = torch.zeros(shape[:-1], dtype=eval('torch.' + dtype)).npu() - triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, xblock_sub, rblock) - test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py b/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py deleted file mode 100644 index 46bd2ebdf8..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import random -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -from test_common import TestUtils - - -def swizzle2d(size_i, size_j, size_g): - i = torch.arange(0, size_i)[:, None] - j = torch.arange(0, size_j)[None, :] - ij = i * size_j + j - size_gj = size_g * size_j - group_id = ij // size_gj - off_i = group_id * size_g - size_g = torch.min(size_i - off_i, torch.tensor(size_g).expand_as(off_i)) - ij = ij % size_gj - new_i = off_i + ij % size_g - new_j = ij // size_g - ret = new_i * size_i + new_j - return ret - - -@triton.jit -def fn_npu_(out0, out1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - i = tl.arange(0, XB)[:, None] - j = tl.arange(0, YB)[None, :] - ij = i * YB + j - xx, yy = tl.swizzle2d(i, j, size_i=XB, size_j=YB, size_g=ZB) - - ptr = tl.load(out0) - xx = tl.cast(xx, dtype=ptr.dtype) - yy = tl.cast(yy, dtype=ptr.dtype) - tl.store(out0 + ij, xx) - tl.store(out1 + ij, yy) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) -def test_swizzle2d(shape, dtype): - if (shape[0] > 255) or (shape[1] > 255): - return - size_g = random.randint(1, min(shape[0], shape[1])) - ans = swizzle2d(shape[0], shape[1], size_g).to(eval('torch.' + dtype)).npu() - - out0 = test_common.generate_tensor(shape, dtype).npu() - out1 = test_common.generate_tensor(shape, dtype).npu() - fn_npu_[1, 1, 1](out0, out1, shape[0], shape[1], size_g) - triton_ret = out0 * shape[0] + out1 - torch.testing.assert_close(triton_ret, ans) diff --git a/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py deleted file mode 100644 index d56ec1bbb4..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, xnumel: tl.constexpr): - idx = tl.arange(0, xnumel) - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 0) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_trans_1d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.randint(1, shape, dtype=data_type).npu() - torch_res = torch.permute(x, (0, )) - fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = yidx[:, None] * ZB + zidx[None, :] - - # XB,YB,1 - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0) - - oidx = zidx[:, None] * YB + yidx[None, :] - - tl.store(output_ptr + oidx, ret) - - -bisheng_notsupport_dtype = ['int64'] -tritonascend_notsupport_dtype = ['bool'] -# check_ub_mem_overflow没拦住,在kernel中最大ub占用超过ubsize -mem_overflow_scene = [ - ('bfloat16', (128, 256)), - ('bfloat16', (256, 128)), - ('int8', (741, 256)), - ('int8', (256, 741)), - ('int16', (256, 256)), - ('float16', (256, 256)), - ('bfloat16', (256, 256)), - ('int32', (128, 256)), - ('int32', (256, 128)), - ('float32', (128, 256)), - ('float32', (256, 128)), -] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: - return - if (dtype, shape) in mem_overflow_scene: - return - if check_ub_mem_overflow(dtype, shape): - return - YB = shape[0] - ZB = shape[1] - data_type = eval('torch.' 
+ dtype) - x = torch.randint(low=0, high=2, size=(YB, ZB), dtype=data_type).npu() - - triton_res = torch.randint(1, (ZB, YB), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0)) - fn_npu_021[1, 1, 1](triton_res, x, YB, ZB) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(37, 3)]: - for dtype in TestUtils.dtype_list: - test_permute(shape, dtype) diff --git a/third_party/ascend/unittest/generalization_cases/test_trans_3d.py b/third_party/ascend/unittest/generalization_cases/test_trans_3d.py deleted file mode 100644 index 6f8428e575..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_3d.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0, 2) - - oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 2, 1, 0) - - oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 0, 2, 1) - - oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] - - tl.store(output_ptr + oidx, ret) - - -bisheng_notsupport_dtype = [] -tritonascend_notsupport_dtype = ['bool'] - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', TestUtils.dtype_list) -def test_permute_3d(shape, dtype): - logging.debug(f'dtype:{dtype} shape:{shape}') - if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: - return - if check_ub_mem_overflow(dtype, shape): - return - - data_type = 
eval('torch.' + dtype) - x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() - - triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - # not support yet: need bisheng support later - # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() - # torch_res = torch.permute(x, (2, 1, 0)) - # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - # test_common.validate_cmp(dtype, triton_res, torch_res) - - triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() - torch_res = torch.permute(x, (0, 2, 1)) - fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -if __name__ == "__main__": - for shape in [(1, 22, 39)]: - for dtype in TestUtils.dtype_list: - test_permute_3d(shape, dtype) - - -@triton.jit -def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - kidx = tl.arange(0, KB) - idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] - - X = tl.load(x_ptr + idx) - - ret = tl.trans(X, 1, 0, 2) - - oidx = (zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :]) - - tl.store(output_ptr + oidx, ret) - - -@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB', [ - ('bfloat16', torch.bfloat16, 2, 8, 4), - ('uint8', torch.uint8, 1, 256, 16), - ('bool', torch.bool, 1, 1, 2), -]) -def test_permute_3d_u(sigtype, dtype, XB, YB, ZB): - x = test_common.generate_tensor((XB, YB, ZB), sigtype).npu() - triton_res = torch.empty((YB, XB, ZB), dtype=dtype).npu() - torch_res = torch.permute(x, (1, 0, 2)) - fn_npu_102[1, 1, 1](triton_res, x, XB, YB, ZB) - test_common.validate_cmp(sigtype, triton_res, torch_res) diff 
--git a/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py deleted file mode 100644 index 8505e974f3..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils, check_ub_mem_overflow -import math -import logging - - -@triton.jit -def triton_trans_4d( - output_ptr, - x_ptr, - PERM: tl.constexpr, - BLOCK_0: tl.constexpr, - BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, - SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, - SHAPE_3: tl.constexpr, - STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, - STRIDE_2: tl.constexpr, - STRIDE_3: tl.constexpr, -): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3 - ret = tl.trans(x_val, (1, 0, 2, 3)) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - elif PERM == 1: # 0, 2, 1, 3 - ret = tl.trans(x_val, (0, 2, 1, 3)) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - else: # 0, 1, 3, 2 - ret = tl.trans(x_val, (0, 1, 3, 2)) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - - s3 = 1 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & 
(tmp2 < shape2) & (tmp3 < shape3) - elif PERM == 1: # 0, 2, 1, 3 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) - else: # 0, 1, 3, 2 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@triton.jit -def triton_trans_5d(output_ptr, x_ptr, PERM: tl.constexpr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, - BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, - SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, - STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, - STRIDE_4: tl.constexpr): - pid = tl.program_id(0) - tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] - tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] - tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] - tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] - tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] - - tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] - tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] - - tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] - tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] - - tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :, None] - tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None, None] - - tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] - tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] - - offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 - masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) - x_val = tl.load(x_ptr + offsets, masks) - - if PERM == 0: # 1, 0, 2, 3, 4 - ret = 
tl.trans(x_val, 1, 0, 2, 3, 4) - shape0 = SHAPE_1 - shape1 = SHAPE_0 - shape2 = SHAPE_2 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 1: # 0, 2, 1, 3, 4 - ret = tl.trans(x_val, 0, 2, 1, 3, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_2 - shape2 = SHAPE_1 - shape3 = SHAPE_3 - shape4 = SHAPE_4 - elif PERM == 2: # 0, 1, 3, 2, 4 - ret = tl.trans(x_val, 0, 1, 3, 2, 4) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_3 - shape3 = SHAPE_2 - shape4 = SHAPE_4 - else: # 0, 1, 2, 4, 3 - ret = tl.trans(x_val, 0, 1, 2, 4, 3) - shape0 = SHAPE_0 - shape1 = SHAPE_1 - shape2 = SHAPE_2 - shape3 = SHAPE_4 - shape4 = SHAPE_3 - - s4 = 1 - s3 = s4 * shape4 - s2 = s3 * shape3 - s1 = s2 * shape2 - s0 = s1 * shape1 - - if PERM == 0: # 1, 0, 2, 3, 4 - out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 1: # 0, 2, 1, 3, 4 - out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) - elif PERM == 2: # 0, 1, 3, 2, 4 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) - else: # 0, 1, 2, 4, 3 - out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 - out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) - tl.store(output_ptr + out_offsets, ret, mask=out_masks) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) -@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode -def test_trans_4d_5d(shape, dtype, perm): - 
logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() - grid = (1, ) - if len(shape) == 4: - blocks = list(x.size()) - strides = list(x.stride()) - if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (1, 0, 2, 3)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 2, 1, 3)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' + dtype)).npu() - ans_4d = torch.permute(x, (0, 1, 3, 2)) - triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_4d, output) - else: - pass - else: - blocks = list(x.size()) - strides = list(x.stride()) - - if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 - output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) - elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 - output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) - elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 - output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' 
+ dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) - else: # 0, 1, 2, 4, 3; exchange axis 3, 4 - output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() - ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) - triton_trans_5d[grid](output, x, perm, *blocks, *blocks, *strides) - test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_umulhi.py b/third_party/ascend/unittest/generalization_cases/test_umulhi.py deleted file mode 100644 index 421fc77322..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_umulhi.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging -import triton -import torch -import pytest -import test_common - -import numpy as np -import triton.language as tl -from test_common import TestUtils - - -# inp the two 32 bit signed integers. -@triton.jit -def umulhi_kernel(X, Y, Z, N: tl.constexpr): - offs = tl.arange(0, N) - x = tl.load(X + offs) - y = tl.load(Y + offs) - z = tl.umulhi(x, y) - tl.store(Z + tl.arange(0, N), z) - - -@triton.jit -def triton_umulhi_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = tl.umulhi(x_val, y_val) - tl.store(output_ptr + offsets, ret, mask=masks) - - -# accuracy reference -def umulhi32(a, b): - a_64 = 
a.astype(np.int64) - b_64 = b.astype(np.int64) - product_64 = a_64 * b_64 - # get the high part - result_high_32 = product_64 >> 32 - return result_high_32.astype(np.int32) - - -@pytest.mark.parametrize('dtype', ['int32']) -@pytest.mark.parametrize('shape', TestUtils.full_shape) -def test_case2(dtype, shape): - N = shape[0] - dtypes = eval('torch.' + dtype) - x = torch.randint(low=0, high=2000, size=shape, dtype=dtypes) - y = torch.randint(low=0, high=2000, size=shape, dtype=dtypes) - xx = x.npu() - yy = y.npu() - z_tri = torch.zeros(size=shape, dtype=dtypes).npu() - umulhi_kernel[(1, )](xx, yy, z_tri, N=N) - - xxx = x.numpy() - yyy = y.numpy() - z_ref = umulhi32(xxx, yyy) - z_ref1 = torch.from_numpy(z_ref).npu() - torch.equal(z_tri, z_ref1) - - -invalid_types = [ - 'int8', - 'int16', - 'int64', - 'float16', - 'float32', - 'bfloat16', - 'bool', -] - - -@pytest.mark.parametrize("dtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") -def test_umulhi_invalid_dtype_case(dtype): - x0 = test_common.generate_tensor((1, ), dtype).npu() - x1 = test_common.generate_tensor((1, ), dtype).npu() - - y_cal = torch.zeros((1, ), dtype=eval('torch.' + dtype)).npu() - umulhi_kernel[(1, )](x0, x1, y_cal, 1) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int32']) -def test_umulhi_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - y = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) - xx = x.npu() - yy = y.npu() - - output = torch.zeros(size=shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - xxx = x.numpy() - yyy = y.numpy() - z = umulhi32(xxx, yyy) - ans = torch.from_numpy(z).npu() - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_umulhi_4d_5d[grid](output, xx, yy, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_where.py b/third_party/ascend/unittest/generalization_cases/test_where.py deleted file mode 100644 index 79345b6f21..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_where.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math - - -def torch_pointwise(x0, x1): - res = torch.where(x0 < x1, x0, 1) - return res - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - tmp2 = X < Y - ret = tl.where(tmp2, X, 1) - - tl.store(output_ptr + idx, ret) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['bool', 'float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - ans = torch_pointwise(x, y) - output = torch.zeros_like(ans) - - if len(shape) == 1: - fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) - elif len(shape) == 2: - if shape[0] > shape[1]: - fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) - else: - fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) - elif len(shape) == 3: - if max(shape[0], shape[1], shape[2]) == shape[0]: - fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) - else: - fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) - - 
test_common.validate_cmp(dtype, ans, output) - - -@triton.jit -def fn_npu_multi_d(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - X = tl.load(x_ptr + offsets) - Y = tl.load(y_ptr + offsets) - - tmp2 = X < Y - ret = tl.where(tmp2, X, 1) - - tl.store(output_ptr + offsets, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (2, 4, 2, 8, 1), - (4, 3, 8, 1), - (3, 4, 2, 8, 4), -]) -@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) -def test_case_4d_5d(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - - ans = torch_pointwise(x, y) - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - grid = (1, ) - fn_npu_multi_d[grid](output, x, y, *triton_shape, len(shape)) - - test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/unittest/generalization_cases/test_xor.py b/third_party/ascend/unittest/generalization_cases/test_xor.py deleted file mode 100644 index fe696552af..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_xor.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import logging - -import pytest -import triton -import triton.language as tl -import torch -import torch_npu -import test_common -from test_common import TestUtils -import math - - -def torch_xor(x0, x1): - return x0 ^ x1 - - -@triton.jit -def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xoffs = tl.program_id(0) * XB - yoffs = tl.program_id(1) * YB - zoffs = tl.program_id(2) * ZB - - xidx = tl.arange(0, XB) + xoffs - yidx = tl.arange(0, YB) + yoffs - zidx = tl.arange(0, ZB) + zoffs - - idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - Y = tl.load(y_ptr + idx) - - ret = X ^ Y - - tl.store(output_ptr + idx, ret) - - -@triton.jit -def triton_xor_4d_5d(output_ptr, x_ptr, y_ptr, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, - BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, - SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, STRIDE_0: tl.constexpr, - STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr): - offsets = tl.program_id(0) - - offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 - masks = tl.arange(0, BLOCK_0) < SHAPE_0 - if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 - masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) - if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 - masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) - if (BLOCK_3 * BLOCK_4) > 1: - offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 - masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) - if BLOCK_4 > 1: - 
offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 - masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) - - x_val = tl.load(x_ptr + offsets, masks) - y_val = tl.load(y_ptr + offsets, masks) - ret = x_val ^ y_val - tl.store(output_ptr + offsets, ret, mask=masks) - - -@pytest.mark.parametrize('shape', TestUtils.full_shape) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_case2(dtype, shape): - # 生成数据 - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - z = test_common.generate_tensor(shape, dtype).npu() - new_shape = shape - - output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() - output1 = output - logging.debug(f"output.dtype={output.dtype}") - - ans = torch_xor(x, y) - - if len(shape) == 1: - XB = 1 - xnumel = 1 - YB = 1 - ynumel = 1 - ZB = shape[0] - znumel = shape[0] - elif len(shape) == 2: - XB = 1 - xnumel = 1 - YB = shape[0] - ynumel = shape[0] - ZB = shape[1] - znumel = shape[1] - else: - XB = shape[0] - xnumel = shape[0] - YB = shape[1] - ynumel = shape[1] - ZB = shape[2] - znumel = shape[2] - - grid = (1, 1, 1) - if x.numel() * x.element_size() >= 8192: - grid = (1, 1, ZB) - ZB = 1 - - fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) - - test_common.validate_cmp(dtype, ans, output) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_xor_4d_5d(shape, dtype): - logging.log(logging.DEBUG, f"shape = {shape}") - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() - - logging.log(logging.DEBUG, f"output.dtype={output.dtype}") - - ans = x ^ y - - blocks = list(x.size()) - strides = list(x.stride()) - while len(blocks) < 5: - blocks.append(1) - strides.append(1) - - grid = (1, ) - triton_xor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) - - test_common.validate_cmp(dtype, ans, output) - - -invalid_types = [ - 'float16', - 'float32', - 'bfloat16', -] - - -@pytest.mark.parametrize("sigtype", invalid_types) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") -def test_invalid_types(sigtype): - N = 32 - x = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - y = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - z = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - output = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() - - fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/unittest/generalization_cases/test_xorsum.py b/third_party/ascend/unittest/generalization_cases/test_xorsum.py deleted file mode 100644 index 633db01c15..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_xorsum.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import math -import triton -import triton.language as tl -import torch -import torch_npu -import pytest -import test_common -import functools -from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size - - -# <<<<<<< test_xorsum_1d -def torch_xorsum(tensor, dim=None, keepdim=False): - if dim is None: - result = tensor.flatten()[0] - for x in tensor.flatten()[1:]: - result = result ^ x - return result - else: - assert dim < tensor.dim(), f"Invalid dim {dim} for tensor shape {tensor.shape}" - result = tensor.select(dim, 0) - for i in range(1, tensor.size(dim)): - result = result ^ tensor.select(dim, i) - if keepdim: - result = result.unsqueeze(dim) - return result - - -@triton.jit -def triton_xorsum_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): - xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + xoffset, None) - tmp4 = tl.xor_sum(tmp0, 0) - tl.store(out_ptr1, tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -def test_xorsum_1d(dtype, shape): - if check_ub_mem_overflow(dtype, shape): - return - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty(1, dtype=eval("torch." 
+ dtype)).npu() - numel = shape[0] - triton_xorsum_1d[1, 1, 1](x0, triton_res, numel, numel) - torch_res = torch_xorsum(x0, dim=0, keepdim=True) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_1d - - -# <<<<<<< test_xorsum_2d -@triton.jit -def triton_xorsum_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) - tmp4 = tl.xor_sum(x, dim) - if dim == 0: - tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) - else: - tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape2d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1]) -def test_xorsum_2d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - elif dtype in ['bool']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): - pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") - shapex, shapey = shape - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[1 - dim], - ], dtype=eval("torch." 
+ dtype)).npu() - triton_xorsum_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_2d - - -# <<<<<<< test_xorsum_3d -def torch_xorsum_3d(x0, no_reduce_dim): - inp = x0 if x0.device == "cpu" else x0.cpu() - if no_reduce_dim == 0: - return torch_xorsum(torch_xorsum(inp, 1), 1).npu() - elif no_reduce_dim == 1: - return torch_xorsum(torch_xorsum(inp, 0), 1).npu() - elif no_reduce_dim == 2: - return torch_xorsum(torch_xorsum(inp, 0), 0).npu() - else: - assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" - - -@triton.jit -def triton_xorsum_3d_0_1(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 0) - ret = tl.xor_sum(tmp, 0) - oidx = zidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_xorsum_3d_0_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 0) - ret = tl.xor_sum(tmp, 1) - oidx = yidx - tl.store(out_ptr + oidx, ret) - - -@triton.jit -def triton_xorsum_3d_1_2(in_ptr, out_ptr, xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, - XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - xidx = tl.arange(0, XB) - yidx = tl.arange(0, YB) - zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + 
zidx[None, None, :] - x = tl.load(in_ptr + idx) - tmp = tl.xor_sum(x, 1) - ret = tl.xor_sum(tmp, 1) - oidx = xidx - tl.store(out_ptr + oidx, ret) - - -def triton_xorsum_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): - if no_reduce_dim == 0: - triton_xorsum_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 1: - triton_xorsum_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - elif no_reduce_dim == 2: - triton_xorsum_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape3d) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) -def test_xorsum_3d(dtype, shape, no_reduce_dim): - x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([ - shape[no_reduce_dim], - ], dtype=eval("torch." + dtype)).npu() - triton_xorsum_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) - torch_res = torch_xorsum_3d(x0, no_reduce_dim) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_3d - - -# <<<<<<< test_xorsum_4d -@triton.jit -def triton_xorsum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, - NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): - offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) - if DIMS > 1: - offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) - if DIMS > 2: - offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) - if DIMS > 3: - offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB - if DIMS > 4: - offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] - - x = tl.load(in_ptr + offsets) - - if DIM is not None: - ret = tl.reshape(tl.xor_sum(x, DIM), 
REDUCE_NUMEL) - o_offsets = tl.arange(0, REDUCE_NUMEL) - tl.store(out_ptr + o_offsets, ret) - else: - ret = tl.xor_sum(x, DIM) - tl.store(out_ptr, ret) - - -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (4, 2, 8, 4), - (4, 3, 8, 1), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1, 2, 3]) -def test_xorsum_4d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - print(f"dtype:{dtype} shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - triton_res = torch.empty_like(torch_res, dtype=eval("torch." + dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_4d - - -# <<<<<<< test_xorsum_5d -@pytest.mark.shape_4d_5d -@pytest.mark.parametrize('shape', [ - (2, 4, 2, 8, 4), - (3, 4, 2, 8, 1), -]) -@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) -@pytest.mark.parametrize('dim', [0, 1, 2, 3, 4]) -def test_xorsum_5d(dtype, shape, dim): - dtype_size = get_dtype_size(dtype) - if dtype in ['int8', 'int16', 'int32', 'int64']: - if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): - print(f"dtype:{dtype} shape:{shape} mem overflow") - return - - x0 = test_common.generate_tensor(shape, dtype).npu() - torch_res = torch_xorsum(x0, dim=dim, keepdim=False) - triton_res = torch.empty_like(torch_res, dtype=eval("torch." 
+ dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None - grid = (1, ) - triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) - test_common.validate_cmp(dtype, triton_res, torch_res) - - -# >>>>>>> test_xorsum_5d - -if __name__ == "__main__": - test_xorsum_3d('int8', (3, 3, 3), 0) diff --git a/third_party/ascend/unittest/generalization_cases/test_zeros_op.py b/third_party/ascend/unittest/generalization_cases/test_zeros_op.py deleted file mode 100644 index 7e5304d153..0000000000 --- a/third_party/ascend/unittest/generalization_cases/test_zeros_op.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import math -import pytest -import random -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow - - -@triton.jit -def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int8) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def 
fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int64) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float32) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, 
ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.bfloat16) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, - YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - xidx = tl.arange(0, XNUMEL) - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Xmask = xidx < X - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) - ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int1) - oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int8) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, 
None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int64) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float32) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.bfloat16) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): - yidx = tl.arange(0, YNUMEL) - zidx = tl.arange(0, 
ZNUMEL) - Ymask = yidx < Y - Zmask = zidx < Z - mask = (Ymask[:, None]) & (Zmask[None, :]) - ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int1) - oidx = yidx[:, None] * Z + zidx[None, :] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int8) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int32) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int64_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int64) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.float16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.float32) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = 
tl.zeros((ZNUMEL, ), dtype=tl.bfloat16) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_bool_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): - zidx = tl.arange(0, ZNUMEL) - Zmask = zidx < Z - mask = (Zmask[:]) - ret = tl.zeros((ZNUMEL, ), dtype=tl.int1) - oidx = zidx[:] - tl.store(output_ptr + oidx, ret, mask=mask) - - -@triton.jit -def fn_npu_int8_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int8) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_int16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int16) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_int32_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int32) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_int64_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int64) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_fp16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.float16) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_fp32_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.float32) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_bf16_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.bfloat16) - tl.store(output_ptr, zero) - - -@triton.jit -def fn_npu_bool_0d(output_ptr, N: tl.constexpr): - zero = tl.zeros((), dtype=tl.int1) - tl.store(output_ptr, zero) - - -test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -test_shape0d = [()] -test_shape1d = TestUtils.test_shape1d -test_shape2d = TestUtils.test_shape2d -test_shape3d = TestUtils.test_shape3d - -# 定义 dtype 到 (test_func, test_sigtype) 的映射 -dtype_mapping3d = { - 'int8': (fn_npu_int8_3d, torch.int8), - 'int16': (fn_npu_int16_3d, torch.int16), - 'int32': (fn_npu_int32_3d, torch.int32), - 'int64': (fn_npu_int64_3d, torch.int64), - 'float16': (fn_npu_fp16_3d, torch.float16), - 'float32': (fn_npu_fp32_3d, 
@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist)
def test_npu(testfunc, sigtype, dtype, shape):
    """Launch the dtype/rank specific ``tl.zeros`` kernel and check its output.

    Args:
        testfunc: Triton kernel selected for this dtype and rank.
        sigtype: dtype name (e.g. ``'int8'``) used for the overflow check and
            for choosing the comparison in ``test_common.validate_cmp``.
        dtype: ``torch`` dtype of the buffers handed to the kernel.
        shape: tensor shape with rank 0 to 3.
    """
    if check_ub_mem_overflow(sigtype, shape):
        pytest.skip(f"dtype:{sigtype} shape:{shape} mem overflow")

    rank = len(shape)
    if rank > 3:
        pytest.fail(f"unsupported rank {rank} for shape {shape}")

    expected = torch.zeros(shape, dtype=dtype).npu()
    # The output buffer must start out different from the expected result.
    # The previous `torch.randint(1, ...)` draws from [0, 1), i.e. it is always
    # 0, so the comparison passed even when the kernel stored nothing at all.
    output = torch.ones(shape, dtype=dtype).npu()

    if rank == 3:
        testfunc[(1, 1, 1)](output, *shape, *shape)
    elif rank == 2:
        rows, cols = shape
        # Buffers of 8192 bytes or more are launched with one program per row
        # (each program then sees a single row). Presumably this keeps the
        # per-program tile small enough for on-chip memory -- TODO confirm.
        if expected.numel() * expected.element_size() >= 8192:
            grid = (rows, 1, 1)
            rows = 1
        else:
            grid = (1, 1, 1)
        testfunc[grid](output, rows, cols, rows, cols)
    elif rank == 1:
        testfunc[(1, 1, 1)](output, shape[0], shape[0])
    else:
        testfunc[(1, )](output_ptr=output, N=1)

    test_common.validate_cmp(sigtype, output, expected)
@pytest.mark.shape_4d_5d
@pytest.mark.parametrize('param_list', [
    ('float32', (4, 2, 16, 16)),
    ('float32', (2, 4, 2, 16, 16)),
])
def test_case_4d_5d(param_list):
    """Check ``fn_npu_multi_d`` (``tl.zeros``) on 4-D and 5-D shapes.

    Args:
        param_list: ``(dtype_name, shape)`` pair; the shape is padded with
            trailing 1s up to the five block sizes the kernel expects.
    """
    dtype, shape = param_list
    if check_ub_mem_overflow(dtype, shape):
        pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow")

    # Resolve the dtype by attribute lookup instead of eval() on a built string.
    torch_dtype = getattr(torch, dtype)
    y_ref = torch.zeros(shape, dtype=torch_dtype).npu()
    print(f"y_ref = {torch.flatten(y_ref)[0:4]}")

    # Start from a non-zero buffer: `torch.randint(1, ...)` is always 0, which
    # made the comparison pass even if the kernel never wrote its result.
    y_cal = torch.ones(shape, dtype=torch_dtype).npu()
    triton_shape = [*shape] + [1] * (5 - len(shape))
    fn_npu_multi_d[(1, )](y_cal, *triton_shape)
    print(f"y_cal = {torch.flatten(y_cal)[0:4]}")
    test_common.validate_cmp(dtype, y_cal, y_ref)
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -import test_common -from test_common import TestUtils, check_ub_mem_overflow - - -@triton.jit -def fn_npu_0d(output_ptr, x_ptr, YB: tl.constexpr): - yidx = tl.arange(0, YB) - - idx = yidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr): - yidx = tl.arange(0, YB) - - idx = yidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): - pid = tl.program_id(0) - yidx = tl.arange(0, YB)[:, None] + pid * YB - zidx = tl.arange(0, ZB)[None, :] - - idx = yidx * ZB + zidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx * ZB + zidx - - tl.store(output_ptr + oidx, ret) - - -@triton.jit -def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): - yidx = tl.arange(0, YB)[:, None, None] * ZB * KB - zidx = tl.arange(0, ZB)[None, :, None] * KB - kidx = tl.arange(0, KB)[None, None, :] - - idx = yidx + zidx + kidx - - X = tl.load(x_ptr + idx) - - ret = tl.zeros_like(X) - - oidx = yidx + zidx + kidx - - tl.store(output_ptr + oidx, ret) - - -test_shape0d = [()] -testlist = test_shape0d + TestUtils.test_shape1_2_3d - - -@pytest.mark.parametrize('shape', testlist) 
@pytest.mark.parametrize('shape', testlist)
@pytest.mark.parametrize('dtype', TestUtils.dtype_list)
def test_npu(shape, dtype):
    """Check that the rank-specific ``tl.zeros_like`` kernel writes all zeros.

    Args:
        shape: tensor shape with rank 0 to 3.
        dtype: dtype name, resolved as an attribute of ``torch``.
    """
    logging.debug('dtype:%s shape:%s', dtype, shape)
    if check_ub_mem_overflow(dtype, shape):
        # Report the case as skipped instead of silently counting it as passed.
        pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow")

    torch_dtype = getattr(torch, dtype)
    # Both buffers start non-zero so that neither "kernel stored nothing" nor
    # "kernel copied its input" can produce the expected all-zero result.
    x = torch.ones(shape, dtype=torch_dtype).npu()
    triton_res = torch.ones(shape, dtype=torch_dtype).npu()
    torch_res = torch.zeros_like(x)

    rank = len(shape)
    if rank == 0:
        fn_npu_0d[1, 1, 1](triton_res, x, 1)
    elif rank == 1:
        fn_npu_1d[1, 1, 1](triton_res, x, shape[0])
    elif rank == 2:
        # One program per row; each program handles a (1, shape[1]) tile.
        fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1])
    elif rank == 3:
        fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2])
    else:
        pytest.fail(f"unsupported rank {rank} for shape {shape}")

    test_common.validate_cmp(dtype, triton_res, torch_res)
+ dtype)).npu() - - triton_shape = [*shape] - while len(triton_shape) < 5: - triton_shape.append(1) - fn_npu_multi_d[(1, )](y_cal, x0, *triton_shape) - print(f"y_cal = {torch.flatten(y_cal)[0:4]}") - test_common.validate_cmp(dtype, y_cal, y_ref) - - -if __name__ == "__main__": - for dtype in TestUtils.dtype_list: - for shape in [(37, ), (37, 3), (1, 22, 39)]: - test_npu(shape, dtype) diff --git a/third_party/ascend/unittest/kernels/README.md b/third_party/ascend/unittest/kernels/README.md deleted file mode 100644 index 20eb7e42aa..0000000000 --- a/third_party/ascend/unittest/kernels/README.md +++ /dev/null @@ -1,62 +0,0 @@ -# 指导:如何新增kernel测试用例 -新增kernel测试用例可以分为三大步: -1、准备pt文件 -2、在triton-ascend仓中添加kernel算子,完成本地kernel测试 -3、将pt文件上传到obs桶中 - -## 1、准备pt文件 - -pt 文件用于把 GPU(或参考实现)上的输入与输出作为 golden 数据,后续测试会在 NPU 上运行 Triton kernel 并与之比对。 - -**三步生成流程** - -- **步骤 1 — 构造GPU输入并保存副本预处理成NPU kernel的输入**:根据GPU上kernel或pytorch算子的参数构造 `input_data`(键名须与 kernel 参数一致),把所有 Tensor 克隆到 CPU,形成 `input_data_before`,若GPU上算子的输入和NPU上算子有出入,需要提前预处理使`input_data_before`符合NPU上算子入参的要求。 -- **步骤 2 — 运行GPU Kernel获取输出**:在GPU上运行GPU kernel,得到 `gpu_output`,并将 Tensor 转为 CPU。 -- **步骤 3 — 打包并保存**:把 `input_data_before`、`grid`、`gpu_output` 封装为字典,通过 `torch.save` 保存为 `{kernel_name}.pt`。如果有多组用例,保存为 list-of-dicts(`[case0, case1]`)。 - -**精简示例** - -```python -import copy -import torch - -DEVICE = torch.device("cuda:0") -batch_size = 2 -grid = (batch_size,) - -input_data = { - "output_token_ids_ptr": torch.zeros((batch_size, 4), dtype=torch.int32, device=DEVICE), - "cu_num_draft_tokens_ptr": torch.tensor([2, 1], dtype=torch.int32, device=DEVICE), - # ... 
其它字段 -} - -# 保存输入副本到 CPU -input_data_before = { - k: (v.clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)) - for k, v in input_data.items() -} -# 预处理 input_data_before 符合 NPU kernel 输入 -input_data_before["npu_need_param_key"] = NPU_NEED_PARAMS_VALUE -# 运行 kernel(在 GPU / 参考实现上)并收集输出 -triton_kernel[grid](**input_data) -# 这里用 input_data 作为示例,实际应调用对应的 triton/pytorch 函数 -gpu_output = {k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in input_data.items()} - -save_obj = {"input_data": input_data_before, "grid": grid, "gpu_output": gpu_output} -torch.save(save_obj, ".pt") -# 多组用例场景:torch.save([save_obj1, save_obj2], ".pt") -``` - -## 2、在triton-ascend新增三方kernel测试用例 - -- **步骤 1 — 在triton-ascend仓中新增kernel算子** :本地验证阶段,在 kernels/xxx(例如vllm、sglang) 下新增与算子同名的 Python 文件,内容为Triton kernel函数。 -- **步骤 2 — 本地测试** :将pt文件放在kernels目录下,在项目根目录运行 -python -m pytest -v third_party/ascend/unittest/kernels/test_triton_kernel.py - -**说明** -- 指定单个 kernel:在项目根目录下执行 python -m pytest -v ascend/test/common/test_triton_kernel.py --kernel={kernel_name} -- pt文件查找策略:优先使用仓库内匹配的本地 pt,若本地不存在则按需从远端 OBS 下载 {kernel_name}.pt文件。 -- 本地已存在的pt文件,在执行完测试后不会删除,从obs桶取的文件在跑完测试后会被测试程序直接删除。 - -## 3、将pt文件上传至obs桶 -本地验证通过后,将pt文件统一上传到OBS桶当中,OBS桶链接:https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com/test/kernels/{xxx}_pt/{kernel_name}.pt,xxx为vllm或sglang diff --git a/third_party/ascend/unittest/kernels/common_kernel.py b/third_party/ascend/unittest/kernels/common_kernel.py deleted file mode 100644 index fbbce42dd9..0000000000 --- a/third_party/ascend/unittest/kernels/common_kernel.py +++ /dev/null @@ -1,7 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def safe_exp(x): - return tl.exp(tl.where(x <= 0, x, float("-inf"))) diff --git a/third_party/ascend/unittest/kernels/test_common.py b/third_party/ascend/unittest/kernels/test_common.py deleted file mode 100644 index ababcb1540..0000000000 --- a/third_party/ascend/unittest/kernels/test_common.py +++ /dev/null @@ 
def validate_cmp(dtype, y_cal, y_ref, overflow_mode: Optional[str] = None, device_type: Optional[str] = None):
    """Assert that a computed tensor matches its reference.

    Args:
        dtype: dtype name (``'float16'``, ``'bfloat16'``, ``'float32'``,
            ``'int8'``..``'int64'`` or ``'bool'``) selecting the comparison.
        y_cal: tensor produced by the code under test.
        y_ref: reference tensor.
        overflow_mode: when ``"saturate"``, the reference is clamped to the
            representable range of ``dtype`` before comparing.
        device_type: device both tensors are moved to; when ``None`` they are
            moved to the NPU via ``.npu()``.

    Raises:
        ValueError: if ``dtype`` is not supported by the requested mode.
        AssertionError: if the tensors do not match.
    """
    if device_type is not None:
        target_device = torch.device(device_type)
        y_cal = y_cal.to(target_device)
        y_ref = y_ref.to(target_device)
    else:
        y_cal = y_cal.npu()
        y_ref = y_ref.npu()

    if overflow_mode == "saturate":
        # torch.finfo / torch.iinfo need a torch.dtype, not its name, and the
        # lower bound is `.min` itself (negating it gave min == max).
        if dtype in ('float32', 'float16'):
            limits = torch.finfo(getattr(torch, dtype))
            y_ref = torch.clamp(y_ref, min=limits.min, max=limits.max)
        elif dtype in ('int32', 'int16', 'int8'):
            limits = torch.iinfo(getattr(torch, dtype))
            y_ref = torch.clamp(y_ref, min=limits.min, max=limits.max)
        elif dtype == 'bool':
            # A bool tensor is already within [0, 1]; nothing to clamp.
            pass
        else:
            raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype))

    if dtype == 'float16':
        torch.testing.assert_close(y_ref, y_cal, rtol=5e-03, atol=5e-03, equal_nan=True)
    elif dtype == 'bfloat16':
        # Compared in float32.
        torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=5e-03, atol=5e-03,
                                   equal_nan=True)
    elif dtype == 'float32':
        torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
    elif dtype in ('int64', 'int32', 'int16', 'int8', 'bool'):
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
- if not keys_ref.issubset(keys_cal): - raise ValueError("The keys of dict_ref is not subset of dict_cal") - - for key in dict_ref.keys(): - val_a, val_b = dict_ref[key], dict_cal[key] - if not isinstance(val_b, type(val_a)): - raise ValueError("The data type of two dicts are different") - - if isinstance(val_a, torch.Tensor): - validate_cmp(dtype=str(val_a.dtype).split('.')[-1], y_ref=val_a, y_cal=val_b, device_type=device_type) - - -def run_and_compare_ptfile(ptfile_path: str, kernel_runner, device_type: str = DEVICE_TYPE_NPU): - try: - datas = torch.load(ptfile_path, map_location=torch.device('cpu')) - except Exception as e: - pytest.fail(f"load file {ptfile_path} failed: {e}") - - def _run_single_case(data): - if not isinstance(data, dict): - pytest.fail("Each case loaded from pt file must be a dict") - - input_data = convert_tensor_with_device_type(data.get("input_data", {}), device_type=device_type) - grid = data.get("grid") - try: - kernel_runner(input_data, grid) - except Exception as e: - pytest.fail(f"kernel_runner execution failed: {e}") - - output_data_cpu = convert_tensor_with_device_type(input_data, device_type='cpu') - expected = data.get("gpu_output", {}) - expected_filtered = {k: expected[k] for k in output_data_cpu.keys() if k in expected} - if not expected_filtered: - pytest.fail("No matching expected outputs found in pt file for comparison") - try: - compare_data_precision(expected_filtered, output_data_cpu, device_type='cpu') - except Exception as e: - pytest.fail(f"The testcase failed: {e}") - - # Supports three scenarios: - # 1) The file stores a single dict (existing behavior) - # 2) The file stores a list, where each element is a case dict - # 3) The file stores a dict, but some tensors represent multiple cases in batch on the 0th dimension (no automatic splitting; it is recommended to use a list) - if isinstance(datas, list): - for _, data in enumerate(datas): - _run_single_case(data) - elif isinstance(datas, dict): - 
def _load_kernel_callable(module_path, kernel_name):
    """Import ``module_path`` and return the kernel object to launch."""
    try:
        mod = importlib.import_module(module_path)
    except Exception as e:
        pytest.fail(f"import {module_path} failed: {e}")

    if hasattr(mod, kernel_name):
        kernel_attr = kernel_name
    else:
        # Fall back to the first attribute (in dir() order) ending in "_kernel".
        kernel_attr = next((a for a in dir(mod) if a.endswith("_kernel")), None)

    if not kernel_attr:
        pytest.fail(f"No kernel callable found in {module_path}")
    return getattr(mod, kernel_attr)


@pytest.mark.parametrize("module_path, kernel_name", KERNEL_ITEMS)
def test_triton_kernel(module_path, kernel_name, pytestconfig):
    """Run one discovered kernel against the golden data in ``<kernel_name>.pt``.

    A ``.pt`` file next to this test is used when present; otherwise it is
    downloaded from the artifact bucket and removed again once the test ends.

    Args:
        module_path: dotted module path of the kernel file.
        kernel_name: stem of the kernel file, also the ``.pt`` file name.
        pytestconfig: pytest config, read for the ``kernel`` option.
    """
    selected = pytestconfig.getoption("kernel")
    if selected and kernel_name not in selected:
        pytest.skip(f"skip {kernel_name} due to --kernel filter")

    base_url = "https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com"
    # The first component of the module path selects the "<name>_pt" folder.
    package = module_path.split(".")[0]
    pt_url = f"{base_url}/test/kernels/{package}_pt/{kernel_name}.pt"
    local_pt = Path(__file__).parent / f"{kernel_name}.pt"

    downloaded = False
    if not local_pt.exists():
        try:
            urllib.request.urlretrieve(pt_url, local_pt)
        except Exception as e:
            # Remove a truncated download; otherwise the next run would treat
            # it as a local fixture and never delete it.
            if local_pt.exists():
                local_pt.unlink()
            pytest.fail(
                f"Failed to download the {kernel_name}.pt file. Please check whether the {kernel_name}.pt file has been uploaded to the OBS bucket: {e}"
            )
        downloaded = True

    # Everything after a successful download runs under `finally`, so an
    # import or lookup failure also cleans up the downloaded file.
    try:
        kernel_callable = _load_kernel_callable(module_path, kernel_name)

        def runner(input_data, grid):
            kernel_callable[grid](**input_data)

        test_common.run_and_compare_ptfile(str(local_pt), runner, device_type='npu')
    finally:
        if downloaded and local_pt.exists():
            local_pt.unlink()
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # Rejection sampling over the draft tokens of one request.
    # Each program handles the request selected by program_id(0). Draft tokens
    # are accepted position by position until the first rejection; the
    # rejected position receives the recovered token and later positions are
    # left untouched. If nothing is rejected, the bonus token is appended.
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests
        return

    # cu_num_draft_tokens_ptr is read as running end offsets: this request's
    # tokens occupy [previous entry, own entry), starting at 0 for request 0.
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                # Without draft probabilities the draft token is treated as
                # having probability 1.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
            target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # Accept when target_prob / draft_prob reaches the uniform sample;
            # a draft probability of 0 always rejects.
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept
                token_id = draft_token_id
            else:
                # Reject. Use recovered token
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            # Each request owns a row of (max_spec_len + 1) output slots.
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
            bonus_token_id,
        )
- # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. - tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0) - for loop_i in range(loop): - vocab_start = loop_i * SUB_BLOCK - vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, - other=float("-inf")) - new_p = prob / q - recovered_id = tl.argmax(new_p, axis=-1) - max_p = extension.get_element(new_p, (recovered_id, )) - if max_p > global_max_p: - global_max_p = max_p - global_recovered_id = vocab_start + recovered_id - else: - for loop_i in range(loop): - vocab_start = loop_i * SUB_BLOCK - vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset - < vocab_size, other=0) - prob = tl.maximum(target_prob - draft_prob, 0) - # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because - # `tl.argmax` will select the maximum value. - - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, - other=float("-inf")) - new_p = prob / q - recovered_id = tl.argmax(new_p, axis=-1) - max_p = extension.get_element(new_p, (recovered_id, )) - if max_p > global_max_p: - global_max_p = max_p - global_recovered_id = vocab_start + recovered_id - - tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. 
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # Element-wise addition: each program adds one BLOCK_SIZE-wide slice of
    # x and y and stores the sum into the same slice of the output.
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` holds element indices, not pointers; they only become
    # addresses once added to x_ptr / y_ptr / output_ptr below:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
def naive_softmax(x):
    """Return the row-wise softmax of a 2-D tensor using plain PyTorch ops.

    Each row is shifted by its own maximum before exponentiation so that
    ``exp`` cannot overflow; softmax is unchanged by such a shift.
    """
    # Per-row maximum, kept as a column so it broadcasts over the row.
    row_peak = torch.max(x, dim=1, keepdim=True).values
    shifted_exp = torch.exp(x - row_peak)
    # Normalise every row by the sum of its exponentials.
    row_total = shifted_exp.sum(dim=1, keepdim=True)
    return shifted_exp / row_total
occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + num_programs = 32 + kernel = softmax_kernel + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + BLOCK_SIZE + ) + return y + + +@pytest.mark.parametrize( + "shape", + [ + (1823, 781), + (128, 257), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_fused_softmax(shape, dtype): + torch.manual_seed(0) + x = torch.randn(shape, dtype=dtype, device="npu") + + y_triton = softmax(x) + y_torch = torch.softmax(x, axis=1) + + torch.testing.assert_close(y_triton, y_torch, atol=1e-4, rtol=1e-4) diff --git a/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py new file mode 100644 index 0000000000..43aabfafc7 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py @@ -0,0 +1,228 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Matrix Multiplication +=============== +""" + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.cann.extension as extension + +DEV = "npu" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. 
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + msk_m = offs_am < M + msk_n = offs_bn < N + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. 
+ a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk + a = tl.load( + a_ptrs, + mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + # Original vector operations + # # ----------------------------------------------------------- + # # Write back the block of the output matrix C with masks. + # The following lines split the workload across two vector cores (each iteration handles half of the BLOCK_SIZE_M rows) + SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 + for s in extension.parallel(0, 2, bind_sub_block=True): + vec_sub_blk = extension.extract_slice( + accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) + ) + if ACTIVATION == "leaky_relu_custom": + vec_sub_blk = leaky_relu_custom(vec_sub_blk) + c_sub_blk = vec_sub_blk.to(tl.float16) + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_sub_blk, mask=c_mask) + + +# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. 
+@triton.jit +def leaky_relu_custom(x): + return tl.where(x >= 0, x, 0.01 * x) + 1.0 + + +def torch_matmul(a, b, activation=""): + c = torch.matmul(a, b) + if activation == "leaky_relu_custom": + c = torch.where(c >= 0, c, 0.01 * c) + 1.0 + return c + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. 
+@pytest.mark.parametrize( + "shape", + [ + (512, 512, 512), + (256, 384, 128), + ], +) +@pytest.mark.parametrize( + "activation", + [ + "", + pytest.param("leaky_relu_custom", marks=pytest.mark.skip(reason="temporarily skip leaky_relu_custom ub overflow case")), + ], +) +def test_matrix_multiplication(shape, activation): + m, k, n = shape + torch.manual_seed(0) + + a = torch.randn((m, k), device=DEV, dtype=torch.float16) + b = torch.randn((k, n), device=DEV, dtype=torch.float16) + + triton_output = matmul(a, b, activation) + torch_output = torch_matmul(a, b, activation) + + torch.testing.assert_close(triton_output, torch_output, atol=1e-3, rtol=1e-3) diff --git a/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py new file mode 100644 index 0000000000..1aab7a20d6 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py @@ -0,0 +1,134 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Low-Memory Dropout +================== +""" + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl + +DEV = "npu" + + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! 
+ output = tl.where(x_keep != 0, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +@pytest.mark.parametrize("shape,p", [((10, ), 0.5), ((256, ), 0.5), ((513, ), 0.2), ((32, 64), 0.35)]) +def test_dropout_matches_reference(shape, p): + torch.manual_seed(0) + x = torch.randn(size=shape, device=DEV, dtype=torch.float32) + x_keep = (torch.rand(size=shape, device=DEV) > p).to(torch.int32) + + output = dropout(x, x_keep=x_keep, p=p) + expected = torch.where(x_keep != 0, x / (1 - p), torch.zeros_like(x)) + + torch.testing.assert_close(output, expected, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("shape,p,seed", [((10, ), 0.5, 123), ((256, ), 0.5, 123), ((513, ), 0.2, 7), ((32, 64), 0.35, 999)]) +def test_seeded_dropout_is_deterministic(shape, p, seed): + torch.manual_seed(0) + x = 
torch.randn(size=shape, device=DEV, dtype=torch.float32) + + output = seeded_dropout(x, p=p, seed=seed) + output_same_seed = seeded_dropout(x, p=p, seed=seed) + output_different_seed = seeded_dropout(x, p=p, seed=512) + + torch.testing.assert_close(output, output_same_seed, atol=1e-6, rtol=0) + + assert output.shape == x.shape + assert output.dtype == x.dtype + assert torch.count_nonzero(output != output_different_seed).item() > 0 + assert torch.count_nonzero(output).item() <= x.numel() diff --git a/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py new file mode 100644 index 0000000000..ef14ac70f7 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py @@ -0,0 +1,120 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Layer Normalization +============= +""" + +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) 
+ _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + return y + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_layer_norm(dtype): + M, N = 128, 128 + eps = 1e-5 + device = 'npu' + + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device=device) + bias = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) diff --git 
a/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py new file mode 100644 index 0000000000..5d8c695c19 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py @@ -0,0 +1,352 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + # Set the processing range [lo, hi) for the current stage (in column block units) + # Causal attention, as the name implies, restricts the flow of information during computation, + # only allowing the model to see the current and previous positions. + # In other words, the output at the current position can only depend on the input at or before this position, + # and cannot access information from future positions. + # Causal attention ensures sequential order and prevents "leakage of future information." 
+ # But the following logic will also be triggered + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K block is [N_CTX, HEAD_DIM], shift along the first dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator, used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Transpose K for the q @ k^T product + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Causal mask: True where query index >= key index + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Push masked positions to a large negative value + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 
format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. 
updated accumulator + tl.store(acc_ptr + block2d_acc, acc) + + m_i = m_ij # Update current block max + # Advance V and K block pointers to next BLOCK_N range + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i + return acc_ptr, l_i, m_i + + +@triton.jit +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, + stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, + stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, + stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, + stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, + Z: tl.constexpr, H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + # Total number of blocks in sequence dimension (M) + NUM_BLOCKS_M = N_CTX // BLOCK_M + # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + + # Current M-dimension block index + pid = tl.program_id(0) + + for block_idx in range(pid, NUM_BLOCKS, 20): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + # Create block pointers for Q, K, V, Output + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) 
+ K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # Initialize offsets + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + + # Initialize accumulator + if HEAD_DIM < 256: + acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + else: + acc_offset = ( + off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + + task_m_idx * BLOCK_M * HEAD_DIM + ) + acc_ptr = acc + acc_offset + + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compiler to schedule the + # two loops independently + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + 
class _attention(torch.autograd.Function):
    """Autograd entry point that launches the fused attention forward kernel."""

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, BM, BN):
        """Run the fused attention forward pass and return the output tensor.

        Args:
            ctx: Autograd context; receives the saved tensors and scalars.
            q: Query tensor, documented upstream as [Z, H, N_CTX, HEAD_DIM].
            k: Key tensor with the same layout as ``q``.
            v: Value tensor with the same layout as ``q``.
            causal: Selects kernel STAGE 3 when true, STAGE 1 otherwise.
            sm_scale: Scaling factor forwarded unchanged to the kernel.
            BM: Value passed as the kernel's BLOCK_M.
            BN: Value passed as the kernel's BLOCK_N.

        Returns:
            A tensor allocated with ``torch.empty_like(q)`` that the kernel
            writes the attention output into.
        """
        # All three operands must agree on the head dimension, and only a
        # fixed set of head sizes is accepted.
        head_dim_q = q.shape[-1]
        head_dim_k = k.shape[-1]
        head_dim_v = v.shape[-1]
        assert head_dim_q == head_dim_k == head_dim_v
        assert head_dim_k in {16, 32, 64, 128, 256}

        out = torch.empty_like(q)

        # Fixed launch width: number of NPU cores (adjust based on hardware).
        launch_grid = (20,)

        batch, heads, seq_len = q.shape[0], q.shape[1], q.shape[2]
        # float32 buffer handed to the kernel as its `acc` argument.
        acc = torch.zeros((batch, heads, seq_len, head_dim_k), dtype=torch.float32, device=q.device)
        # Per-row float32 statistics written by the kernel.
        M = torch.empty((batch, heads, seq_len), device=q.device, dtype=torch.float32)

        # Strides of q, k, v and out for dimensions 0..3, in that order.
        strides = [tensor.stride(dim) for tensor in (q, k, v, out) for dim in range(4)]

        extra_kern_args = {}
        _attn_fwd[launch_grid](
            q, k, v, M, out, acc, sm_scale,
            *strides,
            batch, heads,
            N_CTX=seq_len,
            HEAD_DIM=head_dim_k,
            BLOCK_M=BM,
            BLOCK_N=BN,
            STAGE=3 if causal else 1,
            **extra_kern_args,
        )

        ctx.save_for_backward(q, k, v, out, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = head_dim_k
        ctx.causal = causal
        return out


attention = _attention.apply
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
def run_asin_case(size, use_extern_libs):
    """Compare the Triton ``asin`` kernel against ``torch.asin`` on random input.

    Args:
        size: Shape argument forwarded to ``torch.rand``.
        use_extern_libs: When true, the launch also receives the module-level
            ``extern_libs`` mapping.

    Raises:
        AssertionError: If tensors are not on ``DEV`` or the results differ
            beyond ``rtol=1e-4`` / ``atol=1e-4``.
    """
    torch.manual_seed(0)
    x = torch.rand(size, device=DEV)
    output_triton = torch.empty_like(x)
    output_torch = torch.asin(x)
    assert x.device.type == DEV and output_triton.device.type == DEV

    n_elements = output_torch.numel()

    # One program per BLOCK_SIZE-sized slice of the flattened input.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

    launch_kwargs = {
        "BLOCK_SIZE": 1024,
        **({"extern_libs": extern_libs} if use_extern_libs else {}),
    }

    asin_kernel[grid](x, output_triton, n_elements, **launch_kwargs)
    torch.testing.assert_close(output_torch, output_triton, rtol=1e-4, atol=1e-4)
All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
# Each distinct tile shape is listed once: duplicate entries only make the
# autotuner compile and benchmark the same configuration repeatedly.
@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': NUM_CORES,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': NUM_CORES,
        }),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    NUM_SM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute a group of independent GEMMs with a fixed number of programs.

    Args:
        group_a_ptrs: Device array holding one A-matrix address per problem.
        group_b_ptrs: Device array holding one B-matrix address per problem.
        group_c_ptrs: Device array holding one C-matrix address per problem.
        group_gemm_sizes: Flat array read as ``[M, N, K]`` triples.
        g_lds: Flat array read as ``[lda, ldb, ldc]`` triples.
        group_size: Number of problems in the group.
        NUM_SM: Step by which a program advances its global tile index.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile extents.

    All loads and stores are unmasked, and the matrix addresses are cast to
    float16 pointers.
    NOTE(review): unmasked access presumably requires M, N and K to be
    multiples of the block sizes - confirm against callers.
    """
    tile_idx = tl.program_id(0)
    # Global tile index at which the current problem starts.
    last_problem_end = 0
    for g in range(group_size):
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # Process every tile of this problem that belongs to this program.
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # Position of the tile inside the current problem (row-major).
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for _ in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            tl.store(c_ptrs, c)
            # Jump to the next tile owned by this program.
            tile_idx += NUM_SM

        last_problem_end = last_problem_end + num_tiles
def run_benchmark_case(N, provider):
    """Benchmark a group of four ``N x N`` float16 GEMMs with one provider.

    Args:
        N: Edge length of every square matrix in the group.
        provider: ``'torch'`` to time ``torch.matmul`` per pair, or
            ``'triton'`` to time the grouped Triton kernel.

    Raises:
        ValueError: If ``provider`` is neither ``'torch'`` nor ``'triton'``.
        AssertionError: If any reported timing is negative.
    """
    # Reject unknown providers up front; otherwise the timing variables below
    # would never be bound and the failure would surface as a confusing
    # UnboundLocalError.
    if provider not in ('torch', 'triton'):
        raise ValueError(f"unknown provider: {provider!r}")

    group_size = 4
    group_A = []
    group_B = []
    # Holds the output tensors so the raw addresses in C_addrs stay valid.
    group_C = []
    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    for _ in range(group_size):
        A = torch.rand((N, N), device=DEV, dtype=torch.float16)
        B = torch.rand((N, N), device=DEV, dtype=torch.float16)
        C = torch.empty((N, N), device=DEV, dtype=torch.float16)
        group_A.append(A)
        group_B.append(B)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [N, N, N]
        g_lds += [N, N, N]

    d_a_ptrs = torch.tensor(A_addrs, device=DEV)
    d_b_ptrs = torch.tensor(B_addrs, device=DEV)
    d_c_ptrs = torch.tensor(C_addrs, device=DEV)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEV)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEV)

    quantiles = [0.5, 0.2, 0.8]

    def bench_torch():
        torch_perf_fn(group_A, group_B)

    def bench_triton():
        triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size)

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(bench_torch, quantiles=quantiles)
    else:
        ms, min_ms, max_ms = triton.testing.do_bench(bench_triton, quantiles=quantiles)

    assert ms >= 0
    assert min_ms >= 0
    assert max_ms >= 0
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +""" +Persistent Matmul +===================== +""" + +import time + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +DEV = "npu" +DTYPE = torch.float16 + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def get_num_compute_cores(): + return get_npu_properties()["num_aicore"] + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + bytes_per_elem = args["c_ptr"].element_size() + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * 
stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_sm = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_sm += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_sm): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id 
def get_configs(dtype):
    """Return the launch configuration registered for ``dtype``.

    Args:
        dtype: Torch dtype of the matmul operands.

    Returns:
        A new dict with the block sizes, group size, ``num_stages`` and
        ``num_warps`` for that dtype.

    Raises:
        KeyError: If no configuration exists for ``dtype`` (only
            ``torch.float16`` is registered).
    """
    per_dtype = {
        torch.float16: dict(
            BLOCK_SIZE_M=128,
            BLOCK_SIZE_N=256,
            BLOCK_SIZE_K=64,
            GROUP_SIZE_M=8,
            num_stages=3,
            num_warps=8,
        ),
    }
    return per_dtype[dtype]
def matmul_persistent(a, b):
    """Multiply ``a @ b`` with the persistent Triton matmul kernel.

    Args:
        a: Left operand, 2-D.
        b: Right operand, 2-D, with ``b.shape[0] == a.shape[1]`` and the same
            dtype as ``a``.

    Returns:
        A new tensor of shape ``(a.shape[0], b.shape[1])`` on ``a``'s device.

    Raises:
        AssertionError: On mismatched inner dimensions or dtypes.
    """
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    core_count = get_num_compute_cores()
    M, K = a.shape
    _, N = b.shape
    cfg = get_configs(a.dtype)
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(meta):
        # Never launch more programs than there are output tiles.
        total_tiles = triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (min(core_count, total_tiles),)

    tile_constants = {
        name: cfg[name]
        for name in ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M")
    }

    matmul_kernel_persistent[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        **tile_constants,
        NUM_SMS=core_count,
        num_stages=cfg["num_stages"],
        num_warps=cfg["num_warps"],
    )
    return c
@pytest.mark.skip(reason="temporarily skip persistent matmul validate cases until UB overflow issue is fixed")
@pytest.mark.parametrize(
    "M,N,K",
    [
        (32, 32, 32),
        (8192, 8192, 512),
    ],
)
def test_persistent_matmul_validate_cases(M, N, K):
    """Check that torch, naive and persistent matmul results agree pairwise."""
    torch.manual_seed(0)
    reference, naive, persistent = validate(M, N, K)

    # (actual, expected) pairs, each compared with an absolute tolerance only.
    comparisons = (
        (naive, reference),
        (persistent, reference),
        (naive, persistent),
    )
    for actual, expected in comparisons:
        torch.testing.assert_close(actual, expected, atol=1.0, rtol=0)
# Reference implementation of gather_sorted written with plain torch indexing
def torch_gather_sorted(embeddings, sorted_idxes, aux_idxes):
    """Gather rows of ``embeddings`` and scatter them to ``aux_idxes``.

    Row ``aux_idxes[i]`` of the result receives ``embeddings[sorted_idxes[i]]``.

    Args:
        embeddings: Source table; its last dimension is the row width.
        sorted_idxes: Indices of the rows to read.
        aux_idxes: Destination row for each gathered row.

    Returns:
        A tensor of shape ``(aux_idxes.shape[0], embeddings.shape[-1])`` with
        the dtype and device of ``embeddings``. Rows not addressed by
        ``aux_idxes`` are left uninitialized.
    """
    out_shape = (aux_idxes.shape[0], embeddings.shape[-1])
    gathered = torch.empty(out_shape, dtype=embeddings.dtype, device=embeddings.device)

    picked_rows = embeddings[sorted_idxes]
    gathered[aux_idxes] = picked_rows
    return gathered
tl.cdiv(remain_row_block_size, 2) + row_block_size_2 = remain_row_block_size - row_block_size_1 + + row_start_idx_0 = row_start_idx + row_start_idx_1 = row_start_idx + row_block_size_0 + row_start_idx_2 = row_start_idx + row_block_size_0 + row_block_size_1 + + # process blocks witn shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one + for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB): + + embedding_0 = tl.full((COL_BLOCK_SIZE_SUB, ), default_value, dtype=emb_dtype) + embedding_1 = embedding_0 + 0 + embedding_2 = embedding_0 + 0 + + emb_offsets = col_idx + tl.arange(0, COL_BLOCK_SIZE_SUB) + emb_mask = emb_offsets < cols + + prev_embedding_idx_0 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_1 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_2 = tl.cast(-1, dtype=tl.int32) + for row_idx in tl.range(row_start_idx_0, row_start_idx_1): + # process the first buffer + embedding_idx_0 = tl.load(sorted_indices_ptr + row_idx) + res_idx_0 = tl.load(aux_indices_ptr + row_idx) + + if (embedding_idx_0 != 0) and (embedding_idx_0 != prev_embedding_idx_0): + embedding_0 = tl.load(embeddings_ptr + embedding_idx_0 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + else: + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + + prev_embedding_idx_0 = embedding_idx_0 + + # process the second buffer + if (row_idx + row_block_size_0) < (row_start_idx_1 + row_block_size_1): + embedding_idx_1 = tl.load(sorted_indices_ptr + row_idx + row_block_size_0) + res_idx_1 = tl.load(aux_indices_ptr + row_idx + row_block_size_0) + + if (embedding_idx_1 != 0) and (embedding_idx_1 != prev_embedding_idx_1): + embedding_1 = tl.load(embeddings_ptr + embedding_idx_1 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_1 * cols + emb_offsets, embedding_1, emb_mask) + else: + tl.store(res_ptr + res_idx_1 * cols + emb_offsets, embedding_1, emb_mask) + + prev_embedding_idx_1 = embedding_idx_1 + + # 
# Host-side launcher for the Triton gather_sorted kernel
def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, aux_indices: torch.Tensor, default_value=1.0):
    """Launch ``gather_sorted_kernel`` and return the gathered rows.

    Args:
        embeddings: 2-D source table.
        sorted_indices: 1-D tensor of source row indices.
        aux_indices: 1-D tensor of destination row indices.
        default_value: Forwarded to the kernel as ``DEFAULT_VALUE``.

    Returns:
        A tensor of shape ``(sorted_indices.shape[0], embeddings.shape[1])``
        with the dtype and device of ``embeddings``.
    """
    # Row width is padded so that one row occupies a multiple of ALIGNED bytes.
    ALIGNED = 32
    CORE_NUM = get_npu_properties()["num_vectorcore"]

    n_rows = sorted_indices.shape[0]
    n_cols = embeddings.shape[1]
    output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device)

    # Nothing to gather: avoid launching a kernel with an empty grid.
    if n_rows == 0:
        return output

    # BLOCK_SIZE is the amount of data one program handles; BLOCK_SIZE_SUB is
    # the amount handled in one loop iteration of that program.
    elem_bytes = embeddings.element_size()
    col_size_aligned = triton.cdiv(embeddings.shape[-1] * elem_bytes, ALIGNED) * ALIGNED // elem_bytes

    # Rows cannot always be split evenly: the first `big_core_num` programs
    # take `big_row_block_size` rows, the remaining ones take one row fewer.
    big_row_block_size = triton.cdiv(n_rows, CORE_NUM)
    big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows)
    col_block_size = col_size_aligned
    col_block_size_sub = min(1024, col_size_aligned)

    grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size))
    gather_sorted_kernel[grid](
        embeddings,
        sorted_indices,
        aux_indices,
        output,
        n_rows,
        n_cols,
        default_value,
        BIG_CORE_NUM=big_core_num,
        BIG_ROW_BLOCK_SIZE=big_row_block_size,
        COL_BLOCK_SIZE=col_block_size,
        COL_BLOCK_SIZE_SUB=col_block_size_sub,
    )

    return output
b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py new file mode 100644 index 0000000000..fe8293b51c --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py @@ -0,0 +1,407 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
"""
Relative Attention Bias Timestamps
===============
"""

import math
import pytest
import torch
import torch_npu
import triton
import triton.language as tl
import triton.runtime.driver as driver

NUM_BUCKETS = 128
BUCKET_DIVISOR = 0.301


# get device properties of npu
def get_npu_properties():
    """Return the properties mapping reported by the driver for the current NPU."""
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)


def create_pos_w(train_len: int, num_layers: int) -> torch.Tensor:
    """Return a ``(2 * train_len + 1, num_layers)`` tensor whose row ``i`` is filled with ``i``."""
    return torch.arange(0, 2 * train_len + 1).unsqueeze(1).repeat(1, num_layers)


def create_past_valid_lens(bs: int, past_len: int) -> torch.Tensor:
    """Return ``bs`` random valid lengths drawn from ``[0, past_len)``."""
    return torch.randint(0, past_len, (bs,))


def create_timestamps(
    train_len: int, candidate_len: int, past_valid_lens: torch.Tensor
) -> torch.Tensor:
    """Build a ``(bs, train_len + candidate_len // 2)`` timestamp matrix.

    Row ``i`` starts with ``1 .. past_valid_lens[i]``; the trailing
    ``candidate_len // 2`` columns of every row are set to ``train_len + 1``.
    """
    bs = past_valid_lens.size(0)
    candidate_cols = candidate_len // 2
    timestamps = torch.zeros(bs, train_len + candidate_cols)
    for i, valid_len in enumerate(past_valid_lens):
        if valid_len > 0:
            timestamps[i, :valid_len] = torch.arange(1, valid_len.int() + 1)

    # ``-candidate_len // 2`` parses as ``(-candidate_len) // 2`` and floors
    # towards -inf, so for odd lengths it selected one column too many (and a
    # training column when candidate_len == 1). Negate the column count instead.
    if candidate_cols <= 0:
        return timestamps
    timestamps[:, -candidate_cols:] = train_len + 1

    return timestamps


def create_timestamps_weights(num_layers: int):
    """Return a ``(NUM_BUCKETS + 1, num_layers)`` integer weight table.

    The values are ``0 .. NUM_BUCKETS`` tiled ``num_layers`` times and then
    reshaped row-major.
    """
    return (
        torch.arange(0, NUM_BUCKETS + 1)
        .repeat(num_layers)
        .reshape(NUM_BUCKETS + 1, num_layers)
    )


def create_rab_time_grad(num_layers: int, batchsize: int, s: int):
    """Return a small random ``(num_layers, batchsize, s, s)`` gradient (values in ``[0, 1e-4)``)."""
    return torch.rand(num_layers, batchsize, s, s) * 1e-4


def create_bucket_timestamps(batchsize: int, s: int):
    """Return bucket ids in ``[0, NUM_BUCKETS)`` with shape ``(1, batchsize * s, s)``."""
    result = torch.arange(batchsize * s) % NUM_BUCKETS
    result = result.unsqueeze(-1).repeat(1, 1, s)
    return result


@triton.jit
def rab_time_forward_kernel(
    inp,
    out,
    index,
    index_len: tl.constexpr,
    inp_row_stride: tl.constexpr,
    clamp_max: tl.constexpr,
    bucketization_divisor: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Bucketize ``index`` and gather one row of ``inp`` per program on axis 1.

    Axis 0 selects a span of ``BLOCK_SIZE`` entries of ``index``; axis 1
    selects the row (``pid1``) of ``inp`` and ``out``. Each span is processed
    in chunks of ``COL_BLOCK_SIZE`` entries.
    """
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    col_iter_num = tl.cdiv(BLOCK_SIZE, COL_BLOCK_SIZE)

    for col_idx in tl.range(0, col_iter_num):
        cols_offsets = (
            pid0 * BLOCK_SIZE + col_idx * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE)
        )
        cols_mask = cols_offsets < index_len

        out_mask = cols_offsets < index_len

        # bucket = int(log(clamp(|index|, 1, clamp_max)) / bucketization_divisor)
        index_val = tl.load(index + cols_offsets, mask=cols_mask, other=0.0)
        index_val = tl.abs(index_val)
        index_val = tl.minimum(tl.maximum(index_val, 1.0), clamp_max)
        index_val = tl.log(index_val)
        index_val = index_val / bucketization_divisor
        index_val = tl.cast(index_val, tl.int64)

        # Load the whole row of the table, then pick one entry per index.
        inp_val = tl.load(inp + pid1 * inp_row_stride + tl.arange(0, inp_row_stride))
        out_val = tl.gather(inp_val, index_val, 0)

        tl.store(out + pid1 * index_len + cols_offsets, out_val, mask=out_mask)


def get_outer_loop_num(num_layers, index_len):
    """Split ``num_layers`` into launches whose flat offsets stay below ``2**31 - 1``.

    The per-launch layer count is halved until ``layers * index_len`` is below
    the bound (presumably the int32 offset limit of the kernel - confirm).

    Returns:
        ``(outer_loop_num, sub_num_layers, remain_layers)`` where
        ``remain_layers`` is ``num_layers % sub_num_layers``.

    Raises:
        ValueError: if even a single layer exceeds the bound. The original
            code failed with an unexplained ``ZeroDivisionError`` here.
    """
    if num_layers <= 0:
        # Nothing to launch; avoids dividing by a zero layer count below.
        return 0, 0, 0

    sub_num_layers = num_layers
    while sub_num_layers * index_len >= 2**31 - 1:
        sub_num_layers = sub_num_layers // 2
        if sub_num_layers == 0:
            raise ValueError(
                f"index_len={index_len} is too large: a single layer already "
                "reaches the 2**31 - 1 offset bound"
            )
    outer_loop_num = (num_layers + sub_num_layers - 1) // sub_num_layers
    remain_layers = num_layers % sub_num_layers
    return outer_loop_num, sub_num_layers, remain_layers


def rab_time_forward_triton(ts_w, timestamps, bucketization_divisor):
    """Triton implementation of the relative-attention-bias (time) forward pass.

    Args:
        ts_w: weight table; ``ts_w.shape`` is read as ``(num_buckets + 1, num_layers)``.
        timestamps: ``(bs, seq_len)`` tensor; every timestamp is duplicated so the
            pairwise difference matrix is ``(bs, 2 * seq_len, 2 * seq_len)``.
        bucketization_divisor: divisor applied to the log of the clamped difference.

    Returns:
        Tensor of shape ``(num_layers, bs, 2 * seq_len, 2 * seq_len)``.
    """
    ts_w_trans = ts_w.t().contiguous()

    bs, seq_len = timestamps.shape
    infer_len = 2 * seq_len
    num_layers = ts_w.shape[1]
    num_buckets = ts_w.shape[0] - 1

    # Pairwise differences t[i] - t[j] of the duplicated timestamps, flattened.
    timestamps_expanded = timestamps.unsqueeze(-1).repeat(1, 1, 2)
    timestamps_expanded = timestamps_expanded.reshape(
        bs, infer_len, 1
    ) - timestamps_expanded.reshape(bs, 1, infer_len)

    timestamps_expanded = timestamps_expanded.view(-1)
    timestamps_expanded = timestamps_expanded.contiguous()

    # Largest difference that still maps to the last bucket.
    clamp_max = torch.exp(torch.tensor(num_buckets * bucketization_divisor)).item()
    index_len = bs * infer_len * infer_len

    out = torch.empty((num_layers, index_len), dtype=ts_w.dtype, device=ts_w.device)
    outer_loop_num, sub_num_layers, remain_layers = get_outer_loop_num(
        num_layers, index_len
    )

    CORE_NUM = get_npu_properties()["num_vectorcore"]
    BLOCK_SIZE = math.ceil(index_len / CORE_NUM)
    COL_BLOCK_SIZE = 8 * 1024

    curr_layers = sub_num_layers
    for i in range(outer_loop_num):
        # The final launch only covers the leftover layers, if any.
        if i == outer_loop_num - 1 and remain_layers != 0:
            curr_layers = remain_layers

        def grid(meta):
            return (triton.cdiv(index_len, meta["BLOCK_SIZE"]), curr_layers)

        rab_time_forward_kernel[grid](
            ts_w_trans[i * sub_num_layers],
            out[i * sub_num_layers],
            timestamps_expanded,
            index_len,
            num_buckets + 1,
            clamp_max,
            bucketization_divisor,
            BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    out = out.view(num_layers, bs, infer_len, infer_len)

    return out


@triton.jit
def rab_time_backward_kernel(
    inp, src, index, index_len, BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr
):
    """Scatter-add ``src`` into ``inp`` at positions given by ``index``.

    Each program walks its ``BLOCK_SIZE`` span sequentially and merges runs of
    equal indices into a single ``tl.atomic_add``. The host passes ``index``
    already sorted, which is what makes the run merging effective.
    """
    pid0 = tl.program_id(axis=0)
    # The last program may own fewer than BLOCK_SIZE entries.
    total_col_num = (
        BLOCK_SIZE
        if pid0 * BLOCK_SIZE + BLOCK_SIZE < index_len
        else index_len - pid0 * BLOCK_SIZE
    )
    COL_BLOCK_SIZE = min(COL_BLOCK_SIZE, total_col_num)
    col_iter_num = (total_col_num + COL_BLOCK_SIZE - 1) // COL_BLOCK_SIZE

    for col_idx in tl.range(0, col_iter_num):
        base_idx = 0
        base_idx = base_idx.to(index.dtype.element_ty)

        col_start_offset = col_idx * COL_BLOCK_SIZE

        acc_result = 0.0
        acc_result = acc_result.to(inp.dtype.element_ty)
        cur_col_num = (
            COL_BLOCK_SIZE
            if col_start_offset + COL_BLOCK_SIZE < total_col_num
            else total_col_num - col_start_offset
        )

        for cur_idx in range(0, cur_col_num):
            cur_offset = pid0 * BLOCK_SIZE + col_start_offset + cur_idx

            src_val = tl.load(src + cur_offset)
            new_idx = tl.load(index + cur_offset)

            if base_idx == new_idx:
                acc_result += src_val
            else:
                # Run ended: flush the partial sum and start a new run.
                tl.atomic_add(inp + base_idx, acc_result)

                base_idx = new_idx
                acc_result = 0.0
                acc_result = acc_result.to(inp.dtype.element_ty)
                acc_result += src_val

        # Flush the run that was still open at the end of this chunk.
        tl.atomic_add(inp + base_idx, acc_result)


def rab_time_backward_triton(
    rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor
):
    """Triton implementation of the relative-attention-bias (time) backward pass.

    Args:
        rab_time_grad: ``(num_layers, b, s, s)`` upstream gradient.
        bucket_timestamps: bucket ids reshapeable to ``(b, s // 2, 1, s // 2, 1)``.

    Returns:
        ``(num_layers, NUM_BUCKETS)`` float32 gradient of the weight table.
    """
    num_layers, b, s, _ = rab_time_grad.shape
    tsw_grad = torch.zeros(
        num_layers, NUM_BUCKETS, dtype=torch.float32, device=rab_time_grad.device
    )

    # Duplicate every bucket id along both sequence axes: (b, s/2, s/2) -> (b, s, s).
    bucket_timestamps_expand = (
        bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1)
        .repeat(1, 1, 2, 1, 2)
        .reshape(b, s, s)
        .to(torch.int64)
    ).view(-1)

    index_len = bucket_timestamps_expand.numel()

    rab_time_grad_f32 = rab_time_grad.to(torch.float32)
    # Sorting groups equal bucket ids so the kernel can merge runs.
    sorted_bucket_timestamps_expand, sorted_idx = torch.sort(
        bucket_timestamps_expand
    )

    torch.npu.synchronize()

    def grid(meta):
        return (triton.cdiv(index_len, meta["BLOCK_SIZE"]),)

    CORE_NUM = get_npu_properties()["num_vectorcore"]
    BLOCK_SIZE = math.ceil(index_len / CORE_NUM)

    COL_BLOCK_SIZE = 8 * 1024

    for layer_idx in range(num_layers):
        curr_sorted_grad_f32 = rab_time_grad_f32[layer_idx].view(-1)[sorted_idx]
        rab_time_backward_kernel[grid](
            tsw_grad[layer_idx],
            curr_sorted_grad_f32,
            sorted_bucket_timestamps_expand,
            index_len,
            BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    return tsw_grad


def rab_time_forward_golden(
    ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float
) -> torch.Tensor:
    """
    torch realization of rab time forward for reference.

    The clamp bound is derived from ``ts_w`` and ``bucketization_divisor``
    (not from the module constants) so that it matches
    ``rab_time_forward_triton`` for any weight table.
    """
    infer_len = timestamps.shape[1] * 2
    bs = timestamps.shape[0]
    num_layers = ts_w.shape[1]
    num_buckets = ts_w.shape[0] - 1

    timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2)
    diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(
        bs, 1, infer_len
    )

    clamp_max = torch.exp(torch.tensor(num_buckets * bucketization_divisor))
    diff_timestamps = (
        torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max))
        / bucketization_divisor
    )
    bucket_timestamps = diff_timestamps.long()
    bucket_timestamps = bucket_timestamps.view(-1)
    result = torch.index_select(ts_w, dim=0, index=bucket_timestamps)

    result = result.t()

    result = result.view(num_layers, bs, infer_len, infer_len)
    return result


def rab_time_backward_golden(
    rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor
):
    """
    torch realization of rab time backward for reference.
    """
    num_layers, b, s, _ = rab_time_grad.shape
    tsw_grad = torch.zeros(
        num_layers, NUM_BUCKETS, dtype=torch.float32, device=rab_time_grad.device
    )

    bucket_timestamps_expand = (
        bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1)
        .repeat(1, 1, 2, 1, 2)
        .reshape(b, s, s)
        .to(torch.int64)
    )
    for n, grad in enumerate(rab_time_grad.to(torch.float32)):
        tsw_grad[n] = tsw_grad[n].scatter_add(
            src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0
        )
    return tsw_grad


def run_rab_time_forward_case(num_layers, train_len, candidate_len, bs, dtype):
    """Generate inputs and compare the Triton forward pass against the torch golden."""
    past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32)
    timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(
        torch.int32
    )
    timestamps_weights = create_timestamps_weights(num_layers).to(dtype)
    timestamps = timestamps.npu()
    timestamps_weights = timestamps_weights.npu()

    torch_npu.npu.synchronize()

    # triton output
    rab_time_out_triton = rab_time_forward_triton(
        ts_w=timestamps_weights,
        timestamps=timestamps,
        bucketization_divisor=BUCKET_DIVISOR,
    )
    torch_npu.npu.synchronize()

    # pytorch output
    rab_time_out_golden = rab_time_forward_golden(
        ts_w=timestamps_weights,
        timestamps=timestamps,
        bucketization_divisor=BUCKET_DIVISOR,
    )
    torch_npu.npu.synchronize()

    torch.testing.assert_close(rab_time_out_triton, rab_time_out_golden)


def run_rab_time_backward_case(num_layers: int, batchsize: int, s: int, dtype: torch.dtype):
    """Generate inputs and compare the Triton backward pass against the torch golden."""
    grad = create_rab_time_grad(num_layers, batchsize, s).to(dtype).npu()
    bucket_timestamps = (
        create_bucket_timestamps(batchsize, s // 2).to(torch.int32).npu()
    )

    torch_npu.npu.synchronize()

    golden_result = (
        rab_time_backward_golden(grad, bucket_timestamps).to(torch.float32).cpu()
    )
    op_result = (
        rab_time_backward_triton(grad, bucket_timestamps).to(torch.float32).cpu()
    )

    loss = 1e-4 if dtype == torch.float32 else 1e-3
    torch.testing.assert_close(op_result, golden_result, rtol=loss, atol=loss)


@pytest.mark.parametrize(
    "num_layers, train_len, candidate_len, batch_size, dtype",
    [
        pytest.param(
            8,
            500,
            500,
            4,
            torch.float32,
            marks=pytest.mark.skip(reason="temporarily skip UB overflow case"),
        ),
    ],
)
def test_rab_time_forward(num_layers, train_len, candidate_len, batch_size, dtype):
    torch.manual_seed(0)
    run_rab_time_forward_case(num_layers, train_len, candidate_len, batch_size, dtype)


@pytest.mark.parametrize(
    "num_layers, batch_size, seq_len, dtype",
    [
        (8, 4, 1500, torch.float32),
    ],
)
def test_rab_time_backward(num_layers, batch_size, seq_len, dtype):
    torch.manual_seed(0)
    run_rab_time_backward_case(num_layers, batch_size, seq_len, dtype)


# ---------------------------------------------------------------------------
# Patch metadata that shares the last physical source line with the code above
# (header of the next file in the diff), kept verbatim as comments:
#   diff --git a/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py
#   new file mode 100644
#   index 0000000000..72d6ff6ed4
#   --- /dev/null
#   +++ b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py
#   @@ -0,0 +1,824 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
"""
HSTU Attention
===============
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple
import pytest
import torch
import torch_npu
import triton
import triton.language as tl
import triton.runtime.driver as driver
import numpy as np
import torch.nn.functional as F

DEVICE = "npu"
BLOCK_FWD = 64
BLOCK_BWD = 32


@dataclass
class JaggedData:
    """Inputs for one jagged (variable sequence length) attention test case."""

    # q/k/v/grad are built as (total_seq, num_heads, attention_dim) by jagged_data_gen.
    grad: torch.Tensor
    q: torch.Tensor
    k: torch.Tensor
    v: torch.Tensor
    # bias/mask are built as (batch, num_heads, max_seq_len, max_seq_len).
    bias: torch.Tensor
    mask: torch.Tensor
    max_seq_len: int
    # Cumulative sequence offsets, length batch + 1, starting at 0.
    seq_offset: np.ndarray


def get_npu_properties(coreType):
    """Return the ``coreType`` entry of the current NPU's device properties."""
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)[coreType]


@triton.jit
def _hstu_attn_fwd_one_block(
    q,
    k_block_ptr,
    v_block_ptr,
    bias_block_ptr,
    alpha,
    MAX_SEQ_LEN,
    CAUSAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    mask_block,
):
    """Return ``silu(alpha * q @ k^T [+ bias]) / MAX_SEQ_LEN @ v`` for one K/V block.

    When ``CAUSAL`` is set, entries where ``mask_block`` is zero are dropped
    before the product with ``v``.
    """
    k = tl.load(k_block_ptr)
    qk = tl.dot(q, tl.trans(k)) * alpha
    if HAS_BIAS:
        rel_attn_bias = tl.load(bias_block_ptr)
        qk = qk + rel_attn_bias
    # silu(x) = x * sigmoid(x), scaled by 1 / MAX_SEQ_LEN
    silu = qk / (1.0 + tl.exp(-qk)) * (1.0 / MAX_SEQ_LEN)
    if CAUSAL:
        silu = tl.where(mask_block != 0, silu, 0)
    v = tl.load(v_block_ptr)
    silu = silu.to(v.dtype)
    return tl.dot(silu, v)


@triton.jit
def _hstu_attn_fwd_compute(  # noqa C901
    Q,
    K,
    V,
    seq_offsets,
    Out,
    stride_qm: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_kn: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_vn: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_om: tl.constexpr,
    stride_oh: tl.constexpr,
    alpha,
    head_num,
    MAX_SEQ_LEN,
    off_batch,
    off_head,
    start_m,
    seq_start,
    seq_len,
    CAUSAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    mask_block,
    bias,
):
    """Compute one ``BLOCK_M``-row output tile for a single (sequence, head).

    The tile covers query rows ``start_m .. start_m + BLOCK_M`` of the
    sequence starting at ``seq_start``. K/V are walked in ``BLOCK_N`` steps,
    over ``[0, start_m + BLOCK_M)`` when ``CAUSAL`` and ``[0, seq_len)``
    otherwise. ``seq_offsets`` is accepted for signature parity with the
    caller but is not read here.
    """
    off_head = off_head.to(tl.int64)
    off_seq = seq_start.to(tl.int64)
    start_m = start_m.to(tl.int32)

    # initialize offsets
    q_offset = off_seq * stride_qm + off_head * stride_qh
    k_offset = off_seq * stride_kn + off_head * stride_kh
    # V must be addressed with its own head stride; the previous code reused
    # stride_kh here, which is only correct while K and V share a layout.
    v_offset = off_seq * stride_vn + off_head * stride_vh

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seq_len, BLOCK_D_Q),
        strides=(stride_qm, 1),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M, BLOCK_D_Q),
        order=(1, 0),
    )
    k_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(seq_len, BLOCK_D_Q),
        strides=(stride_kn, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_D_Q),
        order=(1, 0),
    )
    v_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seq_len, BLOCK_D_V),
        strides=(stride_vn, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_D_V),
        order=(1, 0),
    )
    q = tl.load(Q_block_ptr)

    acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32)
    if CAUSAL:
        low = 0
        high = start_m + BLOCK_M
    else:
        low = 0
        high = seq_len

    bias_block_ptr = None
    if HAS_BIAS:
        # bias is indexed as (batch, head, MAX_SEQ_LEN, MAX_SEQ_LEN), row-major.
        bias_block_ptr = tl.make_block_ptr(
            base=bias + off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN,
            shape=(MAX_SEQ_LEN, MAX_SEQ_LEN),
            strides=(MAX_SEQ_LEN, 1),
            offsets=(start_m, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )

    for start_n in range(low, high, BLOCK_N):
        # Only the diagonal tile (start_m == start_n) needs the causal mask.
        acc += _hstu_attn_fwd_one_block(
            q=q,
            k_block_ptr=k_block_ptr,
            v_block_ptr=v_block_ptr,
            bias_block_ptr=bias_block_ptr,
            alpha=alpha,
            MAX_SEQ_LEN=MAX_SEQ_LEN,
            CAUSAL=CAUSAL and start_m == start_n,
            HAS_BIAS=HAS_BIAS,
            mask_block=mask_block,
        )
        k_block_ptr = tl.advance(k_block_ptr, (BLOCK_N, 0))
        v_block_ptr = tl.advance(v_block_ptr, (BLOCK_N, 0))
        if HAS_BIAS:
            bias_block_ptr = tl.advance(bias_block_ptr, (0, BLOCK_N))

    # rematerialize offsets to save registers
    offs_v_d = tl.arange(0, BLOCK_D_V)
    off_o = Out + off_seq * stride_om + off_head * stride_oh
    offs_m = start_m + tl.arange(0, BLOCK_M)
    out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :]
    tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None])
+@triton.jit +def _hstu_attn_fwd( # noqa C901 + Q, + K, + V, + seq_offsets, + Out, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_om: tl.constexpr, + stride_oh: tl.constexpr, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + CORE_NUM: tl.constexpr, + tasks: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + mask, + bias, +): + core_id = tl.program_id(0) + cur_batch = 0 + mask_block = None + if CAUSAL and mask is not None: + mask_ptr = tl.make_block_ptr( + base=mask, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_M), + order=(1, 0), + ) + mask_block = tl.load(mask_ptr) + for col in range(core_id, tasks, CORE_NUM): + seq_end = tl.load(seq_offsets + cur_batch + 1) + start_m = col * BLOCK_M + while start_m >= seq_end * head_num // 2: + cur_batch += 1 + seq_end = tl.load(seq_offsets + cur_batch + 1) + seq_start = tl.load(seq_offsets + cur_batch) + seq_len = seq_end - seq_start + off_batch = cur_batch + off_head = (start_m - seq_start * head_num // 2) // (seq_len // 2) + start_m_1 = (start_m - seq_start * head_num // 2) % (seq_len // 2) + start_m_2 = seq_len - start_m_1 - BLOCK_M + _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, + stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, + start_m_1, seq_start, seq_len, CAUSAL, HAS_BIAS, head_dim, head_dim, BLOCK_M, BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, + stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, + start_m_2, seq_start, seq_len, CAUSAL, HAS_BIAS, 
head_dim, head_dim, BLOCK_M, BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + + +@triton.jit +def _hstu_attn_bwd_one_block( # noqa C901 + start_m, + offs_n, + offs_m, + q_ptrs, + dq_ptrs, + mask_n, + do_ptrs, + dk, + dv, + k, + v, + pos_offs_n, + seq_len, + max_ids, + stride_qm, + stride_dom, + stride_dqm, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias_block_ptr, +): + pos_offs_m = offs_m + start_m + mask_m = pos_offs_m < seq_len + # recompute qk and silu + q = tl.load( + q_ptrs + start_m * stride_qm, + mask=mask_m[:, None], + other=0.0, + ) + q_trans = tl.trans(q) + qk_trans = tl.dot(k, q_trans) * alpha + if HAS_BIAS: + rel_attn_bias = tl.load(bias_block_ptr) + qk_trans = qk_trans + tl.trans(rel_attn_bias) + sig_trans = 1.0 / (1.0 + tl.exp(-qk_trans)) + silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + invalid_mask_trans = pos_offs_m[None, :] == offs_n[:, None] + pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None] + invalid_mask_trans = invalid_mask_trans | (pos_offs_m_minus_n > 0) + silu_trans = tl.where(invalid_mask_trans, silu_trans, 0) + silu_trans = silu_trans.to(k.dtype) + # compute dv + do = tl.load( + do_ptrs + start_m * stride_dom, + mask=mask_m[:, None], + other=0.0, + ) + dv += tl.dot(silu_trans, do) + # compute dk and dq (dqk = do * v^T dk = dqk^T * q dq = dqk * k) + dqk_trans = tl.dot(v, tl.trans(do)) + dqk_trans = dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + dqk_trans = tl.where(invalid_mask_trans, dqk_trans, 0) + dqk_trans = dqk_trans.to(k.dtype) + dq = tl.load( + dq_ptrs + start_m * stride_dqm, + mask=mask_m[:, None], + other=0.0, + ) + dq += tl.dot(tl.trans(dqk_trans), k) * alpha + tl.store( + dq_ptrs + start_m * stride_dqm, + dq, + mask=mask_m[:, None], + ) + # Note: the factor `alpha` is delayed until the end of the function to reduce the cost + dk += tl.dot(dqk_trans, q) + 
return dk, dv + + +@triton.jit +def _hstu_attn_bwd_one_col_block( # noqa C901 + start_n, + seq_len, + Q, + K, + V, + DOut, + DQ, + DK, + DV, + stride_qm, + stride_kn, + stride_vn, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + # Work on the subsequence dv[start_n, start_n + BLOCK_N, :] + if CAUSAL: + low = start_n + high = seq_len + else: + low = 0 + high = seq_len + + # initialize row/col offsets + offs_m = tl.arange(0, BLOCK_M) + offs_qk_d = tl.arange(0, BLOCK_D_Q) + offs_v_d = tl.arange(0, BLOCK_D_V) + offs_n = start_n + tl.arange(0, BLOCK_N) + + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_qk_d[None, :]) + dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32) + + mask_n = offs_n < seq_len + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_qk_d[None, :]) + do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_qk_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :]) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + + max_ids = seq_len + pos_offs_n = offs_n + # loop over rows + for start_m in tl.range(low, high, BLOCK_M): + bias_block_ptr = None + if HAS_BIAS: + bias_block_ptr = tl.make_block_ptr( + base=bias, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(start_m, start_n), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs=q_ptrs, + dq_ptrs=dq_ptrs, + mask_n=mask_n, + do_ptrs=do_ptrs, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + 
stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias_block_ptr=bias_block_ptr, + ) + # write-back + dk = dk * alpha + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :]) + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None]) + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None]) + + +@triton.jit +def _hstu_attn_bwd( # noqa C901 + Q, K, V, Grad, DQ, DK, DV, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_dom: tl.constexpr, + stride_doh: tl.constexpr, + seq_offsets, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + off = tl.program_id(0) + off_batch = off // head_num + off_head = off % head_num + off_head = off_head.to(tl.int64) + seq_start = tl.load(seq_offsets + off_batch).to(tl.int64) + seq_end = tl.load(seq_offsets + off_batch + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + # offset pointers for batch/head + q_offset = seq_start * stride_qm + off_head * stride_qh + k_offset = seq_start * stride_kn + off_head * stride_kh + v_offset = seq_start * stride_vn + off_head * stride_vh + grad_offset = seq_start * stride_dom + off_head * stride_doh + bias_offset = off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN + for start_n in range(0, seq_len, BLOCK_N): + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + Q=Q + q_offset, + K=K + k_offset, + V=V + v_offset, + DOut=Grad + grad_offset, + DQ=DQ + q_offset, + DK=DK + k_offset, + DV=DV + v_offset, + stride_qm=stride_qm, + 
stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_qm, + stride_dkn=stride_kn, + stride_dvn=stride_vn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_D_Q=head_dim, + BLOCK_D_V=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias=bias + bias_offset if HAS_BIAS else bias, + ) + + +def triton_hstu_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + alpha: float, + causal: bool, + mask: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch = seq_offsets.numel() - 1 + total_seq, head_num, head_dim = q.shape + out = torch.empty_like(v) + BLOCK_M = BLOCK_FWD + BLOCK_N = BLOCK_FWD + if total_seq == 0: + print("error") + return out + has_bias = bias is not None + core_num = get_npu_properties('num_aicore') + tasks = total_seq * head_num // BLOCK_M // 2 + grid = (core_num, 1, 1) + _hstu_attn_fwd[grid](q, k, v, seq_offsets, out, q.stride(0), q.stride(1), k.stride(0), k.stride(1), + v.stride(0), v.stride(1), out.stride(0), out.stride(1), alpha, batch, head_num, max_seq_len, head_dim, + causal, has_bias, core_num, tasks, BLOCK_M, BLOCK_N, mask, bias, + ) + return out + + +def triton_hstu_attention_bwd( + grad: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + alpha: float, + causal: bool, + bias: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + if grad.shape[0] == 0: + return dq, dk, dv + batch = seq_offsets.numel() - 1 + _, head_num, head_dim = q.shape + has_bias = bias is not None + grid = (batch * head_num, 1,) + _hstu_attn_bwd[grid](q, k, v, grad, dq, dk, dv, + q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), + grad.stride(0), grad.stride(1), seq_offsets, alpha, batch, head_num, 
max_seq_len, head_dim, + causal, has_bias, BLOCK_BWD, BLOCK_BWD, bias, + ) + return dq, dk, dv + + +def jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, dataType) -> JaggedData: + seq_array = np.arange(256, max_seq_len + 1, 256) + seq_lens = np.random.choice(seq_array, size=batch_size) + if not np.isin(max_seq_len, seq_lens): + seq_lens[np.random.randint(0, batch_size)] = max_seq_len + seq_offset = torch.concat((torch.zeros((1,), dtype=torch.int64), + torch.cumsum(torch.from_numpy(seq_lens), axis=0))).to(torch.int64).numpy() + max_seq_len = np.max(seq_lens) + total_seqs = np.sum(seq_lens) + grad = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + q = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + k = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + v = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + + bias = torch.empty(batch_size, num_heads, max_seq_len, max_seq_len, dtype=dataType).uniform_(-1, 1) + mask = 1 - torch.triu(torch.ones(batch_size, num_heads, max_seq_len, max_seq_len), diagonal=1).to(torch.float32) + return JaggedData( + grad=grad, + q=q, + k=k, + v=v, + bias=bias, + mask=mask, + max_seq_len=max_seq_len, + seq_offset=seq_offset, + ) + + +def dense_to_jagged(q, dense_tensor, seq_lens): + tensor = torch.zeros_like(q) + offset = 0 + for batch_id, seq_len in enumerate(seq_lens): + tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :] + offset = offset + seq_len + return tensor + + +def jagged_to_dense(jagged_tensor, seq_lens, head_nums, atten_dim): + need_pad_seq = [] + offset = 0 + for _, seq_len in enumerate(seq_lens): + src_tensor = jagged_tensor[offset: offset + seq_len, :, :].reshape(seq_len, head_nums, atten_dim) + need_pad_seq.append(src_tensor) + offset = offset + seq_len + + dense_tensor = torch.nn.utils.rnn.pad_sequence(need_pad_seq, batch_first=True) + return dense_tensor + + +def gloden_fwd(q, k, 
v, mask, alpha, seq_offset, attnBias, max_seq_len, enable_mask, enableBias, dataType): + head_nums = q.shape[1] + head_dim = q.shape[2] + batch_size = attnBias.shape[0] + seq_lens = np.zeros((batch_size, )).astype(np.int64) + for batch_id in range(batch_size): + seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id] + q_dens = jagged_to_dense(q, seq_lens, head_nums, head_dim).to(dataType) + k_dens = jagged_to_dense(k, seq_lens, head_nums, head_dim).to(dataType) + v_dens = jagged_to_dense(v, seq_lens, head_nums, head_dim).to(dataType) + q_dens = q_dens.permute(0, 2, 1, 3) + k_dens = k_dens.permute(0, 2, 3, 1) + v_dens = v_dens.permute(0, 2, 1, 3) + + qk_attn = torch.matmul(q_dens, k_dens) * alpha + qk_attn = qk_attn.to(torch.float32) + attnBias = attnBias.to(torch.float32) + mask = mask.to(torch.float32) + if enableBias: + qk_attn = qk_attn + attnBias + silu = F.silu(qk_attn) * (1 / max_seq_len) + if enable_mask: + silu = silu * mask + silu = silu.to(dataType) + atten_output = torch.matmul(silu, v_dens) + + atten_output = atten_output.permute(0, 2, 1, 3) + atten_output = dense_to_jagged(q, atten_output, seq_lens) + return atten_output.to(dataType) + + +def run_fwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type): + alpha = 1 + jagged_data = jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, data_type) + # golden 输出 + golden_output = gloden_fwd( + jagged_data.q, + jagged_data.k, + jagged_data.v, + jagged_data.mask, + alpha, + jagged_data.seq_offset, + jagged_data.bias, + jagged_data.max_seq_len, + True, + False, + data_type, + ) + # triton 输出 + seq_offsets = torch.tensor(jagged_data.seq_offset, dtype=torch.int64, device=DEVICE) + triton_output = triton_hstu_attention_fwd( + q=jagged_data.q.npu(), + k=jagged_data.k.npu(), + v=jagged_data.v.npu(), + seq_offsets=seq_offsets, + max_seq_len=int(jagged_data.max_seq_len), + alpha=alpha, + causal=True, + mask=jagged_data.mask.npu(), + ) + loss = 1e-4 + if data_type == 
def golden_bwd(grad, q, k, v, bias, mask, max_seq_len, seq_offset, enable_mask, silu_scale, enable_bias, data_type):
    """Compute reference gradients of the jagged SiLU attention on the CPU.

    The jagged inputs are padded to dense ``(batch, max_seq_len, head, dim)``
    tensors, the backward pass is evaluated with plain matmuls, and the
    q/k/v gradients are converted back to the jagged layout.

    Args:
        grad: Upstream gradient, indexed as ``(tokens, heads, head_dim)``.
        q, k, v: Jagged attention inputs with the same layout as ``grad``.
        bias: Attention bias; ``bias.shape[0]`` is used as the batch size and
            the values are only added when ``enable_bias`` is true.
        mask: Multiplicative mask applied to the scores when ``enable_mask``.
        max_seq_len: Padded sequence length of the dense tensors.
        seq_offset: Sequence boundaries; sequence ``i`` spans
            ``seq_offset[i]:seq_offset[i + 1]``.
        enable_mask: Whether ``mask`` is applied.
        silu_scale: Scale applied to ``silu(qk)``; ``0`` selects
            ``1 / max_seq_len``.
        enable_bias: Whether ``bias`` is added to ``qk``.
        data_type: dtype the dense operands and intermediate results are
            rounded to.

    Returns:
        ``(q_grad, k_grad, v_grad, bias_grad)``. The first three are jagged
        like their inputs; ``bias_grad`` stays dense.
    """

    def jagged_to_dense_bwd(jagged_tensor, seq_lens, max_seq_len, head_num, head_dim):
        # Copy every sequence into its own zero-padded row of the dense tensor.
        dense_tensor = torch.zeros(len(seq_lens), max_seq_len, head_num, head_dim, dtype=jagged_tensor.dtype)
        offset = 0
        for batch_id, seq_len in enumerate(seq_lens):
            dense_tensor[batch_id, :seq_len, :, :] = jagged_tensor[offset: offset + seq_len, :, :]
            offset = offset + seq_len
        return dense_tensor

    def dense_to_jagged_bwd(jagged_tensor, dense_tensor, seq_lens):
        # Inverse of jagged_to_dense_bwd: drop the padding rows again.
        tensor = torch.zeros_like(jagged_tensor)
        offset = 0
        for batch_id, seq_len in enumerate(seq_lens):
            tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :]
            offset = offset + seq_len
        return tensor

    q = q.cpu()
    k = k.cpu()
    v = v.cpu()
    grad = grad.cpu()
    head_nums = grad.shape[1]
    head_dim = grad.shape[2]
    batch_size = bias.shape[0]
    seq_lens = np.array(
        [int(seq_offset[batch_id + 1] - seq_offset[batch_id]) for batch_id in range(batch_size)],
        dtype=np.int64,
    )

    grad_dens = jagged_to_dense_bwd(grad, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
    q_dens = jagged_to_dense_bwd(q, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
    k_dens = jagged_to_dense_bwd(k, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)
    v_dens = jagged_to_dense_bwd(v, seq_lens, max_seq_len, head_nums, head_dim).to(data_type)

    # Scores and the upstream gradient w.r.t. the scores, accumulated in fp32.
    qk = torch.matmul(q_dens.permute(0, 2, 1, 3), k_dens.permute(0, 2, 3, 1)).float()
    gv = torch.matmul(grad_dens.permute(0, 2, 1, 3), v_dens.permute(0, 2, 3, 1)).float()

    bias = bias.float()
    if enable_mask:
        # Round through data_type first so the mask matches what the kernel sees.
        mask = mask.to(data_type).float()
    if enable_bias:
        bias = bias.to(data_type).float()
        qkb = qk + bias
    else:
        qkb = qk
    real_silu_scale = 1 / max_seq_len if silu_scale == 0.0 else silu_scale

    score = F.silu(qkb) * real_silu_scale
    if enable_mask:
        score = score * mask
    score = score.to(data_type)
    v_grad_dens = torch.matmul(score.permute(0, 1, 3, 2), grad_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

    # d silu(x) / dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))); evaluate the
    # sigmoid once and keep the original multiplication order.
    sig = torch.sigmoid(qkb)
    bias_grad = gv * real_silu_scale
    if enable_mask:
        bias_grad = bias_grad * mask
    bias_grad = bias_grad * sig * (1 + qkb * (1 - sig))
    bias_grad = bias_grad.to(data_type)

    k_grad_dens = torch.matmul(bias_grad.permute(0, 1, 3, 2), q_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
    q_grad_dens = torch.matmul(bias_grad, k_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

    bias_grad = bias_grad.cpu()
    q_grad = dense_to_jagged_bwd(q, q_grad_dens.cpu(), seq_lens)
    k_grad = dense_to_jagged_bwd(k, k_grad_dens.cpu(), seq_lens)
    v_grad = dense_to_jagged_bwd(v, v_grad_dens.cpu(), seq_lens)

    # The reference is pure CPU math; only synchronize when an NPU runtime is
    # actually present so it can also run on hosts without torch_npu.
    npu = getattr(torch, "npu", None)
    if npu is not None:
        npu.synchronize()
    return q_grad, k_grad, v_grad, bias_grad
grad=jagged_data.grad.npu(), + q=jagged_data.q.npu(), + k=jagged_data.k.npu(), + v=jagged_data.v.npu(), + seq_offsets=seq_offsets, + max_seq_len=int(jagged_data.max_seq_len), + alpha=alpha, + causal=True, + ) + loss = 1e-4 + if data_type == torch.float16: + loss = 1e-3 + elif data_type == torch.bfloat16: + loss = 1e-2 + torch.testing.assert_close(dq.cpu(), q_grad_golden.cpu(), atol=loss, rtol=loss) + torch.testing.assert_close(dk.cpu(), k_grad_golden.cpu(), atol=loss, rtol=loss) + torch.testing.assert_close(dv.cpu(), v_grad_golden.cpu(), atol=loss, rtol=loss) + + +@pytest.mark.parametrize( + "batch_size, max_seq_len, num_heads, attention_dim, data_type", + [ + (2, 1024, 2, 32, torch.float32), + ], +) +def test_hstu_attention_fwd(batch_size, max_seq_len, num_heads, attention_dim, data_type): + np.random.seed(0) + torch.manual_seed(0) + run_fwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type) + + +@pytest.mark.parametrize( + "batch_size, max_seq_len, num_heads, attention_dim, data_type", + [ + (2, 1024, 2, 32, torch.float32), + ], +) +def test_hstu_attention_bwd(batch_size, max_seq_len, num_heads, attention_dim, data_type): + np.random.seed(0) + torch.manual_seed(0) + run_bwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type) diff --git a/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py new file mode 100644 index 0000000000..b71c10d682 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py @@ -0,0 +1,239 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}),
    ],
    key=["M", "N", "K"]
)
@triton.jit
def matmul_kernel(
    mat_a, mat_b, mat_c,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    num_cores: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_TRESHHOLD: tl.constexpr,
):
    """Compute ``mat_c = mat_a @ mat_b`` with one persistent program per core.

    Offsets are computed as ``row * K + col`` / ``row * N + col``, i.e. the
    operands are addressed as dense row-major matrices. Each program walks
    the output tiles ``pid, pid + num_cores, ...``; the accumulator is fp32
    and the result is stored as bfloat16.
    """
    pid = tl.program_id(axis=0)
    task_m_idx = 0
    task_n_idx = 0

    # With horizontal (row-major) core partitioning the task blocks are
    # numbered as follows:
    # [0,  1,  2,  3,  4,  5,  6,  7]
    # [8,  9,  10, 11, 12, 13, 14, 15]
    # [16, 17, 18, 19, 20, 21, 22, 23]
    # [24, 25, 26, 27, 28, 29, 30, 31]
    # [32, 33, 34, 35, 36, 37, 38, 39]
    # [40, 41, 42, 43, 44, 45, 46, 47]
    # [48, 49, 50, 51, 52, 53, 54, 55]
    # [56, 57, 58, 59, 60, 61, 62, 63]
    # core 0 handles blocks 0 20 40 60 (4 blocks)
    # core 1 handles blocks 1 21 41 61 (4 blocks)
    # core 2 handles blocks 2 22 42 62 (4 blocks)
    # ...
    # core 19 handles blocks 19 39 59 (3 blocks)
    #
    # For large shapes the traditional horizontal partitioning has two problems:
    # 1: many cores access the same region of the left matrix at the same
    #    time, which causes bank conflicts and lowers memory efficiency.
    # 2: by the time one full row of mat_c tiles is finished, the whole right
    #    matrix has been used; when the right matrix is larger than the L2
    #    cache it is swapped in and out, so every following row suffers cache
    #    misses, the L2 hit rate drops and the operator slows down.
    # Partitioning along the diagonal of 8 * 8 squares mitigates both issues.
    #
    # 8 * 8 is only the example; the actual square size is the autotuned
    # BLOCK_TRESHHOLD. Within each 8 * 8 square the blocks are numbered:
    # [0,  8,  16, 24, 32, 40, 48, 56]
    # [57, 1,  9,  17, 25, 33, 41, 49]
    # [50, 58, 2,  10, 18, 26, 34, 42]
    # [43, 51, 59, 3,  11, 19, 27, 35]
    # [36, 44, 52, 60, 4,  12, 20, 28]
    # [29, 37, 45, 53, 61, 5,  13, 21]
    # [22, 30, 38, 46, 54, 62, 6,  14]
    # [15, 23, 31, 39, 47, 55, 63, 7]
    #
    # Once the M axis spans more than 8 basic blocks the diagonal layout
    # clearly reduces bank conflicts, and once the right matrix exceeds the L2
    # cache it improves L2 utilisation. It is therefore enabled when both the
    # M and N directions exceed the threshold; the gain is largest when the
    # right matrix does not fit in L2.
    NUM_BLOCKS_M = triton.cdiv(M, BLOCK_M)
    NUM_BLOCKS_N = triton.cdiv(N, BLOCK_N)
    NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N
    # With enough tiles, switch to the diagonal partitioning strategy.
    if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD:
        for block_idx in range(
            pid, NUM_BLOCKS, num_cores
        ):
            # Diagonal partitioning: size of the (possibly truncated) square
            # this block falls into, then the position inside that square.
            curThresholdM = BLOCK_TRESHHOLD if block_idx < (NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD
            curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD
            curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < (curThresholdM * NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD
            localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM)
            task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * NUM_BLOCKS_N) * BLOCK_TRESHHOLD
            # Least common multiple of the square's sides, used to locate the
            # tile column. The GCD comes from Euclid's algorithm, which needs
            # no operand ordering (a smaller first operand is swapped by the
            # first iteration). The previous
            # `x, y = a, b if a > b else b, a` parsed as a three-element tuple.
            x, y = curThresholdM, curThresholdN
            while y != 0:
                x, y = y, x % y
            lcm = curThresholdM * curThresholdN // x
            task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD

            m_start = task_m_idx * BLOCK_M
            n_start = task_n_idx * BLOCK_N

            mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            for k_start in range(0, K, BLOCK_K):
                mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (
                    k_start + tl.arange(0, BLOCK_K)
                )[None, :]
                mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
                    (k_start + tl.arange(0, BLOCK_K)) < K
                )[None, :]
                mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0)
                extension.compile_hint(mat_a_block, "dot_pad_only_k")
                mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (
                    n_start + tl.arange(0, BLOCK_N)
                )[None, :]
                mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & (
                    (n_start + tl.arange(0, BLOCK_N)) < N
                )[None, :]
                mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0)
                extension.compile_hint(mat_b_block, "dot_pad_only_k")
                mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block)
            mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (
                n_start + tl.arange(0, BLOCK_N)
            )[None, :]
            mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
                (n_start + tl.arange(0, BLOCK_N)) < N
            )[None, :]
            tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask)
    else:
        # Traditional sequential (row-major) partitioning.
        for block_idx in range(
            pid, NUM_BLOCKS, num_cores
        ):
            task_m_idx = block_idx // NUM_BLOCKS_N
            task_n_idx = block_idx % NUM_BLOCKS_N
            m_start = task_m_idx * BLOCK_M
            n_start = task_n_idx * BLOCK_N

            mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            for k_start in range(0, K, BLOCK_K):
                mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (
                    k_start + tl.arange(0, BLOCK_K)
                )[None, :]
                mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
                    (k_start + tl.arange(0, BLOCK_K)) < K
                )[None, :]
                mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0)
                extension.compile_hint(mat_a_block, "dot_pad_only_k")
                mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (
                    n_start + tl.arange(0, BLOCK_N)
                )[None, :]
                mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & (
                    (n_start + tl.arange(0, BLOCK_N)) < N
                )[None, :]
                mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0)
                extension.compile_hint(mat_b_block, "dot_pad_only_k")
                mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block)
            mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (
                n_start + tl.arange(0, BLOCK_N)
            )[None, :]
            mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
                (n_start + tl.arange(0, BLOCK_N)) < N
            )[None, :]
            tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask)
b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py new file mode 100644 index 0000000000..f8618dd7bc --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py @@ -0,0 +1,149 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +def run_add(x0, x1): + """ + 测试 Triton 实现的向量加法与 PyTorch 的结果,精度比对是否一致。 + + 步骤: + 1. 使用 PyTorch 计算参考结果(torch_ref) + 2. 使用 Triton 编写 kernel 并计算结果(triton_cal) + 3. 调用 accuracy_comparison 进行精度比对 + """ + + # 1. 使用 PyTorch 作为参考实现(golden truth) + def torch_func(x0, x1): + res = x0 + x1 + return res + + # 2. 
def accuracy_comparison(y_cal, y_ref):
    """Compare a computed tensor against its reference with a dtype-aware policy.

    Policy per dtype:
      - float16 / bfloat16 / float32: ``torch.testing.assert_close`` with
        explicit relative/absolute tolerances (NaNs compare equal).
      - int8 / int16 / int32 / int64 (and uint32 when this torch build
        provides it): exact equality via ``torch.equal``.
      - bool: exact equality, evaluated on the CPU.

    Args:
        y_cal: Tensor produced by the implementation under test.
        y_ref: Reference tensor; must have the same dtype as ``y_cal``.

    Raises:
        AssertionError: on a dtype mismatch or when the values differ.
        ValueError: when the dtype is not covered by the policy above.
    """
    # Both results must share a dtype before any value comparison makes sense.
    assert y_cal.dtype == y_ref.dtype, f"dtype mismatch: {y_cal.dtype} vs {y_ref.dtype}"
    tensor_dtype = y_cal.dtype

    # Move both tensors to the NPU (the test is assumed to run on an NPU host).
    y_cal = y_cal.npu()
    y_ref = y_ref.npu()

    # torch.uint32 only exists in newer torch releases; look it up lazily so
    # the integer branch does not raise AttributeError on older builds.
    integer_dtypes = [torch.int64, torch.int32, torch.int16, torch.int8]
    uint32 = getattr(torch, "uint32", None)
    if uint32 is not None:
        integer_dtypes.append(uint32)

    # assert_close takes (actual, expected): rtol is applied relative to the
    # expected value, so the reference must be the second argument.
    if tensor_dtype == torch.float16:
        # float16 has low precision, so allow a slightly larger error.
        torch.testing.assert_close(y_cal, y_ref, rtol=1e-3, atol=1e-3, equal_nan=True)
    elif tensor_dtype == torch.bfloat16:
        # bfloat16 is even less precise; compare after widening to float32.
        torch.testing.assert_close(
            y_cal.to(torch.float32),
            y_ref.to(torch.float32),
            rtol=1e-3,
            atol=1e-3,
            equal_nan=True
        )
    elif tensor_dtype == torch.float32:
        # float32 is more precise, so use a tighter tolerance.
        torch.testing.assert_close(y_cal, y_ref, rtol=1e-4, atol=1e-4, equal_nan=True)
    elif tensor_dtype in integer_dtypes:
        # Integer results must match exactly.
        assert torch.equal(y_cal, y_ref), f"Integer tensors are not equal for dtype {tensor_dtype}"
    elif tensor_dtype == torch.bool:
        # Compare booleans on the CPU to avoid device-specific representations.
        assert torch.equal(y_cal.cpu(), y_ref.cpu()), "Boolean tensors are not equal"
    else:
        raise ValueError(f'Invalid or unsupported tensor dtype: {tensor_dtype}')
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
def torch_calc_func(x0, x1):
    """Return the PyTorch reference value ``exp(x0) + x1 + 10000 - 1``.

    The additions are applied one at a time, left to right, so the result
    is rounded exactly like the single chained expression.
    """
    total = torch.exp(x0)
    total = total + x1
    total = total + 10000
    return total - 1
new file mode 100644 index 0000000000..d6f7d12ab4 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_16_profiler.py @@ -0,0 +1,125 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.
@pytest.mark.parametrize("dtype, low, high", [
    (torch.float32, 0, 1),
    (torch.float16, 0, 1),
    (torch.bfloat16, 0, 1),
    (torch.int64, 1, 100),
    (torch.int32, 1, 100),
    (torch.int16, 1, 100),
    (torch.int8, 1, 100),
    (torch.bool, 0, 2),
])
def test_elementwise_ops(dtype, low, high):
    """Check an element-wise Triton kernel against torch, then profile it.

    Boolean inputs exercise the OR kernel; every other dtype exercises the
    ADD kernel. After the correctness check the same call is replayed under
    ``profiler_wrapper``.
    """
    N = 1024

    if dtype == torch.bool:
        x0 = torch.randint(low=low, high=high, size=(N,)).bool().npu()
        x1 = torch.randint(low=low, high=high, size=(N,)).bool().npu()
        triton_op = triton_or_func
        ref = x0 | x1
    else:
        if dtype.is_floating_point:
            x0 = torch.rand((N,), dtype=dtype).npu()
            x1 = torch.rand((N,), dtype=dtype).npu()
        else:
            x0 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu()
            x1 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu()
        triton_op = triton_add_func
        ref = x0 + x1

    triton_cal = triton_op(x0, x1, N)
    torch.testing.assert_close(triton_cal, ref)

    def wrapper():
        # Replay exactly the op that was validated above.
        _ = triton_op(x0, x1, N)

    profiler_wrapper(wrapper)
def benchmark(func):
    """Decorate ``func`` so each call is timed on the module-level NPU stream.

    The wrapped callable runs ``func`` 10 times as warm-up and then 100 timed
    times, synchronizing ``stream`` after each phase.

    Returns:
        A wrapper that returns ``(result, elapsed_time)``, where ``result`` is
        the value of the last call and ``elapsed_time`` is the mean time per
        timed call in microseconds.
    """
    import functools

    warmup = 10
    repeat = 100

    @functools.wraps(func)  # keep the wrapped function's name/docstring visible
    def wrapper(*args, **kwargs):
        # Warm-up runs (not timed).
        for _ in range(warmup):
            result = func(*args, **kwargs)
        stream.synchronize()
        # Timed runs.
        start_ns = time.perf_counter_ns()
        for _ in range(repeat):
            result = func(*args, **kwargs)
        stream.synchronize()
        end_ns = time.perf_counter_ns()
        # Subtract the integer nanosecond counters first, then convert to
        # microseconds, so no precision is lost on large counter values.
        elapsed_time = (end_ns - start_ns) * 1e-3 / repeat
        return (result, elapsed_time)

    return wrapper
other=-float('inf')) + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1) + softmax_output = numerator / denominator + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +@benchmark +def torch_func(x0: torch.Tensor): + m = torch.nn.Softmax(dim=1) + return m(x0) + + +@benchmark +def triton_func(y0: torch.Tensor, x0: torch.Tensor): + n_rows, n_cols = x0.shape + ncore = 40 + xs = (n_rows + ncore - 1) // ncore + xss = min(xs, 5) + softmax_kernel[(ncore, 1, 1)]( + y0, + x0, + x0.stride(0), + y0.stride(0), + n_rows, + n_cols, + XBLOCK=xs, + XBLOCK_SUB=xss, + RBLOCK=n_cols, + ) + return y0 + + +@pytest.mark.parametrize("batch", [1000 * x for x in range(1, 16 + 1)]) +def test_demo_libentry_softmax(batch): + torch.manual_seed(0) + x = torch.rand((batch, SEQ_LEN), dtype=DTYPE, device=DEV) + y = torch.empty_like(x) + + torch_out, _ = torch_func(x) + triton_out, _ = triton_func(y, x) + + torch.testing.assert_close(triton_out, torch_out) diff --git a/third_party/ascend/unittest/pytest_ut/test_18_gather.py b/third_party/ascend/unittest/pytest_ut/test_18_gather.py new file mode 100644 index 0000000000..5b2041ffcb --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_18_gather.py @@ -0,0 +1,134 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +Gather +=============== +This is an example only for npu. 
def torch_gather(embeddings, idxes, default_value=0.0):
    """Reference gather written with plain torch indexing.

    Output row ``i`` is ``embeddings[idxes[i]]`` when ``idxes[i] >= 0`` and is
    filled with ``default_value`` otherwise. The result has shape
    ``(idxes.shape[0], embeddings.shape[-1])`` and the dtype/device of
    ``embeddings``.
    """
    valid = idxes >= 0
    gathered = embeddings.new_empty((idxes.shape[0], embeddings.shape[-1]))

    # rows with a usable index copy the matching embedding row
    gathered[valid] = embeddings[idxes[valid]]
    # every remaining row receives the default value
    gathered[~valid] = default_value

    return gathered
emb_col_mask + + if idx_val >= 0: + read_row_offset = idx_val * cols + read_emb_mask = emb_col_mask + # read embedding + embedding = tl.load(embeddings_ptr + read_row_offset + emb_col_offsets, mask=read_emb_mask) + tl.store(res_ptr + write_row_offset + emb_col_offsets, embedding, write_emb_mask) + else: + # set default values + tl.store(res_ptr + write_row_offset + emb_col_offsets, default_embedding, write_emb_mask) + + +# triton-version gather's host +def triton_gather(embeddings: torch.Tensor, indices: torch.Tensor, default_value=0.0): + # constant settings for npu + USE_SIZE = 96 * 1024 + CORE_NUM = get_npu_properties()["num_vectorcore"] + + n_rows = indices.shape[0] + n_cols = embeddings.shape[1] + # make the result tensor + output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device) + + # when writing an npu kernel using triton, + # you should note that the difference between BLOCK_SIZE and BLOCK_SIZE_SUB + # BLOCK_SIZE specifies the size of data that are processed in one program + col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), 32) * 32 // embeddings.element_size() + # the data are scattered to multiple programs, which can not be even + # some process more data, some process less + big_row_block_size = triton.cdiv(n_rows, CORE_NUM) + big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows) + col_block_size = col_size_aligned + + # BLOCK_SIZE_SUB specifies the size of data that are processed in one loop of a program + max_col_block_size_sub = USE_SIZE // embeddings.element_size() // 2 + col_block_size_sub = min(col_size_aligned, max_col_block_size_sub) + + grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size)) + # launch the kernel + gather_kernel[grid](embeddings, indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) + + return output + + +# 
==================== Pytest Test ==================== +@pytest.mark.parametrize("n_rows", [500, 1000]) +@pytest.mark.parametrize("n_cols", [16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000]) +@pytest.mark.parametrize("index_num", [19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000]) +def test_gather(n_rows, n_cols, index_num): + indices = torch.randint(0, n_rows, (index_num, ), dtype=torch.int32).npu() + embeddings = torch.randn(n_rows, n_cols, dtype=torch.float).npu() + + expect = torch_gather(embeddings, indices).cpu() + actual = triton_gather(embeddings, indices).cpu() + torch.npu.synchronize() + + torch.testing.assert_close(actual, expect) diff --git a/third_party/ascend/unittest/pytest_ut/test_add.py b/third_party/ascend/unittest/pytest_ut/test_add.py index f6e9dd5a42..3dd8402761 100644 --- a/third_party/ascend/unittest/pytest_ut/test_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_add.py @@ -74,3 +74,20 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp(dtype, y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_address_check.py b/third_party/ascend/unittest/pytest_ut/test_address_check.py new file mode 100644 index 0000000000..ac93fe35bf --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_address_check.py @@ -0,0 +1,69 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ + +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +@triton.jit +def simple_kernel(x_ptr, y_ptr, output_ptr, n_elements): + pid = tl.program_id(axis=0) + offsets = pid * 1024 + tl.arange(0, 1024) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + ret = x + y + tl.store(output_ptr + offsets, ret, mask=mask) + + +def test_npu_tensor_should_success(): + print("Test the NPU tensor. The NPU tensor should be passed and executed properly.") + + size = 1024 + x_npu = torch.rand(size, device='npu') + y_npu = torch.rand(size, device='npu') + output = torch.empty(size, device='npu') + + simple_kernel[(1,)](x_npu, y_npu, output, size) + + expected = x_npu + y_npu + actual = output + + torch.testing.assert_close(expected, actual, rtol=1e-03, atol=1e-03) + + +def test_cpu_tensor_should_fail(): + print("Test the CPU tensor. An address check exception should be raised.") + + size = 1024 + x_cpu = torch.rand(size, device='cpu') + y_cpu = torch.rand(size, device='cpu') + output = torch.empty(size, device='npu') + + with pytest.raises(ValueError) as exc_info: + simple_kernel[(1,)](x_cpu, y_cpu, output, size) + + error_msg = str(exc_info.value) + assert "cannot be accessed from Triton (cpu tensor?)" in error_msg, \ + f"Expected error message to contain CPU tensor rejection hint, but got: {error_msg}" diff --git a/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py new file mode 100644 index 0000000000..76d7e8e1d1 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py @@ -0,0 +1,65 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ + + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, 2), + order=(2, 1, 0) + ) + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, 2), + order=(2, 1, 0) + ) + pid = tl.program_id(axis=0) # pid=0,1 BLOCK_SIZE_N=8 + for _ in range(ZB // 2): + X = tl.load(block_ptr_in, boundary_check=(0, 1, 2)) + tl.store(block_ptr_out, X, boundary_check=(0, 1, 2)) + block_ptr_in = tl.advance(block_ptr_in, (0, 0, 2)) + block_ptr_out = tl.advance(block_ptr_out, (0, 0, 2)) + + +@pytest.mark.parametrize('dtype', ["int32", "float32", "int16"]) +@pytest.mark.parametrize('shape', [(33, 9, 6), (8, 8, 4)]) +def test_advance_with_boundary_check(dtype, shape): + x = torch.randint(low=-128, high=128, size=shape, dtype=eval('torch.' + dtype)).npu() + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + expected = x + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) + torch.testing.assert_close(output, expected) \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py b/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py new file mode 100644 index 0000000000..ba3054b655 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_affine_map_binding.py @@ -0,0 +1,114 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.extra.cann.extension as al +import pytest + + +def test_extension_reexports_affine_bindings(): + assert al.affine_map is ascend_ir.affine_map + assert al.affine_expr is ascend_ir.affine_expr + assert al.affine_constant_expr is ascend_ir.affine_constant_expr + assert al.affine_dim_expr is ascend_ir.affine_dim_expr + assert al.affine_symbol_expr is ascend_ir.affine_symbol_expr + assert al.affine_binary_op_expr is ascend_ir.affine_binary_op_expr + assert al.AffineMap is al.affine_map + assert al.AffineExpr is al.affine_expr + assert al.AffineConstantExpr is al.affine_constant_expr + assert al.AffineDimExpr is al.affine_dim_expr + assert al.AffineSymbolExpr is al.affine_symbol_expr + assert al.AffineBinaryOpExpr is al.affine_binary_op_expr + + +def test_make_affine_map(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + + d0 = ascend_ir.affine_expr.get_dim(0) + d1 = ascend_ir.affine_expr.get_dim(1) + c2 = ascend_ir.affine_expr.get_constant(2) + + expr = (d0 + c2) * d1 + assert "d0" in str(expr) and "d1" in str(expr) + assert not expr.is_pure_affine() + assert hash(expr) == hash(expr) + assert d0 == ascend_ir.affine_expr.get_dim(0) + assert c2 == ascend_ir.affine_expr.get_constant(2) + assert isinstance(c2, ascend_ir.affine_expr) + assert isinstance(d0, ascend_ir.affine_expr) + + identity_map = ascend_ir.affine_map.get_identity(2) + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + transpose_map_by_expr = ascend_ir.affine_map.get(2, 0, [d1, d0]) + sum_map = ascend_ir.affine_map.get(2, 0, [d0 + d1, d1]) + const_map = ascend_ir.affine_map.get_constant(7) + minor_identity_map = ascend_ir.affine_map.get_minor_identity(3, 2) + + assert identity_map.is_identity() + assert identity_map.is_permutation() + assert identity_map.get_num_dims() == 2 + assert identity_map.get_num_symbols() == 0 + assert 
identity_map.get_num_results() == 2 + assert str(identity_map) == "(d0, d1) -> (d0, d1)" + + assert not transpose_map.is_identity() + assert transpose_map.is_permutation() + assert str(transpose_map) == "(d0, d1) -> (d1, d0)" + assert str(transpose_map_by_expr) == "(d0, d1) -> (d1, d0)" + assert str(sum_map) == "(d0, d1) -> (d0 + d1, d1)" + assert transpose_map.to_dict() == { + "num_dims": 2, + "num_symbols": 0, + "results": [1, 0], + } + assert str(sum_map.get_sub_map([1])) == "(d0, d1) -> (d1)" + assert str(sum_map.compose(transpose_map)) == "(d0, d1) -> (d1 + d0, d0)" + assert str(transpose_map.inverse_permutation()) == "(d0, d1) -> (d1, d0)" + assert transpose_map == transpose_map_by_expr + assert hash(transpose_map) == hash(transpose_map) + assert [str(r) for r in sum_map.get_results()] == ["d0 + d1", "d1"] + assert const_map.is_single_constant() + assert const_map.get_constant_result() == 7 + assert str(minor_identity_map) == "(d0, d1, d2) -> (d1, d2)" + + +def test_build_buffer_type_with_affine_map(): + with ir.context() as ctx: + ir.load_dialects(ctx) + ascend_ir.load_dialects(ctx) + builder = ascend_ir.ascendnpu_ir_builder(ctx) + + transpose_map = ascend_ir.affine_map.get(2, 0, [1, 0]) + ub_attr = builder.get_target_attribute(ascend_ir.AddressSpace.UB) + + buffer_ty = builder.get_buffer_ty_with_affine_map([8, 16], builder.get_float_ty(), transpose_map, ub_attr) + + assert "memref<8x16xf32" in str(buffer_ty) + assert "affine_map<(d0, d1) -> (d1, d0)>" in str(buffer_ty) + assert "ub" in str(buffer_ty) + + +if __name__ == '__main__': + test_build_buffer_type_with_affine_map() + test_extension_reexports_affine_bindings() + test_make_affine_map() diff --git a/third_party/ascend/unittest/pytest_ut/test_alloc.py b/third_party/ascend/unittest/pytest_ut/test_alloc.py index a1b2b4360c..53d1994153 100644 --- a/third_party/ascend/unittest/pytest_ut/test_alloc.py +++ b/third_party/ascend/unittest/pytest_ut/test_alloc.py @@ -30,6 +30,7 @@ from triton._C.libtriton import 
ir, buffer_ir from triton._C.libtriton.ascend import ir as ascend_ir + os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" @@ -49,8 +50,7 @@ def compile_kernel(kernel, signature, constants): ir.load_dialects(context) buffer_ir.load_dialects(context) ascend_ir.load_dialects(context) - module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, - {}) + module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, {}) return str(module) @@ -66,6 +66,9 @@ def allocate_local_buffer(XBLOCK: tl.constexpr): bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C) + bl.alloc( + tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True + ) # ============== Main for manual testing ============== @@ -74,6 +77,8 @@ def allocate_local_buffer(XBLOCK: tl.constexpr): print("=" * 60) print("Test 1: Nested Scopes") print("=" * 60) - mlir = compile_kernel(allocate_local_buffer, {}, {"XBLOCK": 256}) + mlir = compile_kernel( + allocate_local_buffer, {}, {"XBLOCK": 256} + ) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) diff --git a/third_party/ascend/unittest/pytest_ut/test_arch.py b/third_party/ascend/unittest/pytest_ut/test_arch.py new file mode 100644 index 0000000000..7ca66dc29a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_arch.py @@ -0,0 +1,93 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os +import pytest +import triton +import triton.language as tl +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + arch = "Ascend950" + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def copy( + A_ptr, + A1_ptr, + M: tl.constexpr, + N: tl.constexpr, +): + offs_a = tl.arange(0, M)[:, None] + offs_b = tl.arange(0, N)[None, :] + + offs_c = (offs_a) * M + (offs_b) + a_ptr = A_ptr + offs_c + a_val = tl.load(a_ptr) + a1_ptr = A1_ptr + offs_c + a1_val = tl.load(a1_ptr) + + add = tl.add(a_val, a1_val) + + add_ub = bl.to_buffer(add, al.ascend_address_space.UB) + A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1) + al.copy(add_ub, A_l1) + + +def test_arch(): + print("=" * 60) + print("Test 1: copy ") + print("=" * 60) + mlir = compile_kernel( + copy, + {"A_ptr": "*fp32", "A1_ptr": "*fp32"}, + {"M": 16, "N": 16}, + ) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) + +# ============== Main for manual testing ============== +if __name__ == "__main__": + test_arch() diff --git a/third_party/ascend/unittest/pytest_ut/test_argmax.py b/third_party/ascend/unittest/pytest_ut/test_argmax.py new file mode 100644 index 0000000000..16bfa28d78 --- /dev/null +++ 
b/third_party/ascend/unittest/pytest_ut/test_argmax.py @@ -0,0 +1,65 @@ +import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common + + +def torch_argmax(x0, dim, keepdim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + return torch.argmax(x0, dim=dim, keepdim=keepdim).npu() + + +@triton.jit +def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmax(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', [(128,), (256,), (37,), (741,)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +def test_argmax_1d(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmax_1d[(1,)](x0, triton_res, numel, numel) + torch_res = torch_argmax(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +@triton.jit +def triton_argmax_2d(in_ptr0, out_ptr0, + dim: tl.constexpr, + M: tl.constexpr, N: tl.constexpr, + MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=float('-inf')) + tmp4 = tl.argmax(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', [(37, 125), (29, 4), (7, 31)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_argmax_2d(dtype, shape, dim): + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + 
triton_res = torch.empty([shape[1 - dim], ], dtype=torch.int32).npu() + triton_argmax_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmax(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_argmin.py b/third_party/ascend/unittest/pytest_ut/test_argmin.py new file mode 100644 index 0000000000..98018baa55 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_argmin.py @@ -0,0 +1,65 @@ +import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common + + +def torch_argmin(x0, dim, keepdim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + return torch.argmin(x0, dim=dim, keepdim=keepdim).npu() + + +@triton.jit +def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmin(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', [(128,), (256,), (37,), (741,)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +def test_argmin_1d(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmin_1d[(1,)](x0, triton_res, numel, numel) + torch_res = torch_argmin(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +@triton.jit +def triton_argmin_2d(in_ptr0, out_ptr0, + dim: tl.constexpr, + M: tl.constexpr, N: tl.constexpr, + MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=float('inf')) + tmp4 = 
tl.argmin(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', [(37, 125), (29, 4), (7, 31)]) +@pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_argmin_2d(dtype, shape, dim): + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1 - dim], ], dtype=torch.int32).npu() + triton_argmin_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmin(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm.py b/third_party/ascend/unittest/pytest_ut/test_asm.py index 02e69bddd6..db668e99f2 100644 --- a/third_party/ascend/unittest/pytest_ut/test_asm.py +++ b/third_party/ascend/unittest/pytest_ut/test_asm.py @@ -1,52 +1,101 @@ -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common - - -def torch_add(x, y): - res = x + y - return res - - -@triton.jit -def triton_asm_add( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = tl.inline_asm_elementwise( - asm=""" - ADD.s64 $0, $1, $2 - """, - constraints=("=l,l,l"), - args=[x, y], - dtype=tl.int64, - is_pure=True, - pack=1, - ) - tl.store(output_ptr + offsets, output, mask=mask) - - -@pytest.mark.parametrize('param_list', [ - ['int64', 4096, 1024], -]) -def test_case(param_list): - dtype, length, block_size = param_list - ncore = length // block_size - x = test_common.generate_tensor((length, ), dtype).npu() - y = 
test_common.generate_tensor((length, ), dtype).npu() - res_ref = torch_add(x, y) - res_cal = torch.zeros((length, ), dtype=eval('torch.' + dtype)).npu() - triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) - test_common.validate_cmp(dtype, res_cal, res_ref) +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_add(x, y): + res = x + y + return res + +@triton.jit +def triton_asm_add(x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = tl.inline_asm_elementwise( + asm=""" + ADD.s64 $0, $1, $2 + """, + constraints=( + "=l,l,l" + ), + args=[x, y], + dtype=tl.int64, + is_pure=True, + pack=1, + ) + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('param_list', + [ + ['int64', 4096, 1024], + ] + ) + +def test_case(param_list): + dtype, length, block_size = param_list + ncore = length // block_size + x = test_common.generate_tensor((length,), dtype).npu() + y = test_common.generate_tensor((length,), dtype).npu() + res_ref = torch_add(x, y) + res_cal = torch.zeros((length,), dtype = eval('torch.' 
+ dtype)).npu() + triton_asm_add[(ncore,)](x, y, res_cal, length, BLOCK_SIZE=block_size) + test_common.validate_cmp(dtype, res_cal, res_ref) + + +@triton.jit +def triton_asm_add_2d(x_ptr, + y_ptr, + output_ptr, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(axis=0) + row_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M) + col_offsets = tl.arange(0, BLOCK_N) + offsets = row_offsets[:, None] * N + col_offsets[None, :] + mask = (row_offsets[:, None] < M) & (col_offsets[None, :] < N) + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = tl.inline_asm_elementwise( + asm=""" + ADD.s64 $0, $1, $2 + """, + constraints=( + "=l,l,l" + ), + args=[x, y], + dtype=tl.int64, + is_pure=True, + pack=1, + ) + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('param_list', + [ + ['int64', 64, 32, 16, 32], + ] + ) +def test_case_2d(param_list): + dtype, M, N, block_m, block_n = param_list + ncore = M // block_m + x = test_common.generate_tensor((M, N), dtype).npu() + y = test_common.generate_tensor((M, N), dtype).npu() + res_ref = torch_add(x, y) + res_cal = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_asm_add_2d[(ncore,)](x, y, res_cal, M, N, BLOCK_M=block_m, BLOCK_N=block_n) + test_common.validate_cmp(dtype, res_cal, res_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py new file mode 100644 index 0000000000..3d86c516aa --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py @@ -0,0 +1,38 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +@triton.jit +def triton_asm_time( + output_ptr, +): + y = tl.inline_asm_elementwise( + asm=""" + MOV $0, SYS_CNT + """, + constraints="=l", + args=[], + dtype=(tl.int64), + is_pure=False, + pack=1, + ) + tl.store(output_ptr, y) + + +@pytest.mark.parametrize( + "param_list", + [ + [ + "int64", + ] + ], +) +def test_case(param_list): + (dtype,) = param_list + res_cal = torch.zeros((1,), dtype=eval("torch." + dtype)).npu() + triton_asm_time[(1,)]( + res_cal, + ) diff --git a/third_party/ascend/unittest/pytest_ut/test_assume1.py b/third_party/ascend/unittest/pytest_ut/test_assume1.py new file mode 100644 index 0000000000..b4c8ec7e3f --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_assume1.py @@ -0,0 +1,34 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +from triton._internal_testing import ( + is_interpreter +) + + +@triton.jit +def assume(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= BLOCK_N: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + +@pytest.mark.parametrize('dtype', ["float32"]) +def test_assume(dtype): + NBLOCKS = 1024 // 128 + BLOCK_N = 128 + N = 1024 + output = torch.zeros(NBLOCKS, device='npu') + pgm = assume[(NBLOCKS, )](output, N=N, BLOCK_N=BLOCK_N) + + if is_interpreter(): 
+ return + + assert 'llvm.intr.assume' in pgm.asm['ttadapter'] \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py index 10aa305204..957bd03e08 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py @@ -51,22 +51,48 @@ def atomic_add_supply(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr): tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) -@pytest.mark.parametrize('param_list', [ - ['int16', (32, 32), 2], - ['int8', (32, 32), 2], - ['float32', (32, 32), 2], - ['float16', (64, 64), 4], - ['float32', (128, 128), 8], - ['float16', (128, 128), 16], - ['float32', (32768, 16), 32], -]) +@triton.jit +def atomic_add_for_load_offset( + index_ptr, in_ptr0, out_ptr0 +): + index = tl.atomic_add(index_ptr, 1) + val = tl.load(in_ptr0 + index) + tl.store(out_ptr0, val) + + +@triton.jit +def atomic_add_for_store_offset( + index_ptr, out_ptr0 +): + index = tl.atomic_add(index_ptr, 1) + tl.store(out_ptr0 + index, 1) + + +@pytest.mark.parametrize('param_list', + [ + ['int64', (256, 32), 2], + ['int32', (32, 32), 2], + ['int16', (32, 32), 2], + ['int8', (32, 32), 2], + ['uint8', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['bfloat16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], + ['float32', (32768, 16), 32], + ] + ) def test_atomic_add(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] / ncore split_size = shape[0] // ncore x0_value = 3 x0 = torch.full(shape, x0_value, dtype=eval(f'torch.{dtype}')).npu() - x1 = torch.full((split_size, shape[1]), 2, dtype=eval(f'torch.{dtype}')).npu() + if dtype == 'int64': + x1 = torch.randint(-10**15, 10**15, (split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + else: + x1 = torch.full((split_size, shape[1]), 2, dtype=eval(f'torch.{dtype}')).npu() y = 
torch.full((split_size, shape[1]), -10, dtype=eval(f'torch.{dtype}')).npu() y_ref = x1 + 0 @@ -155,6 +181,33 @@ def test_atomic_add_2d_supply(dtype, shape): test_common.validate_cmp(dtype, x1, x1_ref) +def test_atomic_add_for_load_offset(): + index = torch.tensor([1]).npu() + input_tensor = torch.zeros(5).npu() + output = torch.tensor([1]).npu() + index_ref = index.clone() + index_ref += 1 + output_ref = output.clone() + output_ref = input_tensor[index] + + atomic_add_for_load_offset[(1, )](index, input_tensor, output) + torch.equal(index, index_ref) + torch.equal(output, output_ref) + + +def test_atomic_add_for_store_offset(): + index = torch.tensor([1]).npu() + output = torch.zeros(5).npu() + index_ref = index.clone() + index_ref += 1 + output_ref = output.clone() + output_ref[index] = 1 + + atomic_add_for_store_offset[(1, )](index, output) + torch.equal(index, index_ref) + torch.equal(output, output_ref) + + if __name__ == "__main__": param_list = ['float32', (32, 32), 2] test_atomic_add_2d(param_list) diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py index 19a3bb6958..be1250d887 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py @@ -39,12 +39,15 @@ def atomic_and(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr tl.store(out_ptr1 + (x1), tmp1, xmask) -@pytest.mark.parametrize('param_list', [ - ['int64', (32, 32), 2], - ['int32', (32, 32), 2], - ['int16', (32, 32), 2], - ['int8', (16, 16), 4], -]) +@pytest.mark.parametrize('param_list', + [ + ['int64', (32, 32), 2], + ['int32', (32, 32), 2], + ['int16', (32, 32), 2], + ['int8', (16, 16), 4], + ['uint8', (16, 16), 4], + ] + ) def test_atomic_and(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] // ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py 
b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py index 3e3b3a6fc6..88bc7089b6 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py @@ -26,6 +26,15 @@ import torch_npu +types_all = [ + (torch.float32, 'float32'), +] + + +def ceil_div(a, b): + return (a + b - 1) // b + + @triton.jit def atomic_cas(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): xoffset = tl.program_id(0) * BLOCK_SIZE @@ -40,15 +49,57 @@ def atomic_cas(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl. tl.store(out_ptr1 + (x1), tmp1, xmask) -@pytest.mark.parametrize('param_list', [ - ['int16', (8, 8), 2], - ['int32', (32, 32), 6], - ['int64', (32, 32), 2], - ['float32', (32, 32), 2], - ['float16', (64, 64), 4], - ['float32', (128, 128), 8], - ['float16', (128, 128), 16], -]) + +@triton.jit +def atomic_cas_with_full( + ptr, + out, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = x < n_elements + + cmp = tl.full((BLOCK_SIZE,), 2.0, tl.float32) + val = tl.full((BLOCK_SIZE,), 1.0, tl.float32) + + old = tl.atomic_cas(ptr + x, cmp, val) # in_ptr(origin 2) -> ref: 1 X + tl.store(out + x, old, mask=mask) # out(origin 1) -> ref: old in_ptr(2) √ + + +@triton.jit +def atomic_cas_without_full( + ptr, + cmp_ptr, + val_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = x < n_elements + + cmp = tl.load(cmp_ptr + x, mask) # 2 + val = tl.load(val_ptr + x, mask) # 1 + + old = tl.atomic_cas(ptr + x, cmp, val) # old : 2 + tl.store(out_ptr + x, old, mask=mask) + + + +@pytest.mark.parametrize('param_list', + [ + ['int16', (8, 8), 2], + ['int32', (32, 32), 6], + ['int64', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], + ] + ) def 
test_atomic_cas(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] // ncore @@ -110,3 +161,43 @@ def test_atomic_cas_return_value(param_list): atomic_cas[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) test_common.validate_cmp(dtype, pointer, pointer_ref) test_common.validate_cmp(dtype, pointer_old, pointer_old_ref) + + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) +@pytest.mark.skip(reason="full tensor has problem, skipped") +def test_atomic_cas_with_full(n_elements, BLOCK_SIZE, dtype, sigtype): + in_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() + out_ptr = torch.empty_like(in_ptr) + + grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) + atomic_cas_with_full[grid]( + in_ptr, out_ptr, n_elements, + BLOCK_SIZE=BLOCK_SIZE + ) + + # old should be all 2 (for in-range) + torch.testing.assert_close(out_ptr, torch.full_like(out_ptr, 2.0)) + + # final ptr should be all 1 + torch.testing.assert_close(in_ptr, torch.ones_like(in_ptr)) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) +def test_atomic_cas_without_full(n_elements, BLOCK_SIZE, dtype, sigtype): + in_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() + cmp_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() + val_ptr = torch.full((n_elements,), 1, dtype=dtype).npu() + out_ptr = torch.full((n_elements,), 1, dtype=dtype).npu() # ref: in_ptr + + grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) + atomic_cas_without_full[grid]( + in_ptr, cmp_ptr, val_ptr, out_ptr, n_elements, + BLOCK_SIZE=BLOCK_SIZE + ) + + torch.testing.assert_close(in_ptr, torch.full_like(in_ptr, 1.0)) + torch.testing.assert_close(out_ptr, torch.full_like(out_ptr, 2.0)) + diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py index 
942f9429fd..77b64eee4f 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py @@ -53,15 +53,22 @@ def triton_test_fn_atomic_max_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const # torch.max do not support int -@pytest.mark.parametrize('param_list', [ - ['int16', (32, 32), 2], - ['float16', (32, 32), 2], - ['float32', (128, 128), 8], - ['float32', (32768, 16), 32], - ['int32', (32, 32), 2], - ['int32', (128, 128), 8], - ['int32', (32768, 16), 32], -]) +@pytest.mark.parametrize('param_list', + [ + ['uint8', (32, 32), 2], + ['int16', (32, 32), 2], + ['bfloat16', (32, 32), 2], + ['float16', (32, 32), 2], + ['float32', (128, 128), 8], + ['float32', (32768, 16), 32], + ['int32', (32, 32), 2], + ['int32', (128, 128), 8], + ['int32', (32768, 16), 32], + ['int64', (32, 32), 2], + ['int64', (128, 128), 8], + ['int64', (8192, 16), 32], + ] + ) def test_atomic_max(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] / ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py index 213740548d..08dd777f76 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py @@ -50,14 +50,18 @@ def triton_test_fn_atomic_min_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.atomic_min(out_ptr0 + (x1), tmp0, xmask) - -@pytest.mark.parametrize('param_list', [ - ['int8', (32, 32), 2], - ['int16', (32, 32), 2], - ['int32', (32, 32), 2], - ['float16', (64, 64), 4], - ['float32', (32, 32), 2], -]) +@pytest.mark.parametrize('param_list', + [ + ['uint8', (32, 32), 2], + ['int8', (32, 32), 2], + ['int16', (32, 32), 2], + ['int32', (32, 32), 2], + ['int64', (32, 32), 2], + ['bfloat16', (64, 64), 4], + ['float16', (64, 64), 4], + ['float32', (32, 32), 2], + ] + ) def test_atomic_min(param_list): dtype, 
shape, ncore = param_list block_size = shape[0] * shape[1] / ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py new file mode 100644 index 0000000000..b3d91384e9 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py @@ -0,0 +1,70 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def atomic_rmw_useanalysis_kernel( + input_ptr, + output_ptr, + m_ptr, + d_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + base_idx = pid * 8 + + term1 = 15.0 * 15.0 + term2 = 8.0 * (7.0 - base_idx) + + delta = term1 + term2 + sqrt_delta = tl.sqrt(delta) + + task_idx = tl.ceil((15.0 - sqrt_delta) / 2.0) + task_idx_i32 = task_idx.to(tl.int32) + + block_start = task_idx_i32 * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + m_val = tl.load(m_ptr + offsets, mask=mask, other=0.0) + d_val = tl.load(d_ptr + offsets, mask=mask, other=0.0) + + scaled = data - m_val + p = tl.exp(scaled) + + result = p * (data * 2.0 - d_val) + + output_offsets = offsets + tl.atomic_add(output_ptr + output_offsets, result, mask=mask) + + +def test_atomic_rmw_useanalysis(): + DEVICE = "npu" + N = 1024 + BLOCK_SIZE = 128 + + torch.manual_seed(42) + input_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + m_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + d_data = torch.randn(N, dtype=torch.float32, device=DEVICE) + output_data = torch.zeros(N, dtype=torch.float32, device=DEVICE) + + grid = (8,) + + atomic_rmw_useanalysis_kernel[grid]( + input_data, + output_data, + m_data, + d_data, + N=N, + BLOCK_SIZE=BLOCK_SIZE, + ) + output_sum = output_data.abs().sum().item() + + if output_sum == 0: + raise AssertionError("UseAnalysis bug detected: atomic_rmw dependencies were erased") + else: + print(" AtomicRMW 
UseAnalysis is working correctly.") diff --git a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py index 5719e7a373..80f85b1e3a 100644 --- a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py +++ b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py @@ -83,4 +83,109 @@ def test_npu(para_type, data_type, XB, YB, ZB): print(a) fn_npu_[1, 1, 1](output, x, y, z, output1, XB=XB, YB=YB, ZB=ZB, debug=True) print(output) - torch.testing.assert_close(output, a) + torch.testing.assert_close(output,a) + + +@triton.jit +def dma_block_ptr( + input_ptr, + output_ptr, + scale_ptr, + batch_size, + cu_seqlens_ptr, + stride_i_m, + stride_i_n, + stride_o_m, + stride_o_n, + stride_s_b, + HEAD_DIM, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + n_progs = tl.num_programs(0) + pid = tl.program_id(0) + + cu_num_blocks = 0 + for bid in range(batch_size): + start_loc = tl.load(cu_seqlens_ptr + bid) + end_loc = tl.load(cu_seqlens_ptr + bid + 1) + scale = tl.load(scale_ptr + bid * stride_s_b) + + len_loc = end_loc - start_loc + prev_num_blocks = cu_num_blocks + new_num_blocks = tl.cdiv(len_loc, BLOCK_SIZE_M).to(tl.int32) + i_block_ptr_bbase = tl.make_block_ptr( + input_ptr + start_loc * stride_i_m, + shape=(len_loc, HEAD_DIM), + strides=(stride_i_m, stride_i_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0), + ) + o_block_ptr_bbase = tl.make_block_ptr( + output_ptr + start_loc * stride_o_m, + shape=(len_loc, HEAD_DIM), + strides=(stride_o_m, stride_o_n), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0), + ) + cu_num_blocks += new_num_blocks + for m_id in range((prev_num_blocks + pid) % n_progs, new_num_blocks, n_progs): + i_block_ptr = tl.advance(i_block_ptr_bbase, (m_id * BLOCK_SIZE_M, 0)) + o_block_ptr = tl.advance(o_block_ptr_bbase, (m_id * BLOCK_SIZE_M, 0)) + i_tile = tl.load(i_block_ptr, boundary_check=[0, 1], padding_option="zero") + 
o_tile = i_tile.to(tl.float32) * scale + tl.store(o_block_ptr, o_tile.to(i_tile.dtype), boundary_check=[0, 1]) + + +def ref_func(inputs, scale, cu_lens): + outputs = torch.zeros_like(inputs) + bsz = cu_lens.size(0) - 1 + for bid in range(bsz): + tmp = inputs[cu_lens[bid]: cu_lens[bid + 1]].to(torch.float32) * scale[bid] + outputs[cu_lens[bid]: cu_lens[bid + 1]] = tmp.to(outputs.dtype) + return outputs + + +def tt_func(inputs, scale, cu_lens): + bsz = cu_lens.size(0) - 1 + outputs = torch.zeros_like(inputs) + head_dim = inputs.size(-1) + assert head_dim <= 1024 + BLOCK_SIZE_N = 1024 + BLOCK_SIZE_M = 4 + dma_block_ptr[20, ]( + inputs, outputs, scale, bsz, cu_lens, + inputs.stride(0), + inputs.stride(1), + outputs.stride(0), + outputs.stride(1), + scale.stride(0), + head_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return outputs + + +@pytest.mark.parametrize('param_list', + [ + [8, 1024, 1024, True], + [8, 1024, 1024, False], + ] + ) +def test_func(param_list): + bsz, max_len, max_n, test_align = param_list + lens = torch.randint(max_len // 2, max_len, (bsz,), dtype=torch.int32, device="npu") + n = torch.randint(max_n // 2, max_n, (1,), dtype=torch.int32, device="npu")[0].item() + if test_align: + lens = (lens + 1023) // 1024 * 1024 + n = (n + 1023) // 1024 * 1024 + cu_lens = torch.cumsum(lens, dim=0) + cu_lens = torch.cat([torch.zeros(1, dtype=torch.int32, device="npu"), cu_lens], dim=0) + inputs = torch.randn(cu_lens[-1], n, dtype=torch.float16, device="npu") + scale = torch.randn(bsz, dtype=torch.float32, device="npu") + ref_output = ref_func(inputs, scale, cu_lens) + tt_output = tt_func(inputs, scale, cu_lens) + torch.testing.assert_close(ref_output, tt_output) diff --git a/third_party/ascend/unittest/pytest_ut/test_boundary_check.py b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py new file mode 100644 index 0000000000..8773a0b038 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py @@ -0,0 +1,317 
@@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import torch +import triton +import triton.language as tl +import pytest + + +# ========== Test 1: Static base address + boundary_check ========== +@triton.jit +def static_base_boundary_check_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr( + base=in_ptr, + shape=(BLOCK_SIZE * 2,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE,), + order=(0,) + ) + data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def ref_static_base(in_tensor, BLOCK_SIZE): + return in_tensor[:BLOCK_SIZE].sum().item() + + +def test_static_base(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + static_base_boundary_check_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = ref_static_base(in_tensor.cpu(), BLOCK_SIZE) + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 2: Simple dynamic base address + boundary_check ========== +@triton.jit +def simple_dynamic_base_boundary_check_kernel( + out_ptr, + in_ptr, + offset: tl.int32, + BLOCK_SIZE: tl.constexpr, +): + base = in_ptr + offset + ptr = tl.make_block_ptr( + base=base, + shape=(BLOCK_SIZE * 2,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE,), + order=(0,) + ) + data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_simple_dynamic_base(): + BLOCK_SIZE = 64 + offset = 32 + in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + simple_dynamic_base_boundary_check_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + offset=offset, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[offset:offset + BLOCK_SIZE].sum().item() + assert 
torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 3: Nested loop + dynamic base address + advance + boundary_check ========== +@triton.jit +def nested_dynamic_advance_boundary_kernel( + out_ptr, + in_ptr, + stride_in: tl.int32, + OUTER_LOOP: tl.constexpr, + INNER_LOOP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Smallest reproducible code: The dynamic base address is in the outer loop, + and tl.advance is in the inner loop, where there is a boundary_check. + """ + for i in range(OUTER_LOOP): + base = in_ptr + i * stride_in + ptr = tl.make_block_ptr( + base=base, + shape=(INNER_LOOP * BLOCK_SIZE,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE,), + order=(0,) + ) + for j in range(INNER_LOOP): + cur_ptr = tl.advance(ptr, (j * BLOCK_SIZE,)) + data = tl.load(cur_ptr, boundary_check=(0,), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr + i * INNER_LOOP + j, result) + + +def ref_nested_dynamic(in_tensor, OUTER_LOOP, INNER_LOOP, BLOCK_SIZE): + """ + PyTorch equivalent implementation: + - Treat in_tensor as a tensor of shape [OUTER_LOOP, INNER_LOOP * BLOCK_SIZE] + - For each (i, j) block: take the BLOCK_SIZE elements starting from j*BLOCK_SIZE in the i-th row and sum them up. + - Note: There is boundary_check + zero padding, but there is no out-of-bound access in this case, so no special handling is needed. 
+ """ + reshaped = in_tensor[:OUTER_LOOP * INNER_LOOP * BLOCK_SIZE].view(OUTER_LOOP, INNER_LOOP * BLOCK_SIZE) + blocks = reshaped.unfold(1, BLOCK_SIZE, BLOCK_SIZE) + return blocks.sum(dim=-1).flatten() + + +def test_nested_dynamic(): + BLOCK_SIZE = 8 + OUTER_LOOP = 2 + INNER_LOOP = 2 + in_tensor = torch.randn(OUTER_LOOP * INNER_LOOP * BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(OUTER_LOOP * INNER_LOOP, dtype=torch.float32).npu() + nested_dynamic_advance_boundary_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + stride_in=INNER_LOOP * BLOCK_SIZE, + OUTER_LOOP=OUTER_LOOP, + INNER_LOOP=INNER_LOOP, + BLOCK_SIZE=BLOCK_SIZE, + ) + ref = ref_nested_dynamic(in_tensor.cpu(), OUTER_LOOP, INNER_LOOP, BLOCK_SIZE) + assert torch.allclose(out_tensor.cpu(), ref, atol=1e-4) + + +# ========== Test 4: Explicit out-of-bounds access + zero padding + boundary_check ========== +@triton.jit +def out_of_bound_zero_padding_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr( + base=in_ptr, + shape=(BLOCK_SIZE,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE * 2,), + order=(0,) + ) + data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_out_of_bound(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + out_of_bound_zero_padding_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu().sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 5:padding_option = NAN + boundary_check========== +@triton.jit +def nan_padding_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + ptr = tl.make_block_ptr( + base=in_ptr, + shape=(BLOCK_SIZE,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE * 2,), + order=(0,) + ) 
+ data = tl.load(ptr, boundary_check=(0,), padding_option="nan") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_nan_padding(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + try: + nan_padding_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + assert torch.isnan(out_tensor.cpu()).any() + except Exception as e: + print(f"Warning: NAN padding test may not be supported: {e}") + + +# ========== Test 6:Multi-layer advance + boundary_check ========== +@triton.jit +def multi_advance_kernel( + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, +): + base = in_ptr + ptr0 = tl.make_block_ptr( + base=base, + shape=(BLOCK_SIZE * 4,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE,), + order=(0,) + ) + ptr1 = tl.advance(ptr0, (BLOCK_SIZE,)) + ptr2 = tl.advance(ptr1, (BLOCK_SIZE,)) + data = tl.load(ptr2, boundary_check=(0,), padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_multi_advance(): + BLOCK_SIZE = 64 + in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + multi_advance_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[2 * BLOCK_SIZE:3 * BLOCK_SIZE].sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +# ========== Test 7:Complex base address calculation + boundary_check ========== +@triton.jit +def complex_base_calculation_kernel( + out_ptr, + in_ptr, + offset1: tl.int32, + offset2: tl.int32, + scale: tl.int32, + BLOCK_SIZE: tl.constexpr, +): + base = in_ptr + offset1 * scale + offset2 + ptr = tl.make_block_ptr( + base=base, + shape=(BLOCK_SIZE * 2,), + strides=(1,), + offsets=(0,), + block_shape=(BLOCK_SIZE,), + order=(0,) + ) + data = tl.load(ptr, boundary_check=(0,), 
padding_option="zero") + result = tl.sum(data) + tl.store(out_ptr, result) + + +def test_complex_base(): + BLOCK_SIZE = 64 + offset1, offset2, scale = 2, 16, 32 + total_offset = offset1 * scale + offset2 + in_tensor = torch.randn(total_offset + BLOCK_SIZE * 2, dtype=torch.float32).npu() + out_tensor = torch.zeros(1, dtype=torch.float32).npu() + complex_base_calculation_kernel[(1,)]( + out_ptr=out_tensor, + in_ptr=in_tensor, + offset1=offset1, + offset2=offset2, + scale=scale, + BLOCK_SIZE=BLOCK_SIZE, + ) + expected = in_tensor.cpu()[total_offset:total_offset + BLOCK_SIZE].sum().item() + assert torch.allclose(out_tensor.cpu(), torch.tensor(expected, device='cpu'), atol=1e-4) + + +if __name__ == "__main__": + print("Running all boundary_check tests...") + test_static_base() + test_simple_dynamic_base() + test_nested_dynamic() + test_out_of_bound() + test_nan_padding() + test_multi_advance() + test_complex_base() + print("All tests completed successfully!") diff --git a/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py new file mode 100644 index 0000000000..3fc6771f3a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py @@ -0,0 +1,720 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import logging +import random +import pytest + +import triton +import triton.language as tl +import torch +import test_common +import numpy as np +import triton.language.extra.cann.extension as extension + + +def gen_1d_cat_shapes(min_val=1, max_val=4096): + shape1 = random.randint(min_val, max_val) + shape2 = random.randint(min_val, max_val) + return (shape1,), (shape2,), 0 + + +def gen_2d_cat_shapes(dim=0, min_val=1, max_val=4096): + if dim == 0: + common_col = random.randint(min_val, max_val) + row1 = random.randint(min_val, max_val) + row2 = random.randint(min_val, max_val) + shape1 = (row1, common_col) + shape2 = (row2, common_col) + elif dim == 1: + common_row = random.randint(min_val, max_val) + col1 = random.randint(min_val, max_val) + col2 = random.randint(min_val, max_val) + shape1 = (common_row, col1) + shape2 = (common_row, col2) + else: + raise ValueError("2d shape only support dim=0 or dim=1") + return shape1, shape2, dim + + +def gen_3d_cat_shapes(dim=0, min_val=1, max_val=4096): + if dim not in [0, 1, 2]: + raise ValueError("3d shape only support dim=0/1/2") + + if dim == 0: + common_d1 = random.randint(min_val, max_val) + common_d2 = random.randint(min_val, max_val) + d0_1 = random.randint(min_val, max_val) + d0_2 = random.randint(min_val, max_val) + shape1 = (d0_1, common_d1, common_d2) + shape2 = (d0_2, common_d1, common_d2) + + elif dim == 1: + common_d0 = random.randint(min_val, max_val) + common_d2 = random.randint(min_val, max_val) + d1_1 = random.randint(min_val, 
max_val) + d1_2 = random.randint(min_val, max_val) + shape1 = (common_d0, d1_1, common_d2) + shape2 = (common_d0, d1_2, common_d2) + + else: # dim == 2 + common_d0 = random.randint(min_val, max_val) + common_d1 = random.randint(min_val, max_val) + d2_1 = random.randint(min_val, max_val) + d2_2 = random.randint(min_val, max_val) + shape1 = (common_d0, common_d1, d2_1) + shape2 = (common_d0, common_d1, d2_2) + + return shape1, shape2, dim + + +def gen_100_cat_shapes( + num_groups=100, + mix_ratio=(0.3, 0.3, 0.4), + min_val=1, + max_val=4096 +): + + shape_list = [] + num_1d = int(num_groups * mix_ratio[0]) + num_2d = int(num_groups * mix_ratio[1]) + num_3d = num_groups - num_1d - num_2d + + for _ in range(num_1d): + shape_list.append(gen_1d_cat_shapes(min_val, max_val)) + + for _ in range(num_2d): + dim = random.choice([0, 1]) + shape_list.append(gen_2d_cat_shapes(dim, min_val, max_val)) + + for _ in range(num_3d): + dim = random.choice([0, 1, 2]) + shape_list.append(gen_3d_cat_shapes(dim, min_val, max_val)) + + random.shuffle(shape_list) + return shape_list + +full_shape = gen_100_cat_shapes( + num_groups=100, + mix_ratio=(0.3, 0.4, 0.3), + min_val=1, + max_val=4096 +) + + +@triton.jit +def _cat_helper_func_2D_1( + in_ptr0, + in_ptr1, + out_ptr0, + in0_x: tl.constexpr, + in1_x: tl.constexpr, + y0_numel, + x1_numel, + Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, + +): + y0_offset = tl.program_id(0) * Y0BLOCK_SUB + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_input0_x1 = tl.arange(0, in0_x)[None, :] + base_input1_x1 = tl.arange(0, in1_x)[None, :] + x1 = tl.arange(0, in0_x + in1_x)[None, :] + + for loop in range(loops_y0): + y0 = y0_offset + (loop * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + x1_mask = x1 < x1_numel + tmp0 = tl.load(in_ptr0 + (base_input0_x1 + in0_x * y0), y0_mask) + tmp1 = tl.load(in_ptr1 + (base_input1_x1 + in1_x * y0), y0_mask) + tmp2 = 
tl.zeros((Y0BLOCK_SUB, in0_x + in1_x), dtype=tmp0.dtype) + tmp3 = extension.insert_slice(tmp2, tmp0, [0, 0], [Y0BLOCK_SUB, in0_x], [1, 1]) + tmp4 = extension.insert_slice(tmp3, tmp1, [0, in0_x], [Y0BLOCK_SUB, in1_x], [1, 1]) + tl.store(out_ptr0 + (x1 + (in0_x + in1_x) * y0), tmp4, x1_mask & y0_mask) + + +@triton.jit +def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((2 * Y0BLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [Y0BLOCK_SUB, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, Y0BLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[None, :, None] + new_z0 = tl.arange(0, 2)[:, None, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + x1_numel * (new_y1 + y0_numel * new_z0)), tmp13, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_numel, x1_numel, YBLOCK: 
tl.constexpr, + YBLOCK_2: tl.constexpr, YBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * YBLOCK + base_y0 = tl.arange(0, YBLOCK_SUB) + loops_y0 = (YBLOCK + YBLOCK_SUB - 1) // YBLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if y0_numel < y1_numel: + min_numel = y0_numel + max_numel = y1_numel + clone_numel = y1_numel - y0_numel + else: + min_numel = y1_numel + max_numel = y0_numel + clone_numel = y0_numel - y1_numel + + for loops_y in range(loops_y0): + y0 = y0_offset + (loops_y * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK + y0_offset, min_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((2 * YBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [YBLOCK_SUB, 0], [YBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, YBLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y0_offset + (loops_y * YBLOCK_SUB) + new_base_y1[None, :, None] + new_z0 = tl.arange(0, 2)[:, None, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < min_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1 + x1_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + loops_y1 = (YBLOCK_2 + YBLOCK_SUB - 1) // YBLOCK_SUB + y2_offset = tl.program_id(0) * YBLOCK_2 + min_numel + if y0_numel < y1_numel: + for loops_y1 in range(loops_y1): + y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK_2 + 
y2_offset, y1_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y2_offset + y0_numel + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None] + sum_numel = y0_numel + y1_numel + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < sum_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loops_y1 in range(loops_y1): + y0 = y2_offset + (loops_y1 * YBLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(YBLOCK_2 + y2_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp8 = tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, YBLOCK_SUB) + new_y1 = y2_offset + (loops_y1 * YBLOCK_SUB) + new_base_y1[:, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + x1_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < x1_numel + + tmp0 = 
tl.load(x_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x1 + x1_numel * y0), x1_mask & y0_mask) + tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, X1BLOCK_SUB], [Y0BLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < x1_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + 2 * x1_numel * new_y1 + x1_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if x0_numel < x1_numel: + min_numel = x0_numel + max_numel = x1_numel + clone_numel = x1_numel - x0_numel + else: + min_numel = x1_numel + max_numel = x0_numel + clone_numel = x0_numel - x1_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + 
tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_y1_mask = new_y1 < y0_numel + sum_numel = x0_numel + x1_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + if x0_numel < x1_numel: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x1_numel + + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < sum_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, y0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + 
new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < x0_numel + new_y1_mask = new_y1 < y0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y1_numel, x1_numel, ZBLOCK: tl.constexpr, ZBLOCK_2: tl.constexpr, ZBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + z0_offset = tl.program_id(0) * ZBLOCK + base_z0 = tl.arange(0, ZBLOCK_SUB) + loops_z0 = (ZBLOCK + ZBLOCK_SUB - 1) // ZBLOCK_SUB + xy_numel = x1_numel * y1_numel + base_x1 = tl.arange(0, X1BLOCK_SUB) + loops_x1 = (xy_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if z0_numel < z1_numel: + min_numel = z0_numel + max_numel = z1_numel + clone_numel = z1_numel - z0_numel + else: + min_numel = z1_numel + max_numel = z0_numel + clone_numel = z0_numel - z1_numel + + for loops_z in range(loops_z0): + z0 = z0_offset + (loops_z * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK + z0_offset, min_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp0 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + tmp10 = tl.zeros((2 * ZBLOCK_SUB, X1BLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [ZBLOCK_SUB, 0], [ZBLOCK_SUB, X1BLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (2, ZBLOCK_SUB, X1BLOCK_SUB)) + + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z0_offset + (loops_z * ZBLOCK_SUB) + new_base_z1[None, :, None] + new_z0 = tl.arange(0, 
2)[:, None, None] + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < min_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1 + xy_numel * z0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask) + + loops_z1 = (ZBLOCK_2 + ZBLOCK_SUB - 1) // ZBLOCK_SUB + z2_offset = tl.program_id(0) * ZBLOCK_2 + min_numel + if z0_numel < z1_numel: + for loops_z1 in range(loops_z1): + z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z1_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp8 = tl.load(y_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z2_offset + z0_numel + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None] + sum_numel = z0_numel + z1_numel + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < sum_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + else: + for loops_z1 in range(loops_z1): + z0 = z2_offset + (loops_z1 * ZBLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(ZBLOCK_2 + z2_offset, z0_numel) + for loop_x1 in range(loops_x1): + x1 = (loop_x1 * X1BLOCK_SUB) + base_x1[None, :] + x1_mask = x1 < xy_numel + + tmp8 = tl.load(x_ptr + (x1 + xy_numel * z0), x1_mask & z0_mask) + new_base_x2 = tl.arange(0, X1BLOCK_SUB) + new_x2 = (loop_x1 * X1BLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, ZBLOCK_SUB) + new_z1 = z2_offset + (loops_z1 * ZBLOCK_SUB) + new_base_z1[:, None] + new_x2_mask = new_x2 < xy_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + xy_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y1_numel, x0_numel, Z0BLOCK: tl.constexpr, Z0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + z0_offset = 
tl.program_id(0) * Z0BLOCK + base_z0 = tl.arange(0, Z0BLOCK_SUB) + loops_z0 = (Z0BLOCK + Z0BLOCK_SUB - 1) // Z0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + if y0_numel < y1_numel: + min_numel = y0_numel * x0_numel + max_numel = y1_numel * x0_numel + clone_numel = (y1_numel - y0_numel) * x0_numel + else: + min_numel = y1_numel * x0_numel + max_numel = y0_numel * x0_numel + clone_numel = (y0_numel - y1_numel) * x0_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask) + tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask) + tmp10 = tl.zeros((Z0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Z0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Z0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_z1_mask = new_z1 < z0_numel + sum_numel = min_numel + max_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1 + x0_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask) + + if y0_numel == y1_numel: + return + + if y0_numel < y1_numel: + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, 
z0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < y1_numel * x0_numel + + tmp8 = tl.load(y_ptr + (x + x0_numel * y1_numel * z0), x_mask & z0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel * y0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None] + sum_numel = min_numel + max_numel + new_x2_mask = new_x2 < sum_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + else: + for loop_z0 in range(loops_z0): + z0 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + base_z0[:, None] + z0_mask = z0 < min(Z0BLOCK + z0_offset, z0_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel * y0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0_numel * z0), x_mask & z0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_z1 = tl.arange(0, Z0BLOCK_SUB) + new_z1 = z0_offset + (loop_z0 * Z0BLOCK_SUB) + new_base_z1[:, None] + sum_numel = min_numel + max_numel + new_x2_mask = new_x2 < x0_numel * y0_numel + new_z1_mask = new_z1 < z0_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_z1), tmp8, new_x2_mask & new_z1_mask) + + +@triton.jit +def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): + y0_offset = tl.program_id(0) * Y0BLOCK + base_y0 = tl.arange(0, Y0BLOCK_SUB) + loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB + base_x = tl.arange(0, XBLOCK_SUB) + min_numel = 0 + max_numel = 0 + clone_numel = 0 + zy_numel = z0_numel * y0_numel + if x0_numel < x1_numel: + min_numel = x0_numel + max_numel = x1_numel + clone_numel = x1_numel - x0_numel + 
else: + min_numel = x1_numel + max_numel = x0_numel + clone_numel = x0_numel - x1_numel + loops_x = (min_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + loops_x2 = (clone_numel + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x in range(loops_x): + x = (loop_x * XBLOCK_SUB) + base_x[None, :] + x_mask = x < min_numel + + tmp0 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + tmp10 = tl.zeros((Y0BLOCK_SUB, 2 * XBLOCK_SUB), dtype=tmp0.dtype) + tmp11 = extension.insert_slice(tmp10, tmp0, [0, 0], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp12 = extension.insert_slice(tmp11, tmp8, [0, XBLOCK_SUB], [Y0BLOCK_SUB, XBLOCK_SUB], [1, 1]) + tmp13 = tl.reshape(tmp12, (Y0BLOCK_SUB, 2, XBLOCK_SUB)) + + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = (loop_x * XBLOCK_SUB) + new_base_x2[None, None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None, None] + new_z0 = tl.arange(0, 2)[None, :, None] + new_x2_mask = new_x2 < min_numel + new_y1_mask = new_y1 < zy_numel + sum_numel = x0_numel + x1_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1 + x0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + + if x0_numel == x1_numel: + return + + if x0_numel < x1_numel: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x1_numel + + tmp8 = tl.load(y_ptr + (x + x1_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = x0_numel + min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, 
None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < sum_numel + new_y1_mask = new_y1 < zy_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + else: + for loop_y0 in range(loops_y0): + y0 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + base_y0[:, None] + y0_mask = y0 < min(Y0BLOCK + y0_offset, zy_numel) + for loop_x2 in range(loops_x2): + x = (loop_x2 * XBLOCK_SUB) + base_x[None, :] + min_numel + x_mask = x < x0_numel + + tmp8 = tl.load(x_ptr + (x + x0_numel * y0), x_mask & y0_mask) + new_base_x2 = tl.arange(0, XBLOCK_SUB) + new_x2 = min_numel + (loop_x2 * XBLOCK_SUB) + new_base_x2[None, :] + new_base_y1 = tl.arange(0, Y0BLOCK_SUB) + new_y1 = y0_offset + (loop_y0 * Y0BLOCK_SUB) + new_base_y1[:, None] + sum_numel = x0_numel + x1_numel + new_x2_mask = new_x2 < x0_numel + new_y1_mask = new_y1 < zy_numel + tl.store(output_ptr + (new_x2 + sum_numel * new_y1), tmp8, new_x2_mask & new_y1_mask) + + +testlist = [ + # ===================== 1D场景(15组,dim=0) ===================== + ((3,), (3,), 0), + ((7,), (9,), 0), + ((13,), (11,), 0), + ((2047,), (2047,), 0), + ((2701,), (3003,), 0), + ((4093,), (3095,), 0), + + # ===================== 2D场景(20组,dim0/dim1) ===================== + # dim0(行拼接,列维度一致) + ((3, 5), (3, 5), 0), + ((1005, 300), (2007, 300), 0), + ((1307, 400), (309, 400), 0), + ((303, 500), (303, 500), 0), + # dim1(列拼接,行维度一致) + ((7, 9), (7, 9), 1), + ((100, 1001), (100, 2003), 1), + ((200, 2005), (200, 207), 1), + ((300, 707), (300, 707), 1), + + # ===================== 3D场景(15组,dim0/dim1/dim2) ===================== + # dim0(第0维拼接,d1/d2一致) + ((378, 200, 300), (101, 200, 300), 0), + ((378, 70, 50), (601, 70, 50), 0), + # dim1(第1维拼接,d0/d2一致) + ((100, 452, 300), (100, 201, 300), 1), + ((65, 1735, 57), (65, 2001, 57), 1), + # dim2(第2维拼接,d0/d1一致) + ((87, 200, 387), (87, 200, 501), 2), + ((20, 337, 543), (20, 337, 401), 2), +] + + +@pytest.mark.parametrize('testlists', testlist) +@pytest.mark.parametrize('dtype', ['bool', 
'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_cat_bigshape(testlists, dtype): + torch_dtype = eval('torch.' + dtype) + np_x0 = test_common.generate_numpy(testlists[0], dtype) + np_x1 = test_common.generate_numpy(testlists[1], dtype) + cat_dim = testlists[2] + + x0 = torch.from_numpy(np_x0).to(torch_dtype).npu() + x1 = torch.from_numpy(np_x1).to(torch_dtype).npu() + + if len(x0.shape) > 3: + pytest.skip("dim > 3 for 3D+ tensor, skipping.") + + torch_res = torch.cat([x0, x1], dim=cat_dim) + triton_res = torch.zeros_like(torch_res) + num_core = 32 + if len(x0.shape) == 3: + if cat_dim == 0: + ZBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core + ZBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core + triton_unk_fused_cat_3d_dim0[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x0.shape[1], x0.shape[2], ZBLOCK, ZBLOCK_2, 1, 256) + elif cat_dim == 1: + Z0BLOCK = (x0.shape[0] + num_core - 1) // num_core + triton_unk_fused_cat_3d_dim1[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], x1.shape[2], Z0BLOCK, 1, 256) + else: + Y0BLOCK = (x0.shape[0] * x0.shape[1] + num_core - 1) // num_core + triton_unk_fused_cat_3d_dim2[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x0.shape[2], x1.shape[2], Y0BLOCK, 1, 256) + test_common.validate_cmp(dtype, torch_res, triton_res) + return + numel_large = torch_res.numel() > 512 and len(x0.shape) < 3 + if numel_large or (cat_dim == 0 and len(x0.shape) == 2): + squeeze_flag = False + if len(x0.shape) == 1: + squeeze_flag = True + x0 = torch.unsqueeze(x0, dim=0) + x1 = torch.unsqueeze(x1, dim=0) + triton_res = torch.unsqueeze(triton_res, dim=0) + cat_dim = 1 + if cat_dim == 1: + Y0BLOCK = (x0.shape[0] + num_core - 1) // num_core + if x0.shape[1] == x1.shape[1]: + triton_unk_fused_cat_dim1_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256) + else: + 
triton_unk_fused_cat_dim1_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], Y0BLOCK, 1, 256) + else: + if x0.shape[0] == x1.shape[0]: + Y0BLOCK = (x0.shape[0] + num_core - 1) // num_core + triton_unk_fused_cat_dim0_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256) + else: + YBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core + YBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core + triton_unk_fused_cat_dim0_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x1.shape[1], YBLOCK, YBLOCK_2, 1, 256) + if squeeze_flag: + triton_res = triton_res.squeeze() + else: + squeeze_flag = False + if len(x0.shape) == 1: + squeeze_flag = True + x0 = torch.unsqueeze(x0, dim=0) + x1 = torch.unsqueeze(x1, dim=0) + triton_res = torch.unsqueeze(triton_res, dim=0) + _cat_helper_func_2D_1[num_core, 1, 1](x0, x1, triton_res, x0.shape[1], x1.shape[1], x0.shape[0], x0.shape[1] + x1.shape[1], 256, 16) + if squeeze_flag: + triton_res = triton_res.squeeze() + + test_common.validate_cmp(dtype, torch_res, triton_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py new file mode 100644 index 0000000000..c9cca12573 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py @@ -0,0 +1,108 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def celoss_indices_kernel( + inp_ptr, + tgt_ptr, + w_ptr, + out_ptr, + w_tgt_ptr, + ignore_index, + C, + D, + BLOCK_C: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_d = tl.program_id(0).to(tl.int64) + pid_n = tl.program_id(1).to(tl.int64) + offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D).to(tl.int64) + + tgt_ptrs = tgt_ptr + pid_n * D + offset_d + tgt_mask = offset_d < D + tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0) + + ignore_mask = not (tgt == ignore_index) and tgt_mask + + tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32) + tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32) + + for off in range(0, C, BLOCK_C): + offset_c = off + tl.arange(0, BLOCK_C) + inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :] + inp_mask = offset_c[:, None] < C and offset_d[None, :] < D + inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32) + cur_max = tl.maximum(tmp_max, inp) + cur_exp = tl.exp(inp - cur_max) + tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp + tmp_max = cur_max + + final_max = tl.max(tmp_max, axis=0) + tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :]) + final_sum = tl.log(tl.sum(tmp_sum, axis=0)) + inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d + inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32) + + out = final_sum + final_max - inp_tgt + w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d + + if w_ptr is None: + w_tgt = ignore_mask + else: + w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32) + w_tgt = tl.where(ignore_mask, w_tgt, 0) + + tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask) + out *= w_tgt + out_ptrs = out_ptr + pid_n * D + offset_d + tl.store(out_ptrs, out, mask=tgt_mask) + + +def test_celoss_indices_kernel(shape=(1, 2)): + device = "npu" + dtype = torch.float16 + ignore_index = -100 + BLOCK_C = 256 + BLOCK_D = 1 + + N, C = shape + D = 1 + + inp 
= torch.randn(shape, dtype=dtype, device=device) + tgt = torch.randint(0, C, (N,), dtype=torch.int64, device=device) + wgt = torch.randn(C, dtype=dtype, device=device) + + out_triton = torch.empty((N * D,), dtype=torch.float32, device=device) + w_tgt_triton = torch.empty((N * D,), dtype=torch.float32, device=device) + + grid = (triton.cdiv(D, BLOCK_D), N) + celoss_indices_kernel[grid]( + inp, tgt, wgt, out_triton, w_tgt_triton, + ignore_index, + C, D, + BLOCK_C=BLOCK_C, + BLOCK_D=BLOCK_D, + ) diff --git a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py index 87b7fb3463..7ab3173013 100644 --- a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py +++ b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py @@ -45,9 +45,12 @@ def triton_compile_hint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_ tl.store(out_ptr0 + (xindex), tmp2, xmask) -@pytest.mark.parametrize('param_list', [ - ['float32', (2, 4096, 8), 2, 32768, 1024], -]) +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) def test_compile_hint(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() diff --git a/third_party/ascend/unittest/pytest_ut/test_complex_mask.py b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py new file mode 100644 index 0000000000..7b26b447d6 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py @@ -0,0 +1,69 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + + +def copy(x): + return x.clone() + + +@triton.jit +def copy_kernel(in_ptr, out_ptr, N: tl.constexpr, NUMEL): + idx_block = tl.arange(0, N) + is_valid = N <= NUMEL + x = tl.load(in_ptr + idx_block, mask=idx_block < N) + mask_i1 = is_valid[:, None] & (idx_block < N)[None, :] + tl.store(out_ptr + idx_block[None, :], x[None, :], mask=mask_i1) + + +@triton.jit +def permute_copy_kernel(in_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, NUMEL): + idx_block_n = tl.arange(0, N) + idx_block_m = tl.arange(0, M) + idx_block = idx_block_m[:, None] + idx_block_n[None, :] * M + is_valid = N <= NUMEL + x = tl.load(in_ptr + idx_block, mask=(idx_block_m[:, None] < M) & (idx_block_n[None, :] < N)) + mask_i1 = (is_valid[:, None, None]) & (idx_block_m[None, :, None] < M) & (idx_block_n[None, None, :] < N) + tl.store(out_ptr + idx_block[None, :], x[None, :], mask=mask_i1) + + +def test_complex_mask_copy(): + N = 1024 + x = torch.randn(N, dtype=torch.float32).npu() + y = torch.empty_like(x).npu() + copy_kernel[(1,)](x, y, N=N, NUMEL=N) + torch.testing.assert_close(x, y) + + +def test_complex_mask_permute_copy(): + M = 4 + N = 32 + x = torch.randn(M * N, dtype=torch.float32).npu() + y = torch.empty_like(x).npu() + permute_copy_kernel[(1,)](x, y, M=M, N=N, NUMEL=M * N) + torch.testing.assert_close(x, y) diff --git a/third_party/ascend/unittest/pytest_ut/test_copy.py b/third_party/ascend/unittest/pytest_ut/test_copy.py index f0a2a778be..99af972025 100644 --- a/third_party/ascend/unittest/pytest_ut/test_copy.py +++ b/third_party/ascend/unittest/pytest_ut/test_copy.py @@ -70,15 +70,16 @@ def copy( a1_val = tl.load(a1_ptr) add = tl.add(a_val, a1_val) - add_ub = bl.to_buffer(add, al.ascend_address_space.UB) + A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1) al.copy_from_ub_to_l1(add_ub, A_l1) + A_ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB) + 
al.copy(add_ub, A_ub) -# ============== Main for manual testing ============== -if __name__ == "__main__": +def test_copy(): print("=" * 60) print("Test 1: copy ") print("=" * 60) @@ -89,3 +90,7 @@ def copy( ) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) + +# ============== Main for manual testing ============== +if __name__ == "__main__": + test_copy() diff --git a/third_party/ascend/unittest/pytest_ut/test_cumprod.py b/third_party/ascend/unittest/pytest_ut/test_cumprod.py index c4c6833574..feed59b7c4 100644 --- a/third_party/ascend/unittest/pytest_ut/test_cumprod.py +++ b/third_party/ascend/unittest/pytest_ut/test_cumprod.py @@ -89,7 +89,7 @@ def cumprod_generate_tensor(shape, dtype): @pytest.mark.parametrize("dtype", support_dtypes) @pytest.mark.parametrize("shape", [(7, 23)]) @pytest.mark.parametrize("dim", [0, 1]) -@pytest.mark.parametrize("reverse", [False]) +@pytest.mark.parametrize("reverse", [False, True]) def test_cumprod(dtype, shape, dim, reverse): x0 = cumprod_generate_tensor(shape=shape, dtype=dtype).npu() triton_cal = triton_func(x0, dim, reverse) diff --git a/third_party/ascend/unittest/pytest_ut/test_cumsum.py b/third_party/ascend/unittest/pytest_ut/test_cumsum.py index edc4553a0a..c9196b6bf7 100644 --- a/third_party/ascend/unittest/pytest_ut/test_cumsum.py +++ b/third_party/ascend/unittest/pytest_ut/test_cumsum.py @@ -73,7 +73,7 @@ def triton_func(x, dim, reverse): @pytest.mark.parametrize("dtype", support_dtypes) @pytest.mark.parametrize("shape", [(7, 23)]) @pytest.mark.parametrize("dim", [0, 1]) -@pytest.mark.parametrize("reverse", [False]) +@pytest.mark.parametrize("reverse", [False, True]) def test_cumsum(dtype, shape, dim, reverse): x0 = generate_tensor(shape=shape, dtype=dtype).npu() triton_cal = triton_func(x0, dim, reverse) diff --git a/third_party/ascend/unittest/pytest_ut/test_custom.py b/third_party/ascend/unittest/pytest_ut/test_custom.py index 9c2d35c0b8..0c1b6488a5 100755 --- 
a/third_party/ascend/unittest/pytest_ut/test_custom.py +++ b/third_party/ascend/unittest/pytest_ut/test_custom.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import subprocess +import os import triton import triton.language as tl import triton.language.extra.cann.extension as al @@ -39,9 +40,31 @@ class my_custom_op: core = al.CORE.VECTOR pipe = al.PIPE.PIPE_V mode = al.MODE.SIMT + symbol = "my_custom_func" + # fake path, this test only check Triton successfully lowered to MLIR + bitcode = os.path.abspath(__file__) + iterator_types = [ + al.IteratorType.Parallel, + al.IteratorType.Broadcast, + al.IteratorType.Transpose, + al.IteratorType.Reduction, + al.IteratorType.Interleave, + al.IteratorType.Deinterleave, + al.IteratorType.Inverse, + al.IteratorType.Pad, + al.IteratorType.Concat, + al.IteratorType.Gather, + al.IteratorType.Cumulative, + al.IteratorType.Opaque, + ] def __init__(self, x, ptr1, ptr2, offset: tl.int64, other, out=None): - pass + # Add optional custom-op attribute: ArrayAttr. + self.indexing_map = [al.affine_map.get_identity(1)] + + # Tag ptr2 as an argument that should be aligned at dimension 1. + # Tag 2nd argument that should be aligned at dimension 0. 
+ self.align_dim = {"ptr2": 1, 1 : 0} @triton.jit @@ -55,6 +78,90 @@ def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): tl.store(out_ptr + i, result, mask=i < n) +@al.register_custom_op +class my_custom_op_extra_buf: + """Custom op declaring extra_buffers with several scalar Triton dtypes.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_func" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = [ + (tl.bfloat16, 256), + (tl.float64, 424242), + (tl.int8, 11), + (tl.float16, 22), + (tl.int32, 33), + ] + + +@al.register_custom_op +class my_custom_op_extra_buf_single_buf: + """Custom op declaring extra_buffers with single buf.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_func_single_buf" + bitcode = os.path.abspath(__file__) + + def __init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = (tl.bfloat16, 256) + + +@triton.jit +def kernel_extra_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +@triton.jit +def kernel_extra_buf_single_buf(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf_single_buf", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + +@al.register_custom_op +class my_custom_op_extra_buf_wide: + """Cover more integer widths and unsigned dtypes in extra_buffers_types.""" + + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + symbol = "my_extra_buf_wide_func" + bitcode = os.path.abspath(__file__) + + def 
__init__(self, x, out=None): + self.indexing_map = [al.affine_map.get_identity(1)] + self.extra_buffers = [ + (tl.int16, 1001), + (tl.uint16, 1002), + (tl.int64, 1003), + (tl.uint32, 1004), + (tl.uint8, 1005), + ] + + +@triton.jit +def kernel_extra_buf_wide(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(out_ptr + i, mask=i < n) + r = al.custom("my_custom_op_extra_buf_wide", x, out=y) + tl.store(out_ptr + i, r, mask=i < n) + + # ============== Pytest tests ============== @@ -75,15 +182,120 @@ def test_custom_op(): assert "hivm.pipe = #hivm.pipe" in line assert "hivm.tcore_type = #hivm.tcore_type" in line assert "hivm.vf_mode = #hivm.vf_mode" in line + # Optional indexing map attribute should be attached. + assert "indexing_map = [" in line + # Tagged argument alignment info is attached as integer operand attr. + assert "align_dim = 1" in line + assert "align_dim = 0" in line # All offset converted to int64. assert 'i64, ' in line assert 'i32, ' not in line + assert "iterator_types" in line + for iterator_name in ( + "parallel", + "broadcast", + "transpose", + "reduction", + "interleave", + "deinterleave", + "inverse", + "pad", + "concat", + "gather", + "cumulative", + "opaque", + ): + assert iterator_name in line + + +def _custom_lines(mlir: str, op_name: str): + # Match the MLIR string attribute exactly (avoid `my_custom_op` matching + # `my_custom_op_extra_buf`). 
+ quoted = f'"{op_name}"' + return [ + line for line in mlir.splitlines() + if "hivm.hir.custom" in line and quoted in line + ] + + +def test_custom_op_extra_buffers_mixed_scalar_types(): + """extra_buffers_types must preserve bf16/f64/i8/f16/i32 (not all lowered to f32).""" + mlir = compile_kernel( + kernel_extra_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf") + assert lines, "expected at least one hivm.hir.custom line for my_custom_op_extra_buf" + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "bf16" in line + assert "f64" in line + assert "i8" in line + assert "f16" in line + assert "i32" in line + assert "424242" in line + + +def test_custom_op_extra_buffers_single_buffer(): + mlir = compile_kernel( + kernel_extra_buf_single_buf, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf_single_buf") + assert lines, "expected at least one hivm.hir.custom line for my_custom_op_extra_buf_single_buf" + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "f32" in line + + +def test_custom_op_extra_buffers_integer_variants(): + """extra_buffers accept int16/uint16/int64/uint32/uint8 (IR uses i* storage types).""" + mlir = compile_kernel( + kernel_extra_buf_wide, + {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir and len(mlir) > 0 + lines = _custom_lines(mlir, "my_custom_op_extra_buf_wide") + assert lines + line = lines[0] + assert "extra_buffers_types" in line + assert "extra_buffers_sizes" in line + assert "i16" in line + assert "i64" in line + assert "i32" in line + assert "i8" in line + assert "1001" in line and "1005" in line + + +def test_custom_op_without_extra_buffers_has_no_extra_buffer_attrs(): 
+ """Ops that do not set extra_buffers should not emit extra_buffers_* attributes.""" + mlir = compile_kernel( + my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert mlir + for line in _custom_lines(mlir, "my_custom_op"): + assert "extra_buffers_types" not in line + assert "extra_buffers_sizes" not in line # ============== Main for manual testing ============== if __name__ == "__main__": - mlir = compile_kernel(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, - {"BLOCK": 256}) + test_custom_op() + test_custom_op_without_extra_buffers_has_no_extra_buffer_attrs() + test_custom_op_extra_buffers_integer_variants() + test_custom_op_extra_buffers_mixed_scalar_types() + test_custom_op_extra_buffers_single_buffer() + mlir = compile_kernel(my_kernel, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py new file mode 100644 index 0000000000..b50ecdf7d5 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py @@ -0,0 +1,48 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import triton +import triton.language as tl +import torch_npu +import pytest + + +@triton.jit +def single_disc_mask_atomic_add_kernel( + in_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + tl.atomic_add(ptr_in, 1, mask=disc_mask) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_atomic_add(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + expected = in_tensor.clone() + single_disc_mask_atomic_add_kernel[(1,)](in_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected[:half] += 1 + assert torch.allclose(in_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{in_tensor.cpu()}" \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py index a13322f361..52b4eccbad 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py @@ -18,6 +18,21 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
+# ============================================================================= +# Discrete mask access conversion test suite +# +# Test matrix (mask type x operation type): +# +# | mask type | load only | store only | load + store | +# |---------------------------------|-----------|------------|--------------| +# | single discrete mask | (A) | (B) | - | +# | single continuous mask | (C) | (D) | - | +# | continuous & discrete 2-way | (E) | (F) | (G) | +# | continuous & discrete 4-way | - | - | (H) | +# | broadcast(cont & disc) 2-D AND | (I) | - | (J) | +# +# ============================================================================= + import torch import triton import triton.language as tl @@ -25,8 +40,207 @@ import pytest +# ============================================================================= +# (A) Single discrete mask -- load only +# ============================================================================= +@triton.jit +def single_disc_mask_load_kernel( + in_ptr, + out_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + ptr_out = out_ptr + col_offs + data = tl.load(ptr_in, mask=disc_mask, other=0.0) + tl.store(ptr_out, data) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_load(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + out_tensor = torch.empty(BLOCK_N, dtype=torch.float16, device='npu') + + single_disc_mask_load_kernel[(1,)](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected = torch.zeros(BLOCK_N, dtype=torch.float16, device='npu') + expected[:half] = in_tensor[:half] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (B) Single discrete mask -- store only +# 
============================================================================= +@triton.jit +def single_disc_mask_store_kernel( + in_ptr, + out_ptr, + BLOCK_N: tl.constexpr, +): + col_offs = tl.arange(0, BLOCK_N) + disc_mask = (col_offs * 2) < BLOCK_N + ptr_in = in_ptr + col_offs + ptr_out = out_ptr + col_offs + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=disc_mask) + + +@pytest.mark.parametrize("BLOCK_N", [8]) +def test_single_discrete_mask_store(BLOCK_N): + in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_N,), -1.0, dtype=torch.float16, device='npu') + + single_disc_mask_store_kernel[(1,)](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + + half = BLOCK_N // 2 + expected = torch.full((BLOCK_N,), -1.0, dtype=torch.float16, device='npu') + expected[:half] = in_tensor[:half] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (C) Single continuous mask -- load only +# ============================================================================= +@triton.jit +def single_cont_mask_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + cont_mask = row_offs < M # Continuous mask + ptr_in = in_ptr + row_offs + ptr_out = out_ptr + row_offs + data = tl.load(ptr_in, mask=cont_mask, other=0.0) + tl.store(ptr_out, data, mask=cont_mask) + + +@pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) +def test_single_continuous_mask_load(M, BLOCK_M): + in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + + single_cont_mask_load_kernel[(1,)](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + + expected = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + expected[:M] = in_tensor[:M] + assert torch.allclose(out_tensor, 
expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (D) Single continuous mask -- store only +# ============================================================================= +@triton.jit +def single_cont_mask_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + cont_mask = row_offs < M + ptr_in = in_ptr + row_offs + ptr_out = out_ptr + row_offs + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=cont_mask) + + +@pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) +def test_single_continuous_mask_store(M, BLOCK_M): + in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + single_cont_mask_store_kernel[(1,)](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + expected = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + expected[:M] = in_tensor[:M] + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (E) Continuous & discrete 2-way AND -- load only +# ============================================================================= +@triton.jit +def cont_disc_combined_mask_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + # Continuous mask + row_boundary = row_offs < M + # Discrete mask + col_stride = (col_offs * 2) < BLOCK_N + mask = row_boundary[:, None] & col_stride[None, :] + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + data = tl.load(ptr_in, mask=mask, other=0.0) + tl.store(ptr_out, data) + + 
+@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(6, 8, 8)]) +def test_cont_disc_combined_mask_load(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + cont_disc_combined_mask_load_kernel[(1,)]( + in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + half_n = BLOCK_N // 2 + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:M, :half_n] = 1.0 + assert torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (F) Continuous & discrete 2-way AND -- store only +# ============================================================================= +@triton.jit +def cont_disc_combined_mask_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + row_boundary = row_offs < M # continuous -> contLeaf + col_stride = (col_offs * 2) < BLOCK_N # discrete -> discLeaf + mask = row_boundary[:, None] & col_stride[None, :] + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + data = tl.load(ptr_in) + tl.store(ptr_out, data, mask=mask) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(6, 8, 8)]) +def test_cont_disc_combined_mask_store(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + cont_disc_combined_mask_store_kernel[(1,)]( + in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + half_n = BLOCK_N // 2 + expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + expected[:M, :half_n] = 1.0 + assert 
torch.allclose(out_tensor, expected), \ + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + + +# ============================================================================= +# (G) Continuous & discrete 2-way AND -- load + store (complex interleave, original) +# ============================================================================= @triton.jit -def simple_discrete_mask_load_kernel( +def interleave_cont_disc_mask_kernel( in_ptr, out_ptr, M: tl.constexpr, @@ -35,9 +249,9 @@ def simple_discrete_mask_load_kernel( pid = tl.program_id(0) col_offs = tl.arange(0, N) even_col_offs = tl.arange(0, N // 2) * 2 - even_col_mask = even_col_offs < N + even_col_mask = even_col_offs < N # discrete: cmpi(muli(range,2), N) row_offs = tl.arange(0, M) - row_mask = row_offs < M + row_mask = row_offs < M # continuous: cmpi(range_M, M) in_even_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] in_odd_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] + 1 even_data = tl.load(in_even_ptr, mask=row_mask[:, None] & even_col_mask[None, :], other=0.0) @@ -47,21 +261,162 @@ def simple_discrete_mask_load_kernel( tl.store(out_ptr, rotated_data) -@pytest.mark.parametrize("M", [(4)]) -@pytest.mark.parametrize("N", [(8)]) +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") +@pytest.mark.parametrize("M", [4]) +@pytest.mark.parametrize("N", [8]) def test_discrete_mask_load_store(M, N): + """Regression test: mask=row_mask & even_col_mask (continuous & discrete 2-way)""" input_tensor = torch.arange(M * N, dtype=torch.float16, device='npu').reshape(M, N) output_tensor = torch.empty_like(input_tensor) - grid = (1, ) - simple_discrete_mask_load_kernel[grid]( - input_tensor, - output_tensor, - M=M, - N=N, - ) + interleave_cont_disc_mask_kernel[(1,)](input_tensor, output_tensor, M=M, N=N) even_cols = input_tensor[:, 0::2] odd_cols = input_tensor[:, 1::2] ref_output = torch.empty_like(input_tensor) ref_output[:, 0::2] = 
-odd_cols ref_output[:, 1::2] = even_cols assert torch.allclose(output_tensor.float(), ref_output.float()) + + +# ============================================================================= +# (H) Continuous & discrete 4-way AND -- load + store +# ============================================================================= +@triton.jit +def multi_cont_disc_mask_load_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M # continuous mask + col_boundary = col_offs < N # continuous mask + row_stride = (row_offs * 2) < BLOCK_M # discrete mask + col_stride = (col_offs * 2) < BLOCK_N # discrete mask + + mask = (row_boundary[:, None] + & col_boundary[None, :] + & row_stride[:, None] + & col_stride[None, :]) + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=mask, other=0.0) + result = data + 1.0 + tl.store(ptr_out, result, mask=mask) + + +@pytest.mark.parametrize("M,N,BLOCK_M,BLOCK_N", [ + (6, 6, 8, 8), +]) +def test_multi_cont_disc_mask_load_store(M, N, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + multi_cont_disc_mask_load_store_kernel[(1,)]( + in_tensor, out_tensor, M=M, N=N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + half_m = BLOCK_M // 2 # = 4 + half_n = BLOCK_N // 2 # = 4 + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:half_m, :half_n] = 2.0 + + assert torch.allclose(out_tensor, expected), ( + f"BLOCK=({BLOCK_M},{BLOCK_N}), valid=({M},{N})\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + ) + + +# ============================================================================= +# (I) 
broadcast(continuous & discrete) 2-D AND -- load only +# ============================================================================= +@triton.jit +def broadcast_cont_disc_2d_load_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M + row_disc = (row_offs * 2) < BLOCK_M + mask = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=mask, other=0.0) + tl.store(ptr_out, data) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(3, 4, 8)]) +def test_broadcast_cont_disc_2d_load(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + + broadcast_cont_disc_2d_load_kernel[(1,)]( + in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + disc_true_rows = BLOCK_M // 2 + both_true_rows = min(M, disc_true_rows) + + expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + expected[:both_true_rows, :] = 1.0 + + assert torch.allclose(out_tensor, expected), ( + f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + ) + + +# ============================================================================= +# (J) broadcast(continuous & discrete) 2-D AND -- load + store +# ============================================================================= +@triton.jit +def broadcast_cont_disc_2d_load_store_kernel( + in_ptr, + out_ptr, + M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + row_boundary = row_offs < M + row_disc = (row_offs * 2) < 
BLOCK_M + + combined = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] + + data = tl.load(ptr_in, mask=combined, other=0.0) + tl.store(ptr_out, data, mask=combined) + + +@pytest.mark.parametrize("M,BLOCK_M,BLOCK_N", [(3, 4, 8)]) +def test_broadcast_cont_disc_2d_load_store(M, BLOCK_M, BLOCK_N): + in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + + broadcast_cont_disc_2d_load_store_kernel[(1,)]( + in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + disc_true_rows = BLOCK_M // 2 + both_true_rows = min(M, disc_true_rows) + + expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') + expected[:both_true_rows, :] = 1.0 + + assert torch.allclose(out_tensor, expected), ( + f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" + ) diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py new file mode 100644 index 0000000000..1bb5925123 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py @@ -0,0 +1,249 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +# ============================================================================= +# MTE (Memory Tag Extension) OOB regression test for DiscreteMaskAccessConversionPass +# +# This file verifies that DiscreteMaskAccessConversionPass correctly bounds +# global-memory accesses when the load/store mask is a combined discrete mask +# +# Test strategy +# ------------- +# The test engineers this condition in four steps: +# Step 1 — probe: Trigger a fresh 2 MB NPU segment; measure its size. +# Step 2 — pre_fill: Fill the segment with small tensors until the remaining +# free space is in [IN_BYTES, TARGET_FREE]. +# Step 3 — in_tensor: Allocate the test tensor; it lands at the segment tail +# with only ~7680 bytes gap to the boundary. +# Step 4 — kernel: Run the kernel + synchronize. Before the fix the +# OOB read (24576 bytes) crosses the boundary → MTE. +# After the fix the copy is bounded to IN_BYTES → no MTE. 
+# +# Memory layout at the time of the kernel call (before fix): +# +# ┌──────────────────────────── 2 MB segment ────────────────────────────────┐ +# │ probe(512 B) │←────── pre_fill (~2025 KB) ──────→│ in_tensor(8192 B) │gap│ +# └──────────────────────────────────────────────────────────────────────────┘ +# ↑ segment end +# ├──── OOB_BYTES (24576 B) ─────→ +# crosses boundary → MTE ✓ +# +# ============================================================================= + +import math +import torch +import triton +import triton.language as tl +import torch_npu +import pytest + + +@triton.jit +def cont_disc_oob_inplace_2d_kernel( + ptr, + M, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + row_offs = tl.arange(0, BLOCK_M) + col_offs = tl.arange(0, BLOCK_N) + + # Continuous bound (contMask): marks the M valid rows in this tile. + row_boundary = row_offs < (M - pid_m * BLOCK_M) + row_disc = (row_offs * 2) < BLOCK_M + combined = row_boundary[:, None] & row_disc[:, None] & (col_offs < BLOCK_N)[None, :] + + row_start = pid_m * BLOCK_M + ptr_2d = ptr + (row_start + row_offs[:, None]) * BLOCK_N + col_offs[None, :] + + # load triggers DiscreteMaskAccessConversionPass. + # Before fix: copy size = BLOCK_M × BLOCK_N × 2 bytes = 32768 bytes (OOB). + # After fix: copy size = M × BLOCK_N × 2 bytes = 8192 bytes (safe). + data = tl.load(ptr_2d, mask=combined, other=0.0) + tl.store(ptr_2d, data, mask=row_boundary[:, None]) + + +# ============================================================================= +# Memory setup helper +# ============================================================================= +def _fill_segment_to_boundary(dtype, device, in_bytes, target_free, chunk_max_bytes): + """Allocate a fresh NPU segment and fill it so that only ~target_free bytes remain. + + Returns + ------- + pre_fillers : list of torch.Tensor + All tensors allocated (probe + fill chunks). The caller is responsible + for deleting them in `finally`. 
+ pool_free_after_fill : int + Segment free space after filling, in bytes. + seg_size : int + Total size of the triggered segment, in bytes. + """ + elem_size = torch.finfo(dtype).bits // 8 + + # --- Step 1: probe — trigger a fresh 2 MB small-alloc segment ---------- + pool0 = torch.npu.memory_reserved(0) + alloc0 = torch.npu.memory_allocated(0) + + probe = torch.empty(1, dtype=dtype, device=device) + + pool1 = torch.npu.memory_reserved(0) + alloc1 = torch.npu.memory_allocated(0) + + seg_size = pool1 - pool0 # should be 2 MB = 2097152 bytes + probe_actual = alloc1 - alloc0 # NPU 512-byte aligned → 512 bytes + + print(f"\n[mte] Step 1: probe") + print(f"[mte] segment_size = {seg_size} bytes ({seg_size // 1024} KB)") + print(f"[mte] probe_actual = {probe_actual} bytes") + print(f"[mte] pool_free = {seg_size - probe_actual} bytes") + + # --- Step 2: pre_fill — leave only [in_bytes, target_free] bytes free --- + # Chunks are kept ≤ chunk_max_bytes to stay in the small-alloc pool and + # avoid opening a new segment via the large-alloc path. 
+    pre_fillers = [probe]
+
+    for chunk in [chunk_max_bytes,
+                  chunk_max_bytes // 2,
+                  chunk_max_bytes // 4,
+                  chunk_max_bytes // 8,
+                  32 * 1024, 16 * 1024, 8 * 1024,
+                  4 * 1024, 2 * 1024, 1024, 512]:
+        while True:
+            free = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0)
+            if free <= in_bytes:
+                break  # not enough room even for in_tensor; stop
+            if free <= target_free:
+                break  # already in target range; try smaller chunk
+            if free <= target_free + chunk:
+                break  # this chunk would overshoot; try smaller chunk
+            try:
+                t = torch.empty(chunk // elem_size, dtype=dtype, device=device)
+                pre_fillers.append(t)
+            except RuntimeError:
+                break  # segment exhausted; try smaller chunk
+
+    pool_free_after_fill = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0)
+    pre_bytes = sum(t.numel() * elem_size for t in pre_fillers)
+    print(f"\n[mte] Step 2: pre_fill")
+    print(f"[mte] tensors = {len(pre_fillers)}, total = {pre_bytes} bytes ({pre_bytes // 1024} KB)")
+    print(f"[mte] pool_free = {pool_free_after_fill} bytes (target [{in_bytes}, {target_free}] bytes)")
+
+    return pre_fillers, pool_free_after_fill, seg_size
+
+
+# =============================================================================
+# Test: MTE OOB via segment-boundary placement
+# =============================================================================
+@pytest.mark.parametrize("BLOCK_M,BLOCK_N,M", [
+    (4, 4096, 1),
+])
+def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M):
+    """Regression: combined discrete mask load causes OOB on tail blocks.
+
+    Verifies that DiscreteMaskAccessConversionPass correctly bounds
+    the memory copy to M rows (the contiguous range), not BLOCK_M rows (the full tile).
+
+    Test outcome:
+      - Before fix: RuntimeError (MTE OOB) — the test would fail.
+      - After fix: no exception — the test passes.
+    """
+    dtype = torch.float16
+    device = 'npu'
+    elem_size = 2  # float16
+
+    in_bytes = M * BLOCK_N * elem_size  # 8192 bytes
+    oob_bytes = (BLOCK_M - M) * BLOCK_N * elem_size  # 24576 bytes
+    # TARGET_FREE: midpoint between in_bytes and oob_bytes.
+    # Ensures in_tensor fits AND gap < oob_bytes so OOB crosses segment boundary.
+    target_free = (in_bytes + oob_bytes) // 2  # 16384 bytes
+    chunk_max_bytes = 512 * 1024  # 512 KB
+
+    print(f"\n[mte] BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N} M={M}")
+    print(f"[mte] in_bytes = {in_bytes} bytes (in_tensor: {M}×{BLOCK_N}×{elem_size})")
+    print(f"[mte] oob_bytes = {oob_bytes} bytes (unfixed copy: {BLOCK_M}×{BLOCK_N}×{elem_size} - in_bytes)")
+    print(f"[mte] target_free = {target_free} bytes (must satisfy in_bytes < target_free < oob_bytes)")
+
+    torch.npu.empty_cache()
+
+    pre_fillers = []
+    in_tensor = None
+
+    try:
+        pre_fillers, pool_free_after_fill, _ = _fill_segment_to_boundary(
+            dtype, device, in_bytes, target_free, chunk_max_bytes
+        )
+    except Exception as exc:
+        torch.npu.empty_cache()
+        pytest.skip(f"Memory layout setup failed (allocator behaviour may differ): {exc}")
+
+    # Verify pre_fill achieved the required free-space window.
+    if not (in_bytes <= pool_free_after_fill <= target_free):
+        # Drop the list's references; `del t` in a loop would only unbind the loop variable.
+        pre_fillers.clear()
+        torch.npu.empty_cache()
+        pytest.skip(
+            f"pre_fill did not reach target range [{in_bytes}, {target_free}] bytes; "
+            f"got {pool_free_after_fill} bytes. "
+            f"Skipping MTE check (NPU allocator behaviour may differ)."
+        )
+
+    try:
+        # Step 3: allocate in_tensor — lands at the very end of the segment.
+        # NPU 512-byte alignment means the allocator consumes
+        # in_bytes + 512 = 8704 bytes, leaving gap ≈ target_free - 8704 = 7680 bytes.
+        in_tensor = torch.ones(M * BLOCK_N, dtype=dtype, device=device).view(M, BLOCK_N)
+
+        gap = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0)
+        print(f"\n[mte] Step 3: in_tensor")
+        print(f"[mte] address = [{in_tensor.data_ptr():#x}, {in_tensor.data_ptr() + in_bytes:#x})")
+        print(f"[mte] gap = {gap} bytes (in_tensor end → segment end)")
+
+        if oob_bytes <= gap:
+            pytest.skip(
+                f"gap ({gap} bytes) >= oob_bytes ({oob_bytes} bytes): "
+                f"OOB would not cross the segment boundary. "
+                f"Skipping MTE check."
+            )
+        print(f"[mte] oob_bytes({oob_bytes} B) > gap({gap} B) → MTE expected if unfixed ✓")
+
+        # Step 4: run kernel
+        num_pids_m = math.ceil(M / BLOCK_M)
+        print(f"\n[mte] Step 4: kernel (grid=({num_pids_m},))")
+        cont_disc_oob_inplace_2d_kernel[(num_pids_m,)](
+            in_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+        torch.npu.synchronize()
+        print("[mte] PASSED: fix is effective, no OOB.")
+
+    except RuntimeError as exc:
+        pytest.fail(
+            f"MTE OOB triggered — DiscreteMaskAccessConversionPass fix "
+            f"may not be applied or is incomplete.\nError: {exc}"
+        )
+
+    finally:
+        if in_tensor is not None:
+            del in_tensor
+        # Clear the list so the filler tensors are really unreferenced before empty_cache().
+        pre_fillers.clear()
+        torch.npu.empty_cache()
+        print("[mte] Memory released.")
diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py
new file mode 100644
index 0000000000..6f44b1de5c
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py
@@ -0,0 +1,232 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Fixed Constants +# --------------------------------------------------------------------------- +_M_ROWS = 16 # Rows per program +_OFFS = 8 # Write window offset for two programs +_HALF = 12 # Mask threshold +_NUM_C = 24 # Rows of matrix C (= OFFS + M_ROWS, ensure pid=1 write window not out of bounds) + +assert _OFFS < _HALF < _M_ROWS, "OFFS < HALF < M_ROWS to ensure True/False on both sides" +assert _NUM_C >= _OFFS + _M_ROWS, "NUM_C must accommodate upper bound of pid=1 write window" + + +# --------------------------------------------------------------------------- +# Triton Kernel +# --------------------------------------------------------------------------- +@triton.jit +def _copy_matrix_kernel( + A_ptr, idx_ptr, + C_ptr, + idx_stride, + A_row_stride, + A_col_stride, + C_row_stride, + C_col_stride, + BLOCK_N: tl.constexpr, + HALF: tl.constexpr, +): + """ + Discrete memory access + overlapping write window + runtime mask. 
+ + pid=0 write window: rows [0, 15], mask=True when idx < HALF + pid=1 write window: rows [8, 23], mask=True when idx >= HALF + Overlap region : rows [8, 15] -> triggers load-select-store RMW + """ + program_id = tl.program_id(axis=0).to(tl.int64) + N_id = tl.program_id(axis=1).to(tl.int64) + + OFFS: tl.constexpr = 8 + M_ROWS: tl.constexpr = 16 + + N_BLOCK = N_id * BLOCK_N + tl.arange(0, BLOCK_N) # shape: (BLOCK_N,) + M_BLOCK = tl.arange(0, M_ROWS) # shape: (M_ROWS,) + + # Discrete row indices (loaded at runtime -> mask cannot be statically analyzed) + idx = tl.load(idx_ptr + program_id * idx_stride + M_BLOCK) + + # Runtime mask (generates scf.if -> compiler converts to load-select-store) + if program_id == 0: + mask = idx < HALF + else: + mask = idx >= HALF + + val = tl.load( + A_ptr + idx[:, None] * A_row_stride + tl.arange(0, BLOCK_N)[None, :] * A_col_stride, + mask=mask[:, None], + ) + + # Write to C (mask=False rows rely on load-select-store to preserve original values) + tl.store( + C_ptr + (OFFS * program_id + M_BLOCK[:, None]) * C_row_stride + + N_BLOCK[None, :] * C_col_stride, + val, + mask=mask[:, None], + ) + # C (24 × N) Program 0 Program 1 + # Row 0~7 ──────────── write value (mask=True) ── Not involved + # Row 8~11 ──────────── write value (mask=True) ── Not written (mask=False, overwritten by P0 to 0) + # Row 12~15 ──────────── Not written (mask=False) ── write value (mask=True) + # Row 16~23 ──────────── Not involved ── write value (mask=True) + + +# --------------------------------------------------------------------------- +# Helper: Construct discrete index vector +# --------------------------------------------------------------------------- +def _make_idx(device: str) -> torch.Tensor: + """ + Construct 2x16 index matrix that meets mask distribution requirements. 
+ + pid=0 row (idx0): + First HALF=12 values ∈ [0, HALF) -> mask=True + Last 4 values ∈ [HALF, M_ROWS) -> mask=False + pid=1 row (idx1): + First 4 values ∈ [OFFS, OFFS+4) -> mask=False + Last HALF=12 values ∈ [HALF, HALF*2) -> mask=True + """ + def shuffle_quads(lst: list) -> list: + """Reverse each group of 4 elements (ignore if less than 4).""" + out = lst[:] + for i in range(0, len(out) - 3, 4): + out[i], out[i + 1], out[i + 2], out[i + 3] = \ + out[i + 3], out[i + 2], out[i + 1], out[i] + return out + + num_false = _M_ROWS - _HALF # = 4 + + seg0_true = shuffle_quads(list(range(0, _HALF))) # 12 values, < 12 + seg0_false = shuffle_quads(list(range(_HALF, _HALF + num_false))) # 4 values, >= 12 + idx0 = seg0_true + seg0_false # Total length 16 + + seg1_false = shuffle_quads(list(range(_OFFS, _OFFS + num_false))) # 4 values, < 12 + seg1_true = shuffle_quads(list(range(_HALF, _HALF + _HALF))) # 12 values, >= 12 + idx1 = seg1_false + seg1_true # Total length 16 + + assert len(idx0) == _M_ROWS, f"idx0 length error: {len(idx0)}" + assert len(idx1) == _M_ROWS, f"idx1 length error: {len(idx1)}" + assert all(v < _HALF for v in seg0_true), "pid=0 True segment should all be < HALF" + assert all(v >= _HALF for v in seg0_false), "pid=0 False segment should all be >= HALF" + assert all(v < _HALF for v in seg1_false), "pid=1 False segment should all be < HALF" + assert all(v >= _HALF for v in seg1_true), "pid=1 True segment should all be >= HALF" + + return torch.tensor([idx0, idx1], dtype=torch.int32, device=device) + + +# --------------------------------------------------------------------------- +# Dtype Mapping +# --------------------------------------------------------------------------- +_DTYPE_MAP = { + 'int32': torch.int32, + 'float32': torch.float32, + 'float16': torch.float16, + 'int16': torch.int16, +} + + +# --------------------------------------------------------------------------- +# Single Execution + Verification +# 
--------------------------------------------------------------------------- +def _run_once(BLOCK_N: int, dtype_str: str) -> None: + """ + Execute kernel once and verify results. + + Expectations: + C[0:HALF, :] all 0 -- pid=0 writes rows [0,HALF) of A (all 0) + C[HALF:NUM_C, :] all 1 -- pid=1 writes rows [HALF, NUM_C) of A (all 1) + """ + dev = 'npu' + td = _DTYPE_MAP[dtype_str] + zero_val = 0.0 if dtype_str.startswith('float') else 0 + one_val = 1.0 if dtype_str.startswith('float') else 1 + + # A: First HALF rows all 0, last HALF rows all 1 + A = torch.zeros((_NUM_C, BLOCK_N), dtype=td, device=dev) + A[_HALF:, :] = one_val + + idx = _make_idx(dev) + + # C: Fill all with 2 + C = torch.full((_NUM_C, BLOCK_N), 2, dtype=td, device=dev) + + grid = (2, 1) + _copy_matrix_kernel[grid]( + A_ptr=A, idx_ptr=idx, C_ptr=C, + idx_stride=idx.stride(0), + A_row_stride=A.stride(0), A_col_stride=A.stride(1), + C_row_stride=C.stride(0), C_col_stride=C.stride(1), + BLOCK_N=BLOCK_N, + HALF=_HALF, + enable_sync_block_lock=True, + ) + + # Verification + assert torch.all(C[:_HALF] == zero_val), ( + f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " + f"C[:HALF] should all be {zero_val}, actual unique values: {C[:_HALF].unique().tolist()}" + ) + assert torch.all(C[_HALF:] == one_val), ( + f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " + f"C[HALF:] should all be {one_val}, actual unique values: {C[_HALF:].unique().tolist()}" + ) + + +@pytest.mark.parametrize("param_list", [ + # --- int32 --- + (16, 'int32'), + (32, 'int32'), + (64, 'int32'), + # --- float32 --- + (16, 'float32'), + (32, 'float32'), + (64, 'float32'), +]) +def test_discrete_overlap_mask(param_list): + """ + Verify no precision issues in discrete access + overlapping write window + runtime mask scenario. + + Race condition errors are probabilistic. Each parameter combination is executed 10 times + to fully cover concurrent timing scenarios. 
+ If sync_block_lock fix is effective, all 10 runs pass; if race condition exists, assertion failure + occurs with high probability. + """ + BLOCK_N, dtype_str = param_list + for _ in range(10): + _run_once(BLOCK_N, dtype_str) + + +if __name__ == "__main__": + configs = [ + (32, 'int32'), + (32, 'float32'), + ] + for BLOCK_N, dtype_str in configs: + print(f"Testing BLOCK_N={BLOCK_N}, dtype={dtype_str} ...", end=" ", flush=True) + for _ in range(10): + _run_once(BLOCK_N, dtype_str) + print("PASS (10 rounds)") + print("All tests passed.") diff --git a/third_party/ascend/unittest/pytest_ut/test_dot.py b/third_party/ascend/unittest/pytest_ut/test_dot.py index a4837d0449..1e6cf3c3a5 100644 --- a/third_party/ascend/unittest/pytest_ut/test_dot.py +++ b/third_party/ascend/unittest/pytest_ut/test_dot.py @@ -26,6 +26,16 @@ import test_common +@pytest.fixture(scope="function") +def restore_npu_hf32_setting(): + original_allow_hf32 = torch_npu.npu.matmul.allow_hf32 + try: + torch_npu.npu.matmul.allow_hf32 = True + yield + finally: + torch_npu.npu.matmul.allow_hf32 = original_allow_hf32 + + def torch_dot_None(x0, x1): res = torch.matmul(x0, x1) return res @@ -49,6 +59,59 @@ def triton_dot_2_None(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr tl.store(output_ptr + oidx, ret, mask=out_mask) +@triton.jit +def triton_dot_2_allow_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, allow_tf32=True) + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + 
+@triton.jit +def triton_dot_2_input_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, input_precision="tf32") + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + +@triton.jit +def triton_dot_2_ignore_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, input_precision="hf32") + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + testlist1 = [ (10, 13, 35, 39), ] @@ -60,12 +123,59 @@ def triton_dot_2_None(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr ] +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") @pytest.mark.parametrize("B, C, D", testlist2) @pytest.mark.parametrize("sigtype", typelist) -def test_dot_2(sigtype, B, C, D): +def test_dot_2(restore_npu_hf32_setting, sigtype, B, C, D): x = test_common.generate_tensor((B, C), sigtype).npu() y = test_common.generate_tensor((C, D), sigtype).npu() z_ref = torch_dot_None(x, y).to(torch.float32) z = torch.zeros((B, D), 
dtype=torch.float32).npu() triton_dot_2_None[1, 1, 1](z, x, y, B, C, D) test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.xfail( + reason="Temporarily disabled: TA backend does not support allow_tf32 yet. Will be fixed in follow-up." +) +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_allow_tf32(restore_npu_hf32_setting, sigtype, B, C, D): + x = test_common.generate_tensor((B, C), sigtype).npu() + y = test_common.generate_tensor((C, D), sigtype).npu() + z_ref = torch_dot_None(x, y).to(torch.float32) + z = torch.zeros((B, D), dtype=torch.float32).npu() + triton_dot_2_allow_tf32[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_input_tf32(restore_npu_hf32_setting, sigtype, B, C, D): + x = test_common.generate_tensor((B, C), sigtype).npu() + y = test_common.generate_tensor((C, D), sigtype).npu() + z_ref = torch_dot_None(x, y).to(torch.float32) + z = torch.zeros((B, D), dtype=torch.float32).npu() + triton_dot_2_input_tf32[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) + + +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2_ignore_tf32(sigtype, B, C, D): + input_type = "bfloat16" + x = test_common.generate_tensor((B, C), input_type).npu() + y = test_common.generate_tensor((C, D), input_type).npu() + z = torch.zeros((B, D), dtype=torch.float32).npu() + + original_allow_hf32 = torch_npu.npu.matmul.allow_hf32 + try: + torch_npu.npu.matmul.allow_hf32 = False + z_ref = torch_dot_None(x.to(torch.float32), y.to(torch.float32)).to(torch.float32) + + finally: + torch_npu.npu.matmul.allow_hf32 = original_allow_hf32 + + triton_dot_2_ignore_tf32[1, 1, 1](z, x, y, B, C, D) + 
test_common.validate_cmp(sigtype, z, z_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_erfinv.py b/third_party/ascend/unittest/pytest_ut/test_erfinv.py index 9e45a5c553..a41521417a 100644 --- a/third_party/ascend/unittest/pytest_ut/test_erfinv.py +++ b/third_party/ascend/unittest/pytest_ut/test_erfinv.py @@ -82,3 +82,29 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_erfinv[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = test_common.generate_tensor(shape, dtype).npu() + x[0][0][0] = 1 # erfinv(1) -> ∞ + x[0][0][1] = -1 # erfinv(-1) -> -∞ + + # Avoid numerical instability near ±1 + # Move values in (threshold, 1) to threshold and (-1, -threshold) to -threshold + threshold = 1 - 1.1e-4 + too_close_pos = (x > threshold) & (x < 1) + too_close_neg = (x < -threshold) & (x > -1) + x[too_close_pos] = threshold + x[too_close_neg] = -threshold + y_ref = torch.erfinv(x).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_erfinv[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp(dtype, y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_expm1.py b/third_party/ascend/unittest/pytest_ut/test_expm1.py index 90665b030f..d1eb5d2cda 100644 --- a/third_party/ascend/unittest/pytest_ut/test_expm1.py +++ b/third_party/ascend/unittest/pytest_ut/test_expm1.py @@ -45,9 +45,11 @@ def triton_expm1(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constex tl.store(out_ptr0 + (x0), tmp1, None) -@pytest.mark.parametrize('param_list', [ - ['float32', (2, 4096, 8), 2, 32768, 1024], -]) +@pytest.mark.skip(reason="expm1 failed sometimes, wait for fix") +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) def test_expm1(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0_ref = test_common.generate_tensor(shape, dtype) diff --git a/third_party/ascend/unittest/generalization_cases/test_tan.py b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py similarity index 66% rename from third_party/ascend/unittest/generalization_cases/test_tan.py rename to third_party/ascend/unittest/pytest_ut/test_fast_dividef.py index 4d6b6454cb..7d0d9ebdb6 100644 --- a/third_party/ascend/unittest/generalization_cases/test_tan.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py @@ -1,62 +1,62 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import triton.language.extra.ascend.libdevice as libdevice - - -def torch_pointwise(x0): - res = torch.tan(x0) - return res - - -@triton.jit -def triton_tan(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop1 in range(loops1): - x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 - x0 = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + (x0), None) - tmp2 = libdevice.tan(tmp0) - tl.store(out_ptr0 + (x0), tmp2, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - y_ref = torch_pointwise(x0) - y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_tan[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common +import triton.language.extra.cann.libdevice as libdevice + + +def torch_pointwise(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = libdevice.fast_dividef(tmp0, tmp1) + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_fast_dividef[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/generalization_cases/test_log1p.py b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py similarity index 66% rename from third_party/ascend/unittest/generalization_cases/test_log1p.py rename to third_party/ascend/unittest/pytest_ut/test_fast_expf.py index fa37cbd298..7d0d9ebdb6 100644 --- a/third_party/ascend/unittest/generalization_cases/test_log1p.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py @@ -1,62 +1,62 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import triton -import triton.language as tl -import torch -import pytest -import test_common -from test_common import TestUtils -import math -import triton.language.extra.ascend.libdevice as libdevice - - -def torch_pointwise(x0): - res = torch.log1p(x0) - return res - - -@triton.jit -def triton_log1p(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base1 = tl.arange(0, XBLOCK_SUB) - loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop1 in range(loops1): - x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 - x0 = offset + (loop1 * XBLOCK_SUB) + base1 - tmp0 = tl.load(in_ptr0 + (x0), None) - tmp2 = libdevice.log1p(tmp0) - tl.store(out_ptr0 + (x0), tmp2, None) - - -@pytest.mark.parametrize('shape', TestUtils.test_shape1d) -@pytest.mark.parametrize('dtype', ['float32', 'float16']) -def test_case(dtype, shape): - x0 = test_common.generate_tensor(shape, dtype).npu() - - numel = x0.numel() - ncore = 1 if numel <= 32 else 32 - xblock = math.ceil(numel / ncore) - xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) - - y_ref = torch_pointwise(x0) - y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_log1p[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) - test_common.validate_cmp(dtype, y_cal, y_ref) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common +import triton.language.extra.cann.libdevice as libdevice + + +def torch_pointwise(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = libdevice.fast_dividef(tmp0, tmp1) + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_fast_dividef[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_gamma.py b/third_party/ascend/unittest/pytest_ut/test_gamma.py index 388ed36af2..19a33dc688 100644 --- a/third_party/ascend/unittest/pytest_ut/test_gamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_gamma.py @@ -68,3 +68,21 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 2048, 8), 2, 32768, 512], + ] + ) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = torch.abs(test_common.generate_tensor(shape, dtype)) + x_np = x.cpu().numpy() + x = x.npu() + y_ref = torch.from_numpy(gamma(x_np)).to(x.device).to(x.dtype).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu()
+    triton_gamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
+    monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL")
diff --git a/third_party/ascend/unittest/pytest_ut/test_if_advance.py b/third_party/ascend/unittest/pytest_ut/test_if_advance.py
new file mode 100644
index 0000000000..beb8c670c7
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_if_advance.py
@@ -0,0 +1,65 @@
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.language.extra.cann.extension as al
+
+
+@triton.jit
+def triton_if_advance_kernel(in_ptr0, in_ptr1, out_ptr,
+                             xnumel, ynumel, k_loops,
+                             XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
+    """Accumulate ``in_ptr0 @ in_ptr1`` into ``out_ptr`` over ``k_loops`` tiles.
+
+    Both input block pointers are advanced inside an ``if`` in the loop body
+    (skipped on the first iteration); that conditional advance is the pattern under test.
+    """
+    K_block_ptr = tl.make_block_ptr(
+        base=in_ptr0,
+        shape=(xnumel, ynumel),
+        strides=(ynumel, 1),
+        offsets=(0, 0),
+        block_shape=(XBLOCK, YBLOCK),
+        order=(1, 0),
+    )
+    V_block_ptr = tl.make_block_ptr(
+        base=in_ptr1,
+        shape=(ynumel, xnumel),
+        strides=(xnumel, 1),
+        offsets=(0, 0),
+        block_shape=(YBLOCK, XBLOCK),
+        order=(1, 0),
+    )
+    O_block_ptr = tl.make_block_ptr(
+        base=out_ptr,
+        shape=(xnumel, xnumel),
+        strides=(xnumel, 1),
+        offsets=(0, 0),
+        block_shape=(XBLOCK, XBLOCK),
+        order=(1, 0),
+    )
+    res = tl.zeros([XBLOCK, XBLOCK], tl.float32)
+    for i in range(0, k_loops):
+        # The first tile is read at offset 0; later iterations step both pointers first.
+        if i > 0:
+            K_block_ptr = tl.advance(K_block_ptr, (0, YBLOCK))
+            V_block_ptr = tl.advance(V_block_ptr, (YBLOCK, 0))
+        a = tl.load(K_block_ptr)
+        b = tl.load(V_block_ptr)
+        res = tl.dot(a, b, acc=res)
+    tl.store(O_block_ptr, res)
+
+
+def test_if_advance():
+    """Compare the kernel's (64, 256) @ (256, 64) product against ``torch.matmul``."""
+    x = torch.randn((64, 256), dtype=torch.float32, device="npu")
+    y = torch.randn((256, 64), dtype=torch.float32, device="npu")
+    out_tri = torch.empty((64, 64), dtype=torch.float32, device="npu")
+    out_std = torch.empty((64, 64), dtype=torch.float32, device="npu")
+    torch.matmul(x, y, out=out_std)
+    triton_if_advance_kernel[1, 1, 1](x, y, out_tri, 64, 256, 4, 64, 64)
+    torch.testing.assert_close(out_std, out_tri, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    test_if_advance()
diff --git a/third_party/ascend/unittest/pytest_ut/test_if_load.py b/third_party/ascend/unittest/pytest_ut/test_if_load.py
new file mode 100644
index 0000000000..0c2e0c8e5a
--- /dev/null
+++ b/third_party/ascend/unittest/pytest_ut/test_if_load.py
@@ -0,0 +1,84 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+ +import pytest + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + +import torch +import torch_npu + + +@triton.jit +def triton_if_load(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + base1 = tl.arange(0, XBLOCK) + index = base1 + if tl.program_id(0) == 0: + base1 = base1 * 1 + else: + base1 = base1 * 2 + tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK, other=0.0) + tl.store(out_ptr0 + index, tmp0, None) + + +@triton.jit +def triton_for_if_load(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + base1 = tl.arange(0, XBLOCK_SUB) + index = base1 + loops = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for i in range(loops): + base1 = base1 + i * XBLOCK_SUB + index = index + i * XBLOCK_SUB + if tl.program_id(0) != 0: + base1 = base1 + 1 + + tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK, other=0.0) + tl.store(out_ptr0 + index, tmp0, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32,), 32], + ]) +def test_if_load(param_list): + dtype, shape, xblock = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0.clone() + + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_if_load[(1,)](x0, y_cal, xblock) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32,), 32, 16], + ]) +def test_for_if_load(param_list):  # distinct name: covers triton_for_if_load without shadowing test_if_load + dtype, shape, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0.clone() + + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_for_if_load[(1,)](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py new file mode 100644 index 0000000000..67e5c729c8 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py @@ -0,0 +1,140 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import torch +import pytest +import triton +import test_common +import triton.language as tl + + +types_all = [ + (torch.float32, 'float32'), +] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def addptr_implicit_perm_atomic_add_2d( + ptr, + out, + ynumel, + xnumel, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YB] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XB, 1] + + val = 1.0 + (x.to(tl.float32) * 0.01) + (y.to(tl.float32) * 0.001) # [XB, YB] + xmask = x < xnumel + ymask = y < ynumel + old = tl.atomic_add(ptr + (x + 4 * y), val, xmask & ymask) + + tl.store(out + (x + 4 * y), old) + + +@triton.jit +def addptr_implicit_perm_atomic_cas_2d( + ptr, + out, + cmp_ptr, + val_ptr, + ynumel, + xnumel, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] + + xmask = x < xnumel + ymask = y < ynumel + mask = xmask & ymask + + offset = x + 4 * y + + cmp = tl.load(cmp_ptr + offset, mask=mask, other=0.0).to(tl.float32) + val = tl.load(val_ptr + offset, mask=mask, other=0.0).to(tl.float32) + + old = tl.atomic_cas(ptr + offset, cmp, val) + + tl.store(out + offset, old, mask=mask) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) +def test_addptr_implicit_perm_atomic_add_2d( + dtype, sigtype, + xnumel, ynumel, + XBLOCK, YBLOCK, +): + in_ptr = torch.zeros((ynumel * 4,), dtype=dtype).npu() + out_ptr = torch.ones_like(in_ptr) + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + addptr_implicit_perm_atomic_add_2d[grid]( + in_ptr, out_ptr, ynumel, xnumel, + YBLOCK=YBLOCK, XBLOCK=XBLOCK + ) + + y_idx = torch.arange(ynumel).unsqueeze(1).npu() + x_idx = torch.arange(xnumel).unsqueeze(0).npu() + idx = (x_idx + 4 * y_idx).reshape(-1) + 
torch.testing.assert_close(out_ptr[idx], torch.zeros_like(out_ptr[idx])) + + val_ref = (1.0 + 0.01 * x_idx.to(torch.float32) + 0.001 * y_idx.to(torch.float32)).reshape(-1) + torch.testing.assert_close(in_ptr[idx], val_ref, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize('dtype,sigtype', types_all) +@pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) +def test_addptr_implicit_perm_atomic_cas_2d( + dtype, sigtype, + xnumel, ynumel, + XBLOCK, YBLOCK, +): + in_ptr = torch.full((ynumel * 4,), 2, dtype=dtype).npu() + out_ptr = torch.full((ynumel * 4,), 1, dtype=dtype).npu() + cmp_ptr = torch.full((ynumel * 4,), 2, dtype=dtype).npu() + val_ptr = torch.full((ynumel * 4,), 1, dtype=dtype).npu() + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + addptr_implicit_perm_atomic_cas_2d[grid]( + in_ptr, out_ptr, cmp_ptr, val_ptr, ynumel, xnumel, + YBLOCK=YBLOCK, XBLOCK=XBLOCK + ) + + y_idx = torch.arange(ynumel).unsqueeze(1).npu() + x_idx = torch.arange(xnumel).unsqueeze(0).npu() + idx = (x_idx + 4 * y_idx).reshape(-1) + + torch.testing.assert_close(out_ptr[idx], torch.full_like(out_ptr[idx], 2.0)) + + torch.testing.assert_close(in_ptr[idx], torch.ones_like(in_ptr[idx])) + + +if __name__ == '__main__': + case_2d = (4, 512, 4, 64) + test_addptr_implicit_perm_atomic_add_2d(*types_all[0], *case_2d) + test_addptr_implicit_perm_atomic_cas_2d(*types_all[0], *case_2d) diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py new file mode 100644 index 0000000000..f5c82bfafd --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py @@ -0,0 +1,1081 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import torch +import pytest +import triton +import triton.language as tl +import test_common + + +types_all = [ + (torch.float32, 'float32'), +] + +case_2d = [ + # X, Y, XBLOCK, YBLOCK + (512, 32, 4, 64), +] + +case_3d = [ + # X, Y, Z, XBLOCK, YBLOCK, ZBLOCK + (100, 40, 32, 10, 4, 4), +] + +case_4d = [ + # X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK + (100, 80, 20, 16, 20, 4, 4, 4), +] + + +# ---------------------------------------------------------- +# Triton kernel +# ---------------------------------------------------------- +@triton.jit +def addptr_implicit_perm_load_store_2d_static_stride( + ptr, + out, + ynumel, + xnumel, + stride_y: tl.constexpr, + stride_x: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr +): + # logical indices (A^T view) + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + mask = (x < xnumel) & (y < ynumel) + + # IMPORTANT: + # ptr is a row-major A, but we interpret it as A^T via stride + offset = x * stride_x + y * stride_y + + val = tl.load(ptr + offset, mask) + tl.store(out + offset, val, mask) + + +@triton.jit +def addptr_implicit_perm_load_store_2d( + ptr, + out, + ynumel, + xnumel, + stride_y, + stride_x, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + # logical indices (A^T view) + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + + mask = (x < xnumel) & (y < ynumel) + + # IMPORTANT: + # ptr is a row-major A, but we interpret it as A^T via stride + offset = x * stride_x + y * stride_y + + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + stride_z: tl.constexpr, + stride_y: tl.constexpr, + stride_x: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, 
+ XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x = pid_x * XBLOCK + tl.arange(0, XBLOCK)[:, None, None] + y = pid_y * YBLOCK + tl.arange(0, YBLOCK)[None, :, None] + z = pid_z * ZBLOCK + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + offset = x * stride_x + y * stride_y + z * stride_z + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_3d( + ptr, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x = pid_x * XBLOCK + tl.arange(0, XBLOCK)[:, None, None] + y = pid_y * YBLOCK + tl.arange(0, YBLOCK)[None, :, None] + z = pid_z * ZBLOCK + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + offset = x * stride_x + y * stride_y + z * stride_z + val = tl.load(ptr + offset, mask=mask) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_4d_static_stride( + ptr, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w: tl.constexpr, + stride_z: tl.constexpr, + stride_y: tl.constexpr, + stride_x: tl.constexpr, + WBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z + + xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK + + w_pid = pid0 // xblocks_per_w + x_pid = pid0 - w_pid * xblocks_per_w + + x0 = x_pid * XBLOCK + y0 = pid1 * YBLOCK + z0 = pid2 * ZBLOCK + w0 = w_pid * WBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :, None] + w = w0 + tl.arange(0, WBLOCK)[None, None, 
None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) & (w < wnumel) + + offset = x * stride_x + y * stride_y + z * stride_z + w * stride_w + val = tl.load(ptr + offset, mask=mask, other=0.0) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def addptr_implicit_perm_load_store_4d( + ptr, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z + + xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK + + w_pid = pid0 // xblocks_per_w + x_pid = pid0 - w_pid * xblocks_per_w + + x0 = x_pid * XBLOCK + y0 = pid1 * YBLOCK + z0 = pid2 * ZBLOCK + w0 = w_pid * WBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :, None] + w = w0 + tl.arange(0, WBLOCK)[None, None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) & (w < wnumel) + + offset = x * stride_x + y * stride_y + z * stride_z + w * stride_w + val = tl.load(ptr + offset, mask=mask, other=0.0) + tl.store(out + offset, val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_store_2d_static_stride( + ptr, + out, + ynumel, + xnumel, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr +): + y0 = tl.program_id(1) * YBLOCK + x0 = tl.program_id(0) * XBLOCK + y = y0 + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + x = x0 + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + xmask = x < xnumel + ymask = y < ynumel + mask = xmask & ymask + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel), + strides=(STRIDE_X, STRIDE_Y), + offsets=(x0, y0), + block_shape=(XBLOCK, YBLOCK), + order=(0, 1), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + STRIDE_Y * y), val, mask=mask) + + +@triton.jit 
+def make_tensor_ptr_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z: tl.constexpr, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_store_3d( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z, + STRIDE_Y, + STRIDE_X, + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +@triton.jit +def make_tensor_ptr_implicit_perm_load_3d_static_stride( + ptr, + out, + znumel, # logical z (== X) + ynumel, # logical y (== Y) + xnumel, # logical x (== Z) + STRIDE_Z: tl.constexpr, + 
STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + # out is row-major with shape (xnumel, ynumel, znumel) + OUT_STRIDE0: tl.constexpr, # = ynumel*znumel + OUT_STRIDE1: tl.constexpr, # = znumel + OUT_STRIDE2: tl.constexpr, # = 1 + ZBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + # load: implicit permute view + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(x0, y0, z0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + val = tl.load(tptr) + + # store: row-major output (no implicit permute) + out_offset = x * OUT_STRIDE0 + y * OUT_STRIDE1 + z * OUT_STRIDE2 + tl.store(out + out_offset, val, mask=mask) + + + +@triton.jit +def advance_implicit_perm_load_store_2d_static_stride( + ptr, + out, + ynumel, + xnumel, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr +): + y0 = tl.program_id(1) * YBLOCK + x0 = tl.program_id(0) * XBLOCK + y = y0 + tl.arange(0, YBLOCK)[None, :] + x = x0 + tl.arange(0, XBLOCK)[:, None] + mask = (x < xnumel) & (y < ynumel) + + tptr = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel), + strides=(STRIDE_X, STRIDE_Y), + offsets=(0, 0), + block_shape=(XBLOCK, YBLOCK), + order=(0, 1), + ) + tptr2 = tl.advance(tptr, (x0, y0)) + val = tl.load(tptr2) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y), val, mask=mask) + + +@triton.jit +def advance_implicit_perm_load_store_3d_static_stride( + ptr, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z: tl.constexpr, + STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, + ZBLOCK: tl.constexpr, + YBLOCK: 
tl.constexpr, + XBLOCK: tl.constexpr, +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + x0 = pid_x * XBLOCK + y0 = pid_y * YBLOCK + z0 = pid_z * ZBLOCK + + x = x0 + tl.arange(0, XBLOCK)[:, None, None] + y = y0 + tl.arange(0, YBLOCK)[None, :, None] + z = z0 + tl.arange(0, ZBLOCK)[None, None, :] + + mask = (x < xnumel) & (y < ynumel) & (z < znumel) + + tptr0 = tl.make_block_ptr( + base=ptr, + shape=(xnumel, ynumel, znumel), + strides=(STRIDE_X, STRIDE_Y, STRIDE_Z), + offsets=(0, 0, 0), + block_shape=(XBLOCK, YBLOCK, ZBLOCK), + order=(0, 1, 2), + ) + tptr = tl.advance(tptr0, (x0, y0, z0)) + val = tl.load(tptr) + tl.store(out + (x * STRIDE_X + y * STRIDE_Y + z * STRIDE_Z), val, mask=mask) + + +# ---------------------------------------------------------- +# pytest case +# ---------------------------------------------------------- +def ceil_div(a, b): + return (a + b - 1) // b + + +def _assert_row_major_2d(A, X, Y): + assert tuple(A.shape) == (X, Y) + assert A.is_contiguous() + assert A.stride() == (Y, 1) + + +def _assert_row_major_3d(A, X, Y, Z): + # [X, Y, Z] contiguous -> stride = (Y*Z, Z, 1) + assert tuple(A.shape) == (X, Y, Z) + assert A.is_contiguous() + assert A.stride() == (Y * Z, Z, 1) + + +def _assert_row_major_4d(A, X, Y, Z, W): + # [X, Y, Z, W] contiguous -> stride = (Y*Z*W, Z*W, W, 1) + assert tuple(A.shape) == (X, Y, Z, W) + assert A.is_contiguous() + assert A.stride() == (Y * Z * W, Z * W, W, 1) + + +# ---------------------------------------------------------- +# pytest case: addptr kernels +# ---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_2d_static_stride( + X, + Y, + XBLOCK, + YBLOCK, + dtype, + sigtype, +): + """ + Test goal: + - Real memory layout: A[X, Y], row-major (stride = (Y, 1)) + - Kernel view: A^T[Y, X], stride = (1, Y) + - Kernel does 
load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor( + shape=(X, Y), + dtype=sigtype, + ).npu() + + _assert_row_major_2d(A, X, Y) + + out = torch.zeros_like(A) + + # A^T logical shape + xnumel = Y # cols of A + ynumel = X # rows of A + + # A^T logical stride + stride_x = 1 + stride_y = Y + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + 1, + ) + + addptr_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + stride_y, + stride_x, + XBLOCK=XBLOCK, + YBLOCK=YBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_2d( + X, + Y, + XBLOCK, + YBLOCK, + dtype, + sigtype, +): + """ + Same as static-stride version, but stride passed as runtime values. + """ + A = test_common.generate_tensor( + shape=(X, Y), + dtype=sigtype, + ).npu() + + _assert_row_major_2d(A, X, Y) + + out = torch.zeros_like(A) + + # A^T logical shape + xnumel = Y # cols of A + ynumel = X # rows of A + + # A^T logical stride + stride_x = 1 + stride_y = Y + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + 1, + ) + + addptr_implicit_perm_load_store_2d[grid]( + A, + out, + ynumel, + xnumel, + stride_y, + stride_x, + XBLOCK=XBLOCK, + YBLOCK=YBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_3d_static_stride( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + """ + Test goal: + - Real memory layout: A[X, Y, Z], row-major (stride = (Y*Z, Z, 1)) + - Kernel view: treat as permuted logical coords via stride: + offset = x*1 + y*Z + z*(Y*Z) + i.e. 
(x,y,z) mapped to base index (z, y, x) + - Kernel does load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + # Logical shape for "A^(perm)" (x fastest) + xnumel = Z + ynumel = Y + znumel = X + + # Logical strides (in elements): (1, Z, Y*Z) + stride_x = 1 + stride_y = Z + stride_z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_3d( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + """ + Same as static-stride version, but stride passed as runtime values. 
+ """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + + stride_x = 1 + stride_y = Z + stride_z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_3d[grid]( + A, + out, + znumel, + ynumel, + xnumel, + stride_z, + stride_y, + stride_x, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_4d_static_stride( + X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype +): + """ + Test goal: + - Real memory layout: A[X, Y, Z, W], row-major (stride = (Y*Z*W, Z*W, W, 1)) + - Kernel view: treat as permuted logical coords via stride: + offset = x*1 + y*W + z*(Z*W) + w*(Y*Z*W) + i.e. (x,y,z,w) mapped to base index (w, z, y, x) + - Kernel does load+store with identical offsets + - Result must satisfy: out == in + """ + A = test_common.generate_tensor(shape=(X, Y, Z, W), dtype=sigtype).npu() + _assert_row_major_4d(A, X, Y, Z, W) + out = torch.zeros_like(A) + + # Logical shape (x fastest) + xnumel = W + ynumel = Z + znumel = Y + wnumel = X + + # Logical strides (in elements): (1, W, Z*W, Y*Z*W) + stride_x = 1 + stride_y = W + stride_z = Z * W + stride_w = Y * Z * W + + # Kernel maps pid0 over (w, x). It uses xblocks_per_w computed from xnumel. 
+ xblocks_per_w = ceil_div(xnumel, XBLOCK) + grid0 = wnumel * xblocks_per_w + grid = ( + grid0, + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_4d_static_stride[grid]( + A, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK=WBLOCK, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_addptr_implicit_perm_load_store_4d( + X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype +): + """ + Same as static-stride version, but stride passed as runtime values. + """ + A = test_common.generate_tensor(shape=(X, Y, Z, W), dtype=sigtype).npu() + _assert_row_major_4d(A, X, Y, Z, W) + out = torch.zeros_like(A) + + xnumel = W + ynumel = Z + znumel = Y + wnumel = X + + stride_x = 1 + stride_y = W + stride_z = Z * W + stride_w = Y * Z * W + + xblocks_per_w = ceil_div(xnumel, XBLOCK) + grid0 = wnumel * xblocks_per_w + grid = ( + grid0, + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + addptr_implicit_perm_load_store_4d[grid]( + A, + out, + wnumel, + znumel, + ynumel, + xnumel, + stride_w, + stride_z, + stride_y, + stride_x, + WBLOCK=WBLOCK, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +# ---------------------------------------------------------- +# pytest case: make_tensor_ptr kernels +# ---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride( + X, Y, XBLOCK, YBLOCK, dtype, sigtype +): + """ + Test goal matches addptr_2d_static_stride, but uses tl.make_block_ptr + tl.load(tptr). 
+ Real layout: A[X,Y] row-major stride=(Y,1) + Kernel view: A^T[Y,X] stride=(1,Y) + Store is by explicit linear offset with same logical stride. + """ + A = test_common.generate_tensor(shape=(X, Y), dtype=sigtype).npu() + _assert_row_major_2d(A, X, Y) + out = torch.zeros_like(A) + + xnumel = Y + ynumel = X + STRIDE_X = 1 + STRIDE_Y = Y + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + make_tensor_ptr_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + """ + Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) + Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_store_3d( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + """ + Same as static stride but STRIDE_* passed at runtime. 
+ """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_store_3d[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z, + STRIDE_Y, + STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_make_tensor_ptr_implicit_perm_load_3d_static_stride( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + + # logical/permuted shape + xnumel = Z + ynumel = Y + znumel = X + + # implicit-permute strides (elements) + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + # output is row-major of shape (Z, Y, X) + out = torch.empty((xnumel, ynumel, znumel), device="npu", dtype=A.dtype) + assert out.is_contiguous() + OUT_STRIDE0 = ynumel * znumel # Y*X + OUT_STRIDE1 = znumel # X + OUT_STRIDE2 = 1 + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + make_tensor_ptr_implicit_perm_load_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + OUT_STRIDE0=OUT_STRIDE0, + OUT_STRIDE1=OUT_STRIDE1, + OUT_STRIDE2=OUT_STRIDE2, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + # expected: out[x,y,z] == A[z,y,x] => out == A.permute(2,1,0) + ref = A.permute(2, 1, 0).contiguous() + torch.testing.assert_close(out, ref) + + +# ---------------------------------------------------------- +# pytest case: advance kernels +# 
---------------------------------------------------------- +@pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_advance_implicit_perm_load_store_2d_static_stride( + X, Y, XBLOCK, YBLOCK, dtype, sigtype +): + """ + Same goal as addptr_2d_static_stride, but uses tl.make_block_ptr + tl.advance. + """ + A = test_common.generate_tensor(shape=(X, Y), dtype=sigtype).npu() + _assert_row_major_2d(A, X, Y) + out = torch.zeros_like(A) + + xnumel = Y + ynumel = X + STRIDE_X = 1 + STRIDE_Y = Y + + grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) + advance_implicit_perm_load_store_2d_static_stride[grid]( + A, + out, + ynumel, + xnumel, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +@pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) +@pytest.mark.parametrize("dtype, sigtype", types_all) +def test_advance_implicit_perm_load_store_3d_static_stride( + X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype +): + """ + Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) + Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) + """ + A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() + _assert_row_major_3d(A, X, Y, Z) + out = torch.zeros_like(A) + + xnumel = Z + ynumel = Y + znumel = X + STRIDE_X = 1 + STRIDE_Y = Z + STRIDE_Z = Y * Z + + grid = ( + ceil_div(xnumel, XBLOCK), + ceil_div(ynumel, YBLOCK), + ceil_div(znumel, ZBLOCK), + ) + + advance_implicit_perm_load_store_3d_static_stride[grid]( + A, + out, + znumel, + ynumel, + xnumel, + STRIDE_Z=STRIDE_Z, + STRIDE_Y=STRIDE_Y, + STRIDE_X=STRIDE_X, + ZBLOCK=ZBLOCK, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) + + torch.testing.assert_close(out, A) + + +if __name__ == "__main__": + test_addptr_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_2d(*case_2d[0], *types_all[0]) + 
test_addptr_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_3d(*case_3d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_4d_static_stride(*case_4d[0], *types_all[0]) + test_addptr_implicit_perm_load_store_4d(*case_4d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_store_3d(*case_3d[0], *types_all[0]) + test_make_tensor_ptr_implicit_perm_load_3d_static_stride(*case_3d[0], *types_all[0]) + test_advance_implicit_perm_load_store_2d_static_stride(*case_2d[0], *types_all[0]) + test_advance_implicit_perm_load_store_3d_static_stride(*case_3d[0], *types_all[0]) diff --git a/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py new file mode 100644 index 0000000000..c7879a5392 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py @@ -0,0 +1,92 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
def torch_reference(logits, topk_ids):
    """Return, per row, the logit selected by ``topk_ids`` minus that row's maximum.

    Args:
        logits: 2-D tensor; one row per request.
        topk_ids: 1-D integer tensor holding one column index per row.

    Returns:
        1-D tensor of length ``logits.shape[0]`` with ``logits.dtype``.
    """
    row_count = logits.shape[0]
    result = torch.empty(row_count, dtype=logits.dtype)
    for row in range(row_count):
        row_logits = logits[row]
        result[row] = row_logits[topk_ids[row]] - row_logits.max()
    return result
output = torch.empty(num_rows, dtype=torch.float32).npu() + + gather_after_reduce_kernel[(num_rows,)]( + logits_flat, topk_ids, output, vocab_size, vocab_size, BLOCK=BLOCK, + ) + + output_ref = torch_reference(logits_ref, topk_ids_ref) + test_common.validate_cmp('float32', output, output_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py new file mode 100644 index 0000000000..e50bcdbf6a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py @@ -0,0 +1,130 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
def torch_interleave_load(q, k, head_dim_half, bias):
    """Reference for the interleaved load: copy even lanes, negate odd lanes.

    For each pair index d in [0, head_dim_half), writes
    ``k[2d + bias] = q[2d + bias]`` and ``k[2d + 1 + bias] = -q[2d + 1 + bias]``.
    ``k`` is modified in place and returned.
    """
    even_lanes = torch.arange(0, head_dim_half) * 2 + bias
    odd_lanes = even_lanes + 1
    k[even_lanes] = q[even_lanes]
    k[odd_lanes] = -q[odd_lanes]
    return k


def torch_interleave_load_with_mask(q, k, head_dim_half, bias, numel):
    """Masked variant: only the first ``min(head_dim_half, numel)`` pairs are written.

    ``k`` is modified in place and returned; untouched lanes keep their value.
    """
    active_pairs = min(head_dim_half, numel)
    even_lanes = torch.arange(0, active_pairs) * 2 + bias
    odd_lanes = even_lanes + 1
    k[even_lanes] = q[even_lanes]
    k[odd_lanes] = -q[odd_lanes]
    return k


def torch_interleave_loadstore_with_mask(q, head_dim_half, bias, numel):
    """In-place masked variant: negate the odd lane of the first pairs of ``q``.

    The even lanes would be stored back unchanged, so they are not rewritten.
    ``q`` is modified in place and returned.
    """
    active_pairs = min(head_dim_half, numel)
    odd_lanes = torch.arange(0, active_pairs) * 2 + 1 + bias
    q[odd_lanes] = -q[odd_lanes]
    return q
@pytest.mark.parametrize(
    'para_type,data_type,head_dim_half,bias',
    [
        ['float32', torch.float32, 16, 4],
    ],
)
def test_interleave(para_type, data_type, head_dim_half, bias):
    """Compare the unmasked interleave kernel against the torch reference."""
    total_len = bias + head_dim_half * 2
    q = torch.randn((total_len,), dtype=data_type).npu()
    actual = torch.zeros_like(q, dtype=data_type).npu()
    expected = torch.zeros_like(q, dtype=data_type).npu()

    triton_interleave_load[(1,)](q, actual, head_dim_half, bias)
    expected = torch_interleave_load(q, expected, head_dim_half, bias)

    assert torch.allclose(actual, expected)


@pytest.mark.parametrize(
    'para_type,data_type,head_dim_half,bias,numel',
    [
        ['float32', torch.float32, 16, 0, 8],
    ],
)
def test_interleave_with_mask(para_type, data_type, head_dim_half, bias, numel):
    """Compare the masked interleave kernel against the torch reference."""
    total_len = bias + head_dim_half * 2
    q = torch.randn((total_len,), dtype=data_type).npu()
    actual = torch.zeros_like(q, dtype=data_type).npu()
    expected = torch.zeros_like(q, dtype=data_type).npu()

    triton_interleave_load_with_mask[(1,)](q, actual, head_dim_half, bias, numel)
    expected = torch_interleave_load_with_mask(q, expected, head_dim_half, bias, numel)

    assert torch.allclose(actual, expected)
@pytest.mark.parametrize(
    'param_list',
    [
        ['float32', (2, 2048, 8), 2, 32768, 512],
    ],
)
def test_auto_blockify(param_list, monkeypatch):
    """Run the lgamma kernel with ``auto_blockify_size`` set and compare to torch.lgamma.

    TRITON_ALL_BLOCKS_PARALLEL is set to "1" for the duration of the test.
    """
    monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1")
    dtype, shape, ncore, xblock, xblock_sub = param_list
    x = test_common.generate_tensor(shape, dtype).npu()

    # Avoid numerical instability near negative integers: inputs that fall
    # within a small, magnitude-dependent tolerance of an integer <= -1 are
    # snapped exactly onto that integer before both implementations run.
    rounded = torch.round(x)
    is_negative_int = rounded <= -1
    tolerance = torch.zeros_like(x)
    if is_negative_int.any():
        tolerance[is_negative_int] = 5.75e-5 * (2.42 ** (-1 - rounded[is_negative_int]))
    snap = (torch.abs(x - rounded) < tolerance) & is_negative_int
    if snap.any():
        x = torch.where(snap, rounded, x)

    expected = torch.lgamma(x).npu()
    actual = torch.zeros(shape, dtype=getattr(torch, dtype)).npu()
    triton_lgamma[ncore, 1, 1](
        x, actual, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore
    )
    test_common.validate_cmp(dtype, actual, expected)
    monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL")
torch_res[mask_bool] = x0[mask_bool] + x1[mask_bool] + + test_common.validate_cmp(dtype, triton_res, torch_res) def profile_performance_test(M, N, dtype, BLOCK_SIZE): diff --git a/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py new file mode 100644 index 0000000000..abd13ba327 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py @@ -0,0 +1,140 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE.
# Copies one (x_size, y_size) tile from input to output, where the *source*
# block pointer starts at (x_offset, y_offset). The load checks both
# dimensions and zero-fills elements that fall outside the (x_size, y_size)
# shape; the store targets the origin of the output and is unchecked.
# The test in this file passes non-positive offsets, so part of the source
# block lies before the start of the tensor.
@triton.jit
def negative_padding_with_load_kernel(
    input_ptr,
    output_ptr,
    x_offset: tl.constexpr,
    y_offset: tl.constexpr,
    x_size: tl.constexpr,
    y_size: tl.constexpr,
):
    # Source block: same extent as the whole tensor, shifted by the offsets.
    in_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(x_size, y_size),
        strides=(y_size, 1),
        offsets=(x_offset, y_offset),
        block_shape=(x_size, y_size),
        order=(1, 0),
    )
    # Destination block: the whole output tensor, no shift.
    out_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(x_size, y_size),
        strides=(y_size, 1),
        offsets=(0, 0),
        block_shape=(x_size, y_size),
        order=(1, 0),
    )
    # Out-of-range elements of the shifted block are read as zero.
    in_val = tl.load(in_ptr, boundary_check=(0, 1), padding_option="zero")
    tl.store(out_ptr, in_val)


# Mirror image of the load kernel: the tile is read unshifted from the input
# and written through a *destination* block pointer that starts at
# (x_offset, y_offset); the store checks both dimensions.
@triton.jit
def negative_padding_with_store_kernel(
    input_ptr,
    output_ptr,
    x_offset: tl.constexpr,
    y_offset: tl.constexpr,
    x_size: tl.constexpr,
    y_size: tl.constexpr,
):
    # Source block: the whole input tensor, no shift.
    in_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(x_size, y_size),
        strides=(y_size, 1),
        offsets=(0, 0),
        block_shape=(x_size, y_size),
        order=(1, 0),
    )
    # Destination block: same extent as the whole tensor, shifted by the offsets.
    out_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(x_size, y_size),
        strides=(y_size, 1),
        offsets=(x_offset, y_offset),
        block_shape=(x_size, y_size),
        order=(1, 0),
    )
    in_val = tl.load(in_ptr)
    # Elements of the shifted block that fall outside the output are checked
    # on both dimensions by the store.
    tl.store(out_ptr, in_val, boundary_check=(0, 1))
def torch_pointwise(x0, x1, dtype):
    """Reference truncated remainder ``x0 - x1 * trunc(x0 / x1)``.

    Floating inputs are widened first ('float16' -> float32,
    'float32' -> float64); the result keeps the widened dtype. Any other
    ``dtype`` string leaves the operands as they are.
    """
    widened = {'float16': torch.float32, 'float32': torch.float64}.get(dtype)
    if widened is not None:
        x0 = x0.to(widened)
        x1 = x1.to(widened)
    quotient = torch.div(x0, x1, rounding_mode="trunc")
    return x0 - x1 * quotient
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("shape", [(16, 1, 1, 1, 1)])  # L=16, others=1
def test_pw_rdc5d(dtype, shape):
    """Multiply two 5-D tensors and min-reduce the last axis with keep_dims.

    With the last extent equal to 1 the reduction leaves the product
    unchanged, so the expected result is simply ``a * b``.
    """
    lhs, rhs = (torch.randn(*shape, dtype=dtype, device='npu') for _ in range(2))
    result = torch.empty(*shape, dtype=dtype, device='npu')

    reference = (lhs * rhs).to(dtype)

    extents = dict(zip(("L", "M", "N", "K", "Z"), shape))
    triton_pw_rdc5d[(1,)](lhs, rhs, result, **extents)

    torch.testing.assert_close(result.cpu(), reference.cpu(), rtol=1e-3, atol=1e-3)
class Options:
    # Plain attribute bag; an instance is handed to ast_to_ttir by
    # compile_kernel below.
    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False


def compile_kernel(kernel, signature, constants):
    """Lower ``kernel`` with ast_to_ttir and return the module as a string.

    Args:
        kernel: the JIT function to lower.
        signature: signature mapping forwarded to ASTSource.
        constants: constexpr values forwarded to ASTSource.
    """
    ast_source = ASTSource(kernel, signature, constants)

    # One context with the core, buffer and ascend dialects loaded, in that order.
    mlir_context = ir.context()
    for dialect_provider in (ir, buffer_ir, ascend_ir):
        dialect_provider.load_dialects(mlir_context)

    ttir_module = ast_to_ttir(kernel, ast_source, mlir_context, Options(), {}, {})
    return str(ttir_module)
# Loads XBLOCK elements under a mask that combines two contradictory range
# conditions, then stores the loaded values unmasked.
# NOTE(review): the exact form of the mask expression appears to be what this
# regression exercises (see the test name) - keep it as written.
@triton.jit
def triton_negative_mask_dim(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
    index = tl.arange(0, XBLOCK)
    # index < 1 only holds for index == 0, while index + 1 >= XBLOCK only holds
    # for index == XBLOCK - 1; for XBLOCK > 1 no lane satisfies both.
    mask = (index < 1) & (index + 1 >= XBLOCK)
    # Masked-off lanes take the `other` value, 0.0.
    tmp0 = tl.load(in_ptr0 + index, mask, other=0.0)
    tl.store(out_ptr0 + index, tmp0, None)


@pytest.mark.parametrize('param_list',
                         [
                             ['float32', (32,), 32],
                         ])
def test_negative_mask_dim(param_list):
    """Check that an always-false load mask yields the `other` value everywhere.

    The input is all ones and the output buffer starts as all ones; with
    XBLOCK == 32 the kernel's mask is false for every lane, so the unmasked
    store must overwrite the whole output with zeros.
    """
    dtype, shape, xblock = param_list
    x0 = torch.ones(shape, dtype=eval('torch.' + dtype)).npu()
    y_ref = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()

    # Start from ones so that a skipped store is detectable.
    y_cal = torch.ones(shape, dtype=eval('torch.' + dtype)).npu()
    triton_negative_mask_dim[(1,)](x0, y_cal, xblock)
    assert torch.allclose(y_cal, y_ref)
def test_bubbleup_extract_nonzero_offset():
    """Check that the rope slice is read at column offset HEAD_DIM_V, not 0.

    Each kv-cache row is filled so that the two column ranges are
    distinguishable: columns [0, head_dim_v) hold ``col + token * 100`` and
    columns [head_dim_v, head_dim) hold ``col + token * 1000``. The kernel
    must return the second range for every token.
    """
    device = "npu"

    PAGE_SIZE = 2
    BLOCK_N = 4
    head_dim = 32
    head_dim_v = 24
    rope_dim = head_dim - head_dim_v
    num_pages = BLOCK_N // PAGE_SIZE

    # Identity page table: page i maps to page i, so token order is preserved.
    req_to_tokens = torch.arange(num_pages, dtype=torch.int32, device=device)
    total_tokens = num_pages * PAGE_SIZE
    kv_cache = torch.zeros(total_tokens, head_dim, dtype=torch.float32, device=device)
    for token_id in range(total_tokens):
        kv_cache[token_id, :head_dim_v] = (
            torch.arange(head_dim_v, dtype=torch.float32) + token_id * 100
        )
        kv_cache[token_id, head_dim_v:] = (
            torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000
        )
    output = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device)

    rope_like_load_kernel[(1,)](
        kv_cache.flatten(),
        req_to_tokens,
        output.flatten(),
        stride_kv=head_dim,
        HEAD_DIM_V=head_dim_v,
        HEAD_DIM=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        BLOCK_N=BLOCK_N,
        ROPE_DIM=rope_dim,
    )

    expected = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device)
    for token_id in range(BLOCK_N):
        expected[token_id] = (
            torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000
        )

    # For reference, a load that dropped the HEAD_DIM_V column offset would
    # return the first rope_dim columns of each row instead, i.e.
    # ``arange(rope_dim) + token_id * 100``. That pattern differs from
    # `expected` everywhere, so the assertion below rejects it; the tensor
    # that used to be built here for it was never read and has been removed.
    assert torch.allclose(output, expected, atol=1e-5)
# Computes c = a @ b with block pointers where b and c are addressed through
# transposed (column-major) views of their buffers.
#   a: shape (M, K), strides (K, 1)  - row-major
#   b: shape (K, N), strides (1, K)  - transposed view of an (N, K) row-major buffer
#   c: shape (M, N), strides (1, M)  - transposed view of an (N, M) row-major buffer
@triton.jit
def zj_fa_fwd_pattern(
    in_ptr0, in_ptr1, out_ptr,
    M, K, N,
    MBLOCK: tl.constexpr,
    NBLOCK: tl.constexpr,
    KBLOCK: tl.constexpr
):
    a_ptr = tl.make_block_ptr(
        base = in_ptr0,
        shape = (M, K), # 8, 3
        strides = (K, 1),
        offsets = (0, 0),
        block_shape = (MBLOCK, KBLOCK),
        order = (1, 0)
    )

    b_ptr = tl.make_block_ptr(
        base = in_ptr1,
        shape = (K, N), # 3, 8
        strides = (1, K),
        offsets = (0, 0),
        block_shape = (KBLOCK, NBLOCK),
        order = (0, 1)
    )

    c_ptr = tl.make_block_ptr(
        base = out_ptr,
        shape = (M, N),
        strides = (1, M),
        offsets = (0, 0),
        block_shape = (MBLOCK, NBLOCK),
        order = (0, 1)
    )

    # Both loads check dimension 0 only and zero-fill what is out of range.
    # NOTE(review): in the test KBLOCK (4) exceeds K (3); for `a` that overrun
    # is along dimension 1, which boundary_check=(0,) does not cover -
    # presumably intentional for this regression, confirm.
    a = tl.load(a_ptr, boundary_check = (0,), padding_option="zero")
    b = tl.load(b_ptr, boundary_check = (0,), padding_option="zero")
    c = tl.dot(a, b)
    tl.store(c_ptr, c, boundary_check = (0, 1))


def test_permute_boundary_check():
    """Run the pattern with K smaller than KBLOCK and compare against ``a @ b.T``.

    ``b`` is allocated as (N, K) and ``c`` as (N, M); the kernel addresses both
    through transposed strides, so the reference is compared with ``c.T``.
    """
    M = 8
    K = 3
    N = 8
    MBLOCK = 8
    NBLOCK = 8
    KBLOCK = 4
    a = torch.randn((M, K), device="npu") # 8, 3
    b = torch.randn((N, K), device="npu") # 8, 3
    c = torch.empty((N, M), device="npu")
    zj_fa_fwd_pattern[(1,1,1)](a, b, c, M, K, N, MBLOCK, NBLOCK, KBLOCK)
    std = a @ b.T
    torch.testing.assert_close(std, c.T, atol = 1e-2, rtol = 1e-2)
of file diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py new file mode 100644 index 0000000000..e1fb74e5b9 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py @@ -0,0 +1,332 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def maximum(a, b): + ret = tl.maximum(a, b, tl.PropagateNan.ALL) + # 经过测试发现,tl.maximum仅在输入类型为bfloat16时,输出的结果会转变为float32,从而导致编译报错。在GPU上测试发现,和NPU上错误的现象一致。 + # 因此此处针对输入类型为bfloat16的情况,对输出进行了类型转换来规避该错误引起的编译报错。 + if a.dtype == tl.bfloat16: + ret = ret.to(tl.bfloat16) + return ret + + +@triton.jit +def triton_max_5d_dim024(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + \ + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[ + None, None, None, None, :] + odx = mblk_idx[:, None] * K + kblk_idx[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, maximum) + ret1 = tl.reduce(ret, 2, maximum) + ret2 = tl.reduce(ret1, 0, maximum) + tl.store(out_ptr0 + odx, ret2) + + +@triton.jit +def triton_max_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None] * N * Z + + nblk_idx[None, :, None] * Z + + zblk_idx[None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret_k = tl.reduce(x, 3, maximum) # [L, M, N, Z] + ret_m = tl.reduce(ret_k, 1, maximum) # [L, N, Z] + tl.store(out_ptr0 + odx, ret_m) + + +@triton.jit +def 
triton_max_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (mblk_idx[:, None, None, None] * N * K * Z + + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 0, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * N * K * Z + + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 1, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + 
nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * K * Z + + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 2, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * Z + + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 3, maximum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_max_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * K + + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + + kblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, maximum) + tl.store(out_ptr0 + odx, 
ret) + + +@triton.jit +def triton_max_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret1 = tl.reduce(x, 4, maximum) + ret2 = tl.reduce(ret1, 3, maximum) + ret3 = tl.reduce(ret2, 2, maximum) + ret4 = tl.reduce(ret3, 1, maximum) + ret5 = tl.reduce(ret4, 0, maximum) + tl.store(out_ptr0, ret5) + + +testlist = [ + (triton_max_5d_dim024, (1, 1, 1, 1, 1), "dim024"), + (triton_max_5d_dim024, (2, 2, 2, 2, 2), "dim024"), + (triton_max_5d_dim024, (3, 11, 1, 3, 42), "dim024"), + + (triton_max_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), + + (triton_max_5d_dim0, (2, 2, 2, 2, 2), "dim0"), + (triton_max_5d_dim1, (2, 2, 2, 2, 2), "dim1"), + (triton_max_5d_dim2, (2, 2, 2, 2, 2), "dim2"), + (triton_max_5d_dim3, (2, 2, 2, 2, 2), "dim3"), + (triton_max_5d_dim4, (2, 2, 2, 2, 2), "dim4"), + + (triton_max_5d_all, (3, 11, 1, 3, 42), "all"), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = ["{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) + for testfunc, shape, dim_name in testlist + ] + + +@pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) +@pytest.mark.parametrize('dtype', typelist) +def test_max(testfunc, dtype, shape, dim_name): + x0 = test_common.generate_tensor(shape=shape, dtype=dtype).npu() + + if dim_name == "dim024": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 4) + ans, _ = torch.max(ans.to(torch.int64), 2) + ans, _ = torch.max(ans.to(torch.int64), 0) + ans = 
ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 4) + ans, _ = torch.max(ans, 2) + ans, _ = torch.max(ans, 0) + output = torch.zeros((shape[1],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim13": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 3) + ans, _ = torch.max(ans.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 3) + ans, _ = torch.max(ans, 1) + output = torch.zeros((shape[0],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim0": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 0) + output = torch.zeros((shape[1],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim1": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 1) + output = torch.zeros((shape[0],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim2": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 2) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 2) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim3": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 3) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 3) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim4": + if 'int' in dtype: + ans, _ = torch.max(x0.to(torch.int64), 4) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.max(x0, 4) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[3],), dtype=eval('torch.' 
+ dtype)).npu() + + elif dim_name == "all": + if 'int' in dtype: + ans = torch.max(x0.to(torch.int64)) + ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) + else: + ans = torch.tensor([torch.max(x0)], dtype=eval('torch.' + dtype)) + output = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu() + + testfunc[(1,)](x0, output, *shape) + + test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py new file mode 100644 index 0000000000..adf3104f52 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import time + +import pytest +import torch +import torch_npu + +import triton +import triton.language as tl +import test_common + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if promote_to_tensor(a_value).dtype.is_floating(): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def triton_min_5d_dim4_keepdim(in_ptr0, in_ptr1, out_ptr0, out_ptr1, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, 
:, None, None, None] * Z * K * N + \ + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret, ret1 = tl.reduce((x, x1), 4, minimum_with_index, keep_dims=True) + zblk_idx = tl.arange(0, 1) + odx = lblk_idx[:, None, None, None, None] * K * N * M + mblk_idx[None, :, None, None, None] * K * N + \ + nblk_idx[None, None, :, None, None] * K + kblk_idx[None, None, None, :, None] \ + + zblk_idx[None, None, None, None, :] + tl.store(out_ptr0 + odx, ret) + tl.store(out_ptr1 + odx, ret1) + + +testlist = [ + # 5D + (triton_min_5d_dim4_keepdim, (1, 1, 1, 1, 1)), + (triton_min_5d_dim4_keepdim, (2, 2, 2, 2, 2)), + (triton_min_5d_dim4_keepdim, (9, 3, 2, 4, 17)), + (triton_min_5d_dim4_keepdim, (3, 11, 1, 3, 42)), + (triton_min_5d_dim4_keepdim, (2, 51, 3, 13, 1)), + (triton_min_5d_dim4_keepdim, (129, 1, 5, 1, 4)), + (triton_min_5d_dim4_keepdim, (203, 1, 2, 2, 3)), + (triton_min_5d_dim4_keepdim, (512, 1, 1, 1, 1)), + (triton_min_5d_dim4_keepdim, (3, 1, 1, 2, 600)), + (triton_min_5d_dim4_keepdim, (1, 1, 1, 1, 1024)), + (triton_min_5d_dim4_keepdim, (15, 2, 2, 2, 54)), + (triton_min_5d_dim4_keepdim, (2, 91, 4, 2, 4)), + (triton_min_5d_dim4_keepdim, (1, 1, 3, 2, 600)), + (triton_min_5d_dim4_keepdim, (5, 2, 4, 1, 26)), + (triton_min_5d_dim4_keepdim, (2, 2, 2, 4, 8)), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = ["{}-{}".format(testfunc.__name__, "-".join(map(str, shape))) + for testfunc, shape in testlist +] + + +@pytest.mark.parametrize('testfunc, shape', testlist, ids=ids) +@pytest.mark.parametrize('sigtype', typelist) +def test_min_dim4_keepdim(testfunc, sigtype, shape): + dtype = eval('torch.' 
+ sigtype) + x0 = torch.randn(shape).to(dtype).npu() + + x1 = torch.arange(x0.numel()).view(x0.shape).npu().to(torch.int32) + if 'int' in sigtype: + ans, ans1 = torch.min(x0.to(torch.int64), 4) + ans = ans.to(dtype) + else: + ans, ans1 = torch.min(x0, 4) + output = torch.zeros(shape[0:4], dtype=dtype).npu() + output1 = torch.zeros(shape[0:4], dtype=torch.int32).npu() + testfunc[(1,)](x0, x1, output, output1, *shape, debug=True) + test_common.validate_cmp(sigtype, output, ans) + test_common.validate_cmp('int32', output1, ans1) \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py new file mode 100644 index 0000000000..c3d3831b90 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py @@ -0,0 +1,332 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def minimum(a, b): + ret = tl.minimum(a, b, tl.PropagateNan.ALL) + # 经过测试发现,tl.minimum仅在输入类型为bfloat16时,输出的结果会转变为float32,从而导致编译报错。在GPU上测试发现,和NPU上错误的现象一致。 + # 因此此处针对输入类型为bfloat16的情况,对输出进行了类型转换来规避该错误引起的编译报错。 + if a.dtype == tl.bfloat16: + ret = ret.to(tl.bfloat16) + return ret + + +@triton.jit +def triton_min_5d_dim024(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + idx = lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + \ + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[ + None, None, None, None, :] + odx = mblk_idx[:, None] * K + kblk_idx[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, minimum) + ret1 = tl.reduce(ret, 2, minimum) + ret2 = tl.reduce(ret1, 0, minimum) + tl.store(out_ptr0 + odx, ret2) + + +@triton.jit +def triton_min_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None] * N * Z + + nblk_idx[None, :, None] * Z + + zblk_idx[None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret_k = tl.reduce(x, 3, minimum) # [L, M, N, Z] + ret_m = tl.reduce(ret_k, 1, minimum) # [L, N, Z] + tl.store(out_ptr0 + odx, ret_m) + + +@triton.jit +def 
triton_min_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (mblk_idx[:, None, None, None] * N * K * Z + + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 0, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * N * K * Z + + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 1, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + 
nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * K * Z + + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 2, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * Z + + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + + zblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 3, minimum) + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_min_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + odx = (lblk_idx[:, None, None, None] * M * N * K + + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + + kblk_idx[None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret = tl.reduce(x, 4, minimum) + tl.store(out_ptr0 + odx, 
ret) + + +@triton.jit +def triton_min_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + Z: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + kblk_idx = tl.arange(0, K) + zblk_idx = tl.arange(0, Z) + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + + kblk_idx[None, None, None, :, None] * Z + + zblk_idx[None, None, None, None, :]) + + x = tl.load(in_ptr0 + idx) + ret1 = tl.reduce(x, 4, minimum) + ret2 = tl.reduce(ret1, 3, minimum) + ret3 = tl.reduce(ret2, 2, minimum) + ret4 = tl.reduce(ret3, 1, minimum) + ret5 = tl.reduce(ret4, 0, minimum) + tl.store(out_ptr0, ret5) + + +testlist = [ + (triton_min_5d_dim024, (1, 1, 1, 1, 1), "dim024"), + (triton_min_5d_dim024, (2, 2, 2, 2, 2), "dim024"), + (triton_min_5d_dim024, (3, 11, 1, 3, 42), "dim024"), + + (triton_min_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), + + (triton_min_5d_dim0, (2, 2, 2, 2, 2), "dim0"), + (triton_min_5d_dim1, (2, 2, 2, 2, 2), "dim1"), + (triton_min_5d_dim2, (2, 2, 2, 2, 2), "dim2"), + (triton_min_5d_dim3, (2, 2, 2, 2, 2), "dim3"), + (triton_min_5d_dim4, (2, 2, 2, 2, 2), "dim4"), + + (triton_min_5d_all, (3, 11, 1, 3, 42), "all"), +] + +typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] + +ids = ["{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) + for testfunc, shape, dim_name in testlist + ] + + +@pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) +@pytest.mark.parametrize('dtype', typelist) +def test_min(testfunc, dtype, shape, dim_name): + x0 = test_common.generate_tensor(shape=shape, dtype=dtype).npu() + + if dim_name == "dim024": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 4) + ans, _ = torch.min(ans.to(torch.int64), 2) + ans, _ = torch.min(ans.to(torch.int64), 0) + ans = 
ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 4) + ans, _ = torch.min(ans, 2) + ans, _ = torch.min(ans, 0) + output = torch.zeros((shape[1],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim13": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 3) + ans, _ = torch.min(ans.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 3) + ans, _ = torch.min(ans, 1) + output = torch.zeros((shape[0],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim0": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 0) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 0) + output = torch.zeros((shape[1],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim1": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 1) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 1) + output = torch.zeros((shape[0],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim2": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 2) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 2) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim3": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 3) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 3) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + + elif dim_name == "dim4": + if 'int' in dtype: + ans, _ = torch.min(x0.to(torch.int64), 4) + ans = ans.to(dtype=eval('torch.' + dtype)) + else: + ans, _ = torch.min(x0, 4) + output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[3],), dtype=eval('torch.' 
+ dtype)).npu() + + elif dim_name == "all": + if 'int' in dtype: + ans = torch.min(x0.to(torch.int64)) + ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) + else: + ans = torch.tensor([torch.min(x0)], dtype=eval('torch.' + dtype)) + output = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu() + + testfunc[(1,)](x0, output, *shape) + + test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py b/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py new file mode 100644 index 0000000000..6a70866cfe --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_runtime_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+import logging +import os +from triton.backends.ascend import utils + + +def test_get_logger(): + logger = utils.get_logger("test_utils", "INFO") + assert logger.level == logging.INFO + + +def test_get_ascend_arch_from_env(): + os.environ["TRITON_ASCEND_ARCH"] = "Ascend910_9599" + result = utils.get_ascend_arch_from_env() + assert result == "Ascend910_9599" diff --git a/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py index 84ecdbeb19..53517eb4eb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py +++ b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py @@ -137,7 +137,7 @@ def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): def torch_func(x0): y = x0[0] - y = y % 2.0 + y = y - 2.0 * torch.div(y, 2.0, rounding_mode="trunc") return torch.tensor(y) dtype, N = param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py new file mode 100644 index 0000000000..012ddc5fcd --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py @@ -0,0 +1,142 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def kernel_cal_select_mask_bool( + Output_ptr, + Indices_ptr, + numel: tl.constexpr, + BLOCK: tl.constexpr, +): + offs = tl.arange(0, BLOCK) + indice = tl.load(Indices_ptr) + + true_tensor = tl.arange(0, BLOCK) < numel + false_tensor = tl.arange(0, BLOCK) >= numel + mask = offs < indice + res = tl.where(mask, true_tensor, false_tensor) + tl.store( + Output_ptr + offs, + res + ) + + +@triton.jit +def kernel_cal_select_mask( + QK_ptr, + Other_ptr, + Output_ptr, + Indices_ptr, + stride_qk: tl.constexpr, + numel: tl.constexpr, + BLOCK: tl.constexpr, +): + rows = tl.arange(0, BLOCK) * stride_qk + cols = tl.arange(0, BLOCK) + offs = rows[:, None] + cols[None, :] + row_indices = tl.load(Indices_ptr) + col_indices = tl.load(Indices_ptr + 1) + + qk_ub = tl.load( + QK_ptr + offs + ) + other = tl.load( + Other_ptr + offs + ) + mask_rows = rows < row_indices * stride_qk + mask_cols = cols < col_indices + + res = tl.where(mask_rows[:, None] & mask_cols[None, :], qk_ub, other) + tl.store( + Output_ptr + offs, + res + ) + + +def torch_cal_select_mask_bool( + Indice: torch.Tensor, + numel, + BLOCK, +): + offs = torch.arange(0, BLOCK) + true_tensor = torch.arange(0, BLOCK) < numel + false_tensor = torch.arange(0, BLOCK) >= numel + mask = offs < Indice + + res = torch.where(mask, true_tensor, false_tensor) + return res + + +def torch_cal_select_mask( + QK: torch.Tensor, + 
Other: torch.Tensor, + Indices: torch.Tensor, +): + row_limit_idx = Indices[0].item() + col_limit_idx = Indices[1].item() + Output = Other.clone() + Output[:row_limit_idx, :col_limit_idx] = QK[:row_limit_idx, :col_limit_idx] + return Output + + +@pytest.mark.parametrize('param_list', + [ + ['bool', 64, 63] + ] + ) +def test_select_analysis_bool(param_list): + dtype, SEQ_LEN, indice = param_list + assert dtype == 'bool' + qk_cal = torch.empty(SEQ_LEN).npu() + indices = torch.tensor([indice]).npu() + qk_ref = torch_cal_select_mask_bool(indice, SEQ_LEN, SEQ_LEN) + kernel_cal_select_mask_bool[(1,)]( + qk_cal, indices, SEQ_LEN, SEQ_LEN + ) + test_common.validate_cmp(dtype, qk_cal, qk_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float16', 64, 63, 62], + ['float32', 64, 63, 62], + ] + ) +def test_select_analysis(param_list): + dtype, SEQ_LEN, indice_x, indice_y = param_list + assert dtype != 'bool' + qk = torch.rand([SEQ_LEN, SEQ_LEN], dtype=eval('torch.' + dtype), device='npu') + qk_cal = torch.empty_like(qk).npu() + other = torch.zeros_like(qk).npu() + indices_tensor = torch.tensor([indice_x, indice_y]).npu() + qk_ref = torch_cal_select_mask(qk, other, indices_tensor) + kernel_cal_select_mask[(1,)]( + qk, other, qk_cal, indices_tensor, + qk.stride(0), SEQ_LEN * SEQ_LEN, SEQ_LEN + ) + test_common.validate_cmp(dtype, qk_cal, qk_ref) \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py new file mode 100644 index 0000000000..0610f0b21a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py @@ -0,0 +1,170 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +@triton.jit +def cal_atten_mask_kernel( + QK_ptr, + Indices_ptr, + stride_qk_m, + stride_qk_n, + stride_ik, + SEQ_LEN: tl.constexpr, + sparse_block_size: tl.constexpr, + BLOCK_SBS: tl.constexpr, + TOPK_BASE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + idx_sub_sbs = pid_n + cur_s1 = pid_m * BLOCK_SBS + cur_s2 = cur_s1 + BLOCK_SBS + + if cur_s1 >= SEQ_LEN: + return + + beg_sbs = idx_sub_sbs * BLOCK_SBS // sparse_block_size + end_sbs = ((idx_sub_sbs + 1) * BLOCK_SBS) // sparse_block_size + + valid_col_end = cur_s1 + (cur_s2 - cur_s1) + + offs_m = cur_s1 + tl.arange(0, BLOCK_SBS) + offs_n_base = idx_sub_sbs * BLOCK_SBS + offs_n = offs_n_base + tl.arange(0, BLOCK_SBS) + + mask_m = offs_m < SEQ_LEN + mask_n = offs_n < SEQ_LEN + mask_load = mask_m[:, None] & mask_n[None, :] + + qk_ub = tl.load( + QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, + mask=mask_load, + other=0.0 + ) + + for idx_k in range(beg_sbs, end_sbs): + idx_s2 = tl.load(Indices_ptr + TOPK_BASE + idx_k * stride_ik) + if idx_s2 != -1 and idx_s2 * sparse_block_size > valid_col_end: + idx_lower_sbs = idx_k * sparse_block_size - \ + idx_sub_sbs * BLOCK_SBS + idx_higher_sbs = (idx_k + 1) * sparse_block_size - \ + idx_sub_sbs * BLOCK_SBS + mask_lower_sbs = tl.arange(0, BLOCK_SBS) >= idx_lower_sbs + mask_higher_sbs = tl.arange(0, BLOCK_SBS) < idx_higher_sbs + qk_ub = tl.where((mask_lower_sbs & mask_higher_sbs)[None, :], float("-inf"), qk_ub) + + tl.store( + QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, + qk_ub, + mask=mask_load + ) + + +def launch_cal_atten_mask( + qk_tensor, + indices_tensor, + sparse_block_size=64, + block_sbs=128 +): + """ + qk_tensor: (SEQ_LEN, SEQ_LEN) + indices_tensor: (K,) / (BATCH, K, ...) 
+ """ + assert qk_tensor.is_contiguous() + M, N = qk_tensor.shape + + stride_qk_m = qk_tensor.stride(0) + stride_qk_n = qk_tensor.stride(1) + + stride_ik = 1 + topk_base = 0 + + grid = (triton.cdiv(M, block_sbs), triton.cdiv(N, block_sbs)) + cal_atten_mask_kernel[grid]( + qk_tensor, + indices_tensor, + stride_qk_m, + stride_qk_n, + stride_ik, + SEQ_LEN=M, + sparse_block_size=sparse_block_size, + BLOCK_SBS=block_sbs, + TOPK_BASE=topk_base, + ) + return qk_tensor + + +def torch_cal_atten_mask( + qk, + indices, + sparse_block_size, + block_sbs, + topk_base=0, +): + device = qk.device + dtype = qk.dtype + M, N = qk.shape + + row_ids = torch.arange(M, device=device).unsqueeze(1) + col_ids = torch.arange(N, device=device).unsqueeze(0) + + k_idx_global = col_ids // sparse_block_size + lookup_idx = k_idx_global + topk_base + max_valid_idx = indices.numel() - 1 + + valid_lookup = (lookup_idx >= 0) & (lookup_idx <= max_valid_idx) + safe_lookup_idx = lookup_idx.clamp(0, max_valid_idx) + idx_s2_map = indices.gather(0, safe_lookup_idx.squeeze(0)).unsqueeze(0) + idx_s2_map = torch.where(valid_lookup, idx_s2_map, torch.tensor(-1, device=device)) + + row_block_ends = ((row_ids // block_sbs) + 1) * block_sbs + row_block_ends = torch.min(row_block_ends, torch.tensor(N, device=device)) + + start_pos_k_map = idx_s2_map * sparse_block_size + cond_valid = (idx_s2_map != -1) + cond_exceed = (start_pos_k_map > row_block_ends) + final_mask = cond_valid & cond_exceed + + qk_out = torch.where(final_mask, torch.tensor(float("-inf"), dtype=dtype, device=device), qk) + return qk_out + + +@pytest.mark.parametrize('param_list', + [ + ['float32', 1024, 128, 64] + ] + ) +def test_divsiop_select_analysis1(param_list): + dtype, SEQ_LEN, BLOCK_SBS, SPARSE_BLOCK = param_list + qk = torch.zeros((SEQ_LEN, SEQ_LEN), dtype=eval('torch.' 
+ dtype), device='npu') + K_SIZE = 20 + indices = torch.full((K_SIZE,), -1, dtype=torch.int32, device='npu') + indices[10] = 20 + qk_ref = torch_cal_atten_mask(qk.clone(), indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) + qk_cal = launch_cal_atten_mask(qk, indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) + test_common.validate_cmp(dtype, qk_cal, qk_ref) \ No newline at end of file diff --git a/third_party/ascend/unittest/pytest_ut/test_signbit.py b/third_party/ascend/unittest/pytest_ut/test_signbit.py index 693b601c23..b576e98a9e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_signbit.py +++ b/third_party/ascend/unittest/pytest_ut/test_signbit.py @@ -64,3 +64,20 @@ def test_all_blocks_parallel(param_list, monkeypatch): triton_signbit[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub) test_common.validate_cmp('bool', y_cal, y_ref) monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") + + +@pytest.mark.parametrize('param_list', + [ + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_auto_blockify(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch.signbit(x).npu() + y_cal = torch.zeros(shape).bool().npu() + triton_signbit[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore) + test_common.validate_cmp('bool', y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py new file mode 100644 index 0000000000..44f95114ab --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py @@ -0,0 +1,80 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + +import torch +import torch_npu + + +@triton.jit +def triton_fn_expanddims(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr, YBLOCK_SUB: tl.constexpr): + base1 = tl.arange(0, XBLOCK)[:, None] + base2 = tl.arange(0, YBLOCK_SUB)[None, :] + loops1: tl.constexpr = YBLOCK // YBLOCK_SUB # assume it's divisible + for _ in range(loops1): + x0 = base1 * YBLOCK + base2 + base2 = base2 + YBLOCK_SUB + tmp0 = tl.load(in_ptr0 + (x0), None) + tl.store(out_ptr0 + (x0), tmp0, None) + + +@triton.jit +def triton_fn_broadcast(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr, YBLOCK_SUB: tl.constexpr): + base1 = tl.arange(0, XBLOCK)[:, None] + base2 = tl.arange(0, YBLOCK_SUB)[None, :] + base2 = base2.broadcast_to((XBLOCK, YBLOCK_SUB)) + loops1: tl.constexpr = YBLOCK // YBLOCK_SUB # assume it's divisible + for _ in range(loops1): + x0 = base1 * YBLOCK + base2 + base2 = base2 + YBLOCK_SUB + tmp0 = tl.load(in_ptr0 + (x0), None) + tl.store(out_ptr0 + (x0), tmp0, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (128, 128), 128, 128, 32], + ]) +def test_expanddims(param_list): + dtype, shape, xblock, yblock, yblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_fn_expanddims[(1,)](x0, y_cal, xblock, yblock, yblock_sub) + test_common.validate_cmp(dtype, y_cal, x0) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (128, 128), 128, 128, 32], + ]) +def test_broadcast(param_list): + dtype, shape, xblock, yblock, yblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_fn_broadcast[(1,)](x0, y_cal, xblock, yblock, yblock_sub) + test_common.validate_cmp(dtype, y_cal, x0) diff --git a/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py new file mode 100644 index 0000000000..cf902cc00a --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + +import torch +import torch_npu + + +@triton.jit +def triton_sink_broadcast1(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): + base1 = tl.arange(0, XBLOCK)[:, None] * YBLOCK + base2 = tl.arange(0, YBLOCK)[None, :] + base1 = base1.broadcast_to((XBLOCK, YBLOCK)) + tmp0 = tl.load(in_ptr0 + base1, None) + index = base1 + base2 + tl.store(out_ptr0 + index, tmp0, None) + + +@triton.jit +def triton_sink_broadcast2(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): + base1 = tl.arange(0, XBLOCK)[:, None] * YBLOCK + base2 = tl.arange(0, YBLOCK)[None, :] + base1 = base1.broadcast_to((XBLOCK, YBLOCK)) + tmp0 = tl.load(in_ptr0 + base1, base1 < XBLOCK * YBLOCK, other=0.0) + index = base1 + base2 + tl.store(out_ptr0 + index, tmp0, index < XBLOCK * YBLOCK) + + +@triton.jit +def triton_sink_broadcast3(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): + base1 = (tl.arange(0, XBLOCK) * YBLOCK)[:, None] + base2 = tl.arange(0, YBLOCK)[None, :] + base1 = base1.broadcast_to((XBLOCK, YBLOCK)) + tmp0 = tl.load(in_ptr0 + base1 + base2, (base1 + base2) < XBLOCK * YBLOCK, other=0.0) + index = base1 + base2 + tl.store(out_ptr0 + index, tmp0, index < XBLOCK * YBLOCK) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32, 32), 32, 32], + ]) +def test_sink_broadcast(param_list): + dtype, shape, xblock, yblock = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0.clone() + y_ref = y_ref[:, 0].unsqueeze(1).expand(-1, x0.size(1)) + + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + y_cal2 = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_sink_broadcast1[(1,)](x0, y_cal, xblock, yblock) + triton_sink_broadcast2[(1,)](x0, y_cal2, xblock, yblock) + test_common.validate_cmp(dtype, y_cal, y_ref) + test_common.validate_cmp(dtype, y_cal2, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32, 32), 32, 32], + ]) +def test_sink_broadcast3(param_list): + dtype, shape, xblock, yblock = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0.clone() + + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_sink_broadcast3[(1,)](x0, y_cal, xblock, yblock) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_sync_block.py b/third_party/ascend/unittest/pytest_ut/test_sync_block.py index 0c84b3166a..8d1df6d243 100644 --- a/third_party/ascend/unittest/pytest_ut/test_sync_block.py +++ b/third_party/ascend/unittest/pytest_ut/test_sync_block.py @@ -17,6 +17,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import torch import triton import triton.language as tl @@ -91,7 +92,7 @@ def test_matmul_exp(dtype, ashape, bshape): C_ref = (A @ B).exp() # compare - test_common.validate_cmp(dtype, C, C_ref) + torch.testing.assert_close(C_ref, C, rtol=3e-2, atol=3e-2, equal_nan=True) if __name__ == "__main__": diff --git a/third_party/ascend/unittest/pytest_ut/test_use_analysis.py b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py new file mode 100644 index 0000000000..7a9ce7ba4d --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py @@ -0,0 +1,73 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + +import torch +import torch_npu + + +@triton.jit +def triton_reduce_deadcode(v_ptr, in_ptr0, in_ptr1, out_ptr0, VBLOCK: tl.constexpr, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): + v_idx = tl.arange(0, VBLOCK) + v = tl.load(v_ptr + v_idx) + v_ret = tl.argmax(v, 0) + if v_ret < v_ret + 1: + for _ in range(v_ret, v_ret + 1): + cube_idx = tl.arange(0, XBLOCK)[:, None] * YBLOCK + tl.arange(0, YBLOCK)[None, :] + c0 = tl.load(in_ptr0 + cube_idx) + c1 = tl.load(in_ptr1 + cube_idx) + ret = tl.dot(c0, c1) + 1 + tl.store(out_ptr0 + cube_idx, ret) + else: + for _ in range(v_ret - 1, v_ret): + cube_idx = tl.arange(0, XBLOCK)[:, None] * YBLOCK + tl.arange(0, YBLOCK)[None, :] + c0 = tl.load(in_ptr0 + cube_idx) + c1 = tl.load(in_ptr1 + cube_idx) + ret = tl.dot(c0, c1) + 1 + tl.store(out_ptr0 + cube_idx, ret) + + +def torch_reduce_deadcode(in0, in1, v): + v_ret = torch.argmax(v) + if v_ret < v_ret + 1: + ret = torch.matmul(in0, in1) + 1 + else: + ret = torch.matmul(in0, in1) + 1 + return ret + + +def test_reduce_deadcode(): + VBLOCK, XBLOCK, YBLOCK = 16, 16, 16 + sigtype = 'float32' + dtype = torch.float32 + in0 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu') + in1 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu') + v = torch.randn((VBLOCK,), dtype=dtype, device='npu') + out = torch.zeros((XBLOCK, YBLOCK), dtype=dtype, device='npu') + + triton_reduce_deadcode[(1,)](v, in0, in1, out, VBLOCK=VBLOCK, XBLOCK=XBLOCK, YBLOCK=YBLOCK) + expected = torch_reduce_deadcode(in0, in1, v) + test_common.validate_cmp(sigtype, out, expected) diff --git a/third_party/ascend/unittest/pytest_ut/test_zeros.py b/third_party/ascend/unittest/pytest_ut/test_zeros.py index cc7fd067f1..3dab18af39 100644 --- a/third_party/ascend/unittest/pytest_ut/test_zeros.py +++ b/third_party/ascend/unittest/pytest_ut/test_zeros.py @@ -27,15 +27,11 @@ 
@triton.jit -def fn_npu_f32(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): +def fn_npu_f32(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): xidx = tl.arange(0, XB) yidx = tl.arange(0, YB) zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - ret = tl.zeros((XB, YB, ZB), dtype=tl.float32) oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] @@ -44,15 +40,11 @@ def fn_npu_f32(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.con @triton.jit -def fn_npu_f16(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): +def fn_npu_f16(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): xidx = tl.arange(0, XB) yidx = tl.arange(0, YB) zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - ret = tl.zeros((XB, YB, ZB), dtype=tl.float16) oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] @@ -61,15 +53,11 @@ def fn_npu_f16(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.con @triton.jit -def fn_npu_i8(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): +def fn_npu_i8(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): xidx = tl.arange(0, XB) yidx = tl.arange(0, YB) zidx = tl.arange(0, ZB) - idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] - - X = tl.load(x_ptr + idx) - ret = tl.zeros((XB, YB, ZB), dtype=tl.int8) oidx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] @@ -87,17 +75,14 @@ def fn_npu_i8(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.cons ]) def test_case(param_list): dtype, shape, ncore, XB, YB, ZB = param_list - x0 = test_common.generate_tensor(shape, dtype) y_ref = torch.full((XB, YB, ZB), 0, 
dtype=eval('torch.' + dtype)).npu() - print(f"y_ref = {y_ref[0, 0, 0:4]}") y_cal = torch.randint(1, (XB, YB, ZB), dtype=eval('torch.' + dtype)).npu() if dtype == "float32": - fn_npu_f32[ncore, 1, 1](y_cal, x0, XB, YB, ZB) + fn_npu_f32[ncore, 1, 1](y_cal, XB, YB, ZB) elif dtype == "float16": - fn_npu_f16[ncore, 1, 1](y_cal, x0, XB, YB, ZB) + fn_npu_f16[ncore, 1, 1](y_cal, XB, YB, ZB) else: - fn_npu_i8[ncore, 1, 1](y_cal, x0, XB, YB, ZB) - print(f"y_cal = {y_cal[0, 0, 0:4]}") + fn_npu_i8[ncore, 1, 1](y_cal, XB, YB, ZB) test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_zeroslike.py b/third_party/ascend/unittest/pytest_ut/test_zeroslike.py index 76ddf08f7e..6d6b4a822f 100644 --- a/third_party/ascend/unittest/pytest_ut/test_zeroslike.py +++ b/third_party/ascend/unittest/pytest_ut/test_zeroslike.py @@ -53,7 +53,7 @@ def fn_npu_(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.conste ]) def test_case(param_list): dtype, shape, ncore, XB, YB, ZB = param_list - x0 = test_common.generate_tensor(shape, dtype) + x0 = test_common.generate_tensor(shape, dtype).npu() y_ref = torch.zeros_like(x0, dtype=eval('torch.' + dtype)).npu() print(f"y_ref = {y_ref[0, 0, 0:4]}") y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() From 825eabf327bed2e6a4ab17a506630110a925b372 Mon Sep 17 00:00:00 2001 From: Stardep <1486216685@qq.com> Date: Mon, 11 May 2026 02:49:23 +0800 Subject: [PATCH 2/5] build: fix ascend cmake subdirectory paths --- third_party/ascend/include/CMakeLists.txt | 7 ------- third_party/ascend/lib/CMakeLists.txt | 13 +++---------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/third_party/ascend/include/CMakeLists.txt b/third_party/ascend/include/CMakeLists.txt index 2f77a4a5c4..e1eae53c09 100644 --- a/third_party/ascend/include/CMakeLists.txt +++ b/third_party/ascend/include/CMakeLists.txt @@ -1,12 +1,5 @@ -add_subdirectory(Dialect) -add_subdirectory(TritonToAnnotation) add_subdirectory(TritonToHFusion) add_subdirectory(TritonToHIVM) -add_subdirectory(TritonToLinalg) -add_subdirectory(Utils) -add_subdirectory(DiscreteMaskAccessConversion) -add_subdirectory(TritonToUnstructure) add_subdirectory(TritonToLLVM) -add_subdirectory(TritonToStructured) add_subdirectory(AutoBlockify) add_subdirectory(TritonAffinityOpt) diff --git a/third_party/ascend/lib/CMakeLists.txt b/third_party/ascend/lib/CMakeLists.txt index c5fa61143a..b04c8d981d 100644 --- a/third_party/ascend/lib/CMakeLists.txt +++ b/third_party/ascend/lib/CMakeLists.txt @@ -1,14 +1,7 @@ add_subdirectory(AutoBlockify) -add_subdirectory(Dialect) -add_subdirectory(TritonToAnnotation) -add_subdirectory(TritonToHFusion) -add_subdirectory(TritonToHIVM) -add_subdirectory(TritonToLinalg) -add_subdirectory(Utils) -add_subdirectory(DiscreteMaskAccessConversion) -add_subdirectory(TritonToUnstructure) -add_subdirectory(TritonToLLVM) -add_subdirectory(TritonToStructured) +add_subdirectory(Conversion/TritonToHFusion) +add_subdirectory(Conversion/TritonToHIVM) +add_subdirectory(Conversion/TritonToLLVM) add_subdirectory(TritonAffinityOpt) if(TRITON_ENABLE_COVERAGE_HITEST) From 61b6d71ed55ca46b952bc8e1ed0f99c02db39593 Mon Sep 17 00:00:00 2001 From: Stardep <1486216685@qq.com> Date: Mon, 11 May 2026 02:58:51 
+0800 Subject: [PATCH 3/5] build: avoid duplicate AscendNPU-IR cmake inclusion --- third_party/ascend/CMakeLists.txt | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt index 5672ac1e5c..742063d2ce 100644 --- a/third_party/ascend/CMakeLists.txt +++ b/third_party/ascend/CMakeLists.txt @@ -12,20 +12,9 @@ include_directories(${CMAKE_BINARY_DIR}/third_party/flir/include) # set(BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES ON) # set(BISHENGIR_BUILD_STANDALONE_IR_ONLY ON) -# Temporarily save and clear the RULE_LAUNCH_* attributes in the current directory to ensure AscendNPU-IR is not affected. -get_property(_saved_launch_compile DIRECTORY PROPERTY RULE_LAUNCH_COMPILE) -get_property(_saved_launch_link DIRECTORY PROPERTY RULE_LAUNCH_LINK) -set_property(DIRECTORY PROPERTY RULE_LAUNCH_COMPILE "") -set_property(DIRECTORY PROPERTY RULE_LAUNCH_LINK "") - -add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR}) - -# restore properties -set_property(DIRECTORY PROPERTY RULE_LAUNCH_COMPILE ${_saved_launch_compile}) -set_property(DIRECTORY PROPERTY RULE_LAUNCH_LINK ${_saved_launch_link}) - -include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) -include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files +# AscendNPU-IR is already added from the top-level CMakeLists when +# FLAGTREE_BACKEND=ascend. Do not add it again here, otherwise CMake will +# fail with "binary directory is already used to build a source directory". 
add_subdirectory(backend/spec/lib) From 0490c7bc58744bb8ccf411c5615eda4d85d54638 Mon Sep 17 00:00:00 2001 From: Stardep <1486216685@qq.com> Date: Mon, 11 May 2026 03:36:21 +0800 Subject: [PATCH 4/5] build: fix ascend source integration for CI --- bin/RegisterTritonDialects.h | 18 ++----- third_party/ascend/ascend_ir.cc | 8 ++-- .../ascend/lib/AutoBlockify/AutoBlockify.cpp | 3 +- .../lib/AutoBlockify/RewriteOperation.cpp | 1 - third_party/ascend/lib/AutoBlockify/Utils.cpp | 1 - .../Conversion/TritonToLLVM/TritonToLLVM.cpp | 10 ---- .../lib/TritonAffinityOpt/DAGSSBuffer.cpp | 1 - .../ascend/lib/TritonAffinityOpt/DAGScope.cpp | 1 - .../ascend/lib/TritonAffinityOpt/DAGSync.cpp | 1 - third_party/ascend/triton_ascend.cc | 47 +++---------------- 10 files changed, 16 insertions(+), 75 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d4c6294b06..47f26957d6 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,13 +1,8 @@ #pragma once -#include "ascend/include/TritonToLinalg/Passes.h" -#include "ascend/include/DiscreteMaskAccessConversion/Passes.h" -#include "ascend/include/TritonToStructured/Passes.h" -#include "ascend/include/TritonToAnnotation/Passes.h" -#include "ascend/include/TritonToUnstructure/Passes.h" -#include "ascend/include/TritonToHIVM/Passes.h" -#include "ascend/include/TritonToHFusion/Passes.h" -#include "ascend/include/TritonToLLVM/Passes.h" -#include "ascend/include/AutoBlockify/Passes.h" +#include "TritonToHIVM/Passes.h" +#include "TritonToHFusion/Passes.h" +#include "TritonToLLVM/Passes.h" +#include "AutoBlockify/Passes.h" // #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" // #include "amd/include/TritonAMDGPUTransforms/Passes.h" // #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" @@ -73,11 +68,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // mlir::triton::registerConvertTritonGPUToLLVMPass(); // 
mlir::triton::registerConvertNVGPUToLLVMPass(); // mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); - mlir::triton::registerTritonToLinalgPasses(); - mlir::triton::registerDiscreteMaskAccessConversion(); - mlir::triton::registerTritonToStructuredPasses(); - mlir::triton::registerTritonToAnnotationPasses(); - mlir::triton::registerTritonToUnstructurePasses(); mlir::triton::registerTritonToHIVMPasses(); mlir::triton::registerTritonToHFusionPasses(); mlir::triton::registerTritonToLLVMPasses(); diff --git a/third_party/ascend/ascend_ir.cc b/third_party/ascend/ascend_ir.cc index 25f5597a2d..3f9b58615b 100644 --- a/third_party/ascend/ascend_ir.cc +++ b/third_party/ascend/ascend_ir.cc @@ -522,9 +522,11 @@ void init_ascend_ir(py::module &&m) { }) .def("get_iterator_types_attr", [](AscendNPUIROpBuilder &self, const std::vector& array) { - auto attrs = llvm::to_vector(llvm::map_range(array, [&self](hivm::IteratorType type) { - return cast(self.getBuilder().getAttr(type)); - })); + llvm::SmallVector attrs; + attrs.reserve(array.size()); + for (auto type : array) { + attrs.push_back(self.getBuilder().getI32IntegerAttr(static_cast(type))); + } return self.getBuilder().getArrayAttr(attrs); }) .def("get_t_core_type_attr_name", diff --git a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp index ec621adf79..8d9129cd5e 100644 --- a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp +++ b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp @@ -22,8 +22,7 @@ #include "AutoBlockify/AutoBlockify.h" #include "AutoBlockify/Utils.h" -#include "Dialect/TritonAscend/IR/TritonAscendDialect.h" -#include "Utils/Utils.h" +#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Pass/PassManager.h" diff --git a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp index 
d51c44997a..f3f71ea7b6 100644 --- a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp +++ b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp @@ -22,7 +22,6 @@ #include "AutoBlockify/AutoBlockify.h" #include "AutoBlockify/Utils.h" -#include "Utils/Utils.h" #include "llvm/Support/Debug.h" diff --git a/third_party/ascend/lib/AutoBlockify/Utils.cpp b/third_party/ascend/lib/AutoBlockify/Utils.cpp index bdb09e792e..6e49e6b0f2 100644 --- a/third_party/ascend/lib/AutoBlockify/Utils.cpp +++ b/third_party/ascend/lib/AutoBlockify/Utils.cpp @@ -21,7 +21,6 @@ */ #include "AutoBlockify/Utils.h" -#include "Utils/Utils.h" #include "llvm/Support/Debug.h" diff --git a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp index 75ea897ef4..a1fc685e17 100644 --- a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp +++ b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp @@ -248,16 +248,6 @@ struct ElementwiseInlineAsmOpConversion : OpRewritePattern outs; - for (int i = 0; i < unpackedResults.size(); i++) { - outs.push_back(rewriter.create( - loc, op->getResult(i).getType(), unpackedResults[i])); - } - rewriter.replaceOp(op, outs); - - return success(); - } }; void TritonToLLVMPass::runOnOperation() { diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp index 211adc311d..baa85cc2f9 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp @@ -44,7 +44,6 @@ #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "Utils/Utils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallPtrSet.h" #include diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp index 0b17b913fe..0f19131232 100644 
--- a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp @@ -40,7 +40,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "Utils/Utils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallPtrSet.h" #include diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp index 17568cc9f0..c0fc70b255 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp @@ -26,7 +26,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/Support/Casting.h" -#include "Utils/Utils.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" diff --git a/third_party/ascend/triton_ascend.cc b/third_party/ascend/triton_ascend.cc index ec40899e18..c13d394fd6 100644 --- a/third_party/ascend/triton_ascend.cc +++ b/third_party/ascend/triton_ascend.cc @@ -9,17 +9,12 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Pass/PassManager.h" -#include "ascend/include/AutoBlockify/Passes.h" -#include "ascend/include/TritonToStructured/Passes.h" -#include "ascend/include/TritonToAnnotation/Passes.h" -#include "ascend/include/TritonToLinalg/Passes.h" -#include "ascend/include/Dialect/TritonAscend/IR/TritonAscendDialect.h" -#include "ascend/include/DiscreteMaskAccessConversion/Passes.h" -#include "ascend/include/TritonToUnstructure/Passes.h" -#include "ascend/include/TritonToHIVM/Passes.h" -#include "ascend/include/TritonToHFusion/Passes.h" -#include "ascend/include/TritonToLLVM/Passes.h" - #include "ascend/include/TritonAffinityOpt/Passes.h" +#include "AutoBlockify/Passes.h" +#include "TritonToHIVM/Passes.h" +#include "TritonToHFusion/Passes.h" +#include "TritonToLLVM/Passes.h" +#include "TritonAffinityOpt/Passes.h" +#include 
"npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "ir.h" // TritonOpBuilder @@ -303,39 +298,9 @@ void init_triton_ascend_passes_ttir(py::module &&m) { opts.autoBlockifySize = autoBlockifySize; pm.addPass(mlir::triton::createAutoBlockifyPass(opts));}); - m.def("add_triton_to_structure", [](mlir::PassManager &pm, - bool enableMaskFallbackConversion, bool optimizeDynamicOffset) { - pm.addPass(mlir::triton::createTritonToStructuredPass( - enableMaskFallbackConversion, optimizeDynamicOffset)); }); - - m.def("add_triton_to_annotation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToAnnotationPass());}); - - m.def("add_triton_to_linalg", [](mlir::PassManager &pm, bool globalKernel, - bool namedOps, bool enableNd2nzOnVector, bool enableSelectAnalysis, - bool compileOn91095) { - pm.addPass(mlir::triton::createTritonToLinalgPass( - globalKernel, namedOps, enableNd2nzOnVector, - enableSelectAnalysis, compileOn91095)); }); - - m.def("add_triton_to_unstructure", [](mlir::PassManager &pm, - bool compileOn91095, bool forceSimtTemplate) { - TritonToUnstructureOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass(mlir::triton::createTritonToUnstructurePass(opts));}); - m.def("add_triton_to_hfusion", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::createTritonToHFusionPass());}); - m.def("add_discrete_mask_access_conversion", [](mlir::PassManager &pm, - bool compileOn91095, bool forceSimtTemplate, bool enableSyncBlockLock) { - DiscreteMaskAccessConversionOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - opts.enableSyncBlockLock = enableSyncBlockLock; - pm.addPass(mlir::triton::createDiscreteMaskAccessConversionPass(opts));}); - m.def("add_triton_to_hivm", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::createTritonToHIVMPass());}); From 20eedb93464662c2595a8ed66deede6723205817 Mon Sep 
17 00:00:00 2001 From: flagtree-bot Date: Sun, 10 May 2026 19:59:16 +0000 Subject: [PATCH 5/5] Apply code-format changes --- bin/RegisterTritonDialects.h | 4 +- python/src/llvm.cc | 2 +- python/triton/extension/__init__.py | 2 +- python/triton/extension/buffer/__init__.py | 2 +- .../extension/buffer/language/builder.py | 2 - .../triton/extension/buffer/language/core.py | 65 +- .../extension/buffer/language/semantic.py | 64 +- .../triton/extension/buffer/src/buffer_ir.cc | 62 +- python/triton/tools/get_ascend_devices.py | 24 +- third_party/ascend/ascend_ir.cc | 336 +- third_party/ascend/backend/compiler.py | 107 +- third_party/ascend/backend/driver.py | 10 +- third_party/ascend/backend/npu_utils.cpp | 233 +- .../ascend/backend/runtime/autoparser.py | 67 +- .../ascend/backend/runtime/autotuner.py | 73 +- .../ascend/backend/runtime/tile_generator.py | 53 +- .../backend/spec/triton/language/semantic.py | 17 +- .../spec/triton/runtime/ascend_interpreter.py | 167 +- .../backend/spec/triton/runtime/autotuner.py | 143 +- .../spec/triton/runtime/interpreter.py | 12 +- third_party/ascend/backend/utils.py | 3 +- .../include/AutoBlockify/AutoBlockify.h | 7 +- .../include/AutoBlockify/CMakeLists.txt | 2 +- .../ascend/include/AutoBlockify/Passes.td | 2 +- .../ascend/include/AutoBlockify/Utils.h | 7 +- .../include/TritonAffinityOpt/CMakeLists.txt | 2 +- .../ascend/include/TritonAffinityOpt/DAG.h | 176 +- .../ascend/include/TritonAffinityOpt/Passes.h | 2 +- .../include/TritonAffinityOpt/Passes.td | 2 +- .../include/TritonAffinityOpt/Utils.hpp | 18 +- .../language/cann/extension/__init__.py | 1 - .../ascend/language/cann/extension/aux_ops.py | 54 +- .../ascend/language/cann/extension/core.py | 69 +- .../language/cann/extension/custom_op.py | 7 +- .../ascend/language/cann/extension/mem_ops.py | 214 +- .../language/cann/extension/semantic.py | 32 +- .../ascend/language/cann/extension/vec_ops.py | 134 +- third_party/ascend/language/cann/libdevice.py | 438 +- 
.../ascend/lib/AutoBlockify/AutoBlockify.cpp | 10 +- .../ascend/lib/AutoBlockify/CMakeLists.txt | 2 +- .../lib/AutoBlockify/RewriteOperation.cpp | 2 +- third_party/ascend/lib/AutoBlockify/Utils.cpp | 2 +- .../Conversion/TritonToLLVM/TritonToLLVM.cpp | 67 +- .../lib/TritonAffinityOpt/CMakeLists.txt | 2 +- .../ascend/lib/TritonAffinityOpt/DAG.cpp | 304 +- .../lib/TritonAffinityOpt/DAGSSBuffer.cpp | 3930 ++++++++--------- .../ascend/lib/TritonAffinityOpt/DAGScope.cpp | 1155 +++-- .../ascend/lib/TritonAffinityOpt/DAGSync.cpp | 2097 ++++----- third_party/ascend/python/src/ir.cc | 124 +- third_party/ascend/triton_ascend.cc | 559 +-- .../tutorials/03-matrix-multiplication.py | 60 +- .../ascend/tutorials/04-low-memory-dropout.py | 4 +- third_party/ascend/tutorials/05-layer-norm.py | 1 - .../ascend/tutorials/06-fused-attention.py | 75 +- .../ascend/tutorials/07-extern-functions.py | 6 +- .../ascend/tutorials/08-grouped-gemm.py | 7 +- .../ascend/tutorials/09-persistent-matmul.py | 7 +- .../TritonToLinalg/copy_use_analysis.mlir | 1 - .../General/AutoBlockify/auto_blockify.mlir | 2 +- .../General/DiscreteMaskAccess/atomic.mlir | 22 +- .../simplify_for_loop.mlir | 3 - .../General/TritonToHFusion/fp_to_fp_rtz.mlir | 2 +- .../General/TritonToHFusion/mod.mlir | 2 +- .../sync_block_op_conversion.mlir | 2 +- .../TritonToLinalg/atomic_rmw_block.mlir | 72 +- .../General/TritonToStructured/parseCmp.mlir | 6 +- .../bubbleupoperation.mlir | 2 +- .../TritonToUnstructure/nested_loop.mlir | 2 +- .../General/TritonToUnstructure/splat.mlir | 2 +- .../TritonToUnstructure/unstructure_mix.mlir | 4 +- .../unittest/autotune_ut/01-vector-add.py | 23 +- .../unittest/autotune_ut/02-fused-softmax.py | 16 +- .../unittest/autotune_ut/03-layer-norm.py | 19 +- .../unittest/autotune_ut/04-libentry.py | 22 +- .../autotune_ut/test_autotune_param_valid.py | 72 +- .../unittest/autotune_ut/test_common.py | 5 +- .../autotune_ut/test_customized_config.py | 44 +- .../autotune_ut/test_low_dim_axes_parse.py | 11 +- 
.../unittest/autotune_ut/test_mask_parse.py | 32 +- .../autotune_ut/test_no_tiling_axis_parse.py | 34 +- .../autotune_ut/test_reduction_axes_parse.py | 58 +- .../autotune_ut/test_split_axis_parse.py | 40 +- .../autotune_ut/test_tiling_axis_parse.py | 42 +- .../unittest/custom_op/builtin_ops_demo.py | 31 +- .../unittest/custom_op/custom_op_demo.py | 3 +- .../custom_op/custom_op_extra_buffer_demo.py | 6 +- .../unittest/custom_op/test_gather_load.py | 12 +- .../unittest/custom_op/test_index_select.py | 29 +- .../unittest/pytest_ut/test_01_vector_add.py | 1 - .../pytest_ut/test_02_fused_softmax.py | 14 +- .../test_03_matrix_multiplication.py | 61 +- .../pytest_ut/test_04_low_memory_dropout.py | 12 +- .../unittest/pytest_ut/test_05_layer_norm.py | 3 +- .../pytest_ut/test_06_fused_attention.py | 54 +- .../pytest_ut/test_07_extern_functions.py | 1 - .../pytest_ut/test_08_grouped_gemm.py | 5 +- .../pytest_ut/test_09_persistent_matmul.py | 5 +- .../pytest_ut/test_10_gather_sorted.py | 20 +- .../unittest/pytest_ut/test_11_rab_time.py | 114 +- .../pytest_ut/test_12_hstu_attention.py | 168 +- ...test_13_matrix_multiplication_optimized.py | 93 +- .../pytest_ut/test_14_accuracy_comparison.py | 35 +- .../pytest_ut/test_15_demo_autotune.py | 12 +- .../unittest/pytest_ut/test_16_profiler.py | 31 +- .../unittest/pytest_ut/test_18_gather.py | 15 +- .../ascend/unittest/pytest_ut/test_add.py | 10 +- .../unittest/pytest_ut/test_address_check.py | 5 +- .../unittest/pytest_ut/test_advance_ptr.py | 24 +- .../ascend/unittest/pytest_ut/test_alloc.py | 12 +- .../ascend/unittest/pytest_ut/test_arch.py | 1 + .../ascend/unittest/pytest_ut/test_argmax.py | 14 +- .../ascend/unittest/pytest_ut/test_argmin.py | 14 +- .../ascend/unittest/pytest_ut/test_asm.py | 71 +- .../unittest/pytest_ut/test_asm_scalar.py | 20 +- .../ascend/unittest/pytest_ut/test_assume1.py | 6 +- .../unittest/pytest_ut/test_atomic_add.py | 40 +- .../unittest/pytest_ut/test_atomic_and.py | 16 +- 
.../unittest/pytest_ut/test_atomic_cas.py | 59 +- .../unittest/pytest_ut/test_atomic_max.py | 30 +- .../unittest/pytest_ut/test_atomic_min.py | 23 +- .../pytest_ut/test_atomic_rmw_useanalysis.py | 10 +- .../unittest/pytest_ut/test_block_ptr.py | 30 +- .../unittest/pytest_ut/test_boundary_check.py | 160 +- .../unittest/pytest_ut/test_cat_help_func.py | 86 +- .../unittest/pytest_ut/test_celoss_indices.py | 16 +- .../unittest/pytest_ut/test_compile_hint.py | 8 +- .../unittest/pytest_ut/test_complex_mask.py | 4 +- .../ascend/unittest/pytest_ut/test_copy.py | 1 + .../ascend/unittest/pytest_ut/test_custom.py | 35 +- .../pytest_ut/test_discrete_mask_atomic.py | 4 +- .../pytest_ut/test_discrete_mask_loadstore.py | 82 +- .../test_discrete_mask_tail_block_mte_oob.py | 63 +- .../pytest_ut/test_discrete_overlap_mask.py | 49 +- .../ascend/unittest/pytest_ut/test_dot.py | 4 +- .../ascend/unittest/pytest_ut/test_erfinv.py | 10 +- .../ascend/unittest/pytest_ut/test_expm1.py | 7 +- .../unittest/pytest_ut/test_fast_dividef.py | 11 +- .../unittest/pytest_ut/test_fast_expf.py | 11 +- .../ascend/unittest/pytest_ut/test_gamma.py | 8 +- .../unittest/pytest_ut/test_if_advance.py | 48 +- .../ascend/unittest/pytest_ut/test_if_load.py | 18 +- .../pytest_ut/test_implicit_atomic.py | 42 +- .../pytest_ut/test_implicit_permute.py | 121 +- .../test_indirect_scalar_load_offset.py | 14 +- .../pytest_ut/test_interleave_optimizaiton.py | 39 +- .../ascend/unittest/pytest_ut/test_lgamma.py | 14 +- .../unittest/pytest_ut/test_linearize_mask.py | 36 +- .../test_makeblockptr_negative_padding.py | 8 +- .../ascend/unittest/pytest_ut/test_mod.py | 2 +- .../unittest/pytest_ut/test_mul_reduce.py | 17 +- .../unittest/pytest_ut/test_multibuffer.py | 6 +- .../pytest_ut/test_negative_mask_dim.py | 10 +- .../unittest/pytest_ut/test_nextafter.py | 9 +- .../pytest_ut/test_paged_kvcache_krope.py | 20 +- .../pytest_ut/test_permuted_boundary_check.py | 54 +- .../unittest/pytest_ut/test_reduce_maximum.py | 108 +- 
...reduce_min_4_keepdim_True_with_index_op.py | 14 +- .../unittest/pytest_ut/test_reduce_minimum.py | 108 +- .../pytest_ut/test_select_analysis.py | 47 +- .../test_select_analysis_for_invert.py | 31 +- .../ascend/unittest/pytest_ut/test_signbit.py | 10 +- .../pytest_ut/test_simplify_iterargs.py | 18 +- .../unittest/pytest_ut/test_sink_broadcast.py | 20 +- .../unittest/pytest_ut/test_use_analysis.py | 7 +- 164 files changed, 6531 insertions(+), 7528 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 47f26957d6..6be9c978c6 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,8 +1,8 @@ #pragma once -#include "TritonToHIVM/Passes.h" +#include "AutoBlockify/Passes.h" #include "TritonToHFusion/Passes.h" +#include "TritonToHIVM/Passes.h" #include "TritonToLLVM/Passes.h" -#include "AutoBlockify/Passes.h" // #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" // #include "amd/include/TritonAMDGPUTransforms/Passes.h" // #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" diff --git a/python/src/llvm.cc b/python/src/llvm.cc index ee4b222eaa..4aa24986c7 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -390,7 +390,7 @@ void init_triton_llvm(py::module &&m) { py::arg("arch") = "", py::arg("features") = "", py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false, - py::call_guard()); + py::call_guard()); m.def( "translate_to_asm", diff --git a/python/triton/extension/__init__.py b/python/triton/extension/__init__.py index 006c0ba6ab..6cbe0ecf88 100644 --- a/python/triton/extension/__init__.py +++ b/python/triton/extension/__init__.py @@ -17,4 +17,4 @@ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. \ No newline at end of file +# THE SOFTWARE. 
diff --git a/python/triton/extension/buffer/__init__.py b/python/triton/extension/buffer/__init__.py index 006c0ba6ab..6cbe0ecf88 100644 --- a/python/triton/extension/buffer/__init__.py +++ b/python/triton/extension/buffer/__init__.py @@ -17,4 +17,4 @@ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. \ No newline at end of file +# THE SOFTWARE. diff --git a/python/triton/extension/buffer/language/builder.py b/python/triton/extension/buffer/language/builder.py index f94ed54f15..bb519df070 100644 --- a/python/triton/extension/buffer/language/builder.py +++ b/python/triton/extension/buffer/language/builder.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ buffer-specific builder utilities for code generation. 
""" @@ -72,4 +71,3 @@ def setup_unified_builder_with_buffer_builder(main_builder, buffer_builder): 'subview', ] attach_builder_methods_with_buffer_builder(main_builder, buffer_builder, buffer_methods) - diff --git a/python/triton/extension/buffer/language/core.py b/python/triton/extension/buffer/language/core.py index eb9d4397c4..d32a787d04 100644 --- a/python/triton/extension/buffer/language/core.py +++ b/python/triton/extension/buffer/language/core.py @@ -37,7 +37,6 @@ import triton.language.core as tl from triton.language import semantic as real_semantic - T = TypeVar("T") TRITON_BUILTIN = "__triton_builtin__" @@ -74,9 +73,7 @@ class address_space: """ def to_ir(self, builder: ir.builder) -> ir.type: - raise NotImplementedError( - "Abstract address_space cannot be converted to ir" - ) + raise NotImplementedError("Abstract address_space cannot be converted to ir") class buffer_type(tl.dtype): @@ -115,10 +112,8 @@ def __repr__(self): def __eq__(self, other) -> bool: if not isinstance(other, buffer_type): return False - return (self.element_ty == other.element_ty and - self.shape == other.shape and - self.space == other.space and - self.strides == other.strides) + return (self.element_ty == other.element_ty and self.shape == other.shape and self.space == other.space + and self.strides == other.strides) def __ne__(self, other) -> bool: return not self.__eq__(other) @@ -162,20 +157,14 @@ def __init__(self, handle, buffer_ty: buffer_type): def __str__(self) -> str: # ex. 
"<16x32xfloat32, address_space>" - res = '<' + 'x'.join(str(s) - for s in self.shape) + 'x' + str(self.dtype) + res = '<' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.dtype) if self.space: res += ', ' + str(self.space) return res + '>' @builtin - def subview( - self, - offsets: List[tl.constexpr], - sizes: List[tl.constexpr], - strides: List[tl.constexpr], - _builder=None - ) -> 'buffer': + def subview(self, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> 'buffer': return subview(self, offsets, sizes, strides, _builder=_builder) @builtin @@ -188,13 +177,8 @@ def to_tensor(self, writable=True, target_shape=None, _builder=None): @builtin -def alloc( - etype: tl.dtype, - shape: List[tl.constexpr], - _address_space: address_space = None, - is_mem_unique: bool = False, - _builder=None -) -> buffer: +def alloc(etype: tl.dtype, shape: List[tl.constexpr], _address_space: address_space = None, is_mem_unique: bool = False, + _builder=None) -> buffer: """ Allocates a region of local memory with the specified shape and type. @@ -209,12 +193,7 @@ def alloc( @builtin -def to_buffer( - tensor: tl.tensor, - space: address_space = None, - bind_buffer: buffer = None, - _builder=None -) -> buffer: +def to_buffer(tensor: tl.tensor, space: address_space = None, bind_buffer: buffer = None, _builder=None) -> buffer: """ Convert a tensor to a buffer. @@ -223,18 +202,11 @@ def to_buffer( :param space: the address space for the buffer (optional). :type space: address_space """ - return semantic.to_buffer( - tensor, space, bind_buffer, _builder - ) + return semantic.to_buffer(tensor, space, bind_buffer, _builder) @builtin -def to_tensor( - memref: buffer, - writable: bool = True, - target_shape=None, - _builder=None -) -> tl.tensor: +def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder=None) -> tl.tensor: """ Create a tl.tensor from a bl.buffer. 
@@ -245,7 +217,7 @@ def to_tensor( """ return semantic.to_tensor(memref, writable, _builder, target_shape=target_shape) - + def check_subview(src, offsets, sizes, strides): """ Check data of subview methods which the data length and the offset value must be 32-byte aligned. @@ -277,7 +249,7 @@ def check_subview(src, offsets, sizes, strides): src_strides = [1] * length if length == 1: if offset[0] % base_byte != 0: - raise TypeError(f"all strides should be 1 and the offset value should be 32-bytes aligned.") + raise TypeError("all strides should be 1 and the offset value should be 32-bytes aligned.") return for i in range(length - 2, -1, -1): src_strides[i] = src_strides[i + 1] * src.shape[i + 1] @@ -293,17 +265,12 @@ def check_subview(src, offsets, sizes, strides): stride_1 = all(s == 1 for s in strides) is_unaligned = result_offset % base_byte != 0 or is_unaligned or not stride_1 if is_unaligned: - raise TypeError(f"all strides should be 1 and the offset value should be 32-bytes aligned.") + raise TypeError("all strides should be 1 and the offset value should be 32-bytes aligned.") @builtin -def subview( - src: buffer, - offsets: List[tl.constexpr], - sizes: List[tl.constexpr], - strides: List[tl.constexpr], - _builder=None -) -> buffer: +def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> buffer: ''' Creates a subview of the source buffer with the specified offsets, sizes, and strides. diff --git a/python/triton/extension/buffer/language/semantic.py b/python/triton/extension/buffer/language/semantic.py index 5dd526366a..b694e3ab29 100644 --- a/python/triton/extension/buffer/language/semantic.py +++ b/python/triton/extension/buffer/language/semantic.py @@ -19,26 +19,18 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
-from typing import ( - TypeVar, List -) +from typing import (TypeVar, List) from triton._C.libtriton import ir import triton.language.core as tl from . import core as bl - T = TypeVar('T') -def alloc( - etype: tl.dtype, - shape: List[tl.constexpr], - address_space: bl.address_space, - is_mem_unique, - builder: ir.builder -) -> bl.buffer: +def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: bl.address_space, is_mem_unique, + builder: ir.builder) -> bl.buffer: shape = tl._unwrap_shape(shape) if etype == tl.int1: raise TypeError("Unsupported alloc int1 type") @@ -47,16 +39,12 @@ def alloc( etype = tl._constexpr_to_value(etype) address_space = tl._constexpr_to_value(address_space) element_ty_ir = etype.to_ir(builder) - addr_space_attr = ( - address_space.to_ir(builder) if address_space else builder.get_null_attr() - ) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) memref_ty = builder.get_buffer_ty(shape, element_ty_ir, addr_space_attr) handle = builder.alloc(memref_ty) if is_mem_unique: builder.create_annotation_mark(handle, "mem_unique", builder.get_unit_attr()) - builder.create_annotation_mark( - handle, "effects", builder.get_str_array_attr(["write", "read"]) - ) + builder.create_annotation_mark(handle, "effects", builder.get_str_array_attr(["write", "read"])) buffer_ty = bl.buffer_type(element_ty=etype, shape=shape, space=address_space) return bl.buffer(handle, buffer_ty) @@ -73,26 +61,19 @@ def to_buffer( if isinstance(bind_buffer, bl.buffer): builder.create_bind_buffer(tensor.handle, bind_buffer.handle) return bind_buffer - if not (bind_buffer is None): + if bind_buffer is not None: raise ValueError("bind_buffer must be a buffer or None") address_space = tl._constexpr_to_value(address_space) - addr_space_attr = ( - address_space.to_ir(builder) if address_space else builder.get_null_attr() - ) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) handle = 
builder.to_buffer(tensor.handle, addr_space_attr) buffer_ty = bl.buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) return bl.buffer(handle, buffer_ty) - -def to_tensor( - memref: bl.buffer, - writable: bool, - builder: ir.builder, - target_shape=None -) -> tl.tensor: + +def to_tensor(memref: bl.buffer, writable: bool, builder: ir.builder, target_shape=None) -> tl.tensor: if not isinstance(memref, bl.buffer): raise TypeError("memref must be bl.buffer") - + need_convert_layout = False shape = memref.shape if target_shape: @@ -110,26 +91,20 @@ def to_tensor( shape=shape, space=memref.space, ) - memref_value = builder.create_convert_layout( - memref_value, buffer_ty.to_ir(builder)) + memref_value = builder.create_convert_layout(memref_value, buffer_ty.to_ir(builder)) return tl.tensor(builder.to_tensor(memref_value, writable), tensor_type) -def subview( - src: bl.buffer, - offsets: List[tl.tensor], - sizes: List[tl.constexpr], - strides: List[tl.constexpr], - builder: ir.builder -) -> bl.buffer: +def subview(src: bl.buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], + builder: ir.builder) -> bl.buffer: new_offsets = [offset.handle for offset in offsets] sizes_int = tl._unwrap_shape(sizes) strides_int = tl._unwrap_shape(strides) result_handle = builder.subview(src.handle, new_offsets, sizes_int, strides_int) - + # calculate the memory layout strides of the source buffer if src.strides: # use the strides of the source buffer @@ -143,16 +118,11 @@ def subview( raise ValueError("Cannot compute strides for buffer with dynamic dimensions") src_memory_strides.insert(0, stride) stride *= dim_size - + result_memory_strides = [] for src_stride, subview_stride in zip(src_memory_strides, strides_int): result_memory_strides.append(src_stride * subview_stride) - + # create buffer_type with strides - buffer_ty = bl.buffer_type( - element_ty=src.dtype, - shape=sizes_int, - space=src.space, - 
strides=result_memory_strides - ) + buffer_ty = bl.buffer_type(element_ty=src.dtype, shape=sizes_int, space=src.space, strides=result_memory_strides) return bl.buffer(result_handle, buffer_ty) diff --git a/python/triton/extension/buffer/src/buffer_ir.cc b/python/triton/extension/buffer/src/buffer_ir.cc index f1f07dda52..bd1bc917b6 100644 --- a/python/triton/extension/buffer/src/buffer_ir.cc +++ b/python/triton/extension/buffer/src/buffer_ir.cc @@ -39,8 +39,7 @@ constexpr unsigned kIntegerAttrBitWidth = 64; struct BufferOpBuilder : public TritonOpBuilder {}; -void init_buffer_ir(py::module &&m) -{ +void init_buffer_ir(py::module &&m) { m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; registry.insert(); @@ -54,9 +53,12 @@ void init_buffer_ir(py::module &&m) .def(py::init()) .def("get_null_attr", [](BufferOpBuilder &self) { return Attribute(); }) .def("get_str_array_attr", - [](BufferOpBuilder &self, const std::vector &array) -> ArrayAttr { - auto strRefVec = to_vector(llvm::map_range(array, [](const auto &s) { return llvm::StringRef(s); })); - return self.getBuilder().getStrArrayAttr(llvm::ArrayRef {strRefVec}); + [](BufferOpBuilder &self, + const std::vector &array) -> ArrayAttr { + auto strRefVec = to_vector(llvm::map_range( + array, [](const auto &s) { return llvm::StringRef(s); })); + return self.getBuilder().getStrArrayAttr( + llvm::ArrayRef{strRefVec}); }) .def("alloc", [](BufferOpBuilder &self, Type memrefType) -> Value { @@ -64,22 +66,27 @@ void init_buffer_ir(py::module &&m) mlir::cast(memrefType)); }) .def("to_buffer", - [](BufferOpBuilder &self, Value &src, const Attribute &addressSpace) -> Value { - auto tensorType = dyn_cast(src.getType()); - if (!tensorType) { - llvm::report_fatal_error("to_buffer: src must be tensor type"); - } - auto memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType(), MemRefLayoutAttrInterface {}); - // TODO: We need to add a pass before OneShotBufferize to generate 
MemorySpaceCastOp - Operation *memref = self.create(memrefType, src); - if (addressSpace) { - memref = self.create( - MemRefType::get(memrefType.getShape(), memrefType.getElementType(), memrefType.getLayout(), - addressSpace), - memref->getResult(0)); - } - return memref->getResult(0); + [](BufferOpBuilder &self, Value &src, + const Attribute &addressSpace) -> Value { + auto tensorType = dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = MemRefType::get(tensorType.getShape(), + tensorType.getElementType(), + MemRefLayoutAttrInterface{}); + // TODO: We need to add a pass before OneShotBufferize to generate + // MemorySpaceCastOp + Operation *memref = + self.create(memrefType, src); + if (addressSpace) { + memref = self.create( + MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout(), addressSpace), + memref->getResult(0)); + } + return memref->getResult(0); }) .def("to_tensor", [](BufferOpBuilder &self, Value &src, bool writable) -> Value { @@ -142,7 +149,8 @@ void init_buffer_ir(py::module &&m) throw std::runtime_error("Expected strides to be positive"); } - // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic dimensions + // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic + // dimensions if (!ShapedType::isDynamic(srcDim)) { // verify the subview size does not exceed the source dimension if (size > srcDim) { @@ -157,13 +165,13 @@ void init_buffer_ir(py::module &&m) } } - mixedSizes.push_back( - IntegerAttr::get(IntegerType::get(context, kIntegerAttrBitWidth), size)); - mixedStrides.push_back( - IntegerAttr::get(IntegerType::get(context, kIntegerAttrBitWidth), stride)); + mixedSizes.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), size)); + mixedStrides.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), stride)); } return self.create(source, mixedOffsets, 
mixedSizes, mixedStrides); }); -} \ No newline at end of file +} diff --git a/python/triton/tools/get_ascend_devices.py b/python/triton/tools/get_ascend_devices.py index 13c28cda81..f6ba1c9c01 100644 --- a/python/triton/tools/get_ascend_devices.py +++ b/python/triton/tools/get_ascend_devices.py @@ -9,16 +9,16 @@ def get_ascend_devices(): devices = [] pci_path = '/sys/bus/pci/devices/*' - + for dev in glob.glob(pci_path): try: vendor_path = os.path.join(dev, 'vendor') device_path = os.path.join(dev, 'device') - + if os.path.exists(vendor_path): with open(vendor_path, 'r') as f: vendor = f.read().strip() - + if vendor == "0x19e5" and os.path.exists(device_path): with open(device_path, 'r') as f: device = f.read().strip() @@ -26,30 +26,24 @@ def get_ascend_devices(): except (IOError, OSError) as e: logger.warning(f"can not fetch device {dev}: {e}") continue - + return devices def check_npu_smi_device(): try: - result = subprocess.run( - ["npu-smi", "info"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - shell=False, - timeout=100 - ) + result = subprocess.run(["npu-smi", "info"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, + shell=False, timeout=100) if result.returncode == 0: output = result.stdout.lower() return "ascend910_95" in output or "ascend950" in output or "910_958b" in output return False - except Exception as e: - logger.warning(f"can not use command: npu-smi info") + except Exception: + logger.warning("can not use command: npu-smi info") return False ascend_devices = get_ascend_devices() pci_condition = any("0xd806" in dev for dev in ascend_devices) npu_smi_condition = check_npu_smi_device() -is_compile_on_910_95 = pci_condition or npu_smi_condition \ No newline at end of file +is_compile_on_910_95 = pci_condition or npu_smi_condition diff --git a/third_party/ascend/ascend_ir.cc b/third_party/ascend/ascend_ir.cc index 3f9b58615b..97ee3f0c2b 100644 --- a/third_party/ascend/ascend_ir.cc +++ 
b/third_party/ascend/ascend_ir.cc @@ -52,17 +52,16 @@ struct AscendNPUIROpBuilder : public TritonOpBuilder { explicit AscendNPUIROpBuilder(MLIRContext *context, std::string target = "") : TritonOpBuilder(context), target(target) {} - bool is_910_95() const - { + bool is_910_95() const { // TODO: Use enum instead of strings after enabling HACC in satandalone // build constexpr size_t kLen910 = sizeof(kTarget910_95) - 1; bool match_910 = target.size() >= kLen910 && - target.compare(0, kLen910, kTarget910_95) == 0; + target.compare(0, kLen910, kTarget910_95) == 0; constexpr size_t kLen950 = sizeof(kTarget950) - 1; - bool match_950 = target.size() >= kLen950 && - target.compare(0, kLen950, kTarget950) == 0; + bool match_950 = + target.size() >= kLen950 && target.compare(0, kLen950, kTarget950) == 0; return match_910 || match_950; } @@ -71,8 +70,7 @@ struct AscendNPUIROpBuilder : public TritonOpBuilder { namespace { MLIRContext *gDefaultAscendContext = nullptr; -MLIRContext *resolveContext(const py::object &contextObj) -{ +MLIRContext *resolveContext(const py::object &contextObj) { if (!contextObj.is_none()) { return &py::cast(contextObj); } @@ -172,58 +170,63 @@ void init_ascend_ir(py::module &&m) { auto affineExprClass = py::class_(m, "affine_expr", py::module_local()); affineExprClass - .def("__str__", [](AffineExpr self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return os.str(); - }) - .def("__repr__", [](AffineExpr self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return ""; - }) + .def("__str__", + [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", + [](AffineExpr self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) .def("is_symbolic_or_constant", &AffineExpr::isSymbolicOrConstant) .def("is_pure_affine", &AffineExpr::isPureAffine) .def("is_function_of_dim", 
&AffineExpr::isFunctionOfDim) .def("compose", [](AffineExpr self, AffineMap map) { return self.compose(map); }) .def("get_largest_known_divisor", &AffineExpr::getLargestKnownDivisor) - .def("floordiv", - [](AffineExpr self, AffineExpr other) { return self.floorDiv(other); }) - .def("ceildiv", - [](AffineExpr self, AffineExpr other) { return self.ceilDiv(other); }) + .def("floordiv", [](AffineExpr self, + AffineExpr other) { return self.floorDiv(other); }) + .def("ceildiv", [](AffineExpr self, + AffineExpr other) { return self.ceilDiv(other); }) .def("mod", [](AffineExpr self, AffineExpr other) { return self % other; }) - .def("__hash__", [](AffineExpr self) { - return py::int_(static_cast(mlir::hash_value(self))); - }) - .def("__eq__", - [](AffineExpr lhs, AffineExpr rhs) { return lhs == rhs; }) + .def("__hash__", + [](AffineExpr self) { + return py::int_(static_cast(mlir::hash_value(self))); + }) + .def("__eq__", [](AffineExpr lhs, AffineExpr rhs) { return lhs == rhs; }) .def(py::self + py::self) .def(py::self - py::self) .def(py::self * py::self) .def(py::self % py::self); affineExprClass - .def_static("get_constant", - [](int64_t val, py::object contextObj) { - auto *context = resolveContext(contextObj); - return getAffineConstantExpr(val, context); - }, - py::arg("value"), py::arg("context") = py::none()) - .def_static("get_dim", - [](uint32_t pos, py::object contextObj) { - auto *context = resolveContext(contextObj); - return getAffineDimExpr(pos, context); - }, - py::arg("pos"), py::arg("context") = py::none()) - .def_static("get_symbol", - [](uint32_t pos, py::object contextObj) { - auto *context = resolveContext(contextObj); - return getAffineSymbolExpr(pos, context); - }, - py::arg("pos"), py::arg("context") = py::none()); + .def_static( + "get_constant", + [](int64_t val, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineConstantExpr(val, context); + }, + py::arg("value"), py::arg("context") = py::none()) + 
.def_static( + "get_dim", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineDimExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()) + .def_static( + "get_symbol", + [](uint32_t pos, py::object contextObj) { + auto *context = resolveContext(contextObj); + return getAffineSymbolExpr(pos, context); + }, + py::arg("pos"), py::arg("context") = py::none()); py::class_(m, "affine_constant_expr", py::module_local()) @@ -239,20 +242,23 @@ void init_ascend_ir(py::module &&m) { .def("get_lhs", &AffineBinaryOpExpr::getLHS) .def("get_rhs", &AffineBinaryOpExpr::getRHS); - auto affineMapClass = py::class_(m, "affine_map", py::module_local()); + auto affineMapClass = + py::class_(m, "affine_map", py::module_local()); affineMapClass - .def("__str__", [](AffineMap &self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return os.str(); - }) - .def("__repr__", [](AffineMap &self) { - std::string str; - llvm::raw_string_ostream os(str); - self.print(os); - return ""; - }) + .def("__str__", + [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("__repr__", + [](AffineMap &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return ""; + }) .def("is_identity", &AffineMap::isIdentity) .def("is_permutation", &AffineMap::isPermutation) .def("get_num_dims", &AffineMap::getNumDims) @@ -261,12 +267,14 @@ void init_ascend_ir(py::module &&m) { .def("is_empty", &AffineMap::isEmpty) .def("is_single_constant", &AffineMap::isSingleConstant) .def("is_constant", &AffineMap::isConstant) - .def("get_constant_result", [](AffineMap &self) -> int64_t { - if (!self.isSingleConstant()) { - throw std::runtime_error("affine map is not a single constant map"); - } - return self.getSingleConstantResult(); - }) + .def("get_constant_result", + [](AffineMap &self) -> int64_t { + if 
(!self.isSingleConstant()) { + throw std::runtime_error( + "affine map is not a single constant map"); + } + return self.getSingleConstantResult(); + }) .def("get_result", [](AffineMap &self, uint32_t pos) { if (pos >= self.getNumResults()) { @@ -284,33 +292,36 @@ void init_ascend_ir(py::module &&m) { return self.replace(expr, replacement, numResultDims, numResultSymbols); }) - .def("compose", [](AffineMap &self, AffineMap map) { - return self.compose(map); - }) - .def("get_results", [](AffineMap &self) -> std::vector { - auto results = self.getResults(); - return std::vector(results.begin(), results.end()); - }) - .def("__hash__", [](AffineMap &self) { - return py::int_(static_cast(mlir::hash_value(self))); - }) + .def("compose", + [](AffineMap &self, AffineMap map) { return self.compose(map); }) + .def("get_results", + [](AffineMap &self) -> std::vector { + auto results = self.getResults(); + return std::vector(results.begin(), results.end()); + }) + .def("__hash__", + [](AffineMap &self) { + return py::int_(static_cast(mlir::hash_value(self))); + }) .def("__eq__", [](AffineMap &lhs, AffineMap &rhs) { return lhs == rhs; }) - .def("inverse_permutation", [](AffineMap &self) -> py::object { - // Validate it's a permutation first - if (!self.isPermutation()) { - throw py::value_error("AffineMap must be a valid permutation to compute inverse"); - } + .def("inverse_permutation", + [](AffineMap &self) -> py::object { + // Validate it's a permutation first + if (!self.isPermutation()) { + throw py::value_error( + "AffineMap must be a valid permutation to compute inverse"); + } - // Returns AffineMap directly, not a pointer - AffineMap inverse = mlir::inversePermutation(self); + // Returns AffineMap directly, not a pointer + AffineMap inverse = mlir::inversePermutation(self); - // Check if result is valid (null AffineMap) - if (!inverse) { - throw py::value_error("Failed to compute inverse permutation"); - } + // Check if result is valid (null AffineMap) + if (!inverse) { 
+ throw py::value_error("Failed to compute inverse permutation"); + } - return py::cast(inverse); - }) + return py::cast(inverse); + }) .def("to_dict", [](AffineMap &self) -> py::dict { py::list results; for (AffineExpr result : self.getResults()) { @@ -331,7 +342,8 @@ void init_ascend_ir(py::module &&m) { return ret; }); affineMapClass - .def_static("get", + .def_static( + "get", [](int64_t numDims, int64_t numSymbols, const py::iterable &resultsIn, py::object contextObj) -> AffineMap { MLIRContext *context = nullptr; @@ -371,44 +383,47 @@ void init_ascend_ir(py::module &&m) { }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("result_dims"), py::arg("context") = py::none()) - .def_static("get_identity", - [](int64_t numDims, py::object contextObj) -> AffineMap { - auto *context = resolveContext(contextObj); - if (numDims < 0) { - throw std::invalid_argument( - "num_dims must be non-negative"); - } - return AffineMap::getMultiDimIdentityMap(numDims, context); - }, - py::arg("num_dims"), py::arg("context") = py::none()) - .def_static("get_minor_identity", - [](int64_t dims, int64_t results, py::object contextObj) { - auto *context = resolveContext(contextObj); - if (dims < 0 || results < 0) { - throw std::invalid_argument( - "dims/results must be non-negative"); - } - return AffineMap::getMinorIdentityMap(dims, results, context); - }, - py::arg("dims"), py::arg("results"), - py::arg("context") = py::none()) - .def_static("get_empty", [](py::object contextObj) { - auto *context = resolveContext(contextObj); - return AffineMap::get(0, 0, {}, context); - }, py::arg("context") = py::none()) - .def_static("get_permutation", - [](const std::vector &permutation, - py::object contextObj) { - auto *context = resolveContext(contextObj); - return AffineMap::getPermutationMap(permutation, context); - }, - py::arg("permutation"), py::arg("context") = py::none()) - .def_static("get_constant", - [](int64_t value, py::object contextObj) { - auto *context = 
resolveContext(contextObj); - return AffineMap::getConstantMap(value, context); - }, - py::arg("value"), py::arg("context") = py::none()); + .def_static( + "get_identity", + [](int64_t numDims, py::object contextObj) -> AffineMap { + auto *context = resolveContext(contextObj); + if (numDims < 0) { + throw std::invalid_argument("num_dims must be non-negative"); + } + return AffineMap::getMultiDimIdentityMap(numDims, context); + }, + py::arg("num_dims"), py::arg("context") = py::none()) + .def_static( + "get_minor_identity", + [](int64_t dims, int64_t results, py::object contextObj) { + auto *context = resolveContext(contextObj); + if (dims < 0 || results < 0) { + throw std::invalid_argument("dims/results must be non-negative"); + } + return AffineMap::getMinorIdentityMap(dims, results, context); + }, + py::arg("dims"), py::arg("results"), py::arg("context") = py::none()) + .def_static( + "get_empty", + [](py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::get(0, 0, {}, context); + }, + py::arg("context") = py::none()) + .def_static( + "get_permutation", + [](const std::vector &permutation, py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getPermutationMap(permutation, context); + }, + py::arg("permutation"), py::arg("context") = py::none()) + .def_static( + "get_constant", + [](int64_t value, py::object contextObj) { + auto *context = resolveContext(contextObj); + return AffineMap::getConstantMap(value, context); + }, + py::arg("value"), py::arg("context") = py::none()); py::enum_(m, "AddressSpace", py::module_local()) .value("L1", hivm::AddressSpace::L1) @@ -443,19 +458,19 @@ void init_ascend_ir(py::module &&m) { .export_values(); py::enum_(m, "IteratorType", py::module_local()) - .value("Parallel", hivm::IteratorType::kParallel) - .value("Broadcast", hivm::IteratorType::kBroadcast) - .value("Transpose", hivm::IteratorType::kTranspose) - .value("Reduction", 
hivm::IteratorType::kReduction) - .value("Interleave", hivm::IteratorType::kInterleave) - .value("Deinterleave", hivm::IteratorType::kDeinterleave) - .value("Inverse", hivm::IteratorType::kInverse) - .value("Pad", hivm::IteratorType::kPad) - .value("Concat", hivm::IteratorType::kConcat) - .value("Gather", hivm::IteratorType::kGather) - .value("Cumulative", hivm::IteratorType::kCumulative) - .value("Opaque", hivm::IteratorType::kOpaque) - .export_values(); + .value("Parallel", hivm::IteratorType::kParallel) + .value("Broadcast", hivm::IteratorType::kBroadcast) + .value("Transpose", hivm::IteratorType::kTranspose) + .value("Reduction", hivm::IteratorType::kReduction) + .value("Interleave", hivm::IteratorType::kInterleave) + .value("Deinterleave", hivm::IteratorType::kDeinterleave) + .value("Inverse", hivm::IteratorType::kInverse) + .value("Pad", hivm::IteratorType::kPad) + .value("Concat", hivm::IteratorType::kConcat) + .value("Gather", hivm::IteratorType::kGather) + .value("Cumulative", hivm::IteratorType::kCumulative) + .value("Opaque", hivm::IteratorType::kOpaque) + .export_values(); py::enum_(m, "FixpipeDMAMode", py::module_local()) .value("NZ2DN", hivm::FixpipeDMAMode::NZ2DN) @@ -521,14 +536,16 @@ void init_ascend_ir(py::module &&m) { return self.getBuilder().getAttr(mode); }) .def("get_iterator_types_attr", - [](AscendNPUIROpBuilder &self, const std::vector& array) { - llvm::SmallVector attrs; - attrs.reserve(array.size()); - for (auto type : array) { - attrs.push_back(self.getBuilder().getI32IntegerAttr(static_cast(type))); - } - return self.getBuilder().getArrayAttr(attrs); - }) + [](AscendNPUIROpBuilder &self, + const std::vector &array) { + llvm::SmallVector attrs; + attrs.reserve(array.size()); + for (auto type : array) { + attrs.push_back(self.getBuilder().getI32IntegerAttr( + static_cast(type))); + } + return self.getBuilder().getArrayAttr(attrs); + }) .def("get_t_core_type_attr_name", [](AscendNPUIROpBuilder &self) -> std::string { return 
hivm::TCoreTypeAttr::name.str(); @@ -546,7 +563,8 @@ void init_ascend_ir(py::module &&m) { .def("parse_attr", [](TritonOpBuilder &self, std::string value) -> Attribute { auto *ctx = self.getBuilder().getContext(); - // Enable parsing of HACC attributes by allowing unregistered dialects. + // Enable parsing of HACC attributes by allowing unregistered + // dialects. ctx->allowUnregisteredDialects(); return mlir::parseAttribute(value, ctx); }) @@ -620,17 +638,16 @@ void init_ascend_ir(py::module &&m) { attrVal); }) .def("create_custom_op", - [](AscendNPUIROpBuilder &self, - const std::string &name, - const py::dict &attrs, - const std::vector &ins, - const std::vector &outs, - const std::vector &arg_attrs) -> std::vector { + [](AscendNPUIROpBuilder &self, const std::string &name, + const py::dict &attrs, const std::vector &ins, + const std::vector &outs, + const std::vector &arg_attrs) -> std::vector { ValueRange inputs{ins}; ValueRange outputs{outs}; ValueRange temp_buffers{}; TypeRange res_types{outputs}; - auto op = self.create(res_types, name, inputs, outputs, temp_buffers); + auto op = self.create(res_types, name, inputs, + outputs, temp_buffers); for (auto &attr : attrs) { std::string attr_name = py::cast(attr.first); Attribute attr_value = py::cast(attr.second); @@ -652,14 +669,15 @@ void init_ascend_ir(py::module &&m) { for (const auto &attr : attrs) { std::string attr_name = py::cast(attr.first); Attribute attr_value = py::cast(attr.second); - namedAttrs.push_back( - NamedAttribute(self.getBuilder().getStringAttr(attr_name), attr_value)); + namedAttrs.push_back(NamedAttribute( + self.getBuilder().getStringAttr(attr_name), attr_value)); } dictAttrs[idx] = self.getBuilder().getDictionaryAttr(namedAttrs); } - ArrayAttr arg_attrs_array = self.getBuilder().getArrayAttr(dictAttrs); + ArrayAttr arg_attrs_array = + self.getBuilder().getArrayAttr(dictAttrs); op->setAttr("arg_attrs", arg_attrs_array); auto results = op->getResults(); diff --git 
a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 1f04b4ec9f..dc7a069737 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -63,7 +63,6 @@ from triton.tools.get_ascend_devices import is_compile_on_910_95 - # TODO: materialize the concrete min shape def min_dot_size(target: GPUTarget): return lambda lhsType, rhsType: (1, 1, 1) @@ -114,10 +113,7 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): auto_blockify_size = 1 pm = ir.pass_manager(mod.context) pm.enable_debug() - ascend.passes.ttir.add_auto_blockify( - pm, - auto_blockify_size - ) + ascend.passes.ttir.add_auto_blockify(pm, auto_blockify_size) if (metadata["add_auto_scheduling"]): ascend.passes.ttir.add_dag_sync(pm) ascend.passes.ttir.add_dag_scope(pm) @@ -127,40 +123,18 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): passes.common.add_cse(pm) passes.common.add_canonicalizer(pm) - ascend.passes.ttir.add_triton_to_structure( - pm, - enable_mask_fallback_conversion, - optimize_dynamic_offset - ) - ascend.passes.ttir.add_discrete_mask_access_conversion( - pm, - compile_on_910_95, - force_simt_template, - enable_sync_block_lock - ) + ascend.passes.ttir.add_triton_to_structure(pm, enable_mask_fallback_conversion, optimize_dynamic_offset) + ascend.passes.ttir.add_discrete_mask_access_conversion(pm, compile_on_910_95, force_simt_template, + enable_sync_block_lock) ascend.passes.ttir.add_triton_to_annotation(pm) - ascend.passes.ttir.add_triton_to_unstructure( - pm, - compile_on_910_95, - force_simt_template - ) + ascend.passes.ttir.add_triton_to_unstructure(pm, compile_on_910_95, force_simt_template) ascend.passes.ttir.add_triton_to_hivm(pm) ascend.passes.ttir.add_triton_to_hfusion(pm) ascend.passes.ttir.add_triton_to_llvm(pm) ascend.passes.ttir.add_bubble_up_operation(pm) - ascend.passes.ttir.add_triton_to_structure( - pm, - enable_mask_fallback_conversion, - optimize_dynamic_offset - ) - 
ascend.passes.ttir.add_triton_to_linalg( - pm, - False, - named_ops, - enable_nd2nz_on_vector, - enable_select_analysis, - compile_on_910_95 - ) + ascend.passes.ttir.add_triton_to_structure(pm, enable_mask_fallback_conversion, optimize_dynamic_offset) + ascend.passes.ttir.add_triton_to_linalg(pm, False, named_ops, enable_nd2nz_on_vector, enable_select_analysis, + compile_on_910_95) pm.run(mod) enable_nd2nz_on_vector = metadata["enable_nd2nz_on_vector"] @@ -294,11 +268,8 @@ def get_auto_bind_sub_block_option(metadata): # auto_tile_and_bind_subblock is read from the module. # enable_auto_bind_sub_block is set by the user and has a higher priority. enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] - return ( - metadata["auto_tile_and_bind_subblock"] - if enable_auto_bind_sub_block is None - else enable_auto_bind_sub_block - ) + return (metadata["auto_tile_and_bind_subblock"] + if enable_auto_bind_sub_block is None else enable_auto_bind_sub_block) def _save_npuir_debug_output(stdout_bytes: bytes, stderr_bytes: bytes, tmpdir: str, metadata_hash: str): @@ -312,11 +283,7 @@ def _save_npuir_debug_output(stdout_bytes: bytes, stderr_bytes: bytes, tmpdir: s f.write(combined) dump_manager = get_dump_manager(metadata_hash) - dump_manager.put( - Path(output_path).read_text(encoding='utf-8'), - "kernel.npuir.mlir", - binary=False - ) + dump_manager.put(Path(output_path).read_text(encoding='utf-8'), "kernel.npuir.mlir", binary=False) def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): @@ -467,11 +434,7 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): if opt.debug: _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] - cmd_list = ( - [npu_compiler_path, ttadapter_path] - + _compile_option_list - + ["-o", bin_file] - ) + cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) vf_merge_level = metadata["vf_merge_level"] if vf_merge_level is not None: cmd_list 
+= [f"--enable-vf-merge-level={vf_merge_level}"] @@ -484,13 +447,7 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") try: - ret = subprocess.run( - cmd_list, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True - ) + ret = subprocess.run(cmd_list, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) except subprocess.CalledProcessError as e: if opt.debug: _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) @@ -673,22 +630,12 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): if opt.debug: _compile_option_list += ["--bishengir-print-ir-after=hivm-graph-sync-solver"] - cmd_list = ( - [npu_compiler_path, ttadapter_path] - + _compile_option_list - + ["-o", bin_file] - ) + cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) if opt.debug: print(f"[DEBUG] cmd_list: {' '.join(cmd_list)}") try: - ret = subprocess.run( - cmd_list, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True - ) + ret = subprocess.run(cmd_list, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) except subprocess.CalledProcessError as e: if opt.debug: _save_npuir_debug_output(e.stdout, e.stderr, tmpdir, metadata["hash"]) @@ -879,17 +826,11 @@ def ttir_to_npubin(mod, metadata, opt): if (enable_libdevice_simt): bisheng_options = metadata["bisheng_options"] if bisheng_options is not None: - _compile_option_list += [ - f"--append-bisheng-options={bisheng_options}" - ] + _compile_option_list += [f"--append-bisheng-options={bisheng_options}"] npu_compiler_path, env = _get_npucompiler_path() - cmd_list = ( - [npu_compiler_path, src_path] - + _compile_option_list - + ["-o", bin_file] - ) - ret = subprocess.run(cmd_list, env = env, capture_output = True, check = True) + cmd_list = ([npu_compiler_path, src_path] + _compile_option_list + ["-o", bin_file]) + ret = 
subprocess.run(cmd_list, env=env, capture_output=True, check=True) if not Path(bin_path).exists(): error_msg = ret.stderr.decode('utf-8') print(f"[DEBUG] {bin_path} is not found") @@ -916,10 +857,8 @@ def parse_options(self, opts) -> Any: args.setdefault("arch", self.target.arch) options = NPUOptions(**args) else: - raise NotImplementedError( - f"Backend '{self.target.backend}' is not supported. " - "Please ensure the target backend is set to 'npu'." - ) + raise NotImplementedError(f"Backend '{self.target.backend}' is not supported. " + "Please ensure the target backend is set to 'npu'.") return options def pack_metadata(self, metadata): @@ -970,10 +909,8 @@ def add_stages(self, stages, options): stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile_A2_A3(src, metadata, options)) else: - raise NotImplementedError( - f"Backend '{self.target.backend}' is not supported. " - "Please ensure the target backend is set to 'npu'." - ) + raise NotImplementedError(f"Backend '{self.target.backend}' is not supported. 
" + "Please ensure the target backend is set to 'npu'.") @functools.lru_cache() def hash(self): diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 803a66595a..b70608ec57 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -218,7 +218,7 @@ def _precompile_npu_ext_with_lock(header_src, enable_precompile): cache = get_cache_manager(precompile_hash) gch_path = cache.get_file("precompiled.h.gch") header_path = cache.get_file("precompiled.h") - if enable_precompile: + if enable_precompile: if header_path is not None and gch_path is not None: return header_path else: @@ -246,7 +246,7 @@ def _precompile_npu_ext_with_lock(header_src, enable_precompile): return header_path finally: fcntl.flock(f, fcntl.LOCK_UN) - + def make_npu_launcher_stub(header_src, wrapper_src, debug=False): """ @@ -256,7 +256,7 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): # if precompile header file and its gch file not exist, do precompile header_path = _precompile_npu_ext_with_lock(header_src, enable_precompile) assert header_path is not None, "the precompiled.h path is empty." 
- + # try to get cached file so_cache_key = hashlib.sha256(wrapper_src.encode("utf-8")).hexdigest() so_cache_manager = get_cache_manager(so_cache_key) @@ -281,12 +281,12 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): kernel_launcher_type = "torch" - with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, f"{name}.cxx") with open(src_path, "w") as f: f.write(wrapper_src) - so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, precompile=enable_precompile) + so_path = _build_npu_ext(name, header_path, src_path, kernel_launcher=kernel_launcher_type, + precompile=enable_precompile) if debug: with open(so_path, "rb") as f: dump_manager.put(f.read(), so_name, binary=True) diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp index ec41487b84..1665ff1b4a 100644 --- a/third_party/ascend/backend/npu_utils.cpp +++ b/third_party/ascend/backend/npu_utils.cpp @@ -37,9 +37,10 @@ static std::unordered_map registered_names; static std::unordered_map> func_stubs; -static std::tuple -registerKernel(const char *name, const void *data, size_t data_size, - int device, const char *kernel_mode_str) { +static std::tuple registerKernel(const char *name, + const void *data, + size_t data_size, int device, + const char *kernel_mode_str) { rtError_t rtRet; rtDevBinary_t devbin; @@ -141,19 +142,19 @@ static PyObject *createStream(PyObject *self, PyObject *args) { rtError_t rtRet = rtStreamCreate(&stream, 0); - if (rtRet != RT_ERROR_NONE) { - printf("rtStreamCreate failed, 0x%x", rtRet); - return nullptr; - } - if (PyErr_Occurred()) { - return nullptr; - } - uint64_t stream_uint64 = reinterpret_cast(stream); - PyObject* result = Py_BuildValue("K", stream_uint64); + if (rtRet != RT_ERROR_NONE) { + printf("rtStreamCreate failed, 0x%x", rtRet); + return nullptr; + } + if (PyErr_Occurred()) { + return nullptr; + } + uint64_t stream_uint64 = reinterpret_cast(stream); + 
PyObject *result = Py_BuildValue("K", stream_uint64); - if (result == nullptr) { - rtStreamDestroy(stream); - } + if (result == nullptr) { + rtStreamDestroy(stream); + } return result; } @@ -193,21 +194,21 @@ std::vector readDataFromBinaryFile(const std::string &filename) { } static PyObject *readDataFromBinaryFileWrapper(PyObject *self, PyObject *args) { - const char *filename; - uint64_t arr_ptr; - if (!PyArg_ParseTuple(args, "sK", &filename, &arr_ptr)) { - return nullptr; - } - - try { - std::vector data = readDataFromBinaryFile(filename); - char *arr = reinterpret_cast(arr_ptr); - std::copy(data.begin(), data.end(), arr); - return Py_None; - } catch (const std::exception& e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - return nullptr; - } + const char *filename; + uint64_t arr_ptr; + if (!PyArg_ParseTuple(args, "sK", &filename, &arr_ptr)) { + return nullptr; + } + + try { + std::vector data = readDataFromBinaryFile(filename); + char *arr = reinterpret_cast(arr_ptr); + std::copy(data.begin(), data.end(), arr); + return Py_None; + } catch (const std::exception &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } } void writeDataToBinaryFile(const std::string &filename, const char *data, @@ -225,99 +226,104 @@ void writeDataToBinaryFile(const std::string &filename, const char *data, } static PyObject *writeDataToBinaryFileWrapper(PyObject *self, PyObject *args) { - const char *filename; - uint64_t arr_ptr; - size_t num_bytes; - - if (!PyArg_ParseTuple(args, "sKn", &filename, &arr_ptr, &num_bytes)) { - return nullptr; - } - - try { - const char* data = reinterpret_cast(arr_ptr); - writeDataToBinaryFile(filename, data, num_bytes); - return Py_None; - } catch (const std::exception& e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - return nullptr; - } + const char *filename; + uint64_t arr_ptr; + size_t num_bytes; + + if (!PyArg_ParseTuple(args, "sKn", &filename, &arr_ptr, &num_bytes)) { + return nullptr; + } + + try { + const 
char *data = reinterpret_cast(arr_ptr); + writeDataToBinaryFile(filename, data, num_bytes); + return Py_None; + } catch (const std::exception &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } } -static PyObject* allocateHostMemory(PyObject* self, PyObject* args) { - uint64_t num_bytes; - if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return nullptr; - } +static PyObject *allocateHostMemory(PyObject *self, PyObject *args) { + uint64_t num_bytes; + if (!PyArg_ParseTuple(args, "K", &num_bytes)) { + return nullptr; + } - void* host_ptr = nullptr; - rtError_t error = rtMallocHost(&host_ptr, num_bytes, RT_MEMORY_HOST); - if (error != RT_ERROR_NONE) { - PyErr_Format(PyExc_RuntimeError, "rtMallocHost failed with error code: 0x%x", error); - return nullptr; - } + void *host_ptr = nullptr; + rtError_t error = rtMallocHost(&host_ptr, num_bytes, RT_MEMORY_HOST); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, + "rtMallocHost failed with error code: 0x%x", error); + return nullptr; + } PyObject *result = Py_BuildValue("K", (uint64_t)host_ptr); - if (result == nullptr) { - rtFreeHost(host_ptr); - } + if (result == nullptr) { + rtFreeHost(host_ptr); + } return result; } -static PyObject* allocateDeviceMemory(PyObject* self, PyObject* args) { - uint64_t num_bytes; - if (!PyArg_ParseTuple(args, "K", &num_bytes)) { - return nullptr; - } +static PyObject *allocateDeviceMemory(PyObject *self, PyObject *args) { + uint64_t num_bytes; + if (!PyArg_ParseTuple(args, "K", &num_bytes)) { + return nullptr; + } - void* device_ptr = nullptr; - rtError_t error = rtMalloc(&device_ptr, num_bytes, RT_MEMORY_HBM, 0); - if (error != RT_ERROR_NONE) { - PyErr_Format(PyExc_RuntimeError, "rtMalloc failed with error code: 0x%x", error); - return nullptr; - } + void *device_ptr = nullptr; + rtError_t error = rtMalloc(&device_ptr, num_bytes, RT_MEMORY_HBM, 0); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, "rtMalloc failed with error code: 
0x%x", + error); + return nullptr; + } PyObject *result = Py_BuildValue("K", (uint64_t)device_ptr); - if (result == nullptr) { - rtFree(device_ptr); - } + if (result == nullptr) { + rtFree(device_ptr); + } return result; } -static PyObject* copyMemory(PyObject* self, PyObject* args) { - uint64_t dst_ptr; - uint64_t src_ptr; - size_t count; - const char* direction_str; - rtMemcpyKind_t copy_direction; - - if (!PyArg_ParseTuple(args, "KKns", &dst_ptr, &src_ptr, &count, &direction_str)) { - return nullptr; - } - - if (strcmp(direction_str, "H2D") == 0) { - copy_direction = RT_MEMCPY_HOST_TO_DEVICE; - } else if (strcmp(direction_str, "D2H") == 0) { - copy_direction = RT_MEMCPY_DEVICE_TO_HOST; - } else { - PyErr_SetString(PyExc_ValueError, "Invalid copy direction. Must be 'H2D' or 'D2H'."); - return nullptr; - } - - void *dst = (void*)dst_ptr; - void *src = (void*)src_ptr; - - rtError_t error = rtMemcpy(dst, count, src, count, copy_direction); - if (error != RT_ERROR_NONE) { - PyErr_Format(PyExc_RuntimeError, "rtMemcpy failed with error code: 0x%x", error); - return nullptr; - } - - Py_INCREF(Py_None); - return Py_None; +static PyObject *copyMemory(PyObject *self, PyObject *args) { + uint64_t dst_ptr; + uint64_t src_ptr; + size_t count; + const char *direction_str; + rtMemcpyKind_t copy_direction; + + if (!PyArg_ParseTuple(args, "KKns", &dst_ptr, &src_ptr, &count, + &direction_str)) { + return nullptr; + } + + if (strcmp(direction_str, "H2D") == 0) { + copy_direction = RT_MEMCPY_HOST_TO_DEVICE; + } else if (strcmp(direction_str, "D2H") == 0) { + copy_direction = RT_MEMCPY_DEVICE_TO_HOST; + } else { + PyErr_SetString(PyExc_ValueError, + "Invalid copy direction. 
Must be 'H2D' or 'D2H'."); + return nullptr; + } + + void *dst = (void *)dst_ptr; + void *src = (void *)src_ptr; + + rtError_t error = rtMemcpy(dst, count, src, count, copy_direction); + if (error != RT_ERROR_NONE) { + PyErr_Format(PyExc_RuntimeError, "rtMemcpy failed with error code: 0x%x", + error); + return nullptr; + } + + Py_INCREF(Py_None); + return Py_None; } static PyMethodDef NpuUtilsMethods[] = { @@ -326,12 +332,17 @@ static PyMethodDef NpuUtilsMethods[] = { {"get_arch", getArch, METH_VARARGS, "Get soc version of NPU"}, // sentinel {"get_aicore_num", getAiCoreNum, METH_VARARGS, "Get the number of AI core"}, - {"create_stream", createStream, METH_VARARGS, "Create a stream"}, - {"read_data_from_file", readDataFromBinaryFileWrapper, METH_VARARGS, "Read binary file into the array already allocated"}, - {"write_data_to_file", writeDataToBinaryFileWrapper, METH_VARARGS, "Write an array to a binary file"}, - {"allocate_device_memory", allocateDeviceMemory, METH_VARARGS, "Allocate device memory"}, - {"allocate_host_memory", allocateHostMemory, METH_VARARGS, "Allocate host memory"}, - {"copy_memory", copyMemory, METH_VARARGS, "Copy data between host and device"}, + {"create_stream", createStream, METH_VARARGS, "Create a stream"}, + {"read_data_from_file", readDataFromBinaryFileWrapper, METH_VARARGS, + "Read binary file into the array already allocated"}, + {"write_data_to_file", writeDataToBinaryFileWrapper, METH_VARARGS, + "Write an array to a binary file"}, + {"allocate_device_memory", allocateDeviceMemory, METH_VARARGS, + "Allocate device memory"}, + {"allocate_host_memory", allocateHostMemory, METH_VARARGS, + "Allocate host memory"}, + {"copy_memory", copyMemory, METH_VARARGS, + "Copy data between host and device"}, {nullptr, nullptr, 0, nullptr}}; static PyModuleDef ModuleDef = { diff --git a/third_party/ascend/backend/runtime/autoparser.py b/third_party/ascend/backend/runtime/autoparser.py index 642ffc780e..4ff29ed9e4 100644 --- 
a/third_party/ascend/backend/runtime/autoparser.py +++ b/third_party/ascend/backend/runtime/autoparser.py @@ -100,7 +100,7 @@ def get_axis(self, var: str, node=None): elif isinstance(child_node, ast.BinOp) and \ isinstance(child_node.op, ast.BitAnd): - + axis = self.handle_lt_node(var, child_node.left) if axis is None: axis = self.handle_lt_node(var, child_node.right) @@ -202,11 +202,8 @@ def parse(self) -> Dict[str, str]: def visit_Assign(self, node): pid_dim = self._get_program_id_dim(node.value) if pid_dim is not None: - if ( - len(node.targets) == 1 - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id not in self.program_id_vars - ): + if (len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) + and node.targets[0].id not in self.program_id_vars): self.program_id_vars.append(node.targets[0].id) self.program_id_var_dims[node.targets[0].id] = pid_dim num_programs_dim = self._get_num_programs_dim(node.value) @@ -239,7 +236,7 @@ def visit_BinOp(self, node): if isinstance(node.left, ast.Name): split_axes_val = node.left.id split_axis_pid_dim = self._get_program_id_dim(node.right) - + if split_axes_val in self.candidates_params and \ split_axes_val not in self.split_axes.values(): split_axes_key = self.get_axis(split_axes_val) @@ -256,12 +253,8 @@ def visit_For(self, node): iter_fn = node.iter.func is_range = isinstance(iter_fn, ast.Name) and iter_fn.id == "range" - is_tl_range = ( - isinstance(iter_fn, ast.Attribute) - and isinstance(iter_fn.value, ast.Name) - and iter_fn.value.id == "tl" - and iter_fn.attr == "range" - ) + is_tl_range = (isinstance(iter_fn, ast.Attribute) and isinstance(iter_fn.value, ast.Name) + and iter_fn.value.id == "tl" and iter_fn.attr == "range") if not (is_range or is_tl_range): self.generic_visit(node) return @@ -285,13 +278,8 @@ def visit_For(self, node): self.generic_visit(node) def _get_program_id_dim(self, node): - if not ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) - and 
isinstance(node.func.value, ast.Name) - and node.func.value.id == "tl" - and node.func.attr == "program_id" - ): + if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name) and node.func.value.id == "tl" and node.func.attr == "program_id"): return None axis_dim = 0 @@ -311,13 +299,8 @@ def _get_program_id_dim(self, node): return axis_dim def _get_num_programs_dim(self, node): - if not ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "tl" - and node.func.attr == "num_programs" - ): + if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name) and node.func.value.id == "tl" and node.func.attr == "num_programs"): return None axis_dim = 0 @@ -370,11 +353,7 @@ def _contains_num_programs_dim(self, node, pid_dim): return False def _is_candidate_name(self, node, candidate_name): - return ( - isinstance(node, ast.Name) - and node.id == candidate_name - and candidate_name in self.candidates_params - ) + return (isinstance(node, ast.Name) and node.id == candidate_name and candidate_name in self.candidates_params) def _extract_pid_multiplied_candidate(self, node, pid_dim): if node is None: @@ -486,10 +465,7 @@ def visit_For(self, node): if isinstance(node.iter, ast.Call) and len(node.iter.args) == 3: step_expr = node.iter.args[2] for_loop_param = self._extract_unique_candidate(step_expr) - if ( - for_loop_param is not None - and for_loop_param not in self.candidates_params_for_loop - ): + if (for_loop_param is not None and for_loop_param not in self.candidates_params_for_loop): self.candidates_params_for_loop.append(for_loop_param) self.generic_visit(node) @@ -541,10 +517,7 @@ def _extract_unique_candidate(self, expr): """ if expr is None: return None - candidates = [ - param for param in self.candidates_params - if self.contains_target_var(expr, param) - ] + 
candidates = [param for param in self.candidates_params if self.contains_target_var(expr, param)] if len(candidates) == 1: return candidates[0] return None @@ -580,7 +553,7 @@ def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): """ super().__init__(func_ast, keys) self.reduction_axes = list() - self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx + self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx self.ndim = 1 def parse(self) -> List[str]: @@ -590,16 +563,16 @@ def parse(self) -> List[str]: def visit_Assign(self, node): self._scan_subscripts(node.value) self.generic_visit(node) - + def _scan_subscripts(self, node): if isinstance(node, ast.Subscript): ndim = self._get_subscripts_ndim(node) if ndim > self.ndim: self.ndim = ndim - + for child in ast.iter_child_nodes(node): self._scan_subscripts(child) - + def _get_subscripts_ndim(self, subscript_node): slice_node = subscript_node.slice @@ -622,7 +595,7 @@ def visit_Call(self, node): return if func.attr not in self.reduction_func: return - + axis_dim = None args = node.args if len(args) == 1: @@ -635,7 +608,7 @@ def visit_Call(self, node): elif len(args) == 2: # Axis passed as positional argument. 
Check the second param axis_dim = self.get_axis_dim(args[1]) - + else: raise ValueError("Reduction funtions args error") @@ -656,7 +629,7 @@ def get_axis_dim(self, node): raise ValueError(f"Reduction function axis error, got: {ast.dump(node)}") if not isinstance(axis_dim, int): - raise ValueError("Reduction function axis must be an integer, " + raise ValueError("Reduction function axis must be an integer, " f"got {type(node.value).__name__}: {node.value}") return axis_dim diff --git a/third_party/ascend/backend/runtime/autotuner.py b/third_party/ascend/backend/runtime/autotuner.py index 5af6fe2aea..0e72d51acb 100644 --- a/third_party/ascend/backend/runtime/autotuner.py +++ b/third_party/ascend/backend/runtime/autotuner.py @@ -36,8 +36,7 @@ import triton from triton.runtime.autotuner import Autotuner, Config -from .autoparser import (LowDimsAxesParser, PtrNumsParser, ReductionAxesParser, - SplitAxesParser, TilingAxesParser) +from .autoparser import (LowDimsAxesParser, PtrNumsParser, ReductionAxesParser, SplitAxesParser, TilingAxesParser) from .utils import get_byte_per_numel, is_valid_axis_name, valid_axis_names @@ -115,10 +114,8 @@ def __init__( self.print_autotuning = os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" # Compile kernels in parallel by default for triton.runtime.JITFunction, # but not for others, e.g., LibEntry, since it's not compatible with AsyncCompileMode - self.compile_parallel = ( - isinstance(self.fn, triton.runtime.JITFunction) - and os.getenv("TRITON_AUTOTUNE_PARALLEL_COMPILE", "1") == "1" - ) + self.compile_parallel = (isinstance(self.fn, triton.runtime.JITFunction) + and os.getenv("TRITON_AUTOTUNE_PARALLEL_COMPILE", "1") == "1") def _init_axis_params(self, key, split_params, tiling_params, low_dim_axes, reduction_axes): if isinstance(key, list): @@ -192,9 +189,7 @@ def _autoparse_axis_params(self, all_args): self.persistent_reduction = True if not self.split_params: - all_split_params = self._autoparse_split_params( - 
self._get_constexpr_candidates() - ) + all_split_params = self._autoparse_split_params(self._get_constexpr_candidates()) self.all_split_params = dict(all_split_params) self.fixed_split_params = {} self.fixed_grid_dim_values = self._get_fixed_grid_dim_values( @@ -203,10 +198,7 @@ def _autoparse_axis_params(self, all_args): ) self.fixed_grid_dims = set(self.fixed_grid_dim_values.keys()) - fixed_grid_axes = { - axis for axis, pid_dim in self.axis_pid_dims.items() - if pid_dim in self.fixed_grid_dims - } + fixed_grid_axes = {axis for axis, pid_dim in self.axis_pid_dims.items() if pid_dim in self.fixed_grid_dims} # Only missing constexpr params are tunable, and fixed-grid axes # should not be tuned on split. @@ -233,10 +225,7 @@ def _autoparse_axis_params(self, all_args): # When split axes are provided by hints, parse axis->program_id mapping # independently for fixed-grid semantics and diagnostics. self._autoparse_axis_pid_dims() - miss_params = [ - arg for arg in miss_params - if arg not in self.split_params.values() - ] + miss_params = [arg for arg in miss_params if arg not in self.split_params.values()] if not self.tiling_params: self.tiling_params = self._autoparse_tiling_params(miss_params) miss_params = [arg for arg in miss_params if arg not in self.tiling_params.values()] @@ -397,8 +386,8 @@ def _batch_bench(self, *args, configs, **kwargs): future_kernels = [] try: with ( - ThreadPoolExecutor(max_workers=max_workers) as executor, - triton.AsyncCompileMode(executor), + ThreadPoolExecutor(max_workers=max_workers) as executor, + triton.AsyncCompileMode(executor), ): for config, fn in kernels_call.items(): future_kernels.append((config, fn(warmup=True))) @@ -485,22 +474,14 @@ def warmup(self, *args, **kwargs): max_workers = min(psutil.cpu_count(logical=False) // 2, len(pruned_configs)) with ( - ThreadPoolExecutor(max_workers=max_workers) as executor, - triton.AsyncCompileMode(executor), + ThreadPoolExecutor(max_workers=max_workers) as executor, + 
triton.AsyncCompileMode(executor), ): for config in pruned_configs: - ret.append(self.fn.warmup( - *args, - **kwargs, - **config.all_kwargs() - )) + ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) else: for config in pruned_configs: - ret.append(self.fn.warmup( - *args, - **kwargs, - **config.all_kwargs() - )) + ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) self.nargs = None return ret @@ -509,9 +490,7 @@ def _profile(self, *args, config, **meta): kernel_call = self._make_kernel_call(*args, config=config, **meta) fn = functools.partial(kernel_call, warmup=False) - do_bench_npu( - fn, prof_dir=self.auto_profile_dir, keep_res=True - ) + do_bench_npu(fn, prof_dir=self.auto_profile_dir, keep_res=True) def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str]: """ @@ -523,11 +502,9 @@ def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) if self.print_autotuning: - print( - f"Ascend autotuning parse split axes: {split_axes}, " - f"split axis pid dims: {self.split_axis_pid_dims}, " - f"axis pid dims: {self.axis_pid_dims}" - ) + print(f"Ascend autotuning parse split axes: {split_axes}, " + f"split axis pid dims: {self.split_axis_pid_dims}, " + f"axis pid dims: {self.axis_pid_dims}") return split_axes def _autoparse_axis_pid_dims(self) -> Dict[str, int]: @@ -545,10 +522,8 @@ def _autoparse_axis_pid_dims(self) -> Dict[str, int]: self.axis_pid_dims = dict(getattr(parser, "axis_pid_dims", {})) self.split_axis_pid_dims = dict(getattr(parser, "split_axis_pid_dims", {})) if self.print_autotuning: - print( - "Ascend autotuning parse axis pid dims (independent): " - f"{self.axis_pid_dims}" - ) + print("Ascend autotuning parse axis pid dims (independent): " + f"{self.axis_pid_dims}") return self.axis_pid_dims def _get_constexpr_candidates(self) -> 
List[str]: @@ -566,12 +541,8 @@ def _get_constexpr_candidates(self) -> List[str]: if not isinstance(arg, ast.arg): continue ann = arg.annotation - if ( - isinstance(ann, ast.Attribute) - and isinstance(ann.value, ast.Name) - and ann.value.id == "tl" - and ann.attr == "constexpr" - ): + if (isinstance(ann, ast.Attribute) and isinstance(ann.value, ast.Name) and ann.value.id == "tl" + and ann.attr == "constexpr"): constexpr_names.append(arg.arg) break return constexpr_names @@ -590,7 +561,7 @@ def _get_fixed_grid_dim_values(self, grid, all_args: Dict[str, object] = None) - def _extract_fixed_grid_dims(self, grid) -> Dict[int, int]: if isinstance(grid, int): - grid = (grid,) + grid = (grid, ) if not isinstance(grid, (tuple, list)): return {} fixed_dims = {} @@ -601,7 +572,7 @@ def _extract_fixed_grid_dims(self, grid) -> Dict[int, int]: def _normalize_grid_tuple(self, grid_out): if isinstance(grid_out, int): - return (grid_out,) + return (grid_out, ) if isinstance(grid_out, (tuple, list)): return tuple(grid_out) return None diff --git a/third_party/ascend/backend/runtime/tile_generator.py b/third_party/ascend/backend/runtime/tile_generator.py index 4e30eb6df3..a4e2a9a83c 100644 --- a/third_party/ascend/backend/runtime/tile_generator.py +++ b/third_party/ascend/backend/runtime/tile_generator.py @@ -93,9 +93,7 @@ def __init__( :param dual_reduction: performing reduction on more than one axis. :param persistent_reduction: there is no splitting in reduction axis. 
""" - self._validate_axis( - axis_sizes, split_params, fixed_split_params, tiling_params, low_dims - ) + self._validate_axis(axis_sizes, split_params, fixed_split_params, tiling_params, low_dims) axis_dict = {} idx = 0 @@ -127,9 +125,7 @@ def __init__( self.axis_info = list(axis_dict.values()) self.split_axis = [x for x in axis_dict.values() if x.is_split_axis] - self.tunable_split_axis = [ - x for x in axis_dict.values() if x.is_tunable_split_axis - ] + self.tunable_split_axis = [x for x in axis_dict.values() if x.is_tunable_split_axis] self.tiling_axis = [x for x in axis_dict.values() if x.is_tiling_axis] self.low_dims_axis = [x for x in axis_dict.values() if x.name in low_dims] self.dtype = dtype @@ -194,16 +190,16 @@ def __init__(self, kernel_meta: KernelMeta): self.is_simt_mode = kernel_meta.is_simt_mode local_mem_size = (rf_size_in_kbytes if self.is_simt_mode else ub_size_in_kbytes) self.max_numel_threshold = local_mem_size * 1024 // self.dtype_bytes // self.num_buffers - self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size for x in self.blocks]) if self.blocks else 1 + self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size + for x in self.blocks]) if self.blocks else 1 self.small_kernel = self.max_total_numel < 128 * 1024 self.tiny_kernel = self.max_total_numel <= 32 * 1024 - self.stop_numel = min(1024 // self.dtype_bytes, self.max_total_numel // (num_vector_core * 2)) if self.small_kernel else 1024 // self.dtype_bytes + self.stop_numel = min(1024 // self.dtype_bytes, self.max_total_numel // + (num_vector_core * 2)) if self.small_kernel else 1024 // self.dtype_bytes self.max_programs_num = 65535 self.tiny_program_threshold = num_vector_core // 8 self.tiny_per_program_cap = 1 - self.tiny_low_program_hist = { - p: 0 for p in range(1, self.tiny_program_threshold + 1) - } + self.tiny_low_program_hist = {p: 0 for p in range(1, self.tiny_program_threshold + 1)} self.tiny_low_program_active = False 
self.tiny_low_program_tile_floor = 0 @@ -269,11 +265,8 @@ def fill_config(self, cfg, candi_block): cfg[block_info.block_name] = curr_numel if axis.is_tiling_axis: tiling_numel = self.aligned_numel(block_info.sub_block_size) - cfg[block_info.sub_block_name] = ( - tiling_numel - if self.is_simt_mode - else min(tiling_numel, candi_block[axis.index]) - ) + cfg[block_info.sub_block_name] = (tiling_numel if self.is_simt_mode else min( + tiling_numel, candi_block[axis.index])) def find_config(self, cfg): for config_var in self.configs: @@ -282,17 +275,12 @@ def find_config(self, cfg): return False def _try_add_tiny_low_program_config(self, total_programs): - if ( - not self.tiny_kernel - or total_programs < 1 - or total_programs > self.tiny_program_threshold - ): + if (not self.tiny_kernel or total_programs < 1 or total_programs > self.tiny_program_threshold): return if self.tiny_low_program_hist.get(total_programs, 0) >= self.tiny_per_program_cap: return - candi_block = tuple([x.block_size for x in self.blocks]) if self.add_to_configs(list(candi_block)): if candi_block not in self.candidate_blocks: @@ -300,19 +288,13 @@ def _try_add_tiny_low_program_config(self, total_programs): if not self.tiny_low_program_active: self.tiny_low_program_active = True self.tiny_low_program_tile_floor = self.calculate_tile_numel() - self.tiny_low_program_hist[total_programs] = ( - self.tiny_low_program_hist.get(total_programs, 0) + 1 - ) + self.tiny_low_program_hist[total_programs] = (self.tiny_low_program_hist.get(total_programs, 0) + 1) def _calc_total_programs(self, candi_block=None): grids = [] for axis in self.kernel_meta.split_axis: numel = self.numels[axis.index] - block_size = ( - self.blocks[axis.index].block_size - if candi_block is None - else candi_block[axis.index] - ) + block_size = (self.blocks[axis.index].block_size if candi_block is None else candi_block[axis.index]) programs = (numel + block_size - 1) // block_size grids.append(programs) @@ -328,15 +310,10 @@ def 
add_to_configs(self, candi_block): total_programs = self._calc_total_programs(candi_block) program_threshold = self.tiny_program_threshold if self.small_kernel else num_vector_core // 2 if total_programs <= program_threshold: - tiny_low_program_threshold = max( - self.stop_numel, self.tiny_low_program_tile_floor // 2 - ) + tiny_low_program_threshold = max(self.stop_numel, self.tiny_low_program_tile_floor // 2) stop_numel_threshold = max(stop_numel_threshold, tiny_low_program_threshold) - if ( - tile_numel <= self.max_numel_threshold - and tile_numel >= stop_numel_threshold - and not self.find_config(newcfg) - ): + if (tile_numel <= self.max_numel_threshold and tile_numel >= stop_numel_threshold + and not self.find_config(newcfg)): self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) return True return False diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 50cd6a8f80..af88701b48 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -1183,9 +1183,8 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ if mask is None: load_handle = builder.create_load(ptr.handle, cache, eviction, is_volatile) else: - load_handle = builder.create_masked_load( - ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile - ) + load_handle = builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, + eviction, is_volatile) if is_bool: load_handle.set_attr("was_bool_to_int8", builder.get_bool_attr(True)) @@ -1664,8 +1663,10 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() if is_compile_on_910_95: - assert lhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, tl.float8e4nv], f"lhs 
matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" - assert rhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, tl.float8e4nv], f"rhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + assert lhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, + tl.float8e4nv], f"lhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" + assert rhs.dtype in [tl.float16, tl.bfloat16, tl.uint8, tl.float8e5, + tl.float8e4nv], f"rhs matrix dtype must be in [bf16, fp16, uint8, e5m2, e4m3]" else: assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" assert rhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" @@ -1685,9 +1686,11 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) - assert isinstance(lhs_scale, tl.tensor) and (lhs_scale.dtype == tl.int8 or lhs_scale.dtype == tl.uint8), f"lhs_scale must be int8 or uint8 tensor" + assert isinstance(lhs_scale, tl.tensor) and (lhs_scale.dtype == tl.int8 or lhs_scale.dtype + == tl.uint8), f"lhs_scale must be int8 or uint8 tensor" if not rhs_scale_is_none: - assert isinstance(rhs_scale, tl.tensor) and (rhs_scale.dtype == tl.int8 or rhs_scale.dtype == tl.uint8), f"rhs_scale must be int8 or uint8 tensor" + assert isinstance(rhs_scale, tl.tensor) and (rhs_scale.dtype == tl.int8 or rhs_scale.dtype + == tl.uint8), f"rhs_scale must be int8 or uint8 tensor" lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) diff --git a/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py index 
6cfb26e682..9482e57ca7 100644 --- a/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py +++ b/third_party/ascend/backend/spec/triton/runtime/ascend_interpreter.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Ascend-specific interpreter builder extensions. @@ -41,6 +40,7 @@ class AscendReduceOps(ReduceOps): Ascend reduce operations that override only the apply_impl logic. All other methods (sum, min_max, generic_reduce, etc.) are inherited from ReduceOps. """ + def apply_impl(self, input_param): if self.combine_fn == tl.standard._argmin_combine_tie_break_left: return self.min_max(input_param[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) @@ -48,7 +48,7 @@ def apply_impl(self, input_param): return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) # Ta has modified the implemention of tl.max elif self.combine_fn == tl.standard._elementwise_max_default: - return self.min_max(input_param[0], val_reduce_op=np.nanmax, idx_reduce_op=None) + return self.min_max(input_param[0], val_reduce_op=np.nanmax, idx_reduce_op=None) elif self.combine_fn == tl.standard._elementwise_max_propagate_nan: return self.min_max(input_param[0], val_reduce_op=np.max, idx_reduce_op=None) elif self.combine_fn == tl.standard._elementwise_min: @@ -70,7 +70,7 @@ def _compute_strides(shape): class AscendInterpreterBuilder(InterpreterBuilder): """ Extended InterpreterBuilder with Ascend-specific extension operations. 
- + This class inherits from InterpreterBuilder and adds support for: - get_element (extract_scalar): Extract scalar from tensor using indices - insert_slice: Insert sub-tensor into full tensor @@ -78,11 +78,11 @@ class AscendInterpreterBuilder(InterpreterBuilder): - index_select_simd: SIMD gather operation - get_sub_vec_id: Get vector core ID for 1:2 ratio emulation - Synchronization operations: sync_block_set/wait/all - + All extension operations handle both TensorHandle and Python int types for interpreter mode compatibility. """ - + def __init__(self) -> None: super().__init__() # Sub-vector core ID for simulating 1:2 hardware ratio @@ -93,7 +93,7 @@ def __init__(self) -> None: def to_int_val(self, val): """ Convert a value (int or TensorHandle) to Python int. - + :param val: Value to convert (int, TensorHandle, or other) :return: Python integer """ @@ -123,18 +123,18 @@ def _dummpy_scope(*args, **kwargs): tl.extra.cann.extension.parallel = _new_range tl.reduce = _new_reduce tl.core.reduce = _new_reduce - + def get_additional_reserved_keywords(self): """ Return additional reserved keywords specific to Ascend backend. - + These keywords will be filtered out from kernel call arguments and are not supported by the interpreter. - + :return: List of additional reserved keyword strings """ return [ - "multibuffer", # Ascend-specific memory buffering + "multibuffer", # Ascend-specific memory buffering "debug", "optimize_dynamic_offset", "enable_mixed_cv", @@ -144,21 +144,21 @@ def get_additional_reserved_keywords(self): # "ascend_option1", # "ascend_option2", ] - + def patch_extensions(self, fn): """ Patch Ascend extension modules for the given function. - + This method handles all Ascend-specific extension module patching, including CANN extensions and any other extension modules found in the function's global namespace. 
- + :param fn: The kernel function to patch extensions for """ # Import _patch_builtin from parent module from .interpreter import _patch_builtin self._patch_lang_ascend(fn) - + # Patch all modules in fn's globals that might be extension modules for name, value in list(fn.__globals__.items()): if value is None: @@ -176,7 +176,7 @@ def patch_extensions(self, fn): pass except Exception: pass - + # Also try importing extension directly as fallback try: import triton.language.extra.cann.extension as extension @@ -184,15 +184,15 @@ def patch_extensions(self, fn): except (ImportError, AttributeError): # Extension module not available (e.g., non-Ascend backend) pass - + def execute_with_sub_vec_simulation(self, fn, args, grid): """ Execute function with optional 1:2 sub-vector core simulation. - + Sub-vector simulation is only activated when create_get_sub_vec_id() is actually called during execution. This avoids unnecessary double execution for code that doesn't use sub_vec_id functionality. - + :param fn: The kernel function to execute :param args: Function arguments :param grid: Grid dimensions (nx, ny, nz) @@ -200,14 +200,14 @@ def execute_with_sub_vec_simulation(self, fn, args, grid): # Reset simulation flag at the beginning of each execution self._sub_vec_simulation_enabled = False self.sub_vec_id = 0 - + # First, try a single execution to see if sub_vec_id is used for x in range(grid[0]): for y in range(grid[1]): for z in range(grid[2]): self.set_grid_idx(x, y, z) fn(**args) - + # If sub_vec_id was accessed during execution, run again with sub_vec_id=1 if self._sub_vec_simulation_enabled: self.sub_vec_id = 1 @@ -224,9 +224,9 @@ def execute_with_sub_vec_simulation(self, fn, args, grid): def create_extract_scalar(self, tensor_handle, indices): """ Extract a scalar from a tensor using indices (equivalent to get_element). - + Handles mixed types: Python int (from loops) and TensorHandle (from other ops). 
- + :param tensor_handle: The tensor to extract from (TensorHandle) :param indices: List of scalar indices (can be TensorHandle or Python int) :return: Scalar value as TensorHandle @@ -242,9 +242,10 @@ def create_extract_scalar(self, tensor_handle, indices): index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) else: # Fallback: try to extract data - index_values.append(int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') - else int(idx.data) if hasattr(idx, 'data') else int(idx)) - + index_values.append( + int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') else + int(idx.data) if hasattr(idx, 'data') else int(idx)) + # Extract the scalar value scalar_data = tensor_handle.data[tuple(index_values)] return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) @@ -252,9 +253,9 @@ def create_extract_scalar(self, tensor_handle, indices): def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): """ Insert a sub-tensor into a full tensor at specified offsets. - + Handles mixed types: Python int and TensorHandle for offsets. 
- + :param full_tensor: The full tensor (destination, TensorHandle) :param sub_tensor: The sub-tensor to insert (TensorHandle) :param offsets: List of offset TensorHandle objects or Python ints @@ -263,7 +264,7 @@ def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): :return: Modified tensor with sub_tensor inserted (TensorHandle) """ result = full_tensor.data.copy() - + # Convert offsets from TensorHandle or Python int to integers offset_values = [] for off in offsets: @@ -275,9 +276,10 @@ def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) else: # Fallback - offset_values.append(int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') - else int(off.data) if hasattr(off, 'data') else int(off)) - + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + # Build slices for insertion slices = [] for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): @@ -286,18 +288,18 @@ def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): slices.append(slice(offset, end)) else: slices.append(slice(offset, end, stride)) - + # Insert the sub-tensor result[tuple(slices)] = sub_tensor.data - + return TensorHandle(result, full_tensor.dtype.scalar) def create_extract_slice(self, full_tensor, offsets, sizes, strides): """ Extract a slice from a full tensor. - + Handles mixed types: Python int and TensorHandle for offsets. 
- + :param full_tensor: The full tensor (TensorHandle) :param offsets: List of offset TensorHandle objects or Python ints :param sizes: List of size integers @@ -315,9 +317,10 @@ def create_extract_slice(self, full_tensor, offsets, sizes, strides): offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) else: # Fallback - offset_values.append(int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') - else int(off.data) if hasattr(off, 'data') else int(off)) - + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + # Build slices for extraction slices = [] for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): @@ -326,19 +329,19 @@ def create_extract_slice(self, full_tensor, offsets, sizes, strides): slices.append(slice(offset, end)) else: slices.append(slice(offset, end, stride)) - + # Extract the slice extracted = full_tensor.data[tuple(slices)] - + return TensorHandle(extracted, full_tensor.dtype.scalar) def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): """ SIMD index_select operation (gather with indices along a dimension). - + This is a hardware-accelerated gather operation that selects elements from a tensor using a set of indices along a specified dimension. 
- + :param src_ptr: Source tensor pointer (TensorHandle), just ptr address, not value :param index_tensor: 1D tensor of indices (TensorHandle or array) :param dim: Dimension to select from (int) @@ -353,13 +356,13 @@ def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_of src_offset_vals = [self.to_int_val(s) if s != -1 else -1 for s in src_offset] read_shape_vals = [self.to_int_val(r) if r != -1 else -1 for r in read_shape] result_shape_vals = [self.to_int_val(r) for r in result_shape] - + # Get index values - handle both array and TensorHandle if isinstance(index_tensor, TensorHandle): indices = index_tensor.data.flatten() else: indices = np.asarray(index_tensor).flatten() - + # Ensure indices are integers if indices.dtype not in [np.int32, np.int64]: indices = indices.astype(np.int32) @@ -369,10 +372,10 @@ def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_of dtype_np = _get_np_dtype(dtype_tt) src_strides = _compute_strides(src_shape_vals) base_addr = int(src_ptr.data.item()) - + # Create result tensor result = np.empty(result_shape_vals, dtype=dtype_np) - + # Perform index_select: for each index, read the specified data for out_idx, in_idx in enumerate(indices): in_idx = int(in_idx) @@ -423,25 +426,25 @@ def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_of else: result_slices.append(slice(None)) result[tuple(result_slices)] = tile_data - + return TensorHandle(result, dtype_tt) def create_get_sub_vec_id(self): """ Get the Vector Core index on the AI Core. - + In Interpreter mode, simulate multiple vector cores by maintaining a sub_vec_id counter. This is used for 1:2 hardware ratio emulation where different vector cores process different partitions of the data. - + The first call to this method enables sub_vec_simulation, causing the kernel to be executed twice (once for each sub_vec_id value). 
- + :return: Vector Core ID as TensorHandle (int64, scalar) """ # Enable sub_vec_id simulation when this method is called self._sub_vec_simulation_enabled = True - + # Return the current sub_vec_id vec_id = np.int64(self.sub_vec_id) return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) @@ -449,10 +452,10 @@ def create_get_sub_vec_id(self): def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): """ Set synchronization event between compute and vector units. - + In Interpreter mode, this is a no-op since we execute single-threaded. Synchronization is not needed in CPU emulation. - + :param sender: Source unit ("cube" or "vector") :param receiver: Destination unit ("cube" or "vector") :param event_id: Event ID (TensorHandle) @@ -465,10 +468,10 @@ def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): """ Wait for synchronization event between compute and vector units. - + In Interpreter mode, this is a no-op since we execute single-threaded. Synchronization is not needed in CPU emulation. - + :param sender: Source unit ("cube" or "vector") :param receiver: Destination unit ("cube" or "vector") :param event_id: Event ID (TensorHandle) @@ -481,10 +484,10 @@ def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receive def sync_block_all(self, mode, event_id): """ Synchronize all compute or vector units globally. - + In Interpreter mode, this is a no-op since we execute single-threaded. Synchronization is not needed in CPU emulation. 
- + :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") :param event_id: Event ID (int, constexpr, or TensorHandle) """ @@ -510,8 +513,7 @@ def create_sort(self, ptr_data, dim: int, descending: bool): ndim = ptr_data.data.ndim norm_dim = dim if dim >= 0 else dim + ndim if not (0 <= norm_dim < ndim): - raise IndexError( - f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + raise IndexError(f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") if descending: sorted_asc = np.sort(ptr_data.data, axis=norm_dim) @@ -524,11 +526,11 @@ def create_flip(self, ptr_data, dim): ndim = ptr_data.data.ndim norm_dim = dim if dim >= 0 else dim + ndim if not (0 <= norm_dim < ndim): - raise IndexError( - f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") + raise IndexError(f"Dimension out of range(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})") return TensorHandle(np.flip(ptr_data.data, axis=norm_dim), ptr_data.dtype.scalar) - def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, src_stride, end_offset, start_offset, other=None): + def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, src_stride, end_offset, start_offset, + other=None): # Convert src_stride, start_offset, end_offset to integers src_stride_vals = [self.to_int_val(s) for s in src_stride] start_offset_vals = [self.to_int_val(s) for s in start_offset] @@ -548,7 +550,7 @@ def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, sr for idx in range(total_elements): coord = np.unravel_index(idx, index_shape) all_coords.append(coord) - + # Compute the source tensor coordinates for each position in all_coords src_coords = [] for coord in all_coords: @@ -578,7 +580,7 @@ def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, sr address = base_addr + offset * element_size 
addresses.append(address) valid_mask.append(True) - + addr_array = np.array(addresses, dtype=np.uint64) mask_array = np.array(valid_mask, dtype=bool) @@ -597,16 +599,17 @@ def create_gather_out_to_ub(self, src_ptr, index_tensor, index_boundary, dim, sr result = flat_result.reshape(index_shape) return TensorHandle(result, dtype_tt) - def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_boundary, dim, dst_stride, end_offset, start_offset): + def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_boundary, dim, dst_stride, end_offset, + start_offset): # Convert dst_stride, start_offset, end_offset to integers dst_stride_vals = [self.to_int_val(s) for s in dst_stride] start_offset_vals = [self.to_int_val(s) for s in start_offset] - end_offset_vals = [self.to_int_val(s) for s in end_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] # Element type dtype_tt = dst_ptr.get_element_ty() dtype_np = _get_np_dtype(dtype_tt) - element_size = np.dtype(dtype_np).itemsize + element_size = np.dtype(dtype_np).itemsize base_addr = int(dst_ptr.data.item()) index_shape = index_tensor.data.shape @@ -633,7 +636,7 @@ def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_bo dst_coord = [] for d in range(index_rank): if d == dim: - dst_coord.append(start_offset_vals[d] + index_value) + dst_coord.append(start_offset_vals[d] + index_value) else: dst_coord.append(start_offset_vals[d] + coord[d]) offset = 0 @@ -642,28 +645,29 @@ def create_scatter_ub_to_out(self, dst_ptr, value_tensor, index_tensor, index_bo address = base_addr + offset * element_size addresses.append(address) valid_mask.append(True) - + addr_array = np.array(addresses, dtype=np.uint64) mask_array = np.array(valid_mask, dtype=bool) _interpreter.store(addr_array, flat_values, mask_array) - def create_index_put(self, dst_ptr, index_tensor, value_tensor, dim, index_boundary, end_offset, start_offset, dst_stride): + def create_index_put(self, 
dst_ptr, index_tensor, value_tensor, dim, index_boundary, end_offset, start_offset, + dst_stride): # Convert dst_stride, start_offset, end_offset_ to integers dst_stride_vals = [self.to_int_val(s) for s in dst_stride] start_offset_vals = [self.to_int_val(s) for s in start_offset] - end_offset_vals = [self.to_int_val(s) for s in end_offset] + end_offset_vals = [self.to_int_val(s) for s in end_offset] # Element type dtype_tt = dst_ptr.get_element_ty() dtype_np = _get_np_dtype(dtype_tt) - element_size = np.dtype(dtype_np).itemsize + element_size = np.dtype(dtype_np).itemsize base_addr = int(dst_ptr.data.item()) value_shape = value_tensor.data.shape value_rank = len(value_shape) - - flat_values = value_tensor.data.flatten() + + flat_values = value_tensor.data.flatten() total_elements = flat_values.size # Generate coordinates @@ -692,7 +696,7 @@ def create_index_put(self, dst_ptr, index_tensor, value_tensor, dim, index_bound dst_coord = [] for d in range(value_rank): if d == dim: - dst_coord.append(index_value) + dst_coord.append(index_value) else: dst_coord.append(start_offset_vals[d] + coord[d]) offset = 0 @@ -710,26 +714,23 @@ def create_index_put(self, dst_ptr, index_tensor, value_tensor, dim, index_bound _interpreter.store(addr_array, values_array, mask_array) def get_bool_attr(self, val): - return bool(val) + return bool(val) def get_unit_attr(self): return None # None valule in compile_hint return uint - def get_int32_attr(self, val): + def get_int32_attr(self, val): return int(val) def get_str_attr(self, val): return str(val) - def get_i64_array_attr(self, val): + def get_i64_array_attr(self, val): return [int(x) for x in val] def create_annotation_mark(self, ptr_data, hint_name: str, hint_val): if hint_name == "overflow_mode": raise ValueError(f"overflow_mode is not supported in interpreter mode, may have accuracy issues") else: - warnings.warn( - f"compile_hint '{hint_name}' is not supported in interpreter mode, just pass it", - UserWarning, - stacklevel=2 - ) 
\ No newline at end of file + warnings.warn(f"compile_hint '{hint_name}' is not supported in interpreter mode, just pass it", UserWarning, + stacklevel=2) diff --git a/third_party/ascend/backend/spec/triton/runtime/autotuner.py b/third_party/ascend/backend/spec/triton/runtime/autotuner.py index 5f4fd8f6b9..993aee56ad 100644 --- a/third_party/ascend/backend/spec/triton/runtime/autotuner.py +++ b/third_party/ascend/backend/spec/triton/runtime/autotuner.py @@ -410,45 +410,28 @@ def decorator(fn): _ALL_PARAMS = { - "BM_list", "BN_list", - "multibuffer", "unit_flag", - "limit_auto_multi_buffer_only_for_local_buffer", - "limit_auto_multi_buffer_of_local_buffer", - "set_workspace_multibuffer", - "enable_hivm_auto_cv_balance", - "tile_mix_vector_loop", - "tile_mix_cube_loop" + "BM_list", "BN_list", "multibuffer", "unit_flag", "limit_auto_multi_buffer_only_for_local_buffer", + "limit_auto_multi_buffer_of_local_buffer", "set_workspace_multibuffer", "enable_hivm_auto_cv_balance", + "tile_mix_vector_loop", "tile_mix_cube_loop" } _DEFAULTS = { - "BM_list": [16, 32, 64, 128], - "BN_list": [16, 32, 64, 128], - "multibuffer": [False], - "unit_flag": [False], - "limit_auto_multi_buffer_only_for_local_buffer": [True], - "limit_auto_multi_buffer_of_local_buffer": ["no-l0c"], - "set_workspace_multibuffer": [2, 4], - "enable_hivm_auto_cv_balance": [True], - "tile_mix_vector_loop": [2, 4], + "BM_list": [16, 32, 64, 128], "BN_list": [16, 32, 64, 128], "multibuffer": [False], "unit_flag": [False], + "limit_auto_multi_buffer_only_for_local_buffer": [True], "limit_auto_multi_buffer_of_local_buffer": ["no-l0c"], + "set_workspace_multibuffer": [2, 4], "enable_hivm_auto_cv_balance": [True], "tile_mix_vector_loop": [2, 4], "tile_mix_cube_loop": [2, 4] } _VALID_VALUES = { - "limit_auto_multi_buffer_of_local_buffer": ["no-limit", "no-l0c"], - "set_workspace_multibuffer": [2, 4], - "tile_mix_vector_loop": [2, 4, 8], - "tile_mix_cube_loop": [2, 4, 8] + "limit_auto_multi_buffer_of_local_buffer": 
["no-limit", "no-l0c"], "set_workspace_multibuffer": [2, 4], + "tile_mix_vector_loop": [2, 4, 8], "tile_mix_cube_loop": [2, 4, 8] } _CUBE_PARAMS = {"multibuffer", "unit_flag", "limit_auto_multi_buffer_of_local_buffer"} _MIXCV_PARAMS = { - "multibuffer", "unit_flag", - "limit_auto_multi_buffer_only_for_local_buffer", - "limit_auto_multi_buffer_of_local_buffer", - "set_workspace_multibuffer", - "enable_hivm_auto_cv_balance", - "tile_mix_vector_loop", - "tile_mix_cube_loop" + "multibuffer", "unit_flag", "limit_auto_multi_buffer_only_for_local_buffer", + "limit_auto_multi_buffer_of_local_buffer", "set_workspace_multibuffer", "enable_hivm_auto_cv_balance", + "tile_mix_vector_loop", "tile_mix_cube_loop" } _VECTOR_PARAMS = {"multibuffer"} @@ -466,37 +449,28 @@ def _check_int_in_set(val: List[Any], valid_set: set, param_name: str) -> bool: _VALIDATION_RULES = { - "multibuffer": { - "desc": "must be non-empty list/tuple of boolean values", - "check": _check_boolean_list - }, - "unit_flag": { - "desc": "must be non-empty list/tuple of boolean values", - "check": _check_boolean_list - }, - "limit_auto_multi_buffer_only_for_local_buffer": { - "desc": "must be non-empty list/tuple of boolean values", - "check": _check_boolean_list - }, - "limit_auto_multi_buffer_of_local_buffer": { - "desc": f"must be one or more of: {_VALID_VALUES['limit_auto_multi_buffer_of_local_buffer']}", - "check": lambda val, param_name: _check_string_in_set(val, _VALID_VALUES['limit_auto_multi_buffer_of_local_buffer'], "limit_auto_multi_buffer_of_local_buffer") - }, - "set_workspace_multibuffer": { - "desc": f"must be one or more of: {_VALID_VALUES['set_workspace_multibuffer']}", - "check": lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['set_workspace_multibuffer'], "set_workspace_multibuffer") - }, - "enable_hivm_auto_cv_balance": { - "desc": "must be non-empty list/tuple of boolean values", - "check": _check_boolean_list - }, - "tile_mix_vector_loop": { - "desc": f"must be one or more 
of: {_VALID_VALUES['tile_mix_vector_loop']}", - "check": lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['tile_mix_vector_loop'], "tile_mix_vector_loop") - }, - "tile_mix_cube_loop": { - "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_cube_loop']}", - "check": lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['tile_mix_cube_loop'], "tile_mix_cube_loop") + "multibuffer": {"desc": "must be non-empty list/tuple of boolean values", "check": _check_boolean_list}, + "unit_flag": {"desc": "must be non-empty list/tuple of boolean values", "check": + _check_boolean_list}, "limit_auto_multi_buffer_only_for_local_buffer": + {"desc": "must be non-empty list/tuple of boolean values", "check": + _check_boolean_list}, "limit_auto_multi_buffer_of_local_buffer": { + "desc": + f"must be one or more of: {_VALID_VALUES['limit_auto_multi_buffer_of_local_buffer']}", "check": + lambda val, param_name: _check_string_in_set(val, _VALID_VALUES['limit_auto_multi_buffer_of_local_buffer'], + "limit_auto_multi_buffer_of_local_buffer") + }, "set_workspace_multibuffer": { + "desc": + f"must be one or more of: {_VALID_VALUES['set_workspace_multibuffer']}", "check": + lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['set_workspace_multibuffer'], + "set_workspace_multibuffer") + }, "enable_hivm_auto_cv_balance": + {"desc": "must be non-empty list/tuple of boolean values", "check": _check_boolean_list}, "tile_mix_vector_loop": { + "desc": + f"must be one or more of: {_VALID_VALUES['tile_mix_vector_loop']}", "check": + lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['tile_mix_vector_loop'], "tile_mix_vector_loop") + }, "tile_mix_cube_loop": { + "desc": f"must be one or more of: {_VALID_VALUES['tile_mix_cube_loop']}", "check": + lambda val, param_name: _check_int_in_set(val, _VALID_VALUES['tile_mix_cube_loop'], "tile_mix_cube_loop") } } @@ -511,13 +485,8 @@ class BaseAutotuner: validation_rules: Validation rules for parameters (described in 
detail below). """ - def __init__( - self, - operator_name: str, - supported_params: set, - default_params: Dict[str, Any], - validation_rules: Dict[str, Dict[str, Any]] - ): + def __init__(self, operator_name: str, supported_params: set, default_params: Dict[str, Any], + validation_rules: Dict[str, Dict[str, Any]]): self.operator_name = operator_name self.supported_params = supported_params self.default_params = default_params @@ -563,10 +532,15 @@ def get_configs(self, **kwargs: Any) -> List[triton.Config]: valid_kwargs = {k: v for k, v in kwargs.items() if k in self.supported_params} - other_kwargs = {k: v for k, v in kwargs.items() if k not in self.supported_params and k not in self.SPECIAL_PARAMS_NO_WARNING} + other_kwargs = { + k: v + for k, v in kwargs.items() + if k not in self.supported_params and k not in self.SPECIAL_PARAMS_NO_WARNING + } if other_kwargs: print( - f"[WARNING] Parameter(s) {list(other_kwargs.keys())} do not belong to {self.operator_name} and have been ignored.") + f"[WARNING] Parameter(s) {list(other_kwargs.keys())} do not belong to {self.operator_name} and have been ignored." 
+ ) configs = [] @@ -579,7 +553,10 @@ def get_configs(self, **kwargs: Any) -> List[triton.Config]: for param_name in sorted(self.supported_params): if param_name == "limit_auto_multi_buffer_only_for_local_buffer": dynamic_params[param_name] = valid_kwargs.get(param_name, _DEFAULTS[param_name]) - elif param_name in ["set_workspace_multibuffer", "enable_hivm_auto_cv_balance", "tile_mix_vector_loop", "tile_mix_cube_loop"]: + elif param_name in [ + "set_workspace_multibuffer", "enable_hivm_auto_cv_balance", "tile_mix_vector_loop", + "tile_mix_cube_loop" + ]: if not limit_flag: dynamic_params[param_name] = valid_kwargs.get(param_name, _DEFAULTS[param_name]) else: @@ -606,26 +583,14 @@ def get_configs(self, **kwargs: Any) -> List[triton.Config]: return configs -CubeAutotuner = BaseAutotuner( - operator_name="cube", - supported_params=_CUBE_PARAMS, - default_params=_DEFAULTS, - validation_rules=_VALIDATION_RULES -) - -MixcvAutotuner = BaseAutotuner( - operator_name="mixcv", - supported_params=_MIXCV_PARAMS, - default_params=_DEFAULTS, - validation_rules=_VALIDATION_RULES -) - -VectorAutotuner = BaseAutotuner( - operator_name="vector", - supported_params=_VECTOR_PARAMS, - default_params=_DEFAULTS, - validation_rules=_VALIDATION_RULES -) +CubeAutotuner = BaseAutotuner(operator_name="cube", supported_params=_CUBE_PARAMS, default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES) + +MixcvAutotuner = BaseAutotuner(operator_name="mixcv", supported_params=_MIXCV_PARAMS, default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES) + +VectorAutotuner = BaseAutotuner(operator_name="vector", supported_params=_VECTOR_PARAMS, default_params=_DEFAULTS, + validation_rules=_VALIDATION_RULES) def get_autotune_cube_config(**kwargs: Any) -> List[triton.Config]: diff --git a/third_party/ascend/backend/spec/triton/runtime/interpreter.py b/third_party/ascend/backend/spec/triton/runtime/interpreter.py index d15232ef90..0f7be77b03 100644 --- 
a/third_party/ascend/backend/spec/triton/runtime/interpreter.py +++ b/third_party/ascend/backend/spec/triton/runtime/interpreter.py @@ -18,6 +18,7 @@ _has_ascend_support = False AscendInterpreterBuilder = None + def _try_import_ascend(): global _has_ascend_support, AscendInterpreterBuilder try: @@ -183,9 +184,9 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode): output_max_exponent = (1 << output_exponent_width) - 1 exponent_output = np.maximum(0, np.minimum(exponent_unclamped, output_max_exponent)) exponent_output = exponent_output.astype(output_unint_dtype) - # mark overflow index + # mark overflow index overflow_index = exponent_unclamped > output_max_exponent - 1 - + sign_output = sign.astype(output_unint_dtype) if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( @@ -214,7 +215,7 @@ def _convert_float(input, input_dtype, output_dtype, rounding_mode): significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) # covert overflow value to inf - significand_output[overflow_index & ~input_nan_index] = 0 + significand_output[overflow_index & ~input_nan_index] = 0 output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( exponent_output << output_dtype.fp_mantissa_width) | significand_output return output.reshape(input.shape) @@ -637,7 +638,6 @@ def create_splat(self, arg, shape): else: # scalar return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) - def create_atomic_cas(self, ptr, cmp, val, sem, scope): if sem not in self.ir_sem_to_interpreter_sem: raise ValueError(f"unsupported semantic {sem}") @@ -1036,7 +1036,7 @@ def _patch_lang(fn): _patch_builtin(lang.math, interpreter_builder) _patch_lang_tensor(lang.tensor) _patch_lang_core(lang) - + # Patch 
Ascend extensions if using AscendInterpreterBuilder if hasattr(interpreter_builder, 'patch_extensions'): interpreter_builder.patch_extensions(fn) @@ -1137,7 +1137,7 @@ def __call__(self, *args_dev, **kwargs): assert len(grid) <= 3, "grid must have at most 3 dimensions" grid = grid + (1, ) * (3 - len(grid)) interpreter_builder.set_grid_dim(*grid) - + try: # Execute kernels - sub_vec_id simulation handled by AscendInterpreterBuilder if hasattr(interpreter_builder, 'execute_with_sub_vec_simulation'): diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py index efd42780e1..cf9983b78f 100644 --- a/third_party/ascend/backend/utils.py +++ b/third_party/ascend/backend/utils.py @@ -357,7 +357,6 @@ def _precompile_npu_ext(header_path, gch_path): raise RuntimeError(f"Failed to compile {gch_path}, error: {result.stderr},cmd={cc_cmd}") - def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="torch", precompile=False) -> str: suffix = sysconfig.get_config_var("EXT_SUFFIX") src_dir = os.path.dirname(src_path) @@ -563,4 +562,4 @@ def triton_support_ffts(): def triton_enable_libdevice_simt(): enable_libdevice_simt = os.getenv("TRITON_ENABLE_LIBDEVICE_SIMT", False) - return enable_libdevice_simt \ No newline at end of file + return enable_libdevice_simt diff --git a/third_party/ascend/include/AutoBlockify/AutoBlockify.h b/third_party/ascend/include/AutoBlockify/AutoBlockify.h index b3633b76ab..3ce8159ef9 100644 --- a/third_party/ascend/include/AutoBlockify/AutoBlockify.h +++ b/third_party/ascend/include/AutoBlockify/AutoBlockify.h @@ -59,7 +59,8 @@ class PropagateUnrealizedCastDown PatternRewriter &rewriter) const override; private: - void handleBlockifyLoop(scf::ForOp blockifyLoop, Operation *op, PatternRewriter &rewriter) const; + void handleBlockifyLoop(scf::ForOp blockifyLoop, Operation *op, + PatternRewriter &rewriter) const; void rewriteSplat(UnrealizedConversionCastOp op, triton::SplatOp splatOp, PatternRewriter &rewriter) 
const; void rewriteExpandDims(UnrealizedConversionCastOp op, @@ -88,8 +89,8 @@ class PropagateUnrealizedCastDown PatternRewriter &rewriter) const; void rewriteLoop(UnrealizedConversionCastOp op, LoopLikeOpInterface loopOp, PatternRewriter &rewriter) const; - void rewriteIf(UnrealizedConversionCastOp &op, scf::IfOp ifOp, ArrayRef indices, - PatternRewriter &rewriter) const; + void rewriteIf(UnrealizedConversionCastOp &op, scf::IfOp ifOp, + ArrayRef indices, PatternRewriter &rewriter) const; void rewriteYield(UnrealizedConversionCastOp &op, scf::YieldOp yieldOp, PatternRewriter &rewriter) const; void rewriteCondition(UnrealizedConversionCastOp op, diff --git a/third_party/ascend/include/AutoBlockify/CMakeLists.txt b/third_party/ascend/include/AutoBlockify/CMakeLists.txt index abae4f6e9f..ca4cf9f552 100644 --- a/third_party/ascend/include/AutoBlockify/CMakeLists.txt +++ b/third_party/ascend/include/AutoBlockify/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name AutoBlockify) -add_public_tablegen_target(AutoBlockifyPassIncGen) \ No newline at end of file +add_public_tablegen_target(AutoBlockifyPassIncGen) diff --git a/third_party/ascend/include/AutoBlockify/Passes.td b/third_party/ascend/include/AutoBlockify/Passes.td index 56ab1587be..7d9f1a80a3 100644 --- a/third_party/ascend/include/AutoBlockify/Passes.td +++ b/third_party/ascend/include/AutoBlockify/Passes.td @@ -12,7 +12,7 @@ def AutoBlockify : Pass<"auto-blockify", "mlir::ModuleOp"> { "mlir::triton::TritonDialect" ]; let options = [ - Option<"autoBlockifySize", "auto-blockify-size", "int", "1", + Option<"autoBlockifySize", "auto-blockify-size", "int", "1", "Apply auto blockify v2 when TRITON_ALL_BLOCKS_PARALLEL is 1." 
"Expand highest dimension with blockify size"> ]; diff --git a/third_party/ascend/include/AutoBlockify/Utils.h b/third_party/ascend/include/AutoBlockify/Utils.h index 639922c8ee..385fa51a10 100644 --- a/third_party/ascend/include/AutoBlockify/Utils.h +++ b/third_party/ascend/include/AutoBlockify/Utils.h @@ -25,16 +25,15 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Dialect.h" using namespace mlir; using namespace triton; constexpr llvm::StringLiteral autoBlockifySizeAttr = "auto_blockify_size"; constexpr llvm::StringLiteral logicalBlockIdAttr = "logical_block_id"; -constexpr llvm::StringLiteral autoBlockifyLoopAttr = - "auto_blockify_loop"; +constexpr llvm::StringLiteral autoBlockifyLoopAttr = "auto_blockify_loop"; constexpr llvm::StringLiteral autoBlockifyRegionOpAttr = "auto_blockify_region_op"; @@ -63,4 +62,4 @@ Operation *createBlockifyLoop(Operation *targetOp, Value logicalBlockId, Value logicalBlockNum, int autoBlockifySize, RewriterBase &rewriter); -std::optional getBlockifyLoop(Operation *op); \ No newline at end of file +std::optional getBlockifyLoop(Operation *op); diff --git a/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt index 4a804f0784..c6193d1f5d 100644 --- a/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt +++ b/third_party/ascend/include/TritonAffinityOpt/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonAffinityOpt) -add_public_tablegen_target(TritonAffinityOptConversionPassIncGen) \ No newline at end of file +add_public_tablegen_target(TritonAffinityOptConversionPassIncGen) diff --git a/third_party/ascend/include/TritonAffinityOpt/DAG.h b/third_party/ascend/include/TritonAffinityOpt/DAG.h index ecebcf4e3e..364c20258c 100644 
--- a/third_party/ascend/include/TritonAffinityOpt/DAG.h +++ b/third_party/ascend/include/TritonAffinityOpt/DAG.h @@ -22,7 +22,8 @@ #include #include -namespace mlir { namespace AffinityDAG { +namespace mlir { +namespace AffinityDAG { enum class OpAbility { PREFER_VECTOR = 1 << 0, @@ -43,11 +44,11 @@ inline constexpr CoreType toCoreType(OpAbility ct) { return static_cast(static_cast(ct)); } -constexpr inline CoreType operator| (CoreType lhs, CoreType rhs) { +constexpr inline CoreType operator|(CoreType lhs, CoreType rhs) { return enumOp(std::bit_or<>(), lhs, rhs); } -inline CoreType operator& (CoreType lhs, CoreType rhs) { +inline CoreType operator&(CoreType lhs, CoreType rhs) { return enumOp(std::bit_and<>(), lhs, rhs); } @@ -55,38 +56,36 @@ inline bool intersects(CoreType lhs, CoreType rhs) { return (lhs & rhs) != CoreType::UNDETERMINED; } -inline CoreType operator& (OpAbility lhs, CoreType rhs) { +inline CoreType operator&(OpAbility lhs, CoreType rhs) { return toCoreType(lhs) & rhs; } -inline CoreType operator!(CoreType ct) -{ - CoreType newCt = UNDETERMINED; - if ((ct & CoreType::CUBE_ONLY) == UNDETERMINED) { - newCt = newCt | CoreType::CUBE_ONLY; - } +inline CoreType operator!(CoreType ct) { + CoreType newCt = UNDETERMINED; + if ((ct & CoreType::CUBE_ONLY) == UNDETERMINED) { + newCt = newCt | CoreType::CUBE_ONLY; + } - if ((ct & CoreType::VECTOR_ONLY) == UNDETERMINED) { - newCt = newCt | CoreType::VECTOR_ONLY; - } + if ((ct & CoreType::VECTOR_ONLY) == UNDETERMINED) { + newCt = newCt | CoreType::VECTOR_ONLY; + } - return newCt; + return newCt; } -inline hivm::TCoreType toHivm(CoreType ct) -{ - switch (ct) { - case UNDETERMINED: - return hivm::TCoreType::CUBE_OR_VECTOR; - case CUBE_ONLY: - return hivm::TCoreType::CUBE; - case VECTOR_ONLY: - return hivm::TCoreType::VECTOR; - case CUBE_AND_VECTOR: - return hivm::TCoreType::CUBE_AND_VECTOR; - default: - llvm_unreachable("Invalid CoreType that cannot convert to hivm"); - } +inline hivm::TCoreType toHivm(CoreType ct) 
{ + switch (ct) { + case UNDETERMINED: + return hivm::TCoreType::CUBE_OR_VECTOR; + case CUBE_ONLY: + return hivm::TCoreType::CUBE; + case VECTOR_ONLY: + return hivm::TCoreType::VECTOR; + case CUBE_AND_VECTOR: + return hivm::TCoreType::CUBE_AND_VECTOR; + default: + llvm_unreachable("Invalid CoreType that cannot convert to hivm"); + } } inline bool intersects(OpAbility lhs, CoreType rhs) { @@ -97,7 +96,7 @@ inline bool exactlyOneType(CoreType ct) { return (ct == CUBE_ONLY) || (ct == VECTOR_ONLY); } -const char* literalCoreType(CoreType ct); +const char *literalCoreType(CoreType ct); class MoveOnly { protected: @@ -119,37 +118,29 @@ ValueNode *getDataSource(OpNode *op); class Graph : MoveOnly { public: - using OpMapRaw = llvm::DenseMap>; + using OpMapRaw = llvm::DenseMap>; using ValueMapRaw = llvm::DenseMap>; using OpMap = std::shared_ptr; using ValueMap = std::shared_ptr; - Graph( - Block* block, - Graph* parent = nullptr, - OpMap opMap = nullptr, - ValueMap valueMap = nullptr, - bool inheritParent = true - ); + Graph(Block *block, Graph *parent = nullptr, OpMap opMap = nullptr, + ValueMap valueMap = nullptr, bool inheritParent = true); static std::unique_ptr fromMultiBlockFunc(triton::FuncOp funcOp); - OpMapRaw& getOpMap() const { - return *opMap; - } + OpMapRaw &getOpMap() const { return *opMap; } - ValueMapRaw& getValueMap() const { - return *valueMap; - } + ValueMapRaw &getValueMap() const { return *valueMap; } // [DEBUG] start - std::unique_ptr> legacyOpMap = nullptr; + std::unique_ptr> legacyOpMap = nullptr; std::unique_ptr> legacyValueTypes = nullptr; - inline llvm::DenseMap& getOpMapLegacy() { + inline llvm::DenseMap &getOpMapLegacy() { if (!legacyOpMap) { - legacyOpMap = std::move(std::make_unique>()); - for(auto& [key, val] : *opMap) { + legacyOpMap = + std::move(std::make_unique>()); + for (auto &[key, val] : *opMap) { (*legacyOpMap)[key] = val.get(); } } @@ -157,7 +148,7 @@ class Graph : MoveOnly { return *legacyOpMap; } - llvm::DenseMap& getValueTypes() 
; + llvm::DenseMap &getValueTypes(); // [DEBUG] end @@ -166,11 +157,11 @@ class Graph : MoveOnly { friend class OpNode; OpMap opMap; ValueMap valueMap; - Block* block; - Graph* parent; - OpNode* terminator = nullptr; + Block *block; + Graph *parent; + OpNode *terminator = nullptr; size_t opCount = 0; - llvm::SmallVector blockArgs; + llvm::SmallVector blockArgs; }; class Node : MoveOnly { @@ -179,19 +170,14 @@ class Node : MoveOnly { friend class ValueNode; bool isUpstreamOfCubeMem = false; virtual CoreType absorbImpl() = 0; - llvm::SmallVector outputs; + llvm::SmallVector outputs; public: CoreType isOnPrivate = UNDETERMINED; - enum NodeKind { - NK_Op, - NK_Value - }; + enum NodeKind { NK_Op, NK_Value }; - inline CoreType isOn() const { - return isOnPrivate; - } + inline CoreType isOn() const { return isOnPrivate; } bool absorb() { auto newCoreType = absorbImpl(); @@ -201,12 +187,10 @@ class Node : MoveOnly { return changed; }; - virtual llvm::SmallVector getAffected() const = 0; - virtual OpNode* getSourceOpNode() = 0; + virtual llvm::SmallVector getAffected() const = 0; + virtual OpNode *getSourceOpNode() = 0; - ArrayRef getOutputs() const { - return outputs; - } + ArrayRef getOutputs() const { return outputs; } CoreType absorbCommon(); @@ -214,9 +198,7 @@ class Node : MoveOnly { const NodeKind kind; public: - NodeKind getKind() const { - return kind; - } + NodeKind getKind() const { return kind; } protected: Node(NodeKind kind) : kind(kind) {} @@ -225,60 +207,52 @@ class Node : MoveOnly { class OpNode : public Node { friend class Graph; friend class ValueNode; - llvm::SmallVector inputs; + llvm::SmallVector inputs; llvm::SmallVector subgraphs; virtual CoreType absorbImpl() override; public: - Operation* op; + Operation *op; - OpNode(Operation* op, Graph* graph); + OpNode(Operation *op, Graph *graph); OpAbility canRunOn() const; - inline ArrayRef getInputs() const { - return inputs; - } + inline ArrayRef getInputs() const { return inputs; } - static bool 
classof(const Node* node) { - return node->getKind() == NK_Op; - } + static bool classof(const Node *node) { return node->getKind() == NK_Op; } - virtual llvm::SmallVector getAffected() const override { - llvm::SmallVector result(inputs.begin(), inputs.end()); + virtual llvm::SmallVector getAffected() const override { + llvm::SmallVector result(inputs.begin(), inputs.end()); result.append(outputs.begin(), outputs.end()); return result; } - virtual OpNode* getSourceOpNode() override { - return this; - } + virtual OpNode *getSourceOpNode() override { return this; } }; class ValueNode : public Node { friend class Graph; friend class OpNode; virtual CoreType absorbImpl() override; -public: - Node* source = nullptr; +public: + Node *source = nullptr; Value value; // ValueNode(OpResult value); // ValueNode(BlockArgument value); - ValueNode(Value value) : Node(NK_Value), value(value) {}; - virtual OpNode* getSourceOpNode() override { + ValueNode(Value value) : Node(NK_Value), value(value){}; + virtual OpNode *getSourceOpNode() override { if (!source) { return nullptr; } return source->getSourceOpNode(); } - static bool classof(const Node* node) { - return node->getKind() == NK_Value; - } + static bool classof(const Node *node) { return node->getKind() == NK_Value; } - virtual llvm::SmallVector getAffected() const override { - llvm::SmallVector result(outputs.begin(), outputs.end()); + virtual llvm::SmallVector getAffected() const override { + llvm::SmallVector result(outputs.begin(), outputs.end()); if (source) result.push_back(source); @@ -296,27 +270,26 @@ class GraphManager { return instance; } - void registerGraph(llvm::StringRef funcName, std::shared_ptr graph) { + void registerGraph(llvm::StringRef funcName, + std::shared_ptr graph) { graphs[funcName] = graph; } - AffinityDAG::Graph* getGraph(llvm::StringRef funcName) { + AffinityDAG::Graph *getGraph(llvm::StringRef funcName) { auto it = graphs.find(funcName); return it != graphs.end() ? 
it->second.get() : nullptr; } - void removeGraph(llvm::StringRef funcName) { - graphs.erase(funcName); - } + void removeGraph(llvm::StringRef funcName) { graphs.erase(funcName); } }; - -inline llvm::DenseMap& Graph::getValueTypes() { +inline llvm::DenseMap &Graph::getValueTypes() { static std::mutex mtx; std::lock_guard lock(mtx); if (!legacyValueTypes) { - legacyValueTypes = std::move(std::make_unique>()); - for(auto& [key, val] : *valueMap) { + legacyValueTypes = + std::move(std::make_unique>()); + for (auto &[key, val] : *valueMap) { llvm::dbgs() << key << "\n"; llvm::dbgs().flush(); (*legacyValueTypes)[key] = val.get()->isOn(); @@ -326,5 +299,6 @@ inline llvm::DenseMap& Graph::getValueTypes() { return *legacyValueTypes; } -} } +} // namespace AffinityDAG +} // namespace mlir #endif diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.h b/third_party/ascend/include/TritonAffinityOpt/Passes.h index 5c9a63225f..f58c7563bc 100644 --- a/third_party/ascend/include/TritonAffinityOpt/Passes.h +++ b/third_party/ascend/include/TritonAffinityOpt/Passes.h @@ -44,4 +44,4 @@ std::unique_ptr> createDAGScopePass(); } // namespace triton } // namespace mlir -#endif // TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H \ No newline at end of file +#endif // TRITON_ADAPTER_TRITON_AFFINITY_OPTIMIZATION_PASSES_H diff --git a/third_party/ascend/include/TritonAffinityOpt/Passes.td b/third_party/ascend/include/TritonAffinityOpt/Passes.td index b2a72f58db..f12de8444e 100644 --- a/third_party/ascend/include/TritonAffinityOpt/Passes.td +++ b/third_party/ascend/include/TritonAffinityOpt/Passes.td @@ -26,4 +26,4 @@ def DAGSync : Pass<"dag-sync", "mlir::ModuleOp"> { let dependentDialects = ["hivm::HIVMDialect", "bufferization::BufferizationDialect", "annotation::AnnotationDialect"]; } -#endif // TRITON_AFFINITY_OPTIMIZATION_PASSES \ No newline at end of file +#endif // TRITON_AFFINITY_OPTIMIZATION_PASSES diff --git a/third_party/ascend/include/TritonAffinityOpt/Utils.hpp 
b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp index 42b8c64a3f..d3aa63be77 100644 --- a/third_party/ascend/include/TritonAffinityOpt/Utils.hpp +++ b/third_party/ascend/include/TritonAffinityOpt/Utils.hpp @@ -5,22 +5,16 @@ namespace mlir::AffinityDAG { -template -constexpr inline T enumOp(F&& func, T lhs, T rhs) { +template +constexpr inline T enumOp(F &&func, T lhs, T rhs) { static_assert(std::is_enum_v, "T must be an enum type"); using U = std::underlying_type_t; - return static_cast( - std::invoke( - std::forward(func), - static_cast(lhs), - static_cast(rhs) - ) - ); + return static_cast(std::invoke(std::forward(func), static_cast(lhs), + static_cast(rhs))); } -} // namespace TritonAffinity::Utils +} // namespace mlir::AffinityDAG - -#endif \ No newline at end of file +#endif diff --git a/third_party/ascend/language/cann/extension/__init__.py b/third_party/ascend/language/cann/extension/__init__.py index 1d8c4dc31b..efa09c71fd 100644 --- a/third_party/ascend/language/cann/extension/__init__.py +++ b/third_party/ascend/language/cann/extension/__init__.py @@ -148,4 +148,3 @@ "scatter_ub_to_out", "index_select_simd", ] - diff --git a/third_party/ascend/language/cann/extension/aux_ops.py b/third_party/ascend/language/cann/extension/aux_ops.py index f7a8d55e8a..fd134a787b 100644 --- a/third_party/ascend/language/cann/extension/aux_ops.py +++ b/third_party/ascend/language/cann/extension/aux_ops.py @@ -1,33 +1,23 @@ import triton.language as tl from triton.language import semantic, core, standard -from triton.language.core import ( - _constexpr_to_value, - _tensor_member_fn, - _unwrap_iterable, - builtin, - constexpr, - dtype, - tensor, - check_bit_width, - _unwrap_if_constexpr, - range -) +from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, + tensor, check_bit_width, _unwrap_if_constexpr, range) from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, + 
wrap_tensor, + _str_to_rounding_mode, + not_equal, _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, bitcast, bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, + to_tensor, + _str_to_load_cache_modifier, _str_to_eviction_policy, - _str_to_padding_option, + _str_to_padding_option, _canonicalize_boundary_check, ) @@ -67,8 +57,10 @@ def sync_block_set(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) receiver = _constexpr_to_value(receiver) event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(sender, str) and (sender == "cube" + or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver + == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" if sender == receiver: raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') @@ -88,8 +80,10 @@ def sync_block_wait(sender, receiver, event_id, _builder=None): sender = _constexpr_to_value(sender) receiver = _constexpr_to_value(receiver) event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = 
{receiver}, only supports cube/vector" + assert isinstance(sender, str) and (sender == "cube" + or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver + == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" if sender == receiver: raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') @@ -106,7 +100,9 @@ class parallel(range): This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of iteration in this loop. Currently on 910B, max 2 vector cores could be used. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, bind_sub_block: bool = False): + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + bind_sub_block: bool = False): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) self.bind_sub_block = bind_sub_block @@ -132,6 +128,7 @@ def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder raise ValueError(f"Unsupported hint value type: {type(hint_val)}") builder.create_annotation_mark(ptr.handle, hint_name, hint_val) + @builtin def compile_hint(ptr, hint_name, hint_val=None, _builder=None): # simt mode does not support hint annotations @@ -150,6 +147,7 @@ def _unwrap(val): hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val compile_hint_impl(ptr, hint_name, hint_val, _builder) + @builtin def multibuffer(src: tensor, size, _builder=None): """ diff --git a/third_party/ascend/language/cann/extension/core.py b/third_party/ascend/language/cann/extension/core.py index 333615c870..e9520c679c 100644 --- a/third_party/ascend/language/cann/extension/core.py +++ 
b/third_party/ascend/language/cann/extension/core.py @@ -21,27 +21,9 @@ # THE SOFTWARE. __all__ = [ - "ascend_address_space", - "builtin", - "CORE", - "copy_from_ub_to_l1", - "copy", - "debug_barrier", - "fixpipe", - "FixpipeDMAMode", - "FixpipeDualDstMode", - "FixpipePreQuantMode", - "FixpipePreReluMode", - "int64", - "is_builtin", - "MODE", - "PIPE", - "IteratorType", - "sub_vec_id", - "sub_vec_num", - "sync_block_all", - "sync_block_set", - "sync_block_wait", + "ascend_address_space", "builtin", "CORE", "copy_from_ub_to_l1", "copy", "debug_barrier", "fixpipe", + "FixpipeDMAMode", "FixpipeDualDstMode", "FixpipePreQuantMode", "FixpipePreReluMode", "int64", "is_builtin", "MODE", + "PIPE", "IteratorType", "sub_vec_id", "sub_vec_num", "sync_block_all", "sync_block_set", "sync_block_wait", "SYNC_IN_VF" ] @@ -58,8 +40,8 @@ from triton.backends.ascend.driver import NPUUtils from . import semantic as semantic -PIPE = semantic.PIPE +PIPE = semantic.PIPE T = TypeVar("T") @@ -95,6 +77,7 @@ class int64(int): For custom op, python int argument will be converted to int32 by default, if a device-side int64 is required, you can pass an al.int64(x) to it. 
""" + def __new__(cls, value): obj = int.__new__(cls, value) obj.type = tl.int64 @@ -141,6 +124,7 @@ class IteratorType(enum.Enum): class ascend_address_space_base(bl.address_space): + def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: super().__init__() self.real_address_space = address_space_value @@ -152,11 +136,9 @@ def to_ir(self, builder: ir.builder) -> ir.attribute: class ascend_address_space_group: def __init__(self): - for k, v in { - k: v - for k, v in ascend_ir.AddressSpace.__dict__.items() - if isinstance(v, ascend_ir.AddressSpace) - }.items(): + for k, v in {k: v + for k, v in ascend_ir.AddressSpace.__dict__.items() + if isinstance(v, ascend_ir.AddressSpace)}.items(): setattr(self, k, ascend_address_space_base(v)) @@ -199,13 +181,13 @@ def copy(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _bu return semantic.copy(src, dst, _builder) -def create_sync_block(sender, receiver, event_id, is_set: bool, - sender_pipe=None, receiver_pipe=None, - _builder=None): +def create_sync_block(sender, receiver, event_id, is_set: bool, sender_pipe=None, receiver_pipe=None, _builder=None): sender = _constexpr_to_value(sender) receiver = _constexpr_to_value(receiver) - assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(sender, str) and (sender == "cube" + or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver + == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" if isinstance(event_id, int): assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" if sender == receiver: @@ -240,7 +222,8 @@ def sync_block_all(mode, event_id, _builder=None): 
event_id = _constexpr_to_value(event_id) assert isinstance(mode, str), f"mode: {mode} is not string" assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - assert mode in ("all_cube", "all_vector", "all", "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" + assert mode in ("all_cube", "all_vector", "all", + "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" _builder.sync_block_all(mode, event_id) @@ -300,26 +283,19 @@ def fixpipe( if dst.space != ascend_address_space.UB: raise TypeError("dst must be located in the UB memory region") - if len(dst.shape) == 2 and ( - dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32 - ): + if len(dst.shape) == 2 and (dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32): N = dst.shape[1] if N % 8 != 0: raise ValueError("32b Fixpipe last dim must be aligned to 8") if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0): raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16") if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0): - raise ValueError( - "32b Column split dual Fixpipe last dim must be aligned to 32" - ) + raise ValueError("32b Column split dual Fixpipe last dim must be aligned to 32") M = dst.shape[0] if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0): raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8") - dst16bits = ( - dst.type.element_ty == tl.float16 - or dst.type.element_ty == tl.int16 - or dst.type.element_ty == tl.bfloat16 - ) + dst16bits = (dst.type.element_ty == tl.float16 or dst.type.element_ty == tl.int16 + or dst.type.element_ty == tl.bfloat16) if len(dst.shape) == 2 and dst16bits: N = dst.shape[1] if N % 16 != 0: @@ -328,9 +304,8 @@ def fixpipe( if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0): raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 
16") - return semantic.fixpipe( - src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU, _builder - ) + return semantic.fixpipe(src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU, + _builder) class SYNC_IN_VF(enum.Enum): diff --git a/third_party/ascend/language/cann/extension/custom_op.py b/third_party/ascend/language/cann/extension/custom_op.py index e1b149e2b9..8644a6d9e6 100644 --- a/third_party/ascend/language/cann/extension/custom_op.py +++ b/third_party/ascend/language/cann/extension/custom_op.py @@ -176,10 +176,10 @@ def _make_align_dim_attrs(op, builder, arg_attrs): for arg, align_val in op.align_dim.items(): if isinstance(arg, str) and arg in align_arg_indices: - arg_attrs[align_arg_indices[arg]] = { name : builder.get_int_attr(align_val) } + arg_attrs[align_arg_indices[arg]] = {name: builder.get_int_attr(align_val)} print(arg_attrs[align_arg_indices[arg]]) elif isinstance(arg, int): - arg_attrs[arg] = { name : builder.get_int_attr(align_val) } + arg_attrs[arg] = {name: builder.get_int_attr(align_val)} print(arg_attrs[arg]) else: assert False, f"{name}'s keys should be string or int" @@ -216,7 +216,7 @@ def _add_optional_extra_buffer_attr(op, builder, attrs): extra_buffers = getattr(op, name) if isinstance(extra_buffers, tuple): - extra_buffers = [ extra_buffers ] + extra_buffers = [extra_buffers] extra_buffer_types, extra_buffer_sizes = zip(*extra_buffers) attrs[name + "_types"] = builder.get_type_array_attr([ty.to_ir(builder) for ty in extra_buffer_types]) @@ -256,7 +256,6 @@ def _make_attrs(op, builder): # Add bit code path attribute, formalize to abosulte path. 
_add_bitcode_attr(op, builder, attrs) - _add_optional_indexing_map_attr(op, builder, attrs) _add_optional_iterator_types_attr(op, builder, attrs) diff --git a/third_party/ascend/language/cann/extension/mem_ops.py b/third_party/ascend/language/cann/extension/mem_ops.py index a59b71add9..72ba4a54ea 100644 --- a/third_party/ascend/language/cann/extension/mem_ops.py +++ b/third_party/ascend/language/cann/extension/mem_ops.py @@ -13,21 +13,21 @@ _unwrap_if_constexpr, ) from triton.language.semantic import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, + wrap_tensor, + _str_to_rounding_mode, + not_equal, _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, bitcast, bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, + to_tensor, + _str_to_load_cache_modifier, _str_to_eviction_policy, - _str_to_padding_option, + _str_to_padding_option, _canonicalize_boundary_check, ) @@ -39,17 +39,8 @@ @_tensor_member_fn @builtin -def index_put( - ptr: tensor, - index: tensor, - value: tensor, - dim: int, - index_boundary: int, - end_offset: tuple, - start_offset: tuple, - dst_stride: tuple, - _builder=None -): +def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, + start_offset: tuple, dst_stride: tuple, _builder=None): """ Index put values from a tensor into a destination tensor. @@ -59,10 +50,10 @@ def index_put( out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] 2. 
3D index scatter (0 <= dim < 2): 2.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] + out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] 2.2 dim = 1 - out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] + out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] @@ -121,17 +112,8 @@ def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] """ - def index_put_impl( - ptr: tl.tensor, - index: tl.tensor, - value: tl.tensor, - dim: int, - index_boundary: int, - end_offset: Tuple, - start_offset: Tuple, - dst_stride: Tuple, - _builder: ir.builder - ): + def index_put_impl(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, + end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, _builder: ir.builder): assert index.dtype.is_int(), "index must be an integer tensor" if not ptr.dtype.element_ty.is_floating(): raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") @@ -144,18 +126,16 @@ def index_put_impl( raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") if dim < 0 or dim >= v_rank - 1: raise ValueError(f"dim must satisfy 0<=dim 0 @@ -478,21 +404,13 @@ def _is_ranked_tensor(x): if not _is_ranked_tensor(value) or isinstance(value, constexpr): element_ty = ptr.type.scalar.element_ty value = real_semantic.full(index.shape, value, element_ty, _builder) - return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, - dst_stride, end_offset, start_offset, _builder) + return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, + _builder) @_tensor_member_fn @builtin -def index_select_simd( - src, - dim, - 
index, - src_shape, - src_offset, - read_shape, - _builder=None -) -> tensor: +def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: """ Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). @@ -560,15 +478,9 @@ def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): :rtype: tensor """ - def index_select_simd_impl( - src: tl.tensor, - dim: int, - index: tl.tensor, - src_shape: List[Union[int, tl.tensor]], - src_offset: List[Union[int, tl.tensor]], - read_shape: List[Union[int, tl.tensor]], - _builder: ir.builder - ) -> tl.tensor: + def index_select_simd_impl(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], + src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], + _builder: ir.builder) -> tl.tensor: # Validate inputs ndim = len(src_shape) assert len(src_offset) == ndim, \ @@ -602,13 +514,11 @@ def index_select_simd_impl( newsrc_offset.append(s.handle if hasattr(s, 'handle') else s) # Create output type - return_shape = [ - index.shape[0] if i == dim else read_shape[i] - for i in range(ndim) - ] + return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] element_ty = src.type.element_ty output_ty = tl.block_type(element_ty, return_shape) - out = _builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, return_shape) + out = _builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, + return_shape) return tl.tensor(out, output_ty) dim = _constexpr_to_value(dim) @@ -621,16 +531,8 @@ def process_param(val): else: return _constexpr_to_value(val) - newsrc_shape = [ - real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o - for o in src_shape - ] - newsrc_offset = [ - real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o - for o in src_offset - ] + newsrc_shape = 
[real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape] + newsrc_offset = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_offset] assert len(index.shape) == 1, "index must be a 1D tensor" - return index_select_simd_impl( - src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder - ) + return index_select_simd_impl(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder) diff --git a/third_party/ascend/language/cann/extension/semantic.py b/third_party/ascend/language/cann/extension/semantic.py index 2f43733dca..e4a90ad9d5 100644 --- a/third_party/ascend/language/cann/extension/semantic.py +++ b/third_party/ascend/language/cann/extension/semantic.py @@ -26,9 +26,7 @@ ] import enum -from typing import ( - TypeVar, List, Union -) +from typing import (TypeVar, List, Union) from triton._C.libtriton import ir from triton._C.libtriton.ascend import ir as ascend_ir @@ -41,10 +39,8 @@ T = TypeVar('T') -def create_address_space( - address_space: ascend_ir.AddressSpace, - builder: ascend_ir.ascendnpu_ir_builder -) -> ir.attribute: +def create_address_space(address_space: ascend_ir.AddressSpace, + builder: ascend_ir.ascendnpu_ir_builder) -> ir.attribute: return builder.get_target_attribute(address_space) @@ -62,29 +58,27 @@ class PIPE(enum.Enum): def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): if isinstance(event_id, int): _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, - sender_pipe.value, receiver_pipe.value) + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) elif isinstance(event_id, tl.constexpr): _builder.sync_block_set(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, - sender_pipe.value, receiver_pipe.value) + real_semantic.to_tensor(event_id, _builder).handle, 
sender_pipe.value, + receiver_pipe.value) else: - _builder.sync_block_set(sender, receiver, - event_id.handle, sender_pipe.value, receiver_pipe.value) + _builder.sync_block_set(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): if isinstance(event_id, int): _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, - sender_pipe.value, receiver_pipe.value) + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) elif isinstance(event_id, tl.constexpr): _builder.sync_block_wait(sender, receiver, - real_semantic.to_tensor(event_id, _builder).handle, - sender_pipe.value, receiver_pipe.value) + real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, + receiver_pipe.value) else: - _builder.sync_block_wait(sender, receiver, - event_id.handle, sender_pipe.value, receiver_pipe.value) + _builder.sync_block_wait(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: diff --git a/third_party/ascend/language/cann/extension/vec_ops.py b/third_party/ascend/language/cann/extension/vec_ops.py index ea2a5f7c41..57e152f9a2 100644 --- a/third_party/ascend/language/cann/extension/vec_ops.py +++ b/third_party/ascend/language/cann/extension/vec_ops.py @@ -7,34 +7,24 @@ import triton.language as tl from triton.language import semantic, core, standard -from triton.language.core import ( - _constexpr_to_value, - _tensor_member_fn, - _unwrap_iterable, - builtin, - constexpr, - dtype, - tensor, - check_bit_width, - _unwrap_if_constexpr, - range -) +from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, + tensor, check_bit_width, _unwrap_if_constexpr, range) from triton.language.semantic 
import ( - wrap_tensor, - _str_to_rounding_mode, - not_equal, + wrap_tensor, + _str_to_rounding_mode, + not_equal, _str_to_dot_input_precision, - binary_op_type_checking_impl, - integer_promote_impl, - broadcast_impl_shape, - _str_to_sem, - _str_to_scope, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, bitcast, bitwise_op_type_checking_impl, - to_tensor, - _str_to_load_cache_modifier, + to_tensor, + _str_to_load_cache_modifier, _str_to_eviction_policy, - _str_to_padding_option, + _str_to_padding_option, _canonicalize_boundary_check, ) @@ -44,6 +34,7 @@ from typing import Optional, Tuple, List, overload from triton._C.libtriton import ir + @_tensor_member_fn @builtin def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: @@ -62,12 +53,13 @@ def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=No :type strides: tuple of ints """ - def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tensor: - assert(len(ful.shape) == len(offsets)) - assert(len(ful.shape) == len(sizes)) - assert(len(ful.shape) == len(strides)) - assert(all([s>=1 for s in sizes])) - assert(all([s>=0 for s in strides])) + def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) # Handle both tensor and int offsets (for interpreter mode) new_offsets = [] for o in offsets: @@ -84,10 +76,7 @@ def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: Li assert len(ful.shape) > 0 assert len(ful.shape) == len(sub.shape) - new_offsets = [ - semantic.to_tensor(o, _builder) if isinstance(o, constexpr) 
else o - for o in offsets - ] + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) return out @@ -108,12 +97,13 @@ def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) :type strides: tuple of ints """ - def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tensor: - assert(len(ful.shape) == len(offsets)) - assert(len(ful.shape) == len(sizes)) - assert(len(ful.shape) == len(strides)) - assert(all([s>=1 for s in sizes])) - assert(all([s>=0 for s in strides])) + def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) # Handle both tensor and int offsets (for interpreter mode) new_offsets = [] for o in offsets: @@ -129,13 +119,11 @@ def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], str return tensor(out, ret_type) assert len(ful.shape) > 0 - new_offsets = [ - semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o - for o in offsets - ] + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) return sub + @_tensor_member_fn @builtin def get_element(src, indice, _builder=None, _generator=None): @@ -165,17 +153,15 @@ def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): else: # Try to use .handle attribute if available new_indice.append(i.handle if hasattr(i, 'handle') else i) - + result = builder.create_extract_scalar(src.handle, new_indice) return wrap_tensor(result, src.type.scalar, None) assert 
len(src.shape) > 0 - new_indice = [ - semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i - for i in indice - ] + new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] return get_element_impl(src, new_indice, _builder) + @builtin def flip(ptr, dim=-1, _builder=None, _generator=None): @@ -236,15 +222,11 @@ def flip_simd(ptr: tensor, dim: int, builder: ir.builder): raise ValueError("ascend.flip requires tensor rank >= 1") norm_dim = dim if dim >= 0 else dim + rank if not (0 <= norm_dim < rank): - raise ValueError( - f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}" - ) + raise ValueError(f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}") dim = norm_dim else: if dim < 0: - raise ValueError( - "ascend.flip with unknown rank requires non-negative dim" - ) + raise ValueError("ascend.flip with unknown rank requires non-negative dim") flipped_vals = builder.create_flip(ptr.handle, dim) flipped = tensor(flipped_vals, type=ptr.type) @@ -262,7 +244,10 @@ def flip_simd(ptr: tensor, dim: int, builder: ir.builder): return ptr # reshape the swap dimension to (2, 2, ..., 2) idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) - y = core.reshape(ptr.to(idtype, bitcast=True, _builder=builder), ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), _builder=builder) + y = core.reshape( + ptr.to(idtype, bitcast=True, _builder=builder), + ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), + _builder=builder) for i in static_range(steps): y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) @@ -282,6 +267,7 @@ class static_range: Iterator for non-JIT Python functions that need to iterate over constexpr values. 
This is used in functions like flip that are called during compilation. """ + def __init__(self, arg1, arg2=None, step=None): if step is None: self.step = core.constexpr(1) @@ -328,13 +314,14 @@ def sort(ptr, dim=-1, descending=False, _builder=None): """ def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): - allowed_types = {tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5} + allowed_types = { + tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 + } base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type if base_ty not in allowed_types: raise TypeError( f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" - f"but got {ptr.type}" - ) + f"but got {ptr.type}") shape = getattr(ptr, "shape", None) if shape is None or shape == (): @@ -353,17 +340,13 @@ def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): last_dim = rank - 1 norm_dim = dim if dim >= 0 else dim + rank if norm_dim != last_dim: - raise ValueError( - f"ascend.sort only supports sorting along the last dimension " - f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}" - ) + raise ValueError(f"ascend.sort only supports sorting along the last dimension " + f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") dim = last_dim else: if dim != -1: - raise ValueError( - "ascend.sort only supports the last dimension; when rank is unknown " - "you must pass dim=-1" - ) + raise ValueError("ascend.sort only supports the last dimension; when rank is unknown " + "you must pass dim=-1") if hasattr(descending, "value"): descending = bool(descending.value) @@ -397,8 +380,8 @@ def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): return ret -def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, - fp_downcast_rounding: Optional[str] = None, overflow_mode: 
Optional[str] = None) -> tensor: +def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, + overflow_mode: Optional[str] = None) -> tensor: src_ty = input.type if isinstance(dst_ty, tl.constexpr): dst_ty = dst_ty.value @@ -408,7 +391,7 @@ def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) if src_ty == dst_ty: return input - + src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar if src_sca_ty == dst_sca_ty: @@ -427,9 +410,9 @@ def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) if not is_compile_on_910_95: - if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): raise ValueError("[fp8, fp64] is unsupported on Ascend for now." - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): assert builder.codegen_fns.get( "convert_custom_types") is not None, "target doesn't provide conversion for this type." 
@@ -472,7 +455,7 @@ def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, if dst_sca_ty.is_bool(): ty = input.dtype.to_ir(builder) _0 = tensor(builder.get_null_value(ty), input.dtype) - return not_equal(input, _0, builder) + return not_equal(input, _0, builder) elif overflow_mode == "saturate" and \ (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: @@ -509,7 +492,8 @@ def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, if bitwidth == 64: return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) if bitwidth == 1: - return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), builder) + return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), + builder) # Casting integer types to pointer types if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): @@ -521,9 +505,11 @@ def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, assert False, f'cannot cast {input} to {dst_ty}' + @_tensor_member_fn @builtin -def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, overflow_mode: Optional[str] = None, _builder=None): +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, + overflow_mode: Optional[str] = None, _builder=None): """ Casts a tensor to the given :code:`dtype`. 
diff --git a/third_party/ascend/language/cann/libdevice.py b/third_party/ascend/language/cann/libdevice.py index d27bdefb4f..07836b7142 100644 --- a/third_party/ascend/language/cann/libdevice.py +++ b/third_party/ascend/language/cann/libdevice.py @@ -84,10 +84,9 @@ def atan(arg0, _builder=None): @core.extern def tanh(arg0, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_tanh_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanh_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( "", "", [arg0], { @@ -117,10 +116,9 @@ def ldexp(arg0, arg1, _builder=None): @core.extern def pow(arg0, arg1, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_pow_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_pow_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( "", "", [arg0, arg1], { @@ -142,10 +140,10 @@ def isnan(arg0, _builder=None): @core.extern def div_rz(arg0, arg1, _builder=None): - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_div_rz_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_div_rz_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + @core.builtin def fast_dividef(arg0, arg1, _builder=None): @@ -154,6 +152,7 @@ def fast_dividef(arg0, arg1, _builder=None): ret = semantic.fdiv(arg0, arg1, 
False, _builder) return ret + @core.builtin def fast_expf(arg0, _builder=None): arg0 = semantic.to_tensor(arg0, _builder) @@ -163,19 +162,16 @@ def fast_expf(arg0, _builder=None): @core.extern def fmod(arg0, arg1, _builder=None): - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_fmod_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_fmod_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def float_as_int(arg0, _builder=None): - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"),): ("__hmf_float_as_int_fp32", core.dtype("int32")), - }, is_pure=True, _builder=_builder) - + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_float_as_int_fp32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern @@ -190,37 +186,34 @@ def atan2(arg0, arg1, _builder): }, is_pure=True, _builder=_builder) -@core.builtin -@math._check_dtype(dtypes=["fp32"]) +@core.builtin +@math._check_dtype(dtypes=["fp32"]) @math._add_math_1arg_docstr("trunc") def trunc(arg0, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_trunc_fp16", core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_trunc_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_trunc_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_trunc_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - arg0 = semantic.to_tensor(arg0, _builder) - - - zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) - condition = semantic.greater_equal(arg0, zero, _builder) - - - floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) - ceil_result = core.tensor(_builder.create_ceil(arg0.handle), 
arg0.type) - - + arg0 = semantic.to_tensor(arg0, _builder) + + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + condition = semantic.greater_equal(arg0, zero, _builder) + + floor_result = core.tensor(_builder.create_floor(arg0.handle), arg0.type) + ceil_result = core.tensor(_builder.create_ceil(arg0.handle), arg0.type) + return semantic.where(condition, floor_result, ceil_result, _builder) @core.extern def round(arg0, _builder=None): - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @@ -232,8 +225,8 @@ def acos(arg0: core.tensor, _builder: ir.builder): core.static_assert(False) return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_acos_fp16", core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_acos_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_acos_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_acos_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: pi = 3.1415926536 @@ -293,8 +286,8 @@ def sinh(arg0: core.tensor, _builder: ir.builder): core.static_assert(False) return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_sinh_fp16", core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_sinh_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_sinh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_sinh_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: arg0 = semantic.to_tensor(arg0, _builder) @@ -315,8 +308,8 @@ def cosh(arg0: core.tensor, _builder: ir.builder): core.static_assert(False) return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_cosh_fp16", 
core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_cosh_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_cosh_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_cosh_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: arg0 = semantic.to_tensor(arg0, _builder) @@ -402,8 +395,8 @@ def expm1(arg0: core.tensor, _builder: ir.builder): core.static_assert(False) return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_expm1_fp16", core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_expm1_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_expm1_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_expm1_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: arg0 = semantic.to_tensor(arg0, _builder) @@ -424,21 +417,9 @@ def nextafter(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): else: x = semantic.to_tensor(arg0, _builder) y = semantic.to_tensor(arg1, _builder) - dtype_map = { - "bf16": core.int16, - "fp16": core.int16, - "fp32": core.int32 - } - min_pos_bit = { - "bf16": 0x0001, - "fp16": 0x0001, - "fp32": 0x00000001 - } - max_neg_bit = { - "bf16": 0x8001, - "fp16": 0x8001, - "fp32": 0x80000001 - } + dtype_map = {"bf16": core.int16, "fp16": core.int16, "fp32": core.int32} + min_pos_bit = {"bf16": 0x0001, "fp16": 0x0001, "fp32": 0x00000001} + max_neg_bit = {"bf16": 0x8001, "fp16": 0x8001, "fp32": 0x80000001} int_type = dtype_map[x.type.scalar.name] x_eq_y = semantic.equal(x, y, _builder) x_gt_0 = semantic.greater_than(x, 0, _builder) @@ -501,69 +482,68 @@ def cyl_bessel_i0(arg0: core.tensor, _builder: ir.builder): if arg0.dtype == core.dtype("fp16"): core.static_print("extern livdevice.cyl_bessel_i0 for dtype bf16 is unspported for now.") core.static_assert(False) - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_cyl_bessel_i0_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_cyl_bessel_i0_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: param1 = [ - -4.41534164647933937950e-18, - +3.33079451882223809783e-17, - -2.43127984654795469359e-16, - +1.71539128555513303061e-15, - -1.16853328779934516808e-14, - +7.67618549860493561688e-14, - -4.85644678311192946090e-13, - +2.95505266312963983461e-12, - -1.72682629144155570723e-11, - +9.67580903537323691224e-11, - -5.18979560163526290666e-10, - +2.65982372468238665035e-09, - -1.30002500998624804212e-08, - +6.04699502254191894932e-08, - -2.67079385394061173391e-07, - +1.11738753912010371815e-06, - -4.41673835845875056359e-06, - +1.64484480707288970893e-05, - -5.75419501008210370398e-05, - +1.88502885095841655729e-04, - -5.76375574538582365885e-04, - +1.63947561694133579842e-03, - -4.32430999505057594430e-03, - +1.05464603945949983183e-02, - -2.37374148058994688156e-02, - +4.93052842396707084878e-02, - -9.49010970480476444210e-02, - +1.71620901522208775349e-01, - -3.04682672343198398683e-01, - +6.76795274409476084995e-01, + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + 
-3.04682672343198398683e-01, + +6.76795274409476084995e-01, ] param2 = [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - +4.46562142029675999901e-17, - +3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - +1.77256013305652638360e-15, - +3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - +1.54008621752140982691e-14, - +3.85277838274214270114e-13, - +7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - +1.18891471078464383424e-11, - +4.94060238822496958910e-10, - +3.39623202570838634515e-09, - +2.26666899049817806459e-08, - +2.04891858946906374183e-07, - +2.89137052083475648297e-06, - +6.88975834691682398426e-05, - +3.36911647825569408990e-03, - +8.04490411014108831608e-01, + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, ] arg0 = semantic.to_tensor(arg0, _builder) abs_x = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) @@ -603,14 +583,14 @@ def signbit(arg0, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_signbit_fp16", core.dtype("int32")), - (core.dtype("fp32"),): 
("__hmf_signbit_fp32", core.dtype("int32")), + (core.dtype("fp16"), ): ("__hmf_signbit_fp16", core.dtype("int32")), + (core.dtype("fp32"), ): ("__hmf_signbit_fp32", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: arg0_scalar_ty = arg0.type.scalar if arg0_scalar_ty == core.float32: int_ty = core.int32 - else: # arg0 type: float16 / bfloat16 + else: # arg0 type: float16 / bfloat16 int_ty = core.int16 arg0 = semantic.to_tensor(arg0, _builder) @@ -622,8 +602,7 @@ def signbit(arg0, _builder=None): shift = semantic.full(arg0.shape, shift, int_ty, _builder) sign_bit_tensor = semantic.lshr(int_tensor, shift, _builder) - sign_bit_tensor = semantic.and_( - sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) + sign_bit_tensor = semantic.and_(sign_bit_tensor, semantic.full(arg0.shape, 1, int_ty, _builder), _builder) return semantic.equal(sign_bit_tensor, 1, _builder) @@ -637,16 +616,14 @@ def signbit(arg0, _builder=None): @math._check_dtype(dtypes=["fp32"]) def erfinv(arg0, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_erfinv_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_erfinv_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: arg0_scalar_ty = arg0.type.scalar arg0 = semantic.to_tensor(arg0, _builder) - inv_sqrt_pi_times_2 = semantic.full( - arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) + inv_sqrt_pi_times_2 = semantic.full(arg0.shape, 1.128379167, arg0_scalar_ty, _builder).handle # 2 / sqrt(pi) coeff_low_numerator = [-0.140543331, 0.914624893, -1.645349621, 0.886226899] coeff_low_denominator = [0.012229801, -0.329097515, 1.442710462, -2.118377725, 1.0] coeff_high_numerator = [1.641345311, 3.429567803, -1.624906493, -1.970840454] @@ -654,18 +631,17 @@ def erfinv(arg0, 
_builder=None): # low cal arg0_squared = _builder.create_fmul(arg0.handle, arg0.handle) - numerator_low_range = semantic.full( - arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle + numerator_low_range = semantic.full(arg0.shape, coeff_low_numerator[0], arg0_scalar_ty, _builder).handle for i in range(1, len(coeff_low_numerator)): - numerator_low_range = _builder.create_fma(numerator_low_range, arg0_squared, + numerator_low_range = _builder.create_fma( + numerator_low_range, arg0_squared, semantic.full(arg0.shape, coeff_low_numerator[i], arg0_scalar_ty, _builder).handle) - denominator_low_range = semantic.full( - arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle + denominator_low_range = semantic.full(arg0.shape, coeff_low_denominator[0], arg0_scalar_ty, _builder).handle for i in range(1, len(coeff_low_denominator)): denominator_low_range = _builder.create_fma( - denominator_low_range, arg0_squared, semantic.full( - arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) + denominator_low_range, arg0_squared, + semantic.full(arg0.shape, coeff_low_denominator[i], arg0_scalar_ty, _builder).handle) low_res = _builder.create_fmul(arg0.handle, _builder.create_fdiv(numerator_low_range, denominator_low_range)) @@ -677,94 +653,67 @@ def erfinv(arg0, _builder=None): _builder.create_fdiv( _builder.create_fsub( semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder).handle, - _builder.create_fabs(arg0.handle) - ), - semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle - ) - ) - ) - ) + _builder.create_fabs(arg0.handle)), + semantic.full(arg0.shape, 2, arg0_scalar_ty, _builder).handle)))) numerator_high_range = semantic.full(arg0.shape, coeff_high_numerator[0], arg0_scalar_ty, _builder).handle for i in range(1, len(coeff_high_numerator)): numerator_high_range = _builder.create_fma( - numerator_high_range, arg0_erf_trans, semantic.full( - arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) + 
numerator_high_range, arg0_erf_trans, + semantic.full(arg0.shape, coeff_high_numerator[i], arg0_scalar_ty, _builder).handle) denominator_high_range = semantic.full(arg0.shape, coeff_high_denominator[0], arg0_scalar_ty, _builder).handle for i in range(1, len(coeff_high_denominator)): denominator_high_range = _builder.create_fma( - denominator_high_range, arg0_erf_trans, semantic.full( - arg0.shape, coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) + denominator_high_range, arg0_erf_trans, + semantic.full(arg0.shape, coeff_high_denominator[i], arg0_scalar_ty, _builder).handle) high_res = _builder.create_fdiv(numerator_high_range, denominator_high_range) high_res = semantic.mul( - semantic.where( - signbit(arg0, _builder=_builder), - semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), - semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), - _builder), - core.tensor(high_res, arg0.type), True, _builder - ).handle + semantic.where(signbit(arg0, _builder=_builder), semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder), + semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder), + core.tensor(high_res, arg0.type), True, _builder).handle for _ in range(2): low_res = _builder.create_fsub( - low_res, _builder.create_fdiv( - _builder.create_fsub( - _builder.create_erf(low_res), arg0.handle - ), + low_res, + _builder.create_fdiv( + _builder.create_fsub(_builder.create_erf(low_res), arg0.handle), _builder.create_fmul( - inv_sqrt_pi_times_2, _builder.create_exp( + inv_sqrt_pi_times_2, + _builder.create_exp( _builder.create_fmul( semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(low_res, low_res) - ) - ) - ) - ) - ) + _builder.create_fmul(low_res, low_res)))))) high_res = _builder.create_fsub( - high_res, _builder.create_fdiv( - _builder.create_fsub( - _builder.create_erf(high_res), arg0.handle - ), + high_res, + _builder.create_fdiv( + _builder.create_fsub(_builder.create_erf(high_res), arg0.handle), 
_builder.create_fmul( - inv_sqrt_pi_times_2, _builder.create_exp( + inv_sqrt_pi_times_2, + _builder.create_exp( _builder.create_fmul( semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle, - _builder.create_fmul(high_res, high_res) - ) - ) - ) - ) - ) + _builder.create_fmul(high_res, high_res)))))) arg0_abs = core.tensor(_builder.create_fabs(arg0.handle), arg0.type) # Check if |arg0| > 1 - arg0_over = semantic.greater_than( - arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) + arg0_over = semantic.greater_than(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) nan_tensor = semantic.full(arg0.shape, float("nan"), arg0_scalar_ty, _builder) # Check if |arg0| = 1 - arg0_equal1 = semantic.equal( - arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder - ) + arg0_equal1 = semantic.equal(arg0_abs, semantic.full(arg0.shape, 1, arg0_scalar_ty, _builder), _builder) pos_inf_tensor = semantic.full(arg0.shape, float("inf"), arg0_scalar_ty, _builder) neg_inf_tensor = semantic.full(arg0.shape, float("-inf"), arg0_scalar_ty, _builder) - inf_res = semantic.where( - signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder - ) + inf_res = semantic.where(signbit(arg0, _builder=_builder), neg_inf_tensor, pos_inf_tensor, _builder) # Check if |arg0| >= 0.7 - arg0_high = semantic.greater_equal( - arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder - ) + arg0_high = semantic.greater_equal(arg0_abs, semantic.full(arg0.shape, 0.7, arg0_scalar_ty, _builder), _builder) return semantic.where( - arg0_equal1, inf_res, semantic.where( - arg0_over, nan_tensor, semantic.where( - arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder - ), _builder - ), _builder - ) + arg0_equal1, inf_res, + semantic.where( + arg0_over, nan_tensor, + semantic.where(arg0_high, core.tensor(high_res, arg0.type), core.tensor(low_res, arg0.type), _builder), + _builder), 
_builder) # Note: @@ -786,9 +735,7 @@ def gamma(arg0, _builder=None): -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 ] condition = semantic.less_than(arg0, 0.5, _builder) # 1 - x = x -> x = 0.5 - reflect_arg0 = semantic.where( - condition, semantic.sub(1, arg0, True, _builder), arg0, _builder - ) + reflect_arg0 = semantic.where(condition, semantic.sub(1, arg0, True, _builder), arg0, _builder) x = semantic.full(arg0.shape, 0.99999999999980993, arg0_scalar_ty, _builder) for i in range(0, len(lanczos_coeff)): @@ -799,42 +746,31 @@ def gamma(arg0, _builder=None): t = semantic.add(reflect_arg0, 6.5, True, _builder) gamma_res = _builder.create_fmul( + _builder.create_fmul(sqrt_2pi_tensor, + pow(t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder).handle), _builder.create_fmul(sqrt_2pi_tensor, pow(t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder).handle), _builder.create_fmul( - sqrt_2pi_tensor, pow( - t, semantic.sub(reflect_arg0, 0.5, True, _builder), _builder=_builder - ).handle - ), - _builder.create_fmul( - x.handle, _builder.create_exp( - _builder.create_fmul( - t.handle, semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle - ) - ) - ) - ) - - gamma_res_reflect = _builder.create_fdiv( - _builder.create_fdiv(pi_tensor, gamma_res), - _builder.create_sin(_builder.create_fmul(pi_tensor, arg0.handle)) - ) - - is_neg_int = semantic.logical_and( - semantic.equal(math.floor(arg0, _builder=_builder), arg0, _builder), - semantic.less_than(arg0, 0, _builder), _builder - ) + x.handle, + _builder.create_exp( + _builder.create_fmul(t.handle, + semantic.full(arg0.shape, -1, arg0_scalar_ty, _builder).handle)))) + + gamma_res_reflect = _builder.create_fdiv(_builder.create_fdiv(pi_tensor, gamma_res), + _builder.create_sin(_builder.create_fmul(pi_tensor, arg0.handle))) + + is_neg_int = semantic.logical_and(semantic.equal(math.floor(arg0, _builder=_builder), arg0, _builder), + semantic.less_than(arg0, 0, _builder), 
_builder) pos_inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) neg_inf_tensor = semantic.full(arg0.shape, float('-inf'), arg0_scalar_ty, _builder) - gamma_res_reflect = semantic.where( - is_neg_int, pos_inf_tensor, core.tensor(gamma_res_reflect, arg0.type), _builder) + gamma_res_reflect = semantic.where(is_neg_int, pos_inf_tensor, core.tensor(gamma_res_reflect, arg0.type), _builder) res = semantic.where(condition, gamma_res_reflect, core.tensor(gamma_res, arg0.type), _builder) is_pos_inf_input = semantic.equal(arg0, pos_inf_tensor, _builder) is_neg_inf_input = semantic.equal(arg0, neg_inf_tensor, _builder) - return semantic.where(is_pos_inf_input, pos_inf_tensor, semantic.where( - is_neg_inf_input, neg_inf_tensor, res, _builder), _builder) + return semantic.where(is_pos_inf_input, pos_inf_tensor, + semantic.where(is_neg_inf_input, neg_inf_tensor, res, _builder), _builder) # Note: @@ -849,18 +785,15 @@ def gamma(arg0, _builder=None): @math._check_dtype(dtypes=["fp32"]) def lgamma(arg0, _builder=None): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"), ): ("__hmf_lgamma_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_lgamma_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: arg0_scalar_ty = arg0.type.scalar arg0 = semantic.to_tensor(arg0, _builder) inf_tensor = semantic.full(arg0.shape, float('inf'), arg0_scalar_ty, _builder) - is_inf = semantic.equal( - core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder - ) + is_inf = semantic.equal(core.tensor(_builder.create_fabs(arg0.handle), arg0.type), inf_tensor, _builder) gamma_res = _builder.create_fabs(gamma(arg0, _builder=_builder).handle) lgamma_res = _builder.create_log(gamma_res) @@ -874,10 +807,9 @@ def lgamma(arg0, _builder=None): 
@math._add_math_1arg_docstr("nearbyint") def nearbyint(arg0: core.tensor, _builder: ir.builder): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"),): ("__hmf_nearbyint_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_nearbyint_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: """ Round argument x to an integer value in floating-point format. @@ -912,14 +844,11 @@ def nearbyint(arg0: core.tensor, _builder: ir.builder): is_even = semantic.equal(basic_round, double_half, _builder) - adjustment = semantic.where(is_positive, - semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), - semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), - _builder) + adjustment = semantic.where(is_positive, semantic.full(arg0.shape, -1.0, arg0.type.scalar, _builder), + semantic.full(arg0.shape, 1.0, arg0.type.scalar, _builder), _builder) - banker_result = semantic.where(is_even, basic_round, - semantic.add(basic_round, adjustment, True, _builder), - _builder) + banker_result = semantic.where(is_even, basic_round, semantic.add(basic_round, adjustment, True, _builder), + _builder) # Final result: Use banker's rounding for cases exactly at 0.5, otherwise use basic rounding. 
return semantic.where(is_half, banker_result, basic_round, _builder) @@ -934,8 +863,8 @@ def asin(arg0: core.tensor, _builder: ir.builder): if triton_enable_libdevice_simt() and is_compile_on_910_95: return core.extern_elementwise( "", "", [arg0], { - (core.dtype("fp16"),): ("__hmf_asin_fp16", core.dtype("fp16")), - (core.dtype("fp32"),): ("__hmf_asin_fp32", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_asin_fp16", core.dtype("fp16")), + (core.dtype("fp32"), ): ("__hmf_asin_fp32", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: """ @@ -959,10 +888,9 @@ def asin(arg0: core.tensor, _builder: ir.builder): @math._add_math_1arg_docstr("base-10 logarithm") def log10(arg0: core.tensor, _builder: ir.builder): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0], { - (core.dtype("fp32"),): ("__hmf_log10_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_log10_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: """ Calculate the base 10 logarithm of the input argument x. @@ -985,10 +913,9 @@ def log10(arg0: core.tensor, _builder: ir.builder): @math._add_math_2arg_docstr("copysign") def copysign(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): if triton_enable_libdevice_simt() and is_compile_on_910_95: - return core.extern_elementwise( - "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_copysign_fp32", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_copysign_fp32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) else: """ Create a floating-point value with the magnitude of x and the sign of y. 
@@ -1009,7 +936,8 @@ def copysign(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): is_negative_nonzero = semantic.less_than(y, zero, _builder) is_negative = semantic.or_(is_negative_zero, is_negative_nonzero, _builder) - neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), True, _builder) + neg_magnitude = semantic.mul(magnitude, semantic.full(magnitude.shape, -1.0, magnitude.type.scalar, _builder), + True, _builder) return semantic.where(is_negative, neg_magnitude, magnitude, _builder) diff --git a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp index 8d9129cd5e..fa82404dc8 100644 --- a/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp +++ b/third_party/ascend/lib/AutoBlockify/AutoBlockify.cpp @@ -103,9 +103,8 @@ PropagateUnrealizedCastDown::matchAndRewrite(UnrealizedConversionCastOp op, } else if (auto conditionOp = dyn_cast(user)) { rewriteCondition(op, conditionOp, rewriter); } else if (user->hasTrait() || - isa(user)) { + isa(user)) { rewriteGeneraleOp(op, user, rewriter); } else if (isa(user)) { auto *newOp = @@ -147,7 +146,8 @@ bool AutoBlockifyPass::checkBlockifiable(Value v) { auto &os = llvm::dbgs(); os << "User:\n" << *user << "\n"; }); - if (isa(user) || + if (isa( + user) || llvm::any_of(user->getOperandTypes(), isTensorPtrType)) return false; if (auto ifOp = dyn_cast(user)) { @@ -359,4 +359,4 @@ void AutoBlockifyPass::runOnOperation() { std::unique_ptr> triton::createAutoBlockifyPass(const AutoBlockifyOptions &options) { return std::make_unique(options); -} \ No newline at end of file +} diff --git a/third_party/ascend/lib/AutoBlockify/CMakeLists.txt b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt index a0ccd59b2e..20ffce4753 100644 --- a/third_party/ascend/lib/AutoBlockify/CMakeLists.txt +++ b/third_party/ascend/lib/AutoBlockify/CMakeLists.txt @@ -19,4 +19,4 @@ add_triton_library(AutoBlockify MLIRTransforms 
MLIRSupport MLIRSCFTransforms -) \ No newline at end of file +) diff --git a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp index f3f71ea7b6..0b610ebf0c 100644 --- a/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp +++ b/third_party/ascend/lib/AutoBlockify/RewriteOperation.cpp @@ -505,4 +505,4 @@ void PropagateUnrealizedCastDown::rewriteCondition( newUccOp = rewriter.create( arg.getLoc(), oldArgType, ValueRange({arg, mask})); rewriter.replaceAllUsesExcept(arg, newUccOp->getResult(0), newUccOp); -} \ No newline at end of file +} diff --git a/third_party/ascend/lib/AutoBlockify/Utils.cpp b/third_party/ascend/lib/AutoBlockify/Utils.cpp index 6e49e6b0f2..c6deed7c78 100644 --- a/third_party/ascend/lib/AutoBlockify/Utils.cpp +++ b/third_party/ascend/lib/AutoBlockify/Utils.cpp @@ -207,4 +207,4 @@ std::optional getBlockifyLoop(Operation *op) { op = forOp; } return std::nullopt; -} \ No newline at end of file +} diff --git a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp index a1fc685e17..7184b0c402 100644 --- a/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp +++ b/third_party/ascend/lib/Conversion/TritonToLLVM/TritonToLLVM.cpp @@ -36,10 +36,9 @@ static Type getElementType(Value value) { return type; } -static int64_t getTensorNumElements(Value tensor) -{ - auto type = mlir::cast(tensor.getType()); - return type.getNumElements(); +static int64_t getTensorNumElements(Value tensor) { + auto type = mlir::cast(tensor.getType()); + return type.getNumElements(); } static Value getInt32Value(RewriterBase &rewriter, Location loc, int val) { @@ -77,25 +76,27 @@ SmallVector packOperands(mlir::triton::ElementwiseInlineAsmOp op, return packedOperands; } -static SmallVector unpackElements(Location loc, Value packedValues, RewriterBase &rewriter) -{ - auto type = mlir::cast(packedValues.getType()); - auto 
elementType = type.getElementType(); - auto shape = type.getShape(); - - int64_t numElements = type.getNumElements(); - - SmallVector result; - for (int64_t linearIdx = 0; linearIdx < numElements; linearIdx++) { - SmallVector indexes(shape.size()); - int64_t remaining = linearIdx; - for (int64_t dim = shape.size() - 1; dim >= 0; dim--) { - indexes[dim] = rewriter.create(loc, remaining % shape[dim]); - remaining /= shape[dim]; - } - Value extracted = rewriter.create(loc, elementType, packedValues, indexes); - result.push_back(extracted); +static SmallVector unpackElements(Location loc, Value packedValues, + RewriterBase &rewriter) { + auto type = mlir::cast(packedValues.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); + + int64_t numElements = type.getNumElements(); + + SmallVector result; + for (int64_t linearIdx = 0; linearIdx < numElements; linearIdx++) { + SmallVector indexes(shape.size()); + int64_t remaining = linearIdx; + for (int64_t dim = shape.size() - 1; dim >= 0; dim--) { + indexes[dim] = + rewriter.create(loc, remaining % shape[dim]); + remaining /= shape[dim]; } + Value extracted = rewriter.create(loc, elementType, + packedValues, indexes); + result.push_back(extracted); + } return result; } @@ -176,8 +177,7 @@ createDestOps(triton::ElementwiseInlineAsmOp op, RewriterBase &rewriter, } static LogicalResult processScalarInlineAsm(triton::ElementwiseInlineAsmOp op, - PatternRewriter &rewriter) -{ + PatternRewriter &rewriter) { Location loc = op.getLoc(); auto outsWrapped = createDestOps(op, rewriter, {}, loc); @@ -192,8 +192,7 @@ static LogicalResult processScalarInlineAsm(triton::ElementwiseInlineAsmOp op, } static LogicalResult processVectorInlineAsm(triton::ElementwiseInlineAsmOp op, - PatternRewriter &rewriter) -{ + PatternRewriter &rewriter) { Location loc = op.getLoc(); SmallVector> unpackedOperands; @@ -239,15 +238,15 @@ static LogicalResult processVectorInlineAsm(triton::ElementwiseInlineAsmOp op, } // 
namespace -struct ElementwiseInlineAsmOpConversion : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ElementwiseInlineAsmOpConversion + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, - PatternRewriter &rewriter) const final - { - return op.getOperands().empty() ? processScalarInlineAsm(op, rewriter) - : processVectorInlineAsm(op, rewriter); - } + LogicalResult matchAndRewrite(triton::ElementwiseInlineAsmOp op, + PatternRewriter &rewriter) const final { + return op.getOperands().empty() ? processScalarInlineAsm(op, rewriter) + : processVectorInlineAsm(op, rewriter); + } }; void TritonToLLVMPass::runOnOperation() { diff --git a/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt index 2f49f3f0b4..925ca52d92 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt +++ b/third_party/ascend/lib/TritonAffinityOpt/CMakeLists.txt @@ -16,4 +16,4 @@ add_triton_library(TritonAffinityOpt MLIRSupport TritonIR MLIRSCFDialect -) \ No newline at end of file +) diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp index b796f8eb96..d0222255e1 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAG.cpp @@ -1,5 +1,7 @@ #include "TritonAffinityOpt/DAG.h" +#include "bishengir/Dialect/Annotation/IR/Annotation.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" @@ -11,7 +13,6 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "triton/Dialect/Triton/IR/Types.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" 
@@ -30,41 +31,36 @@ #include #include #include -#include "bishengir/Dialect/Annotation/IR/Annotation.h" -namespace mlir { namespace AffinityDAG { +namespace mlir { +namespace AffinityDAG { -const auto printFlags = OpPrintingFlags() - .enableDebugInfo(true, true) - .skipRegions(); +const auto printFlags = + OpPrintingFlags().enableDebugInfo(true, true).skipRegions(); -const char* literalCoreType(CoreType ct) { +const char *literalCoreType(CoreType ct) { switch (ct) { - case VECTOR_ONLY: - return "VECTOR_ONLY"; - case CUBE_ONLY: - return "CUBE_ONLY"; - case CUBE_AND_VECTOR: - return "CUBE_AND_VECTOR"; - case UNDETERMINED: - return "UNDETERMINED"; + case VECTOR_ONLY: + return "VECTOR_ONLY"; + case CUBE_ONLY: + return "CUBE_ONLY"; + case CUBE_AND_VECTOR: + return "CUBE_AND_VECTOR"; + case UNDETERMINED: + return "UNDETERMINED"; } return "Unknown"; } - -bool opIsScf(Operation* op) { +bool opIsScf(Operation *op) { if (!llvm::isa(op->getDialect())) return false; return true; } -Graph::Graph(Block* block, Graph* parent, OpMap opMap, ValueMap valueMap, bool inheritParent) : - block(block), - parent(parent), - opMap(opMap), - valueMap(valueMap) -{ +Graph::Graph(Block *block, Graph *parent, OpMap opMap, ValueMap valueMap, + bool inheritParent) + : block(block), parent(parent), opMap(opMap), valueMap(valueMap) { if (parent && inheritParent) { if (!this->opMap) { @@ -84,12 +80,12 @@ Graph::Graph(Block* block, Graph* parent, OpMap opMap, ValueMap valueMap, bool i this->valueMap = std::make_shared(); } - for(auto blockArg : block->getArguments()) { + for (auto blockArg : block->getArguments()) { (*this->valueMap)[blockArg] = std::make_unique(blockArg); blockArgs.push_back((*this->valueMap)[blockArg].get()); } - for(auto& opRef : block->getOperations()) { + for (auto &opRef : block->getOperations()) { opCount += 1; auto op = &opRef; auto opNodeUnique = std::make_unique(op, this); @@ -100,7 +96,7 @@ Graph::Graph(Block* block, Graph* parent, OpMap opMap, ValueMap valueMap, bool i 
terminator = opNode; } - for (auto& subgraph : opNode->subgraphs) { + for (auto &subgraph : opNode->subgraphs) { opCount += subgraph.opCount; } } @@ -140,63 +136,59 @@ OpAbility OpNode::canRunOn() const { if (opIsScf(op)) { return OpAbility::CUBE_AND_VECTOR; } - return llvm::TypeSwitch(op) - .Case([](auto) { - return OpAbility::CUBE_ONLY; - }) - .Case([](auto) { - return OpAbility::CUBE_AND_VECTOR; - }) - .Case([](arith::SelectOp op) { - // when cond is vector, selectOp should be vector, otherwise scalar - return ( - valueIsScalar(op.getCondition()) ? OpAbility::CUBE_AND_VECTOR : OpAbility::PREFER_VECTOR - ); - }) - .Default([](Operation* op) { - auto isVector = false; - for(auto operand : op->getOperands()) { - if (!valueIsScalar(operand)) { - // if (valueIsTensorOfPtr(operand)) { - // return SCALAR; - // } - isVector = true; + return llvm::TypeSwitch(op) + .Case([](auto) { return OpAbility::CUBE_ONLY; }) + .Case([](auto) { return OpAbility::CUBE_AND_VECTOR; }) + .Case([](arith::SelectOp op) { + // when cond is vector, selectOp should be vector, otherwise scalar + return (valueIsScalar(op.getCondition()) ? 
OpAbility::CUBE_AND_VECTOR + : OpAbility::PREFER_VECTOR); + }) + .Default([](Operation *op) { + auto isVector = false; + for (auto operand : op->getOperands()) { + if (!valueIsScalar(operand)) { + // if (valueIsTensorOfPtr(operand)) { + // return SCALAR; + // } + isVector = true; + } } - } - for(auto result : op->getResults()) { - if (!valueIsScalar(result)) { - // if (valueIsTensorOfPtr(result)) { - // return SCALAR; - // } - isVector = true; + for (auto result : op->getResults()) { + if (!valueIsScalar(result)) { + // if (valueIsTensorOfPtr(result)) { + // return SCALAR; + // } + isVector = true; + } } - } - if (isVector) { - return OpAbility::PREFER_VECTOR; - } + if (isVector) { + return OpAbility::PREFER_VECTOR; + } - return OpAbility::CUBE_AND_VECTOR; - }); + return OpAbility::CUBE_AND_VECTOR; + }); } -OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { +OpNode::OpNode(Operation *op, Graph *graph) : Node(Node::NK_Op), op(op) { if (op == nullptr) { return; } llvm::outs() << op << "\n"; - auto& valueMap = *graph->valueMap.get(); - auto& opMap = *graph->opMap.get(); - for(const auto operand : op->getOperands()) { + auto &valueMap = *graph->valueMap.get(); + auto &opMap = *graph->opMap.get(); + for (const auto operand : op->getOperands()) { auto valueNode = valueMap.at(operand).get(); valueNode->outputs.push_back(this); inputs.push_back(valueNode); } - for(const auto& result : op->getResults()) { + for (const auto &result : op->getResults()) { auto valueNodeUnique = std::make_unique(result); auto valueNode = valueNodeUnique.get(); valueMap[result] = std::move(valueNodeUnique); @@ -205,39 +197,39 @@ OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { } // if (!op->hasTrait()) { - // llvm::dbgs() << "Not building subgraph because op is not SingleBlock: " << op << '\n'; - // return; + // llvm::dbgs() << "Not building subgraph because op is not SingleBlock: " + // << op << '\n'; return; // } if (auto branchOp = 
llvm::dyn_cast(op)) { - OpNode* terminator = nullptr; - llvm::SmallVector, 2> validRegions; + OpNode *terminator = nullptr; + llvm::SmallVector, 2> validRegions; - for(auto& region : branchOp->getRegions()) { + for (auto ®ion : branchOp->getRegions()) { if (region.getBlocks().empty()) continue; subgraphs.emplace_back(®ion.getBlocks().front(), graph); validRegions.emplace_back(region, subgraphs.back()); } - for(auto [region, subgraph] : validRegions) { + for (auto [region, subgraph] : validRegions) { SmallVector succRegions; branchOp.getSuccessorRegions(region, succRegions); - if (auto currTerminator = dyn_cast(subgraph.terminator->op)) { - for(auto& succ : succRegions) { + if (auto currTerminator = dyn_cast( + subgraph.terminator->op)) { + for (auto &succ : succRegions) { auto forwardedVal = currTerminator.getSuccessorOperands(succ); if (succ.isParent()) { // Step1: first yield to parent -> results: double direction if (!terminator && subgraph.terminator) { terminator = subgraph.terminator; - for(auto [forwardedVal, resultNode] : llvm::zip_equal( - forwardedVal, - outputs - )) { + for (auto [forwardedVal, resultNode] : + llvm::zip_equal(forwardedVal, outputs)) { auto resultValueNode = llvm::dyn_cast(resultNode); - assert(resultValueNode && "Output of a OpNode should be ValueNode!"); + assert(resultValueNode && + "Output of a OpNode should be ValueNode!"); auto forwardedNode = valueMap[forwardedVal].get(); resultValueNode->source = forwardedNode; forwardedNode->outputs.push_back(resultNode); @@ -248,10 +240,8 @@ OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { // Step2: Region terminator -> Succ Operands auto succRegion = succ.getSuccessor(); - for(auto [operand, succInput] : llvm::zip_equal( - forwardedVal, - succ.getSuccessorInputs() - )) { + for (auto [operand, succInput] : + llvm::zip_equal(forwardedVal, succ.getSuccessorInputs())) { auto forwardedNode = valueMap[operand].get(); auto succNode = valueMap[succInput].get(); 
forwardedNode->outputs.push_back(succNode); @@ -263,13 +253,17 @@ OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { } if (auto loopOp = llvm::dyn_cast(op)) { - // Step3: inits->iter_args (single directional) (should be handled in step 2: ) last terminator -> iter_args (bidirectional) - for(auto [init, iterArgVal] : llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs())) { - auto& initNode = valueMap[init]; - auto& iterArgNode = valueMap[iterArgVal]; + // Step3: inits->iter_args (single directional) (should be handled in step + // 2: ) last terminator -> iter_args (bidirectional) + for (auto [init, iterArgVal] : + llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs())) { + auto &initNode = valueMap[init]; + auto &iterArgNode = valueMap[iterArgVal]; initNode->outputs.push_back(iterArgNode.get()); } - // for(auto [init, iterArgVal, yieldNode] : llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs(), terminator->outputs)) { + // for(auto [init, iterArgVal, yieldNode] : + // llvm::zip_equal(loopOp.getInits(), loopOp.getRegionIterArgs(), + // terminator->outputs)) { // auto& initNode = valueMap[init]; // auto& iterArgNode = valueMap[iterArgVal]; // initNode->outputs.push_back(iterArgNode.get()); @@ -295,16 +289,17 @@ OpNode::OpNode(Operation* op, Graph* graph) : Node(Node::NK_Op), op(op) { // return 0; // }; -// std::stable_sort(result.begin(), result.end(), [&](ValueNode* a, ValueNode* b) { +// std::stable_sort(result.begin(), result.end(), [&](ValueNode* a, ValueNode* +// b) { // return getPriority(a) < getPriority(b); // }); // return result; // } -ValueNode* getWriteDataSource(OpNode* op) { +ValueNode *getWriteDataSource(OpNode *op) { auto inputRange = op->getInputs(); - for(auto node : inputRange.drop_front()) { + for (auto node : inputRange.drop_front()) { auto typ = getElementTypeOrSelf(node->value); if (!typ.isInteger(1)) { return node; @@ -314,11 +309,7 @@ ValueNode* getWriteDataSource(OpNode* op) { return nullptr; 
} -enum class MemPolicy { - NONE, - READ, - WRITE -}; +enum class MemPolicy { NONE, READ, WRITE }; CoreType Node::absorbCommon() { @@ -327,7 +318,7 @@ CoreType Node::absorbCommon() { if (!sourceNode || !op) { CoreType newCoreType = isOnPrivate; - for(auto output : outputs) { + for (auto output : outputs) { newCoreType = newCoreType | output->isOn(); isUpstreamOfCubeMem = isUpstreamOfCubeMem || output->isUpstreamOfCubeMem; } @@ -346,7 +337,8 @@ CoreType Node::absorbCommon() { auto memPolicy = MemPolicy::NONE; if (memIface) { - // Possible improvements: Determine the policy to use based on shapes, inputs and outputs, etc + // Possible improvements: Determine the policy to use based on shapes, + // inputs and outputs, etc if (memIface.hasEffect()) { memPolicy = MemPolicy::WRITE; } else if (memIface.hasEffect()) { @@ -369,29 +361,24 @@ CoreType Node::absorbCommon() { return VECTOR_ONLY; } - for(auto output : outputs) { + for (auto output : outputs) { switch (output->isOn()) { - case CUBE_AND_VECTOR: - newCoreType = newCoreType | VECTOR_ONLY; - // not breaking the switch because we need to handle cube - case CUBE_ONLY: - if ( - ability != OpAbility::PREFER_VECTOR || - output->isUpstreamOfCubeMem || - memPolicy == MemPolicy::READ - ) { - isUpstreamOfCubeMem = ( - isUpstreamOfCubeMem || - output->isUpstreamOfCubeMem || - memPolicy == MemPolicy::READ - ); - newCoreType = newCoreType | CUBE_ONLY; - } - break; - case VECTOR_ONLY: - newCoreType = newCoreType | VECTOR_ONLY; - default: // UNDETERMINED, skip - break; + case CUBE_AND_VECTOR: + newCoreType = newCoreType | VECTOR_ONLY; + // not breaking the switch because we need to handle cube + case CUBE_ONLY: + if (ability != OpAbility::PREFER_VECTOR || output->isUpstreamOfCubeMem || + memPolicy == MemPolicy::READ) { + isUpstreamOfCubeMem = + (isUpstreamOfCubeMem || output->isUpstreamOfCubeMem || + memPolicy == MemPolicy::READ); + newCoreType = newCoreType | CUBE_ONLY; + } + break; + case VECTOR_ONLY: + newCoreType = newCoreType 
| VECTOR_ONLY; + default: // UNDETERMINED, skip + break; }; } @@ -414,9 +401,7 @@ CoreType OpNode::absorbImpl() { return newCoreType; } -CoreType ValueNode::absorbImpl() { - return absorbCommon(); -} +CoreType ValueNode::absorbImpl() { return absorbCommon(); } std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { @@ -425,37 +410,35 @@ std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { auto dummyNode = std::make_unique(nullptr, dummyGraph.get()); size_t opCount = 0; - for (auto& block : funcOp.getBody()) { - auto& subgraph = dummyNode->subgraphs.emplace_back( - &block, - dummyGraph.get() - ); + for (auto &block : funcOp.getBody()) { + auto &subgraph = + dummyNode->subgraphs.emplace_back(&block, dummyGraph.get()); opCount += subgraph.opCount; } - auto& opMap = *dummyGraph->opMap.get(); - auto& valueMap = *dummyGraph->valueMap.get(); + auto &opMap = *dummyGraph->opMap.get(); + auto &valueMap = *dummyGraph->valueMap.get(); - llvm::SmallVector nodes; + llvm::SmallVector nodes; nodes.reserve(opMap.size() + valueMap.size()); - for(auto& [_, node] : opMap) { + for (auto &[_, node] : opMap) { if (node.get()) nodes.push_back(node.get()); } - for(auto& [_, node] : valueMap) { + for (auto &[_, node] : valueMap) { if (node.get()) nodes.push_back(node.get()); } auto diffuse = [&]() { // Not sure if determinism is required - llvm::SmallSetVector worklist(nodes.begin(), nodes.end()); + llvm::SmallSetVector worklist(nodes.begin(), nodes.end()); size_t threshold = worklist.size() * 5; - for(size_t i = 0; i< threshold; i++) { + for (size_t i = 0; i < threshold; i++) { if (worklist.empty()) { break; } @@ -471,7 +454,7 @@ std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { diffuse(); - for(auto node : nodes) { + for (auto node : nodes) { if (node->isOn() == UNDETERMINED) { node->isOnPrivate = VECTOR_ONLY; } @@ -482,29 +465,28 @@ std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { OpPrintingFlags flags; flags.skipRegions(); - 
for(auto [idx, node] : llvm::enumerate(nodes)) { - llvm::TypeSwitch(node) - .Case([&, idx=idx](OpNode* node) { - if (node->op) { - llvm::dbgs() << llvm::formatv("\n\n====== OpNode on: {1} @ {0} ======\n", - node->op, - literalCoreType(node->isOn()) - ); - node->op->print(llvm::dbgs(), flags); - llvm::dbgs() << "\nAbility: " << literalCoreType(toCoreType(node->canRunOn())); - llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->op); - } - }) - .Case([&, idx=idx](ValueNode* node) { - if (node->value) { - llvm::dbgs() << llvm::formatv("\n\n====== ValueNode on {1} @ {0} ======\n", - node->value, - literalCoreType(node->isOn()) - ); - node->value.print(llvm::dbgs(), flags); - llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->value); - } - }); + for (auto [idx, node] : llvm::enumerate(nodes)) { + llvm::TypeSwitch(node) + .Case([&, idx = idx](OpNode *node) { + if (node->op) { + llvm::dbgs() << llvm::formatv( + "\n\n====== OpNode on: {1} @ {0} ======\n", node->op, + literalCoreType(node->isOn())); + node->op->print(llvm::dbgs(), flags); + llvm::dbgs() << "\nAbility: " + << literalCoreType(toCoreType(node->canRunOn())); + llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->op); + } + }) + .Case([&, idx = idx](ValueNode *node) { + if (node->value) { + llvm::dbgs() << llvm::formatv( + "\n\n====== ValueNode on {1} @ {0} ======\n", node->value, + literalCoreType(node->isOn())); + node->value.print(llvm::dbgs(), flags); + llvm::dbgs() << llvm::formatv("\n====== {0} ======\n", node->value); + } + }); // if (auto opNode = llvm::dyn_cast(node)) { // if (auto forOp = llvm::dyn_cast_if_present(opNode->op)) { // llvm::dbgs() << "\n==== ForOp ====\n"; @@ -522,7 +504,8 @@ std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { // } // llvm::dbgs() << "\n---- Results ----\n"; // for(auto result : forOp.getResults()) { - // llvm::dbgs() << result.getResultNumber() << ' ' << literalCoreType(valueMap[result]->isOn()) << '\n'; + // llvm::dbgs() << 
result.getResultNumber() << ' ' << + // literalCoreType(valueMap[result]->isOn()) << '\n'; // } // } // } @@ -531,4 +514,5 @@ std::unique_ptr Graph::fromMultiBlockFunc(triton::FuncOp funcOp) { return dummyGraph; }; -} } +} // namespace AffinityDAG +} // namespace mlir diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp index baa85cc2f9..21b1af55a2 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSSBuffer.cpp @@ -22,30 +22,30 @@ #include "TritonAffinityOpt/Passes.h" -#include "bishengir/Dialect/Scope/IR/Scope.h" #include "bishengir/Dialect/HIVM/IR/HIVM.h" #include "bishengir/Dialect/HIVM/IR/HIVMImpl.h" -#include "bishengir/Dialect/HIVM/Transforms/Passes.h" #include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" +#include "bishengir/Dialect/HIVM/Transforms/Passes.h" #include "bishengir/Dialect/HIVM/Utils/Utils.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "bishengir/Dialect/Scope/IR/Scope.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Block.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include 
"triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include // #include "mlir/Pass/Pass.h" @@ -66,8 +66,7 @@ using namespace hivm; namespace { struct DAGSSBufferPass - : public mlir::triton::impl::DAGSSBufferBase< - DAGSSBufferPass> { + : public mlir::triton::impl::DAGSSBufferBase { void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -77,1140 +76,1093 @@ struct DAGSSBufferPass } // namespace void ControlSsbufV2(ModuleOp module) { - mlir::OpBuilder builder(module.getContext()); - // 用于记录已经处理过的scope.scope操作 - llvm::DenseSet processedScopes; - - auto aiCAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::CUBE); - int cubeControlIndex = 15; - int vectorControlIndex = 14; - - llvm::DenseSet processedScopes2; - module->walk([&](SyncBlockWaitOp op) { - auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S); - if (op.getTpipe() == pipeS || op.getPipe() == pipeS) { - return; - } + mlir::OpBuilder builder(module.getContext()); + // 用于记录已经处理过的scope.scope操作 + llvm::DenseSet processedScopes; - // 向上查找父scope.scope操作 - mlir::Operation* parentOp = op->getParentOp(); - mlir::Operation* scopeOp = nullptr; - mlir::Operation* forOp = nullptr; - - // 向上遍历查找scope.scope操作 - while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); - } - parentOp = op->getParentOp(); - while (parentOp) { - if (dyn_cast(parentOp)) { - forOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); - } - // 如果没有找到scope.scope操作,则跳过 - if (!scopeOp) { - return; - } - if (!forOp) { - return; - } + auto aiCAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + int cubeControlIndex = 15; + int vectorControlIndex = 14; - // 如果该scope已经处理过,则跳过 - if (processedScopes2.count(forOp) > 0) return; - - // 标记该scope为已处理 - 
processedScopes2.insert(forOp); + llvm::DenseSet processedScopes2; + module->walk([&](SyncBlockWaitOp op) { + auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S); + if (op.getTpipe() == pipeS || op.getPipe() == pipeS) { + return; + } - }); - bool firstSet = true; - bool firstWait = true; - for (auto forOp : processedScopes2) { - mlir::Operation* parentOp = forOp->getParentOp(); - mlir::Operation* scopeOp = nullptr; - - // 向上遍历查找scope.scope操作 - while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); - } - bool isAIC = false; - // 1. 先检查操作是否有这个属性 - - if (scopeOp->hasAttr("hivm.tcore_type")) { - auto attr = scopeOp->getAttr("hivm.tcore_type"); - if (attr == aiCAttr) { - isAIC = true; - } - } + // 向上查找父scope.scope操作 + mlir::Operation *parentOp = op->getParentOp(); + mlir::Operation *scopeOp = nullptr; + mlir::Operation *forOp = nullptr; - if (isAIC) { - // 在for循环的开头插入代码 - builder.setInsertionPoint(scopeOp); - // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 - auto i64Type = builder.getIntegerType(64); - auto i32Type = builder.getIntegerType(32); - - builder.setInsertionPointToStart(&forOp->getRegion(0).front()); - // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 - // add sync_block_wait - auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); - auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - - // 在循环末尾(yield之前)插入代码 - auto &loopBody = forOp->getRegion(0).front(); - // 找到循环体的terminator(应该是yield操作) - auto *terminator = loopBody.getTerminator(); - builder.setInsertionPoint(terminator); - - // add sync_block_set - coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); - 
setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - - if (firstWait) { - auto &scopeBlock = scopeOp->getRegion(0).front(); - auto *scope_terminator = scopeBlock.getTerminator(); - builder.setInsertionPoint(scope_terminator); - // add sync_block_wait - coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); - setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - firstWait = false; - } - } - else { - // 1. 在scopeop的开头插入代码 - // 假设scopeOp是一个具有区域的操作,我们获取其第一个块 - if (firstSet) { - auto &scopeBlock = scopeOp->getRegion(0).front(); - builder.setInsertionPointToStart(&scopeBlock); - - // add sync_block_wait - auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); - auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - firstSet = false; - } + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + parentOp = op->getParentOp(); + while (parentOp) { + if (dyn_cast(parentOp)) { + forOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } + // 如果没有找到scope.scope操作,则跳过 + if (!scopeOp) { + return; + } + if (!forOp) { + return; + } - auto i64Type = builder.getIntegerType(64); - auto i32Type = builder.getIntegerType(32); - 
- // 创建需要的常量 - auto c32ConstAttr = mlir::IntegerAttr::get(i64Type, 32); - auto c32ConstOp = builder.create( - scopeOp->getLoc(), i64Type, c32ConstAttr); - - auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0); - auto c0i64ConstOp = builder.create( - scopeOp->getLoc(), i64Type, c0i64ConstAttr); - - auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0); - auto c0i32ConstOp = builder.create( - scopeOp->getLoc(), i32Type, c0i32ConstAttr); - - auto c1i32ConstAttr = mlir::IntegerAttr::get(i32Type, 1); - auto c1i32ConstOp = builder.create( - scopeOp->getLoc(), i32Type, c1i32ConstAttr); - - // %sub_id = hivm.hir.get_sub_block_idx -> i64 - // 这里假设有一个getSubBlockIdxOp操作 - auto subIdOp = builder.create( - scopeOp->getLoc(), i64Type); - - // %ssb_addr_offset = arith.muli %sub_id, %c32_i64 : i64 - auto ssbAddrOffsetOp = builder.create( - scopeOp->getLoc(), - subIdOp.getResult(), - c32ConstOp.getResult()); - - // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 - auto ssbAddrOp = builder.create( - scopeOp->getLoc(), - ssbAddrOffsetOp.getResult(), - c32ConstOp.getResult()); - - // %vec_id = arith.cmpi eq, %sub_id, %c0_i64 : i64 - auto vecIdOp = builder.create( - scopeOp->getLoc(), - mlir::arith::CmpIPredicate::eq, - subIdOp.getResult(), - c0i64ConstOp.getResult()); - - // 2. 
在parentop的开头插入代码 - builder.setInsertionPointToStart(&forOp->getRegion(0).front()); - - // add sync_block_wait - auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); - auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - - // 在循环末尾(yield之前)插入代码 - auto &loopBody = forOp->getRegion(0).front(); - // 找到循环体的terminator(应该是yield操作) - auto *terminator = loopBody.getTerminator(); - builder.setInsertionPoint(terminator); - - // add sync_block_wait - coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::VECTOR); - setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); - builder.create(forOp->getLoc(), coreAttr, setPipe, waitPipe, flagId); - } + // 如果该scope已经处理过,则跳过 + if (processedScopes2.count(forOp) > 0) + return; + + // 标记该scope为已处理 + processedScopes2.insert(forOp); + }); + bool firstSet = true; + bool firstWait = true; + for (auto forOp : processedScopes2) { + mlir::Operation *parentOp = forOp->getParentOp(); + mlir::Operation *scopeOp = nullptr; + + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); } - - auto i64Type = builder.getIntegerType(64); - auto i32Type = builder.getIntegerType(32); - auto initPtrType = mlir::LLVM::LLVMPointerType::get(builder.getContext(), 11); - SmallVector scopeOps; - module->walk([&](mlir::Operation* op) { - // 检查是否为目标操作 - if (auto scopeOp = dyn_cast(op)) { - scopeOps.push_back(scopeOp); - } - }); - if (!scopeOps.empty()) { - auto scopeOp = scopeOps[0]; + bool isAIC = false; + // 1. 
先检查操作是否有这个属性 + + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + isAIC = true; + } + } + + if (isAIC) { + // 在for循环的开头插入代码 builder.setInsertionPoint(scopeOp); + // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // %ssb_ready_addr = llvm.mlir.constant(0 : i64) : i64 + // add sync_block_wait + auto coreAttr = + hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = + builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); + + // 在循环末尾(yield之前)插入代码 + auto &loopBody = forOp->getRegion(0).front(); + // 找到循环体的terminator(应该是yield操作) + auto *terminator = loopBody.getTerminator(); + builder.setInsertionPoint(terminator); + + // add sync_block_set + coreAttr = + hivm::TCoreTypeAttr::get(module.getContext(), hivm::TCoreType::CUBE); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); + + if (firstWait) { + auto &scopeBlock = scopeOp->getRegion(0).front(); + auto *scope_terminator = scopeBlock.getTerminator(); + builder.setInsertionPoint(scope_terminator); + // add sync_block_wait + coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), + hivm::TCoreType::CUBE); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = + builder.getIntegerAttr(builder.getI64Type(), 
vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); + firstWait = false; + } + } else { + // 1. 在scopeop的开头插入代码 + // 假设scopeOp是一个具有区域的操作,我们获取其第一个块 + if (firstSet) { + auto &scopeBlock = scopeOp->getRegion(0).front(); + builder.setInsertionPointToStart(&scopeBlock); + + // add sync_block_wait + auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), + hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = + builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); + firstSet = false; + } + + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + + // 创建需要的常量 + auto c32ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto c32ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c32ConstAttr); + auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0); auto c0i64ConstOp = builder.create( scopeOp->getLoc(), i64Type, c0i64ConstAttr); - auto c32i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); - auto c32i64ConstOp = builder.create( - scopeOp->getLoc(), i64Type, c32i64ConstAttr); - auto c64i64ConstAttr = mlir::IntegerAttr::get(i64Type, 64); - auto c64i64ConstOp = builder.create( - scopeOp->getLoc(), i64Type, c64i64ConstAttr); - auto c96i64ConstAttr = mlir::IntegerAttr::get(i64Type, 96); - auto c96i64ConstOp = builder.create( - scopeOp->getLoc(), i64Type, c96i64ConstAttr); + auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0); auto c0i32ConstOp = builder.create( scopeOp->getLoc(), i32Type, c0i32ConstAttr); - - auto c0initInttoptrOp = builder.create( - scopeOp->getLoc(), initPtrType, c0i64ConstOp.getResult()); - auto c32initInttoptrOp = builder.create( - scopeOp->getLoc(), initPtrType, c32i64ConstOp.getResult()); - auto c64initInttoptrOp = builder.create( - 
scopeOp->getLoc(), initPtrType, c64i64ConstOp.getResult()); - auto c96initInttoptrOp = builder.create( - scopeOp->getLoc(), initPtrType, c96i64ConstOp.getResult()); - - builder.create( - scopeOp->getLoc(), - c0i32ConstOp, - c0initInttoptrOp - ); - builder.create( - scopeOp->getLoc(), - c0i32ConstOp, - c32initInttoptrOp - ); - builder.create( - scopeOp->getLoc(), - c0i32ConstOp, - c64initInttoptrOp - ); - builder.create( - scopeOp->getLoc(), - c0i32ConstOp, - c96initInttoptrOp - ); + + auto c1i32ConstAttr = mlir::IntegerAttr::get(i32Type, 1); + auto c1i32ConstOp = builder.create( + scopeOp->getLoc(), i32Type, c1i32ConstAttr); + + // %sub_id = hivm.hir.get_sub_block_idx -> i64 + // 这里假设有一个getSubBlockIdxOp操作 + auto subIdOp = + builder.create(scopeOp->getLoc(), i64Type); + + // %ssb_addr_offset = arith.muli %sub_id, %c32_i64 : i64 + auto ssbAddrOffsetOp = builder.create( + scopeOp->getLoc(), subIdOp.getResult(), c32ConstOp.getResult()); + + // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 + auto ssbAddrOp = builder.create( + scopeOp->getLoc(), ssbAddrOffsetOp.getResult(), + c32ConstOp.getResult()); + + // %vec_id = arith.cmpi eq, %sub_id, %c0_i64 : i64 + auto vecIdOp = builder.create( + scopeOp->getLoc(), mlir::arith::CmpIPredicate::eq, + subIdOp.getResult(), c0i64ConstOp.getResult()); + + // 2. 
在parentop的开头插入代码 + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + + // add sync_block_wait + auto coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), + hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto flagId = + builder.getIntegerAttr(builder.getI64Type(), cubeControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); + + // 在循环末尾(yield之前)插入代码 + auto &loopBody = forOp->getRegion(0).front(); + // 找到循环体的terminator(应该是yield操作) + auto *terminator = loopBody.getTerminator(); + builder.setInsertionPoint(terminator); + + // add sync_block_wait + coreAttr = hivm::TCoreTypeAttr::get(module.getContext(), + hivm::TCoreType::VECTOR); + setPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + waitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + flagId = builder.getIntegerAttr(builder.getI64Type(), vectorControlIndex); + builder.create(forOp->getLoc(), coreAttr, setPipe, + waitPipe, flagId); } + } + + auto i64Type = builder.getIntegerType(64); + auto i32Type = builder.getIntegerType(32); + auto initPtrType = mlir::LLVM::LLVMPointerType::get(builder.getContext(), 11); + SmallVector scopeOps; + module->walk([&](mlir::Operation *op) { + // 检查是否为目标操作 + if (auto scopeOp = dyn_cast(op)) { + scopeOps.push_back(scopeOp); + } + }); + if (!scopeOps.empty()) { + auto scopeOp = scopeOps[0]; + builder.setInsertionPoint(scopeOp); + auto c0i64ConstAttr = mlir::IntegerAttr::get(i64Type, 0); + auto c0i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c0i64ConstAttr); + auto c32i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto c32i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c32i64ConstAttr); + auto c64i64ConstAttr = mlir::IntegerAttr::get(i64Type, 64); + auto c64i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c64i64ConstAttr); + auto 
c96i64ConstAttr = mlir::IntegerAttr::get(i64Type, 96); + auto c96i64ConstOp = builder.create( + scopeOp->getLoc(), i64Type, c96i64ConstAttr); + auto c0i32ConstAttr = mlir::IntegerAttr::get(i32Type, 0); + auto c0i32ConstOp = builder.create( + scopeOp->getLoc(), i32Type, c0i32ConstAttr); + + auto c0initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c0i64ConstOp.getResult()); + auto c32initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c32i64ConstOp.getResult()); + auto c64initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c64i64ConstOp.getResult()); + auto c96initInttoptrOp = builder.create( + scopeOp->getLoc(), initPtrType, c96i64ConstOp.getResult()); + + builder.create(scopeOp->getLoc(), c0i32ConstOp, + c0initInttoptrOp); + builder.create(scopeOp->getLoc(), c0i32ConstOp, + c32initInttoptrOp); + builder.create(scopeOp->getLoc(), c0i32ConstOp, + c64initInttoptrOp); + builder.create(scopeOp->getLoc(), c0i32ConstOp, + c96initInttoptrOp); + } } scf::ForOp transformLoop(scf::ForOp forOp, OpBuilder &builder) { - - // 1. 获取原始循环的信息 - Value originalLowerBound = forOp.getLowerBound(); - Value originalUpperBound = forOp.getUpperBound(); - Value originalStep = forOp.getStep(); - SmallVector iterArgs; - for (auto arg : forOp.getInitArgs()) { - iterArgs.push_back(arg); - } - auto yields = forOp.getBody()->getTerminator(); - - // 2. 检查循环体中是否有特定操作 - int hasTargetOps = 0; - forOp.walk([&](Operation* op) { - if (auto ifOp = dyn_cast(op)) { - if (ifOp->hasAttr("ssbuffer")) { - hasTargetOps++; - } - } - }); - // 3. 如果存在目标操作,在迭代参数中添加计数器 - Value counterInit = nullptr; - mlir::Operation* parentOp = forOp->getParentOp(); - mlir::Operation* scopeOp = nullptr; - // 向上遍历查找scope.scope操作 - while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); + + // 1. 
获取原始循环的信息 + Value originalLowerBound = forOp.getLowerBound(); + Value originalUpperBound = forOp.getUpperBound(); + Value originalStep = forOp.getStep(); + SmallVector iterArgs; + for (auto arg : forOp.getInitArgs()) { + iterArgs.push_back(arg); + } + auto yields = forOp.getBody()->getTerminator(); + + // 2. 检查循环体中是否有特定操作 + int hasTargetOps = 0; + forOp.walk([&](Operation *op) { + if (auto ifOp = dyn_cast(op)) { + if (ifOp->hasAttr("ssbuffer")) { + hasTargetOps++; + } } + }); + // 3. 如果存在目标操作,在迭代参数中添加计数器 + Value counterInit = nullptr; + mlir::Operation *parentOp = forOp->getParentOp(); + mlir::Operation *scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; + } + parentOp = parentOp->getParentOp(); + } - builder.setInsertionPoint(scopeOp); - for (int i = 0; i < hasTargetOps; i++) { - Location loc = forOp.getLoc(); - auto argType = originalLowerBound.getType(); - - // 添加到迭代参数列表 - iterArgs.push_back(originalLowerBound); - } - // 2. 创建新的上界:originalUpperBound * 2 + builder.setInsertionPoint(scopeOp); + for (int i = 0; i < hasTargetOps; i++) { Location loc = forOp.getLoc(); - Type ubType = originalStep.getType(); - builder.setInsertionPoint(forOp); - - int count = 0; - for (auto &op : forOp.getBody()->getOperations()) { - if (auto ifOp = dyn_cast(op)) { - auto parentOp = ifOp->getParentOp(); - if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { - count++; - } + auto argType = originalLowerBound.getType(); + + // 添加到迭代参数列表 + iterArgs.push_back(originalLowerBound); + } + // 2. 
创建新的上界:originalUpperBound * 2 + Location loc = forOp.getLoc(); + Type ubType = originalStep.getType(); + builder.setInsertionPoint(forOp); + + int count = 0; + for (auto &op : forOp.getBody()->getOperations()) { + if (auto ifOp = dyn_cast(op)) { + auto parentOp = ifOp->getParentOp(); + if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { + count++; } } + } - Value two; - if (ubType.isIndex()) { - two = builder.create(loc, count - 1); - } else if (auto intType = dyn_cast(ubType)) { - // 对于整数类型,创建相应类型的常数2 - two = builder.create(loc, count - 1, intType); - } else { - // 其他类型可能需要特殊处理 - llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType << "\n"; - // 尝试创建索引类型的2然后转换 - auto indexTwo = builder.create(loc, count - 1); - two = builder.create(loc, ubType, indexTwo); - } - - auto steps = builder.create( - forOp.getLoc(), - originalStep, - two - ); + Value two; + if (ubType.isIndex()) { + two = builder.create(loc, count - 1); + } else if (auto intType = dyn_cast(ubType)) { + // 对于整数类型,创建相应类型的常数2 + two = builder.create(loc, count - 1, intType); + } else { + // 其他类型可能需要特殊处理 + llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType + << "\n"; + // 尝试创建索引类型的2然后转换 + auto indexTwo = builder.create(loc, count - 1); + two = builder.create(loc, ubType, indexTwo); + } - auto nowUpperBound = builder.create( - forOp.getLoc(), - originalUpperBound, - steps - ); + auto steps = builder.create(forOp.getLoc(), originalStep, two); - // 3. Create a new for loop - auto newForOp = builder.create( - forOp.getLoc(), - originalLowerBound, - nowUpperBound, - originalStep, - iterArgs); - - // 4. 
设置IR映射表,将旧循环的变量映射到新循环 - IRMapping mapper; - - // 映射迭代变量 - mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - // 映射迭代参数 - for (auto [oldArg, newArg] : - llvm::zip(forOp.getRegionIterArgs(), - newForOp.getRegionIterArgs())) { - mapper.map(oldArg, newArg); - } - - SmallVector newCounterArgs; - for (int i = forOp.getRegionIterArgs().size(); i < newForOp.getRegionIterArgs().size(); i++) { - newCounterArgs.push_back(newForOp.getRegionIterArgs()[i]); - } - // 5. 克隆循环体内容到新循环 - auto &newLoopBody = *newForOp.getBody(); - builder.setInsertionPointToStart(&newLoopBody); - - for (auto &op : forOp.getBody()->without_terminator()) { - builder.clone(op, mapper); - } - - // 6. 克隆yield操作 - if (auto yieldOp = dyn_cast(yields)) { - SmallVector newYieldOperands; - for (auto operand : yieldOp.getOperands()) { - newYieldOperands.push_back(mapper.lookupOrDefault(operand)); - } - if (hasTargetOps != 0) { - for (auto currentCounter : newCounterArgs) { - // 将更新后的计数器添加到yield操作数中 - newYieldOperands.push_back(currentCounter); - } - } - builder.create(yieldOp.getLoc(), newYieldOperands); + auto nowUpperBound = + builder.create(forOp.getLoc(), originalUpperBound, steps); + + // 3. Create a new for loop + auto newForOp = + builder.create(forOp.getLoc(), originalLowerBound, + nowUpperBound, originalStep, iterArgs); + + // 4. 设置IR映射表,将旧循环的变量映射到新循环 + IRMapping mapper; + + // 映射迭代变量 + mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // 映射迭代参数 + for (auto [oldArg, newArg] : + llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) { + mapper.map(oldArg, newArg); + } + + SmallVector newCounterArgs; + for (int i = forOp.getRegionIterArgs().size(); + i < newForOp.getRegionIterArgs().size(); i++) { + newCounterArgs.push_back(newForOp.getRegionIterArgs()[i]); + } + // 5. 
克隆循环体内容到新循环 + auto &newLoopBody = *newForOp.getBody(); + builder.setInsertionPointToStart(&newLoopBody); + + for (auto &op : forOp.getBody()->without_terminator()) { + builder.clone(op, mapper); + } + + // 6. 克隆yield操作 + if (auto yieldOp = dyn_cast(yields)) { + SmallVector newYieldOperands; + for (auto operand : yieldOp.getOperands()) { + newYieldOperands.push_back(mapper.lookupOrDefault(operand)); } - - // 7. 替换原循环的结果 if (hasTargetOps != 0) { - // 新循环有额外的计数器结果,但原循环没有对应结果 - // 我们可以选择只替换原循环对应的结果,或者忽略计数器结果 - unsigned numOriginalResults = forOp.getNumResults(); - SmallVector originalResults; - for (unsigned i = 0; i < numOriginalResults; i++) { - originalResults.push_back(newForOp.getResult(i)); - } - forOp.replaceAllUsesWith(originalResults); - } else { - forOp.replaceAllUsesWith(newForOp.getResults()); + for (auto currentCounter : newCounterArgs) { + // 将更新后的计数器添加到yield操作数中 + newYieldOperands.push_back(currentCounter); + } } - - // 8. 删除原循环 - forOp.erase(); - return newForOp; - + builder.create(yieldOp.getLoc(), newYieldOperands); + } + + // 7. 替换原循环的结果 + if (hasTargetOps != 0) { + // 新循环有额外的计数器结果,但原循环没有对应结果 + // 我们可以选择只替换原循环对应的结果,或者忽略计数器结果 + unsigned numOriginalResults = forOp.getNumResults(); + SmallVector originalResults; + for (unsigned i = 0; i < numOriginalResults; i++) { + originalResults.push_back(newForOp.getResult(i)); + } + forOp.replaceAllUsesWith(originalResults); + } else { + forOp.replaceAllUsesWith(newForOp.getResults()); + } + + // 8. 
删除原循环 + forOp.erase(); + return newForOp; } -// Find the first occurrence of convert_layout or fixpipe operation after the specified operation -Value findFirstTargetOpAfterWait(SyncBlockWaitOp waitOp, SmallVector& excludedValues) -{ - bool startSearching = false; - - for (Operation &op : waitOp->getBlock()->getOperations()) { - Value res = nullptr; - if (&op == waitOp) { - startSearching = true; - continue; - } - - if (startSearching) { - if (isa(op)) { - res = op.getOperands()[0]; - } - if (isa(op)) { - res = op.getOperands()[1]; - } - if (isa(op)) { - res = op.getOperands()[1]; - } - if (isa(op)) { - res = op.getOperands()[0]; - } - } - if (res) { - if (llvm::is_contained(excludedValues, res)) { - continue; - } - excludedValues.push_back(res); - return res; - } +// Find the first occurrence of convert_layout or fixpipe operation after the +// specified operation +Value findFirstTargetOpAfterWait(SyncBlockWaitOp waitOp, + SmallVector &excludedValues) { + bool startSearching = false; + + for (Operation &op : waitOp->getBlock()->getOperations()) { + Value res = nullptr; + if (&op == waitOp) { + startSearching = true; + continue; } - - return nullptr; + + if (startSearching) { + if (isa(op)) { + res = op.getOperands()[0]; + } + if (isa(op)) { + res = op.getOperands()[1]; + } + if (isa(op)) { + res = op.getOperands()[1]; + } + if (isa(op)) { + res = op.getOperands()[0]; + } + } + if (res) { + if (llvm::is_contained(excludedValues, res)) { + continue; + } + excludedValues.push_back(res); + return res; + } + } + + return nullptr; } -void getWaitType(std::string CoreType, scf::ForOp forOp, SmallVector& waitTypes, SmallVector& allocTypes) -{ - auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); - auto cubeWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_FIX); - auto vectorWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_MTE3); - SmallVector excludedValues; - forOp.walk([&](Operation* op) { - if (auto waitOp = 
dyn_cast(op)) { - auto parentOp = op->getParentOp(); - if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { - auto ifOp = dyn_cast(parentOp); - if (forOp == ifOp->getParentOp()) { - auto waitPipe = waitOp.getPipe(); - if ((waitPipe == cubeWaitPipe && CoreType == "cube") || (waitPipe == vectorWaitPipe && CoreType == "vector")) { - auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); - waitTypes.push_back(0); - allocTypes.push_back(allocOp); - } - else if (waitPipe != scalarWaitPipe) { - auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); - waitTypes.push_back(1); - allocTypes.push_back(allocOp); - } - } - } +void getWaitType(std::string CoreType, scf::ForOp forOp, + SmallVector &waitTypes, SmallVector &allocTypes) { + auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); + auto cubeWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_FIX); + auto vectorWaitPipe = + PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_MTE3); + SmallVector excludedValues; + forOp.walk([&](Operation *op) { + if (auto waitOp = dyn_cast(op)) { + auto parentOp = op->getParentOp(); + if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { + auto ifOp = dyn_cast(parentOp); + if (forOp == ifOp->getParentOp()) { + auto waitPipe = waitOp.getPipe(); + if ((waitPipe == cubeWaitPipe && CoreType == "cube") || + (waitPipe == vectorWaitPipe && CoreType == "vector")) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + waitTypes.push_back(0); + allocTypes.push_back(allocOp); + } else if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + waitTypes.push_back(1); + allocTypes.push_back(allocOp); + } } - }); + } + } + }); } DenseMap getCounterOffset(scf::ForOp forOp) { - int i = 0; - DenseMap bufferMap; - auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); - forOp.walk([&](Operation* op) { - bufferMap[i] = 0; - auto ifOp = dyn_cast(op); - if (ifOp && 
ifOp->hasAttr("ssbuffer") && ifOp->getParentOp() == forOp) { - ifOp.walk([&](Operation* op) { - if (auto waitOp = dyn_cast(op)) { - if (auto waitIfOp = dyn_cast(op->getParentOp())) { - if (waitIfOp == ifOp) { - auto waitPipe = waitOp.getPipe(); - if ((waitPipe != scalarWaitPipe)) { - bufferMap[i]++; - } - } - } - } - }); - i ++; + int i = 0; + DenseMap bufferMap; + auto scalarWaitPipe = PipeAttr::get(forOp.getContext(), hivm::PIPE::PIPE_S); + forOp.walk([&](Operation *op) { + bufferMap[i] = 0; + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer") && ifOp->getParentOp() == forOp) { + ifOp.walk([&](Operation *op) { + if (auto waitOp = dyn_cast(op)) { + if (auto waitIfOp = dyn_cast(op->getParentOp())) { + if (waitIfOp == ifOp) { + auto waitPipe = waitOp.getPipe(); + if ((waitPipe != scalarWaitPipe)) { + bufferMap[i]++; + } + } + } } - }); - return bufferMap; + }); + i++; + } + }); + return bufferMap; } -SmallVector addBufValLoop(scf::ForOp forOp, DenseMap VecBitMap, DenseMapCubeBitMap, OpBuilder &builder) -{ - auto aiCAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::CUBE); - bool isAIC = false; - // 向上查找父scope.scope操作 - mlir::Operation* parentOp = forOp->getParentOp(); - mlir::Operation* scopeOp = nullptr; - // 向上遍历查找scope.scope操作 - while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); +SmallVector addBufValLoop(scf::ForOp forOp, + DenseMap VecBitMap, + DenseMap CubeBitMap, + OpBuilder &builder) { + auto aiCAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + bool isAIC = false; + // 向上查找父scope.scope操作 + mlir::Operation *parentOp = forOp->getParentOp(); + mlir::Operation *scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; } - if (scopeOp->hasAttr("hivm.tcore_type")) { - auto attr = scopeOp->getAttr("hivm.tcore_type"); - if (attr == aiCAttr) { - isAIC = true; - } 
+ parentOp = parentOp->getParentOp(); + } + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + isAIC = true; } - auto bufferMap = getCounterOffset(forOp); - SmallVector buf_vals; - SmallVector if_conditions; - builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); - - // 1. 提取并处理end值 - Value startValue = forOp.getLowerBound(); - Value endValue = forOp.getUpperBound(); - // 2. 提取并处理step值 - Value stepValue = forOp.getStep(); - builder.setInsertionPoint(forOp); - Location loc = forOp.getLoc(); - int count = 0; - for (auto &op : forOp.getBody()->getOperations()) { - if (auto ifOp = dyn_cast(op)) { - auto parentOp = ifOp->getParentOp(); - if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { - count++; - } + } + auto bufferMap = getCounterOffset(forOp); + SmallVector buf_vals; + SmallVector if_conditions; + builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); + + // 1. 提取并处理end值 + Value startValue = forOp.getLowerBound(); + Value endValue = forOp.getUpperBound(); + // 2. 
提取并处理step值 + Value stepValue = forOp.getStep(); + builder.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + int count = 0; + for (auto &op : forOp.getBody()->getOperations()) { + if (auto ifOp = dyn_cast(op)) { + auto parentOp = ifOp->getParentOp(); + if (parentOp == forOp && ifOp->hasAttr("ssbuffer")) { + count++; } } + } - Value two; - Type ubType = stepValue.getType(); - if (ubType.isIndex()) { - two = builder.create(loc, count - 1); - } else if (auto intType = dyn_cast(ubType)) { - // 对于整数类型,创建相应类型的常数2 - two = builder.create(loc, count - 1, intType); - } else { - // 其他类型可能需要特殊处理 - llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType << "\n"; - // 尝试创建索引类型的2然后转换 - auto indexTwo = builder.create(loc, count - 1); - two = builder.create(loc, ubType, indexTwo); - } + Value two; + Type ubType = stepValue.getType(); + if (ubType.isIndex()) { + two = builder.create(loc, count - 1); + } else if (auto intType = dyn_cast(ubType)) { + // 对于整数类型,创建相应类型的常数2 + two = builder.create(loc, count - 1, intType); + } else { + // 其他类型可能需要特殊处理 + llvm::errs() << "Warning: Unexpected type for upper bound: " << ubType + << "\n"; + // 尝试创建索引类型的2然后转换 + auto indexTwo = builder.create(loc, count - 1); + two = builder.create(loc, ubType, indexTwo); + } - auto steps = builder.create( - forOp.getLoc(), - endValue.getType(), - stepValue, - two - ); + auto steps = builder.create(forOp.getLoc(), endValue.getType(), + stepValue, two); - auto subLoopValue = builder.create( - forOp.getLoc(), - endValue.getType(), - endValue, - steps - ); + auto subLoopValue = builder.create( + forOp.getLoc(), endValue.getType(), endValue, steps); - SmallVector WaitType; - SmallVector AllocType; - SmallVector bufferPtrs; - if (isAIC) { - builder.setInsertionPointToStart(&forOp->getRegion(0).front()); - // 创建常量32和64 - Value c0 = builder.create( - forOp.getLoc(), 0, 32 // 值32,64位 - ); - Value c32 = builder.create( - forOp.getLoc(), 32, 64 // 值32,64位 - ); - Value c64 = builder.create( - 
forOp.getLoc(), 64, 64 // 值64,64位 - ); - // 创建inttoptr操作 - Value ssb_vec0_ptr = builder.create( - forOp.getLoc(), - LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 - c32 - ); - Value ssb_vec1_ptr = builder.create( - forOp.getLoc(), - LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 - c64 - ); - bufferPtrs.push_back(ssb_vec0_ptr); - bufferPtrs.push_back(ssb_vec1_ptr); - // 创建load操作 - Value status_vec0 = builder.create( - forOp.getLoc(), builder.getI32Type(), ssb_vec0_ptr - ); - - Value status_vec1 = builder.create( - forOp.getLoc(), builder.getI32Type(), ssb_vec1_ptr + SmallVector WaitType; + SmallVector AllocType; + SmallVector bufferPtrs; + if (isAIC) { + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // 创建常量32和64 + Value c0 = + builder.create(forOp.getLoc(), 0, 32 // 值32,64位 ); + Value c32 = builder.create(forOp.getLoc(), 32, + 64 // 值32,64位 + ); + Value c64 = builder.create(forOp.getLoc(), 64, + 64 // 值64,64位 + ); + // 创建inttoptr操作 + Value ssb_vec0_ptr = builder.create( + forOp.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + c32); + Value ssb_vec1_ptr = builder.create( + forOp.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + c64); + bufferPtrs.push_back(ssb_vec0_ptr); + bufferPtrs.push_back(ssb_vec1_ptr); + // 创建load操作 + Value status_vec0 = builder.create( + forOp.getLoc(), builder.getI32Type(), ssb_vec0_ptr); + + Value status_vec1 = builder.create( + forOp.getLoc(), builder.getI32Type(), ssb_vec1_ptr); + + getWaitType("cube", forOp, WaitType, AllocType); + + for (auto i = 0; i < WaitType.size(); i++) { + auto correnspondAlloc = CubeBitMap[AllocType[i]]; + auto i32ConstAttr = + mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + Value bufi_vec0_val = builder.create( + forOp.getLoc(), status_vec0, buf_constant_set); + Value bufi_vec1_val = 
builder.create( + forOp.getLoc(), status_vec1, buf_constant_set); + Value flag_bufi_vec0; + Value flag_bufi_vec1; + // 创建比较操作 + if (WaitType[i] == 0) { + flag_bufi_vec0 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, c0); + flag_bufi_vec1 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, c0); + } else { + flag_bufi_vec0 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, + buf_constant_set); + flag_bufi_vec1 = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, + buf_constant_set); + } + // 创建最终的and操作 + Value bufi_val = builder.create( + forOp.getLoc(), flag_bufi_vec0, flag_bufi_vec1); + buf_vals.push_back(bufi_val); + } - getWaitType("cube", forOp, WaitType, AllocType); - - for (auto i = 0; i < WaitType.size(); i++) { - auto correnspondAlloc = CubeBitMap[AllocType[i]]; - auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - Value bufi_vec0_val = builder.create( - forOp.getLoc(), status_vec0, buf_constant_set - ); - Value bufi_vec1_val = builder.create( - forOp.getLoc(), status_vec1, buf_constant_set - ); - Value flag_bufi_vec0; - Value flag_bufi_vec1; - // 创建比较操作 - if (WaitType[i] == 0) { - flag_bufi_vec0 = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, c0 - ); - flag_bufi_vec1 = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, c0 - ); - } - else { - flag_bufi_vec0 = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec0_val, buf_constant_set - ); - flag_bufi_vec1 = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_vec1_val, buf_constant_set - ); - } - // 创建最终的and操作 - Value bufi_val = builder.create( - forOp.getLoc(), flag_bufi_vec0, flag_bufi_vec1 - ); - buf_vals.push_back(bufi_val); - } - - } else { - 
builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); - Value c0 = builder.create( - forOp.getLoc(), 0, 32 // 值32,64位 - ); - auto i64Type = builder.getIntegerType(64); - // %sub_id = hivm.hir.get_sub_block_idx -> i64 - // 这里假设有一个getSubBlockIdxOp操作 - auto subIdOp = builder.create( - scopeOp->getLoc(), i64Type); - auto i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); - auto cst_offset = builder.create( - scopeOp->getLoc(), i64Type, i64ConstAttr); - auto ssb_addr_offset = builder.create( - scopeOp->getLoc(), subIdOp, cst_offset); - auto ssb_addr = builder.create( - scopeOp->getLoc(), ssb_addr_offset, cst_offset); - builder.setInsertionPointToStart(&forOp->getRegion(0).front()); - // 创建inttoptr操作 - Value ssb_cube_ptr = builder.create( - forOp.getLoc(), - LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 - ssb_addr - ); - bufferPtrs.push_back(ssb_cube_ptr); - // 创建load操作 - Value status_cube = builder.create( - forOp.getLoc(), builder.getI32Type(), ssb_cube_ptr + } else { + builder.setInsertionPointToStart(&scopeOp->getRegion(0).front()); + Value c0 = + builder.create(forOp.getLoc(), 0, 32 // 值32,64位 ); + auto i64Type = builder.getIntegerType(64); + // %sub_id = hivm.hir.get_sub_block_idx -> i64 + // 这里假设有一个getSubBlockIdxOp操作 + auto subIdOp = builder.create(scopeOp->getLoc(), i64Type); + auto i64ConstAttr = mlir::IntegerAttr::get(i64Type, 32); + auto cst_offset = builder.create( + scopeOp->getLoc(), i64Type, i64ConstAttr); + auto ssb_addr_offset = + builder.create(scopeOp->getLoc(), subIdOp, cst_offset); + auto ssb_addr = builder.create(scopeOp->getLoc(), + ssb_addr_offset, cst_offset); + builder.setInsertionPointToStart(&forOp->getRegion(0).front()); + // 创建inttoptr操作 + Value ssb_cube_ptr = builder.create( + forOp.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext(), 11), // 地址空间11 + ssb_addr); + bufferPtrs.push_back(ssb_cube_ptr); + // 创建load操作 + Value status_cube = builder.create( + forOp.getLoc(), builder.getI32Type(), 
ssb_cube_ptr); + + getWaitType("vector", forOp, WaitType, AllocType); + for (auto i = 0; i < WaitType.size(); i++) { + auto correnspondAlloc = VecBitMap[AllocType[i]]; + auto i32ConstAttr = + mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + Value bufi_cube_val = builder.create( + forOp.getLoc(), status_cube, buf_constant_set); + + Value flag_bufi_cube; + // 创建比较操作 + if (WaitType[i] == 0) { + flag_bufi_cube = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, c0); + } else { + flag_bufi_cube = builder.create( + forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, + buf_constant_set); + } + buf_vals.push_back(flag_bufi_cube); + } + } + int bufIdx = 0; + int groupIdx = 0; - getWaitType("vector", forOp, WaitType, AllocType); - for (auto i = 0; i < WaitType.size(); i++) { - auto correnspondAlloc = VecBitMap[AllocType[i]]; - auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - Value bufi_cube_val = builder.create( - forOp.getLoc(), status_cube, buf_constant_set - ); - - Value flag_bufi_cube; - // 创建比较操作 - if (WaitType[i] == 0) { - flag_bufi_cube = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, c0 - ); - } - else { - flag_bufi_cube = builder.create( - forOp.getLoc(), arith::CmpIPredicate::eq, bufi_cube_val, buf_constant_set - ); - } - buf_vals.push_back(flag_bufi_cube); - } + for (const auto &pair : bufferMap) { + if (bufferMap[groupIdx] == 0) { + continue; } - int bufIdx = 0; - int groupIdx = 0; - for (const auto &pair : bufferMap) { - if (bufferMap[groupIdx] == 0) { - continue; - } - - // 获取对应的region迭代参数 - Value cnti = builder.create( - forOp.getLoc(), arith::CmpIPredicate::slt, - forOp.getRegionIterArgs()[forOp.getRegionIterArgs().size() - (bufferMap.size() - 
1 - groupIdx)], - subLoopValue - ); - - // 计算该组中所有buffer值的AND - Value finalBufVal = buf_vals[bufIdx]; - for (int count = 1; count < bufferMap[groupIdx]; count++) { - finalBufVal = builder.create( - forOp.getLoc(), finalBufVal, buf_vals[bufIdx + count] - ); - } - - auto cond = builder.create( - forOp.getLoc(), finalBufVal, cnti - ); - if_conditions.push_back(cond); - - // 更新索引 - bufIdx += bufferMap[groupIdx]; - groupIdx++; - } - int ifIndex = 0; - int acc = 0; - int bufferBit = 0; - for (int i = 0; i < CubeBitMap.size(); i++) { - bufferBit += (1 << i); - } - forOp.getBody()->walk([&](Operation* op) { - auto ifOp = dyn_cast(op); - if (ifOp && ifOp->hasAttr("ssbuffer")) { - // 获取then区域 - Block* thenBlock = &ifOp.getThenRegion().front(); - - // 找到then区域中的yield操作 - Operation* yieldOp = nullptr; - for (auto& op : *thenBlock) { - if (isa(op)) { - yieldOp = &op; - break; - } + // 获取对应的region迭代参数 + Value cnti = builder.create( + forOp.getLoc(), arith::CmpIPredicate::slt, + forOp.getRegionIterArgs()[forOp.getRegionIterArgs().size() - + (bufferMap.size() - 1 - groupIdx)], + subLoopValue); + + // 计算该组中所有buffer值的AND + Value finalBufVal = buf_vals[bufIdx]; + for (int count = 1; count < bufferMap[groupIdx]; count++) { + finalBufVal = builder.create(forOp.getLoc(), finalBufVal, + buf_vals[bufIdx + count]); + } + + auto cond = + builder.create(forOp.getLoc(), finalBufVal, cnti); + if_conditions.push_back(cond); + + // 更新索引 + bufIdx += bufferMap[groupIdx]; + groupIdx++; + } + int ifIndex = 0; + int acc = 0; + int bufferBit = 0; + for (int i = 0; i < CubeBitMap.size(); i++) { + bufferBit += (1 << i); + } + forOp.getBody()->walk([&](Operation *op) { + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer")) { + // 获取then区域 + Block *thenBlock = &ifOp.getThenRegion().front(); + + // 找到then区域中的yield操作 + Operation *yieldOp = nullptr; + for (auto &op : *thenBlock) { + if (isa(op)) { + yieldOp = &op; + break; } - if (yieldOp) { - builder.setInsertionPoint(yieldOp); - - if (isAIC) 
{ - // 创建插入的语句 - // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 - Value status_v2_0 = builder.create( - yieldOp->getLoc(), - builder.getIntegerType(32), // i32类型 - bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 - ); - Value status_v2_1 = builder.create( - yieldOp->getLoc(), - builder.getIntegerType(32), // i32类型 - bufferPtrs[1] // 假设ssb_ptr已在作用域中定义 - ); - Value buf_val_new_0 = status_v2_0; - Value buf_val_new_1 = status_v2_1; - auto bufferNum = bufferMap[ifIndex]; - for (int i = 0; i < bufferNum; i++) { - if (WaitType[acc + i] == 0) { - auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; - auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - buf_val_new_0 = builder.create( - yieldOp->getLoc(), - buf_val_new_0, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - buf_val_new_1 = builder.create( - yieldOp->getLoc(), - buf_val_new_1, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - } - else { - auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; - int bitPos = correnspondAlloc; - int basePattern = bufferBit; - int finalValue = basePattern ^ (1 << bitPos); - auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), finalValue); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - buf_val_new_0 = builder.create( - yieldOp->getLoc(), - buf_val_new_0, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - buf_val_new_1 = builder.create( - yieldOp->getLoc(), - buf_val_new_1, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - } - } - acc += bufferNum; - builder.create( - yieldOp->getLoc(), - buf_val_new_0, - bufferPtrs[0] - ); - builder.create( - yieldOp->getLoc(), - buf_val_new_1, - bufferPtrs[1] - ); - + } + if (yieldOp) { + builder.setInsertionPoint(yieldOp); + + if (isAIC) { + // 创建插入的语句 + // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 + Value status_v2_0 = 
builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 + ); + Value status_v2_1 = builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[1] // 假设ssb_ptr已在作用域中定义 + ); + Value buf_val_new_0 = status_v2_0; + Value buf_val_new_1 = status_v2_1; + auto bufferNum = bufferMap[ifIndex]; + for (int i = 0; i < bufferNum; i++) { + if (WaitType[acc + i] == 0) { + auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), + 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new_0 = builder.create( + yieldOp->getLoc(), buf_val_new_0, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + buf_val_new_1 = builder.create( + yieldOp->getLoc(), buf_val_new_1, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } else { + auto correnspondAlloc = CubeBitMap[AllocType[acc + i]]; + int bitPos = correnspondAlloc; + int basePattern = bufferBit; + int finalValue = basePattern ^ (1 << bitPos); + auto i32ConstAttr = + mlir::IntegerAttr::get(builder.getI32Type(), finalValue); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new_0 = builder.create( + yieldOp->getLoc(), buf_val_new_0, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + buf_val_new_1 = builder.create( + yieldOp->getLoc(), buf_val_new_1, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); } - else { - // 创建插入的语句 - // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 - Value status_v2 = builder.create( - yieldOp->getLoc(), - builder.getIntegerType(32), // i32类型 - bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 - ); - Value buf_val_new = status_v2; - auto bufferNum = bufferMap[ifIndex]; - for (int i = 0; i < bufferNum; i++) { - if (WaitType[acc + i] == 0) { - auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; - auto i32ConstAttr = 
mlir::IntegerAttr::get(builder.getI32Type(), 1 << correnspondAlloc); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - buf_val_new = builder.create( - yieldOp->getLoc(), - buf_val_new, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - } - else { - auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; - int bitPos = correnspondAlloc; - int basePattern = bufferBit; - int finalValue = basePattern ^ (1 << bitPos); - auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), finalValue); - auto buf_constant_set = builder.create( - scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); - buf_val_new = builder.create( - yieldOp->getLoc(), - buf_val_new, - buf_constant_set // 假设buf3_clear已在作用域中定义 - ); - } - } - acc += bufferNum; - builder.create( - yieldOp->getLoc(), - buf_val_new, - bufferPtrs[0] - ); + } + acc += bufferNum; + builder.create(yieldOp->getLoc(), buf_val_new_0, + bufferPtrs[0]); + builder.create(yieldOp->getLoc(), buf_val_new_1, + bufferPtrs[1]); + + } else { + // 创建插入的语句 + // %status_v2 = llvm.load %ssb_ptr : !llvm.ptr<11> -> i32 + Value status_v2 = builder.create( + yieldOp->getLoc(), + builder.getIntegerType(32), // i32类型 + bufferPtrs[0] // 假设ssb_ptr已在作用域中定义 + ); + Value buf_val_new = status_v2; + auto bufferNum = bufferMap[ifIndex]; + for (int i = 0; i < bufferNum; i++) { + if (WaitType[acc + i] == 0) { + auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; + auto i32ConstAttr = mlir::IntegerAttr::get(builder.getI32Type(), + 1 << correnspondAlloc); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new = builder.create( + yieldOp->getLoc(), buf_val_new, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); + } else { + auto correnspondAlloc = VecBitMap[AllocType[acc + i]]; + int bitPos = correnspondAlloc; + int basePattern = bufferBit; + int finalValue = basePattern ^ (1 << bitPos); + auto i32ConstAttr = + 
mlir::IntegerAttr::get(builder.getI32Type(), finalValue); + auto buf_constant_set = builder.create( + scopeOp->getLoc(), builder.getI32Type(), i32ConstAttr); + buf_val_new = builder.create( + yieldOp->getLoc(), buf_val_new, + buf_constant_set // 假设buf3_clear已在作用域中定义 + ); } - ifIndex ++; + } + acc += bufferNum; + builder.create(yieldOp->getLoc(), buf_val_new, + bufferPtrs[0]); } + ifIndex++; } - }); + } + }); - return if_conditions; + return if_conditions; } -void ReplaceIf(scf::ForOp forOp, SmallVector conditions, SmallVector& opsToErase, DenseMap& ifArgMap, OpBuilder &builder, ModuleOp moduleOp) -{ - SmallVector ifToProcess; - llvm::outs()<<"enter replaceif\n"; - Value step = forOp.getStep(); - auto aiCAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::CUBE); - forOp.getBody()->walk([&](Operation* op) { - auto ifOp = dyn_cast(op); - if (ifOp && ifOp->hasAttr("ssbuffer") && forOp == ifOp->getParentOp()) { - ifToProcess.push_back(ifOp); +void ReplaceIf(scf::ForOp forOp, SmallVector conditions, + SmallVector &opsToErase, + DenseMap &ifArgMap, OpBuilder &builder, + ModuleOp moduleOp) { + SmallVector ifToProcess; + llvm::outs() << "enter replaceif\n"; + Value step = forOp.getStep(); + auto aiCAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + forOp.getBody()->walk([&](Operation *op) { + auto ifOp = dyn_cast(op); + if (ifOp && ifOp->hasAttr("ssbuffer") && forOp == ifOp->getParentOp()) { + ifToProcess.push_back(ifOp); + } + }); + + IRMapping IRMap; + for (int i = 0; i < ifToProcess.size(); i++) { + auto ifOp = ifToProcess[i]; + auto parentOp = ifOp->getParentOp(); + auto loc = ifOp.getLoc(); + // 获取for循环的iterargs(迭代参数) + auto iterArgs = forOp.getRegionIterArgs(); + if (iterArgs.size() < conditions.size()) { + return; + } + auto thenYieldOp = + dyn_cast(ifOp.getThenRegion().front().getTerminator()); + SmallVector thenResults; + if (thenYieldOp) { + // 如果已有返回值,保留它们 + for (auto result : thenYieldOp.getResults()) { + 
thenResults.push_back(result); } - }); + } + // 创建新的else区域,返回两个迭代参数 + SmallVector elseResults; + scf::YieldOp elseYieldOp = nullptr; + bool hasElse = false; + if (!ifOp.getElseRegion().empty()) { + elseYieldOp = + dyn_cast(ifOp.getElseRegion().front().getTerminator()); + hasElse = true; + } + if (elseYieldOp) { + for (auto result : elseYieldOp.getResults()) { + elseResults.push_back(result); + } + } + // 获取最后两个迭代参数 + Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; + // 创建新的then区域,返回两个迭代参数 + thenResults.push_back(iterArgMinus); + elseResults.push_back(iterArgMinus); - IRMapping IRMap; - for (int i = 0; i < ifToProcess.size(); i++) { - auto ifOp = ifToProcess[i]; - auto parentOp = ifOp->getParentOp(); - auto loc = ifOp.getLoc(); - // 获取for循环的iterargs(迭代参数) - auto iterArgs = forOp.getRegionIterArgs(); - if (iterArgs.size() < conditions.size()) { - return; - } - auto thenYieldOp = dyn_cast(ifOp.getThenRegion().front().getTerminator()); - SmallVector thenResults; - if (thenYieldOp) { - // 如果已有返回值,保留它们 - for (auto result : thenYieldOp.getResults()) { - thenResults.push_back(result); - } - } - // 创建新的else区域,返回两个迭代参数 - SmallVector elseResults; - scf::YieldOp elseYieldOp = nullptr; - bool hasElse = false; - if (!ifOp.getElseRegion().empty()) { - elseYieldOp = dyn_cast(ifOp.getElseRegion().front().getTerminator()); - hasElse = true; - } - if (elseYieldOp) { - for (auto result : elseYieldOp.getResults()) { - elseResults.push_back(result); - } + // 保存原有的操作,以便后续克隆 + SmallVector thenOps; + for (auto &op : ifOp.getThenRegion().front()) { + thenOps.push_back(&op); + } + + SmallVector elseOps; + if (!ifOp.getElseRegion().empty()) { + for (auto &op : ifOp.getElseRegion().front()) { + elseOps.push_back(&op); + } + } + SmallVector resultTypes; + for (auto val : thenResults) { + resultTypes.push_back(val.getType()); + } + // 创建新的scf.if操作 + builder.setInsertionPoint(ifOp); + auto newIfOp = builder.create(loc, resultTypes, conditions[i], + 
/*withElseRegion=*/true); + newIfOp->setAttr("ssbuffer", builder.getUnitAttr()); + // 处理then区域 + auto &newThenBlock = newIfOp.getThenRegion().front(); + builder.setInsertionPointToStart(&newThenBlock); + + // 克隆then区域的操作 + for (auto op : thenOps) { + if (auto yieldOp = dyn_cast(op)) { + // 处理yield的操作数映射 + SmallVector mappedOperands; + for (auto operand : yieldOp->getOperands()) { + mappedOperands.push_back(IRMap.lookupOrDefault(operand)); } // 获取最后两个迭代参数 - Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; - // 创建新的then区域,返回两个迭代参数 - thenResults.push_back(iterArgMinus); - elseResults.push_back(iterArgMinus); - - // 保存原有的操作,以便后续克隆 - SmallVector thenOps; - for (auto &op : ifOp.getThenRegion().front()) { - thenOps.push_back(&op); - } - - SmallVector elseOps; - if (!ifOp.getElseRegion().empty()) { - for (auto &op : ifOp.getElseRegion().front()) { - elseOps.push_back(&op); - } - } - SmallVector resultTypes; - for (auto val : thenResults) { - resultTypes.push_back(val.getType()); - } - // 创建新的scf.if操作 - builder.setInsertionPoint(ifOp); - auto newIfOp = builder.create( - loc, - resultTypes, - conditions[i], - /*withElseRegion=*/true); - newIfOp->setAttr("ssbuffer", builder.getUnitAttr()); - // 处理then区域 - auto &newThenBlock = newIfOp.getThenRegion().front(); - builder.setInsertionPointToStart(&newThenBlock); - - // 克隆then区域的操作 - for (auto op : thenOps) { - if (auto yieldOp = dyn_cast(op)) { - // 处理yield的操作数映射 - SmallVector mappedOperands; - for (auto operand : yieldOp->getOperands()) { - mappedOperands.push_back(IRMap.lookupOrDefault(operand)); - } - // 获取最后两个迭代参数 - Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; - - // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 - auto AddIOp = builder.create( - forOp->getLoc(), - iterArgMinus, - step); - // 这里加个add1 - mappedOperands.push_back(AddIOp); - builder.create(loc, mappedOperands); - } else { - auto newOp = builder.clone(*op, IRMap); - IRMap.map(op->getResults(), 
newOp->getResults()); - } - } - - // 处理else区域 - auto &newElseBlock = newIfOp.getElseRegion().front(); - builder.setInsertionPointToStart(&newElseBlock); - // 克隆else区域的操作 - if (hasElse) { - for (auto op : elseOps) { - if (auto yieldOp = dyn_cast(op)) { - // 处理yield的操作数映射 - SmallVector mappedOperands; - for (auto operand : yieldOp->getOperands()) { - mappedOperands.push_back(IRMap.lookupOrDefault(operand)); - } - Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; - mappedOperands.push_back(iterArgMinus); - builder.create(loc, mappedOperands); - } else { - auto newOp = builder.clone(*op, IRMap); - IRMap.map(op->getResults(), newOp->getResults()); - } + Value iterArgMinus = + iterArgs[iterArgs.size() - (conditions.size() - i)]; + + // %ssb_addr = arith.addi %ssb_addr_offset, %c32_i64 : i64 + auto AddIOp = builder.create(forOp->getLoc(), + iterArgMinus, step); + // 这里加个add1 + mappedOperands.push_back(AddIOp); + builder.create(loc, mappedOperands); + } else { + auto newOp = builder.clone(*op, IRMap); + IRMap.map(op->getResults(), newOp->getResults()); + } + } + + // 处理else区域 + auto &newElseBlock = newIfOp.getElseRegion().front(); + builder.setInsertionPointToStart(&newElseBlock); + // 克隆else区域的操作 + if (hasElse) { + for (auto op : elseOps) { + if (auto yieldOp = dyn_cast(op)) { + // 处理yield的操作数映射 + SmallVector mappedOperands; + for (auto operand : yieldOp->getOperands()) { + mappedOperands.push_back(IRMap.lookupOrDefault(operand)); } + Value iterArgMinus = + iterArgs[iterArgs.size() - (conditions.size() - i)]; + mappedOperands.push_back(iterArgMinus); + builder.create(loc, mappedOperands); } else { - SmallVector cntOperands; - cntOperands.push_back(iterArgMinus); - builder.create(loc, cntOperands); - } - - // 替换原有if操作的使用 - // 首先,将原if操作的结果替换为新if操作的对应结果 - for (unsigned j = 0; j < ifOp.getNumResults(); ++j) { - ifOp.getResult(j).replaceAllUsesWith(newIfOp.getResult(j)); + auto newOp = builder.clone(*op, IRMap); + IRMap.map(op->getResults(), 
newOp->getResults()); } - // 获取新if操作所在的块 - Block* newIfBlock = ifOp->getBlock(); - // 在for循环体内替换迭代参数的使用 - forOp.getBody()->walk([&](Operation* op) { - // 检查操作是否与新ifOp在同一个块中 - Block* opBlock = op->getBlock(); - if (opBlock != newIfBlock) { - // 不在同一个块中,跳过 - return; - } - if (op->isBeforeInBlock(newIfOp)) { - return; // 只处理if操作之后的use - } - for (unsigned j = 0; j < op->getNumOperands(); ++j) { - for (auto argIndex = 0; argIndex < conditions.size(); argIndex ++) { - // 获取最后两个迭代参数 - Value iterArgMinus = iterArgs[iterArgs.size() - (conditions.size() - i)]; - if (op->getOperand(j) == iterArgMinus) { - op->setOperand(j, newIfOp.getResults()[newIfOp.getNumResults() - 1]); - } - } - } - }); - - // // 删除原有的if操作 - opsToErase.push_back(ifOp); - if (ifArgMap.find(newIfOp) == ifArgMap.end()) { - ifArgMap[newIfOp] = iterArgMinus; - } + } + } else { + SmallVector cntOperands; + cntOperands.push_back(iterArgMinus); + builder.create(loc, cntOperands); + } + + // 替换原有if操作的使用 + // 首先,将原if操作的结果替换为新if操作的对应结果 + for (unsigned j = 0; j < ifOp.getNumResults(); ++j) { + ifOp.getResult(j).replaceAllUsesWith(newIfOp.getResult(j)); + } + // 获取新if操作所在的块 + Block *newIfBlock = ifOp->getBlock(); + // 在for循环体内替换迭代参数的使用 + forOp.getBody()->walk([&](Operation *op) { + // 检查操作是否与新ifOp在同一个块中 + Block *opBlock = op->getBlock(); + if (opBlock != newIfBlock) { + // 不在同一个块中,跳过 + return; + } + if (op->isBeforeInBlock(newIfOp)) { + return; // 只处理if操作之后的use + } + for (unsigned j = 0; j < op->getNumOperands(); ++j) { + for (auto argIndex = 0; argIndex < conditions.size(); argIndex++) { + // 获取最后两个迭代参数 + Value iterArgMinus = + iterArgs[iterArgs.size() - (conditions.size() - i)]; + if (op->getOperand(j) == iterArgMinus) { + op->setOperand(j, + newIfOp.getResults()[newIfOp.getNumResults() - 1]); + } } + } + }); + + // // 删除原有的if操作 + opsToErase.push_back(ifOp); + if (ifArgMap.find(newIfOp) == ifArgMap.end()) { + ifArgMap[newIfOp] = iterArgMinus; + } + } } int getNestingDepth(scf::ForOp forOp) { - int depth = 0; - 
Operation* op = forOp.getOperation(); - while (op) { - if (op->getDialect() && op->getDialect()->getNamespace() == "scf") { - ++depth; - } - op = op->getParentOp(); + int depth = 0; + Operation *op = forOp.getOperation(); + while (op) { + if (op->getDialect() && op->getDialect()->getNamespace() == "scf") { + ++depth; } - return depth; + op = op->getParentOp(); + } + return depth; } -void printDenseMap(const mlir::DenseMap& Map) -{ - for (const auto& pair : Map) { - mlir::Value val = pair.first; - int bitValue = pair.second; - llvm::outs()< &Map) { + for (const auto &pair : Map) { + mlir::Value val = pair.first; + int bitValue = pair.second; + llvm::outs() << val << " " << bitValue << " allocmap\n\n\n"; + llvm::outs().flush(); + } + llvm::outs() << "------------------------------\n\n\n"; } -void getAllocBit(ModuleOp module, DenseMap& VecBitMap, DenseMap& CubeBitMap, OpBuilder builder) -{ - auto aiCAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::CUBE); - auto scalarWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); - auto cubeWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_FIX); - auto vectorWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_MTE3); - - int cubeAcc = 0; - int vecAcc = 0; - SmallVector scopeOpToEdit; - module.walk([&](scope::ScopeOp scopeOp) { - scopeOpToEdit.push_back(scopeOp); - }); - for (auto scopeOp : scopeOpToEdit) { - SmallVector excludedValues; - if (scopeOp->hasAttr("hivm.tcore_type")) { - auto attr = scopeOp->getAttr("hivm.tcore_type"); - if (attr == aiCAttr) { - scopeOp.walk([&](SyncBlockWaitOp waitOp) { - auto parentOp = waitOp->getParentOp(); - if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { - auto waitPipe = waitOp.getPipe(); - if (waitPipe != scalarWaitPipe) { - auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); - if (VecBitMap.find(allocOp) != VecBitMap.end()) { - CubeBitMap[allocOp] = VecBitMap[allocOp]; - } else { - CubeBitMap[allocOp] = 
cubeAcc; - cubeAcc++; - } - } - } - }); - } else { - scopeOp.walk([&](SyncBlockWaitOp waitOp) { - auto parentOp = waitOp->getParentOp(); - if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { - auto waitPipe = waitOp.getPipe(); - if (waitPipe != scalarWaitPipe) { - auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); - if (VecBitMap.find(allocOp) == VecBitMap.end()) { - VecBitMap[allocOp] = vecAcc; - vecAcc++; - } - } - } - }); +void getAllocBit(ModuleOp module, DenseMap &VecBitMap, + DenseMap &CubeBitMap, OpBuilder builder) { + auto aiCAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto scalarWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_S); + auto cubeWaitPipe = PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_FIX); + auto vectorWaitPipe = + PipeAttr::get(module.getContext(), hivm::PIPE::PIPE_MTE3); + + int cubeAcc = 0; + int vecAcc = 0; + SmallVector scopeOpToEdit; + module.walk( + [&](scope::ScopeOp scopeOp) { scopeOpToEdit.push_back(scopeOp); }); + for (auto scopeOp : scopeOpToEdit) { + SmallVector excludedValues; + if (scopeOp->hasAttr("hivm.tcore_type")) { + auto attr = scopeOp->getAttr("hivm.tcore_type"); + if (attr == aiCAttr) { + scopeOp.walk([&](SyncBlockWaitOp waitOp) { + auto parentOp = waitOp->getParentOp(); + if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { + auto waitPipe = waitOp.getPipe(); + if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + if (VecBitMap.find(allocOp) != VecBitMap.end()) { + CubeBitMap[allocOp] = VecBitMap[allocOp]; + } else { + CubeBitMap[allocOp] = cubeAcc; + cubeAcc++; + } + } + } + }); + } else { + scopeOp.walk([&](SyncBlockWaitOp waitOp) { + auto parentOp = waitOp->getParentOp(); + if (isa(parentOp) && parentOp->hasAttr("ssbuffer")) { + auto waitPipe = waitOp.getPipe(); + if (waitPipe != scalarWaitPipe) { + auto allocOp = findFirstTargetOpAfterWait(waitOp, excludedValues); + if 
(VecBitMap.find(allocOp) == VecBitMap.end()) { + VecBitMap[allocOp] = vecAcc; + vecAcc++; + } + } } + }); } } + } } -void modifyForIterargDeps(scf::ForOp forOp, DenseMap ifCounters) -{ +void modifyForIterargDeps(scf::ForOp forOp, + DenseMap ifCounters) { Value iterArg = forOp.getInductionVar(); for (Operation &op : forOp.getBody()->without_terminator()) { @@ -1218,10 +1170,10 @@ void modifyForIterargDeps(scf::ForOp forOp, DenseMap ifCounter if (ifCounters.find(ifOp) != ifCounters.end()) { Value counter = ifCounters[ifOp]; - ifOp.walk([&](Operation* opInIf) { + ifOp.walk([&](Operation *opInIf) { for (auto [i, operand] : llvm::enumerate(opInIf->getOperands())) { if (operand == iterArg) { - opInIf->setOperand(i, counter); + opInIf->setOperand(i, counter); } } }); @@ -1231,91 +1183,88 @@ void modifyForIterargDeps(scf::ForOp forOp, DenseMap ifCounter } void FlowSssbuf(ModuleOp module) { - mlir::OpBuilder builder(module.getContext()); - // 收集所有需要转换的循环 - SmallVector targetLoops; - llvm::outs()<<"enter flowsssbuf\n\n"; - module.walk([&](Operation* op) { - if (auto forOp = dyn_cast(op)) { - // 检查循环是否包含特定的 sync_block_set 操作 - bool hasSyncBlockSet = false; - forOp.walk([&](Operation *op) { - if (isa(op)) { - if (auto ifOp = dyn_cast(op->getParentOp())) { - if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { - hasSyncBlockSet = true; - } - } - } - }); - - if (hasSyncBlockSet) { - if (llvm::find(targetLoops, forOp) == targetLoops.end()) { - targetLoops.push_back(forOp); - } + mlir::OpBuilder builder(module.getContext()); + // 收集所有需要转换的循环 + SmallVector targetLoops; + llvm::outs() << "enter flowsssbuf\n\n"; + module.walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) { + // 检查循环是否包含特定的 sync_block_set 操作 + bool hasSyncBlockSet = false; + forOp.walk([&](Operation *op) { + if (isa(op)) { + if (auto ifOp = dyn_cast(op->getParentOp())) { + if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { + hasSyncBlockSet = true; } + } } - }); - llvm::outs()<<"enter 
flowsssbuf\n\n"; - - SmallVector transformLoops; - // 转换每个目标循环 - for (scf::ForOp forOp : targetLoops) { - auto newforOp = transformLoop(forOp, builder); - } - - module.walk([&](Operation* op) { - if (auto forOp = dyn_cast(op)) { - // 检查循环是否包含特定的 sync_block_set 操作 - bool hasSyncBlockSet = false; - forOp.walk([&](Operation *op) { - if (isa(op)) { - if (auto ifOp = dyn_cast(op->getParentOp())) { - if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { - hasSyncBlockSet = true; - } - } - } - }); - - if (hasSyncBlockSet) { - if (llvm::find(transformLoops, forOp) == transformLoops.end()) { - transformLoops.push_back(forOp); - } - } + }); + + if (hasSyncBlockSet) { + if (llvm::find(targetLoops, forOp) == targetLoops.end()) { + targetLoops.push_back(forOp); } - - }); + } + } + }); + llvm::outs() << "enter flowsssbuf\n\n"; - llvm::sort(transformLoops, [](scf::ForOp a, scf::ForOp b) { - return getNestingDepth(a) > getNestingDepth(b); - }); - DenseMap VecBitMap; - DenseMap CubeBitMap; - getAllocBit(module, VecBitMap, CubeBitMap, builder); - printDenseMap(CubeBitMap); - printDenseMap(VecBitMap); - SmallVector opsToErase; - for (scf::ForOp forOp : transformLoops) { - DenseMap ifArgMap; - llvm::outs()<<"before replaceif\n"; - auto bufvals = addBufValLoop(forOp, VecBitMap, CubeBitMap, builder); - ReplaceIf(forOp, bufvals, opsToErase, ifArgMap, builder, module); - llvm::outs()<<"after replaceif\n"; - for (const auto& pair : ifArgMap) { - auto val = pair.first; - auto bitValue = pair.second; - llvm::outs()< transformLoops; + // 转换每个目标循环 + for (scf::ForOp forOp : targetLoops) { + auto newforOp = transformLoop(forOp, builder); + } + + module.walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) { + // 检查循环是否包含特定的 sync_block_set 操作 + bool hasSyncBlockSet = false; + forOp.walk([&](Operation *op) { + if (isa(op)) { + if (auto ifOp = dyn_cast(op->getParentOp())) { + if (forOp == ifOp->getParentOp() && ifOp->hasAttr("ssbuffer")) { + hasSyncBlockSet = true; + } + } } + }); - 
modifyForIterargDeps(forOp, ifArgMap); - } - for (auto op : opsToErase) { - op->erase(); + if (hasSyncBlockSet) { + if (llvm::find(transformLoops, forOp) == transformLoops.end()) { + transformLoops.push_back(forOp); + } + } } + }); - + llvm::sort(transformLoops, [](scf::ForOp a, scf::ForOp b) { + return getNestingDepth(a) > getNestingDepth(b); + }); + DenseMap VecBitMap; + DenseMap CubeBitMap; + getAllocBit(module, VecBitMap, CubeBitMap, builder); + printDenseMap(CubeBitMap); + printDenseMap(VecBitMap); + SmallVector opsToErase; + for (scf::ForOp forOp : transformLoops) { + DenseMap ifArgMap; + llvm::outs() << "before replaceif\n"; + auto bufvals = addBufValLoop(forOp, VecBitMap, CubeBitMap, builder); + ReplaceIf(forOp, bufvals, opsToErase, ifArgMap, builder, module); + llvm::outs() << "after replaceif\n"; + for (const auto &pair : ifArgMap) { + auto val = pair.first; + auto bitValue = pair.second; + llvm::outs() << val << " " << bitValue << " ifargmrp\n\n\n"; + llvm::outs().flush(); + } + + modifyForIterargDeps(forOp, ifArgMap); + } + for (auto op : opsToErase) { + op->erase(); + } } bool isTransOp(mlir::Operation *op) { @@ -1328,17 +1277,21 @@ bool isTransOp(mlir::Operation *op) { return false; else { - Value copySrc = copyOp.getODSOperands(0).front(); - MemRefType copySrcTy = dyn_cast(copySrc.getType()); - auto SrcAddrSpace = dyn_cast_or_null(copySrcTy.getMemorySpace()); - bool isSrcUbSpace = SrcAddrSpace.getAddressSpace() == hivm::AddressSpace::UB; - - Value copyDst = copyOp.getODSOperands(1).front(); - MemRefType copyDstTy = dyn_cast(copyDst.getType()); - auto DstAddrSpace = dyn_cast_or_null(copyDstTy.getMemorySpace()); - bool isDstCbufSpace = DstAddrSpace.getAddressSpace() == hivm::AddressSpace::L1; - - return isSrcUbSpace && isDstCbufSpace; + Value copySrc = copyOp.getODSOperands(0).front(); + MemRefType copySrcTy = dyn_cast(copySrc.getType()); + auto SrcAddrSpace = + dyn_cast_or_null(copySrcTy.getMemorySpace()); + bool isSrcUbSpace = + 
SrcAddrSpace.getAddressSpace() == hivm::AddressSpace::UB; + + Value copyDst = copyOp.getODSOperands(1).front(); + MemRefType copyDstTy = dyn_cast(copyDst.getType()); + auto DstAddrSpace = + dyn_cast_or_null(copyDstTy.getMemorySpace()); + bool isDstCbufSpace = + DstAddrSpace.getAddressSpace() == hivm::AddressSpace::L1; + + return isSrcUbSpace && isDstCbufSpace; } } @@ -1351,7 +1304,6 @@ void FindAndMarkBuffer(ModuleOp module) { IntegerAttr idxAttr = builder.getI32IntegerAttr(BufferIdx); module.walk([&](mlir::Operation *op) { - if (isTransOp(op)) { llvm::outs() << "Buffer idx" << BufferIdx << "\n"; llvm::outs() << "Trans Op" << *op << "\n"; @@ -1375,12 +1327,13 @@ void FindAndMarkBuffer(ModuleOp module) { op->setAttr("Set Flag", builder.getI32IntegerAttr(1)); for (Operation *consumerOp : SharedBuffer.getUsers()) { - if (consumerOp == op) + if (consumerOp == op) + continue; + if (!consumerOp) continue; - if (!consumerOp) continue; - + llvm::outs() << "consumerOp: " << *consumerOp << "\n"; - + consumerOp->setAttr("Buffer idx", builder.getI32IntegerAttr(BufferIdx)); consumerOp->setAttr("Wait Flag", builder.getI32IntegerAttr(0)); } @@ -1404,9 +1357,8 @@ struct MergedRegion { SmallVector resultTypes; }; -void MoveIterArgUsersIntoIf( - scf::ForOp forOp, - SmallVector &mergedRegions) { +void MoveIterArgUsersIntoIf(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -1469,14 +1421,13 @@ void MoveIterArgUsersIntoIf( } } -void ComputeYieldForMergedRegion( - MergedRegion &mr, Block &body) { +void ComputeYieldForMergedRegion(MergedRegion &mr, Block &body) { mr.yieldValues.clear(); mr.resultTypes.clear(); - SmallPtrSet inRegion( - mr.opsToMove.begin(), mr.opsToMove.end()); + SmallPtrSet inRegion(mr.opsToMove.begin(), + mr.opsToMove.end()); for (Operation *op : mr.opsToMove) { for (Value res : op->getResults()) { @@ -1504,15 +1455,14 @@ void ComputeYieldForMergedRegion( } } -static void ComputeYieldForMergedRegionV2( - MergedRegion &mr, 
Block &body) { +static void ComputeYieldForMergedRegionV2(MergedRegion &mr, Block &body) { mr.yieldValues.clear(); mr.resultTypes.clear(); // 当前 region 内的 ops - SmallPtrSet inRegion( - mr.opsToMove.begin(), mr.opsToMove.end()); + SmallPtrSet inRegion(mr.opsToMove.begin(), + mr.opsToMove.end()); for (Operation *op : mr.opsToMove) { for (Value res : op->getResults()) { @@ -1540,93 +1490,92 @@ static void ComputeYieldForMergedRegionV2( } static void ComputeYieldForMergedRegionV3(MergedRegion &mr) { - mr.yieldValues.clear(); - mr.resultTypes.clear(); - - // 用 DenseSet 暂存当前 region 的所有 ops - DenseSet regionOps(mr.opsToMove.begin(), mr.opsToMove.end()); + mr.yieldValues.clear(); + mr.resultTypes.clear(); - for (Operation *op : mr.opsToMove) { - for (Value res : op->getResults()) { + // 用 DenseSet 暂存当前 region 的所有 ops + DenseSet regionOps(mr.opsToMove.begin(), mr.opsToMove.end()); - bool needsYield = false; + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { - for (OpOperand &use : res.getUses()) { - Operation *user = use.getOwner(); + bool needsYield = false; - // 如果 user 不在当前 region,则需要 yield - if (!regionOps.contains(user)) { - needsYield = true; - break; - } - } + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); - if (needsYield) { - mr.yieldValues.push_back(res); - mr.resultTypes.push_back(res.getType()); - } + // 如果 user 不在当前 region,则需要 yield + if (!regionOps.contains(user)) { + needsYield = true; + break; } + } + + if (needsYield) { + mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } } + } } // 递归收集 op 和它所有 region 内的 ops -static void CollectAllNestedOps(Operation *op, DenseSet ®ionOps) { - if (!op) - return; - - if (regionOps.contains(op)) - return; // 已经收集过 +static void CollectAllNestedOps(Operation *op, + DenseSet ®ionOps) { + if (!op) + return; - regionOps.insert(op); + if (regionOps.contains(op)) + return; // 已经收集过 - // 遍历所有 region,递归收集 - for (Region ®ion : op->getRegions()) { - for 
(Block &block : region) { - for (Operation &nestedOp : block) { - CollectAllNestedOps(&nestedOp, regionOps); - } - } + regionOps.insert(op); + + // 遍历所有 region,递归收集 + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : block) { + CollectAllNestedOps(&nestedOp, regionOps); + } } + } } static void ComputeYieldForMergedRegionV4(MergedRegion &mr) { - mr.yieldValues.clear(); - mr.resultTypes.clear(); - - // 用 DenseSet 暂存当前 region 的所有 ops - // 初始 DenseSet: 顶层 opsToMove - DenseSet regionOps; - for (Operation *op : mr.opsToMove) { - CollectAllNestedOps(op, regionOps); // 完整展开嵌套 - } + mr.yieldValues.clear(); + mr.resultTypes.clear(); - for (Operation *op : mr.opsToMove) { - for (Value res : op->getResults()) { + // 用 DenseSet 暂存当前 region 的所有 ops + // 初始 DenseSet: 顶层 opsToMove + DenseSet regionOps; + for (Operation *op : mr.opsToMove) { + CollectAllNestedOps(op, regionOps); // 完整展开嵌套 + } - bool needsYield = false; + for (Operation *op : mr.opsToMove) { + for (Value res : op->getResults()) { - for (OpOperand &use : res.getUses()) { - Operation *user = use.getOwner(); + bool needsYield = false; - // 如果 user 不在当前 region,则需要 yield - if (!regionOps.contains(user)) { - needsYield = true; - break; - } - } + for (OpOperand &use : res.getUses()) { + Operation *user = use.getOwner(); - if (needsYield) { - mr.yieldValues.push_back(res); - mr.resultTypes.push_back(res.getType()); - } + // 如果 user 不在当前 region,则需要 yield + if (!regionOps.contains(user)) { + needsYield = true; + break; } + } + + if (needsYield) { + mr.yieldValues.push_back(res); + mr.resultTypes.push_back(res.getType()); + } } + } } -int findTargetRegion( - Operation *startOp, - Block &body, - DenseMap &opToRegion) { +int findTargetRegion(Operation *startOp, Block &body, + DenseMap &opToRegion) { SmallVector worklist{startOp}; SmallPtrSet visited; @@ -1653,20 +1602,16 @@ int findTargetRegion( return -1; } -void greedyAbsorbToRegion( - Operation *startOp, - int regionIdx, - int 
lowerBound, - Block &body, - DenseMap &opIndex, - DenseMap &opToRegion, - SmallVector &mergedRegions) { +void greedyAbsorbToRegion(Operation *startOp, int regionIdx, int lowerBound, + Block &body, DenseMap &opIndex, + DenseMap &opToRegion, + SmallVector &mergedRegions) { auto &mr = mergedRegions[regionIdx]; SmallVector worklist; - SmallPtrSet visited( - mr.opsToMove.begin(), mr.opsToMove.end()); + SmallPtrSet visited(mr.opsToMove.begin(), + mr.opsToMove.end()); // 先把 startOp 本身吸收(如果还没被吸收) if (!opToRegion.count(startOp)) { @@ -1697,8 +1642,7 @@ void greedyAbsorbToRegion( auto it = opToRegion.find(defOp); // 不能跨到其他 region - if (it != opToRegion.end() && - it->second != regionIdx) + if (it != opToRegion.end() && it->second != regionIdx) continue; // 去重 @@ -1713,9 +1657,10 @@ void greedyAbsorbToRegion( } } -SmallVector getOperationInput(Operation *op, SmallVector dependValues, - DenseMap>> &collectDepValueMap) -{ +SmallVector +getOperationInput(Operation *op, SmallVector dependValues, + DenseMap>> + &collectDepValueMap) { // Analyse each Op's input DenseSet opInput; if (isa(op) || isa(op)) { @@ -1729,12 +1674,13 @@ SmallVector getOperationInput(Operation *op, SmallVector dependVal } // recursively walk scf op - for (Block *curBlock: regionBlocks) { + for (Block *curBlock : regionBlocks) { for (auto &curOp : *curBlock) { - for (auto operand : getOperationInput(&curOp, dependValues, collectDepValueMap)) { + for (auto operand : + getOperationInput(&curOp, dependValues, collectDepValueMap)) { Operation *defOp; if (auto blockArg = dyn_cast(operand)) { - Block* ownerBlock = blockArg.getOwner(); + Block *ownerBlock = blockArg.getOwner(); defOp = ownerBlock->getParentOp(); } else { defOp = operand.getDefiningOp(); @@ -1767,20 +1713,22 @@ SmallVector getOperationInput(Operation *op, SmallVector dependVal } } -SmallVector collectDepValuesCalculation(DenseSet forRegionOps, - DenseSet regionOps, Operation *op, SmallVector dependValues, - DenseMap>> &collectDepValueMap) -{ 
+SmallVector collectDepValuesCalculation( + DenseSet forRegionOps, DenseSet regionOps, + Operation *op, SmallVector dependValues, + DenseMap>> + &collectDepValueMap) { DenseSet collectOps; std::deque opStack; bool flag = false; - + opStack.push_back(op); while (opStack.size()) { Operation *curOp = opStack.front(); opStack.pop_front(); - for (auto operand : getOperationInput(curOp, dependValues, collectDepValueMap)) { + for (auto operand : + getOperationInput(curOp, dependValues, collectDepValueMap)) { if (llvm::is_contained(dependValues, operand)) { flag = true; } @@ -1807,9 +1755,11 @@ SmallVector collectDepValuesCalculation(DenseSet forRe } } -void copyOpsToMergedRegion(scf::ForOp forOp, SmallVector collectOps, MergedRegion &mergedRegion, - DenseMap>> &collectDepValueMap) -{ +void copyOpsToMergedRegion( + scf::ForOp forOp, SmallVector collectOps, + MergedRegion &mergedRegion, + DenseMap>> + &collectDepValueMap) { Block *forBodyBlock = forOp.getBody(); OpBuilder builder(forOp); SmallVector clonedOps; @@ -1843,7 +1793,7 @@ void copyOpsToMergedRegion(scf::ForOp forOp, SmallVector collectOps DenseSet mergedRegionOps; for (Operation *op : mergedRegion.opsToMove) { - CollectAllNestedOps(op, mergedRegionOps); + CollectAllNestedOps(op, mergedRegionOps); } // replace the ifresult value by new cloned op's result @@ -1860,19 +1810,19 @@ void copyOpsToMergedRegion(scf::ForOp forOp, SmallVector collectOps mergedRegion.opsToMove = clonedOps; } -void copyLoadCalculation(scf::ForOp forOp, SmallVector dependValues, SmallVector &mergedRegions) -{ - mlir::Operation* parentOp = forOp->getParentOp(); - mlir::Operation* scopeOp = nullptr; +void copyLoadCalculation(scf::ForOp forOp, SmallVector dependValues, + SmallVector &mergedRegions) { + mlir::Operation *parentOp = forOp->getParentOp(); + mlir::Operation *scopeOp = nullptr; while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); + if (dyn_cast(parentOp)) { + scopeOp = 
parentOp; + break; + } + parentOp = parentOp->getParentOp(); } - auto coreTypeAttr = scopeOp->getAttrOfType( - hivm::TCoreTypeAttr::name); + auto coreTypeAttr = + scopeOp->getAttrOfType(hivm::TCoreTypeAttr::name); // only process the vector core if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) { return; @@ -1881,21 +1831,23 @@ void copyLoadCalculation(scf::ForOp forOp, SmallVector dependValues, Smal // recursively collect all op in forOp DenseSet forRegionOps; for (Operation &op : forOp.getBody()->without_terminator()) { - CollectAllNestedOps(&op, forRegionOps); + CollectAllNestedOps(&op, forRegionOps); } - + for (MergedRegion &mr : mergedRegions) { DenseSet regionOps; for (Operation *op : mr.opsToMove) { - CollectAllNestedOps(op, regionOps); + CollectAllNestedOps(op, regionOps); } for (Operation *op : regionOps) { if (isa(op) || isa(op)) { - // recusively check that whether load/store op's operands originated from if results - DenseMap>> collectDepValueMap; - SmallVector collectOps = \ - collectDepValuesCalculation(forRegionOps, regionOps, op, dependValues, collectDepValueMap); + // recusively check that whether load/store op's operands originated + // from if results + DenseMap>> + collectDepValueMap; + SmallVector collectOps = collectDepValuesCalculation( + forRegionOps, regionOps, op, dependValues, collectDepValueMap); copyOpsToMergedRegion(forOp, collectOps, mr, collectDepValueMap); } } @@ -1904,9 +1856,8 @@ void copyLoadCalculation(scf::ForOp forOp, SmallVector dependValues, Smal // 以 forOp 的 yield value 为中心 // 决定它应该归属哪个 mergedRegion, 然后再向前吸 operand -void ExpandMergedRegionOpsForAIV( - scf::ForOp forOp, - SmallVector &mergedRegions) { +void ExpandMergedRegionOpsForAIV(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -1923,8 +1874,7 @@ void ExpandMergedRegionOpsForAIV( opToRegion[op] = r; // 取 scf.yield - auto yieldOp = - cast(body.getTerminator()); + auto yieldOp = cast(body.getTerminator()); // 依次处理每个 
yield value(按编号顺序) for (Value yv : yieldOp.getOperands()) { @@ -1941,8 +1891,7 @@ void ExpandMergedRegionOpsForAIV( targetRegion = it->second; } else { // 否则向前搜索确定归属 - targetRegion = - findTargetRegion(defOp, body, opToRegion); + targetRegion = findTargetRegion(defOp, body, opToRegion); } if (targetRegion == -1) @@ -1952,34 +1901,26 @@ void ExpandMergedRegionOpsForAIV( int lowerBound = 0; if (targetRegion > 0) { - Operation *prevLast = - mergedRegions[targetRegion - 1] - .opsToMove.back(); + Operation *prevLast = mergedRegions[targetRegion - 1].opsToMove.back(); lowerBound = opIndex[prevLast] + 1; } // 真正贪心吸收 - greedyAbsorbToRegion(defOp, - targetRegion, - lowerBound, - body, - opIndex, - opToRegion, - mergedRegions); + greedyAbsorbToRegion(defOp, targetRegion, lowerBound, body, opIndex, + opToRegion, mergedRegions); } // 每个 region 内按 block 顺序排序 for (auto &mr : mergedRegions) { - llvm::sort(mr.opsToMove, - [&](Operation *a, Operation *b) { - return opIndex[a] < opIndex[b]; - }); + llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); } } // 以 mergedRegion 为中心, 向前吸 operand void ExpandMergedRegionOpsForAIC(scf::ForOp forOp, - SmallVector &mergedRegions) { + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); // 记录每个 mergedRegion 的起始 op index @@ -1998,19 +1939,17 @@ void ExpandMergedRegionOpsForAIC(scf::ForOp forOp, // 边界: 前一个 mergedRegion 的最后一个 op if (r > 0) { - Operation *prevLast = - mergedRegions[r - 1].opsToMove.back(); + Operation *prevLast = mergedRegions[r - 1].opsToMove.back(); lowerBound = opIndex[prevLast] + 1; } - SmallVector worklist(mr.opsToMove.begin(), - mr.opsToMove.end()); - SmallPtrSet visited( - mr.opsToMove.begin(), mr.opsToMove.end()); + SmallVector worklist(mr.opsToMove.begin(), mr.opsToMove.end()); + SmallPtrSet visited(mr.opsToMove.begin(), + mr.opsToMove.end()); while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); - + // 往前吸收operand for (Value operand : 
op->getOperands()) { // BlockArgument @@ -2042,18 +1981,15 @@ void ExpandMergedRegionOpsForAIC(scf::ForOp forOp, } // 最后按原 block 顺序排序 - llvm::sort(mr.opsToMove, - [&](Operation *a, Operation *b) { - return opIndex[a] < opIndex[b]; - }); + llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); } } -static void pullInRegionDependencies( - Operation *regionOp, - int regionId, - DenseMap &opToRegion, - Block &body) { +static void pullInRegionDependencies(Operation *regionOp, int regionId, + DenseMap &opToRegion, + Block &body) { SmallVector worklist; @@ -2100,10 +2036,9 @@ static void pullInRegionDependencies( } // BFS 查找某个 op 最早被哪个 region 使用 -static int findEarliestRegion( - Operation *startOp, - const DenseMap &seedRegionMap, - Block &body) { +static int findEarliestRegion(Operation *startOp, + const DenseMap &seedRegionMap, + Block &body) { SmallVector worklist{startOp}; SmallPtrSet visited; @@ -2137,9 +2072,8 @@ static int findEarliestRegion( return earliestRegion; } -void ExpandMergedRegionOpsForAll( - scf::ForOp forOp, - SmallVector &mergedRegions) { +void ExpandMergedRegionOpsForAll(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2207,16 +2141,14 @@ void ExpandMergedRegionOpsForAll( // ---------- Step4 排序 ---------- for (auto &mr : mergedRegions) { - llvm::sort(mr.opsToMove, - [&](Operation *a, Operation *b) { - return opIndex[a] < opIndex[b]; - }); + llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); } } -void ExpandMergedRegionOpsByInput( - scf::ForOp forOp, - SmallVector &mergedRegions) { +void ExpandMergedRegionOpsByInput(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2284,16 +2216,15 @@ void ExpandMergedRegionOpsByInput( // ---------- Step4 排序 ---------- for (auto &mr : mergedRegions) { - llvm::sort(mr.opsToMove, - [&](Operation *a, Operation *b) { - return opIndex[a] < 
opIndex[b]; - }); + llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); } } -static void ExpandMergedRegionOpsByOutput( - scf::ForOp forOp, - SmallVector &mergedRegions) { +static void +ExpandMergedRegionOpsByOutput(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2349,16 +2280,14 @@ static void ExpandMergedRegionOpsByOutput( } // 排序保持原 block 顺序 - llvm::sort(merged.opsToMove, - [&](Operation *a, Operation *b) { - return opOrder[a] < opOrder[b]; - }); + llvm::sort(merged.opsToMove, [&](Operation *a, Operation *b) { + return opOrder[a] < opOrder[b]; + }); } } -static void MoveIndependentOpsIntoIf( - scf::ForOp forOp, - SmallVector &mergedRegions) { +static void MoveIndependentOpsIntoIf(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2430,17 +2359,16 @@ static void MoveIndependentOpsIntoIf( // 排序保持 block 顺序 for (auto &mr : mergedRegions) { - llvm::sort(mr.opsToMove, - [&](Operation *a, Operation *b) { - return opIndex[a] < opIndex[b]; - }); + llvm::sort(mr.opsToMove, [&](Operation *a, Operation *b) { + return opIndex[a] < opIndex[b]; + }); } } // 暴力包裹 -static void ExpandMergedRegionOpsGreedyMaximum( - scf::ForOp forOp, - SmallVector &mergedRegions) { +static void +ExpandMergedRegionOpsGreedyMaximum(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2513,10 +2441,9 @@ static void ExpandMergedRegionOpsGreedyMaximum( // 最后保持 block 顺序 for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( std::unique(region.opsToMove.begin(), region.opsToMove.end()), @@ -2524,10 +2451,9 @@ static void ExpandMergedRegionOpsGreedyMaximum( } } -static void CollectForYieldRelatedOps( - scf::ForOp forOp, - 
SmallVector &mergedRegions, - DenseSet &yieldRelatedOps) { +static void CollectForYieldRelatedOps(scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &yieldRelatedOps) { Block &body = forOp.getRegion().front(); @@ -2575,10 +2501,10 @@ static void CollectForYieldRelatedOps( } // 贪心吸收region前后的op -static void ExpandMergedRegionOpsGreedy( - scf::ForOp forOp, - SmallVector &mergedRegions, - DenseSet &skipOps) { +static void +ExpandMergedRegionOpsGreedy(scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &skipOps) { Block &body = forOp.getRegion().front(); @@ -2663,10 +2589,9 @@ static void ExpandMergedRegionOpsGreedy( // 最后保持 block 顺序 for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( std::unique(region.opsToMove.begin(), region.opsToMove.end()), @@ -2675,10 +2600,10 @@ static void ExpandMergedRegionOpsGreedy( } // 贪心吸收region前面的op -static void ExpandMergedRegionOpsGreedyV2( - scf::ForOp forOp, - SmallVector &mergedRegions, - DenseSet &skipOps) { +static void +ExpandMergedRegionOpsGreedyV2(scf::ForOp forOp, + SmallVector &mergedRegions, + DenseSet &skipOps) { Block &body = forOp.getRegion().front(); @@ -2755,22 +2680,20 @@ static void ExpandMergedRegionOpsGreedyV2( // ---------- 保持 block 顺序 ---------- for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( - std::unique(region.opsToMove.begin(), - region.opsToMove.end()), + std::unique(region.opsToMove.begin(), region.opsToMove.end()), region.opsToMove.end()); } } // 贪心吸收region前面的op -static void ExpandMergedRegionOpsGreedyV2ForAIC( - scf::ForOp forOp, - SmallVector &mergedRegions) { +static 
void +ExpandMergedRegionOpsGreedyV2ForAIC(scf::ForOp forOp, + SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); @@ -2843,22 +2766,19 @@ static void ExpandMergedRegionOpsGreedyV2ForAIC( // ---------- 保持 block 顺序 ---------- for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( - std::unique(region.opsToMove.begin(), - region.opsToMove.end()), + std::unique(region.opsToMove.begin(), region.opsToMove.end()), region.opsToMove.end()); } } -static void MoveForYieldOpIntoRegion( - scf::ForOp forOp, - DenseSet &yieldRelatedOps, - SmallVector &mergedRegions) { +static void MoveForYieldOpIntoRegion(scf::ForOp forOp, + DenseSet &yieldRelatedOps, + SmallVector &mergedRegions) { DenseMap opToRegion; @@ -2915,10 +2835,9 @@ static void MoveForYieldOpIntoRegion( for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( std::unique(region.opsToMove.begin(), region.opsToMove.end()), @@ -2926,10 +2845,10 @@ static void MoveForYieldOpIntoRegion( } } -static void MoveRemainingYieldOpsToPrevRegion( - scf::ForOp forOp, - DenseSet &yieldRelatedOps, - SmallVector &mergedRegions) { +static void +MoveRemainingYieldOpsToPrevRegion(scf::ForOp forOp, + DenseSet &yieldRelatedOps, + SmallVector &mergedRegions) { if (yieldRelatedOps.empty()) return; @@ -2979,10 +2898,9 @@ static void MoveRemainingYieldOpsToPrevRegion( // 排序 + 去重 for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); 
region.opsToMove.erase( std::unique(region.opsToMove.begin(), region.opsToMove.end()), @@ -2991,8 +2909,7 @@ static void MoveRemainingYieldOpsToPrevRegion( } static void MoveIndependentOpsIntoRegionBackwardV2( - scf::ForOp forOp, - SmallVector &mergedRegions) { + scf::ForOp forOp, SmallVector &mergedRegions) { Block &body = forOp.getRegion().front(); SmallVector ops; @@ -3009,25 +2926,29 @@ static void MoveIndependentOpsIntoRegionBackwardV2( for (int i = 0; i < mergedRegions.size(); i++) { MergedRegion ®ion = mergedRegions[i]; - if (region.opsToMove.empty()) continue; + if (region.opsToMove.empty()) + continue; Operation *firstOp = region.opsToMove.front(); - Operation *lastOp = region.opsToMove.back(); + Operation *lastOp = region.opsToMove.back(); auto itFirst = std::find(ops.begin(), ops.end(), firstOp); - auto itLast = std::find(ops.begin(), ops.end(), lastOp); - if (itFirst == ops.end() || itLast == ops.end()) continue; + auto itLast = std::find(ops.begin(), ops.end(), lastOp); + if (itFirst == ops.end() || itLast == ops.end()) + continue; int startIdx = std::distance(ops.begin(), itFirst); - int endIdx = std::distance(ops.begin(), itLast); + int endIdx = std::distance(ops.begin(), itLast); // ----------- 收集 wait-set 区间 ----------- - SmallVector> waitIntervals; + SmallVector> waitIntervals; bool inWait = false; int begin = -1; for (int j = startIdx; j <= endIdx; j++) { Operation *op = ops[j]; if (op->getName().getStringRef().contains("sync_block_wait")) { - inWait = true; begin = j + 1; continue; + inWait = true; + begin = j + 1; + continue; } if (op->getName().getStringRef().contains("sync_block_set") && inWait) { inWait = false; @@ -3036,42 +2957,46 @@ static void MoveIndependentOpsIntoRegionBackwardV2( } auto isInWaitSet = [&](int idx) { for (auto &p : waitIntervals) - if (idx >= p.first && idx <= p.second) return true; + if (idx >= p.first && idx <= p.second) + return true; return false; }; // ----------- 从后往前扫描 region 内的 op ----------- for (int j = 
endIdx; j >= startIdx; j--) { Operation *op = ops[j]; - if (isa(op) || isInWaitSet(j)) continue; + if (isa(op) || isInWaitSet(j)) + continue; // ---------- operand 是否依赖本 region ---------- bool dependCurrentRegion = false; for (Value operand : op->getOperands()) { Operation *def = operand.getDefiningOp(); - if (!def) continue; - if (std::find(region.opsToMove.begin(), - region.opsToMove.end(), - def) != region.opsToMove.end()) { - dependCurrentRegion = true; break; + if (!def) + continue; + if (std::find(region.opsToMove.begin(), region.opsToMove.end(), def) != + region.opsToMove.end()) { + dependCurrentRegion = true; + break; } } - if (dependCurrentRegion) continue; + if (dependCurrentRegion) + continue; // ---------- 当前 region 后续是否使用 ---------- bool usedLaterInSameRegion = false; for (Value result : op->getResults()) for (Operation *user : result.getUsers()) - if (std::find(region.opsToMove.begin(), - region.opsToMove.end(), + if (std::find(region.opsToMove.begin(), region.opsToMove.end(), user) != region.opsToMove.end() && - std::find(region.opsToMove.begin(), - region.opsToMove.end(), op) < - std::find(region.opsToMove.begin(), - region.opsToMove.end(), user)) { - usedLaterInSameRegion = true; break; + std::find(region.opsToMove.begin(), region.opsToMove.end(), op) < + std::find(region.opsToMove.begin(), region.opsToMove.end(), + user)) { + usedLaterInSameRegion = true; + break; } - if (usedLaterInSameRegion) continue; + if (usedLaterInSameRegion) + continue; // ---------- 找使用该 op 的后续 region ---------- int targetRegion = -1; @@ -3079,12 +3004,16 @@ static void MoveIndependentOpsIntoRegionBackwardV2( for (Operation *candidate : mergedRegions[k].opsToMove) for (Value operand : candidate->getOperands()) if (operand.getDefiningOp() == op) { - targetRegion = k; break; + targetRegion = k; + break; } - if (targetRegion != -1) break; - if (targetRegion != -1) break; + if (targetRegion != -1) + break; + if (targetRegion != -1) + break; } - if (targetRegion == -1) 
continue; + if (targetRegion == -1) + continue; movePlan[op] = targetRegion; // llvm::outs() << "MJ: plan move " << *op @@ -3092,16 +3021,16 @@ static void MoveIndependentOpsIntoRegionBackwardV2( } } - // ----------- 统一应用移动 ----------- + // ----------- 统一应用移动 ----------- for (auto &it : movePlan) { - Operation *op = it.first; - int targetRegionIdx = it.second; - MergedRegion &targetRegion = mergedRegions[targetRegionIdx]; - // 更新数据结构 - targetRegion.opsToMove.push_back(op); - - llvm::outs() << "MJ: move " << *op - << " -> region " << targetRegionIdx << "\n"; + Operation *op = it.first; + int targetRegionIdx = it.second; + MergedRegion &targetRegion = mergedRegions[targetRegionIdx]; + // 更新数据结构 + targetRegion.opsToMove.push_back(op); + + llvm::outs() << "MJ: move " << *op << " -> region " << targetRegionIdx + << "\n"; } // ----------- 更新原 region 的 opsToMove ----------- @@ -3109,29 +3038,28 @@ static void MoveIndependentOpsIntoRegionBackwardV2( MergedRegion ®ion = mergedRegions[i]; SmallVector newOps; for (Operation *op : region.opsToMove) { - auto it = movePlan.find(op); - if (it == movePlan.end() || it->second == i) { - // 没有移动计划,或者移动的目标就是自己,保留 - newOps.push_back(op); - } + auto it = movePlan.find(op); + if (it == movePlan.end() || it->second == i) { + // 没有移动计划,或者移动的目标就是自己,保留 + newOps.push_back(op); + } } region.opsToMove.swap(newOps); } // ----------- 排序 + 去重 ----------- for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( - std::unique(region.opsToMove.begin(), - region.opsToMove.end()), + std::unique(region.opsToMove.begin(), region.opsToMove.end()), region.opsToMove.end()); } } -// // debug: 如果一个forop的第一个region的最后3条op是%27 = tt.expand_dims %25#1 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> +// // debug: 如果一个forop的第一个region的最后3条op是%27 = 
tt.expand_dims %25#1 +// {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> // %28 = tt.broadcast %27 : tensor<64x1xf32> -> tensor<64x128xf32> // %29 = arith.mulf %arg10, %28 : tensor<64x128xf32> // 直接放到第2个region里 @@ -3163,7 +3091,7 @@ static void TempChange(scf::ForOp forOp, llvm::outs() << "TempChange triggered\n"; - SmallVector opsToMove = {op1, op2, op3}; + SmallVector opsToMove = {op1, op2, op3}; // ---------- 移动到 region2 末尾 ---------- for (Operation *op : opsToMove) { @@ -3176,23 +3104,21 @@ static void TempChange(scf::ForOp forOp, // ---------- 排序 ---------- for (auto ®ion : mergedRegions) { - llvm::sort(region.opsToMove, - [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); + llvm::sort(region.opsToMove, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); region.opsToMove.erase( - std::unique(region.opsToMove.begin(), - region.opsToMove.end()), + std::unique(region.opsToMove.begin(), region.opsToMove.end()), region.opsToMove.end()); } } -static void sortOperationsByDataFlow(llvm::SmallVector &ops) { - llvm::DenseSet visited; - llvm::SmallVector result; +static void sortOperationsByDataFlow(llvm::SmallVector &ops) { + llvm::DenseSet visited; + llvm::SmallVector result; - std::function dfs = [&](Operation *op) { + std::function dfs = [&](Operation *op) { if (!visited.insert(op).second) return; @@ -3299,7 +3225,7 @@ static void CopyOpsToAfterwardRegions( if (yieldDefOps.contains(op)) { cloneAndOriYieldMap[cloned] = op; } - + // 记录copy的for op if (auto forOp = dyn_cast(cloned)) { copiedForOps.push_back(forOp); @@ -3310,8 +3236,8 @@ static void CopyOpsToAfterwardRegions( } // 插入到当前 region 开头 - curRegion.opsToMove.insert(curRegion.opsToMove.begin(), - clonedOps.begin(), clonedOps.end()); + curRegion.opsToMove.insert(curRegion.opsToMove.begin(), clonedOps.begin(), + clonedOps.end()); // rebuild SSA for (Operation *op : curRegion.opsToMove) { @@ -3320,7 +3246,6 @@ static void CopyOpsToAfterwardRegions( // 排序保证拓扑顺序 
sortOperationsByDataFlow(curRegion.opsToMove); - } } @@ -3338,7 +3263,8 @@ static void GetYieldMap(scf::ForOp forOp, // 获取生成 yieldVal 的原始 op Operation *defOp = yieldVal.getDefiningOp(); - // 对 block arg(可能是 iter_arg)没有 definingOp 的情况,可以跳过或直接记录 nullptr + // 对 block arg(可能是 iter_arg)没有 definingOp 的情况,可以跳过或直接记录 + // nullptr if (!defOp) continue; @@ -3365,10 +3291,10 @@ static Value findIterArgForAIC(Value v, scf::ForOp forOp) { } } -static Operation *findCloneOfYieldOp( - Operation *oriYieldOp, - DenseMap &cloneAndOriYieldMap, - MergedRegion ®ion) { +static Operation * +findCloneOfYieldOp(Operation *oriYieldOp, + DenseMap &cloneAndOriYieldMap, + MergedRegion ®ion) { for (Operation *op : region.opsToMove) { auto it = cloneAndOriYieldMap.find(op); @@ -3379,8 +3305,7 @@ static Operation *findCloneOfYieldOp( } static void RebuildForYielValuesForAIC( - scf::ForOp forOp, - SmallVector &mergedRegions, + scf::ForOp forOp, SmallVector &mergedRegions, DenseMap &yieldMap, DenseMap &cloneAndOriYieldMap) { @@ -3439,9 +3364,9 @@ void ExpandMergedRegionOps(scf::ForOp forOp, auto scopeOp = forOp->getParentOfType(); if (!scopeOp) return; - - auto coreTypeAttr = scopeOp->getAttrOfType( - hivm::TCoreTypeAttr::name); + + auto coreTypeAttr = + scopeOp->getAttrOfType(hivm::TCoreTypeAttr::name); if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) { isInAIV = true; @@ -3465,35 +3390,33 @@ void ExpandMergedRegionOps(scf::ForOp forOp, // 5 剩余 yield chain 放入前一个 region MoveRemainingYieldOpsToPrevRegion(forOp, yieldRelatedOps, mergedRegions); - } - else { // AIC单独处理, 避免出现CUBE内的tensor变量依赖 + } else { // AIC单独处理, 避免出现CUBE内的tensor变量依赖 // 用Map记录原始的for yield op的的映射 DenseMap yieldMap; GetYieldMap(forOp, yieldMap); - - llvm::outs()<<"YieldMap:\n"; - for(auto it: yieldMap) { - llvm::outs()<<*(it.second)<<"\n"; + + llvm::outs() << "YieldMap:\n"; + for (auto it : yieldMap) { + llvm::outs() << *(it.second) << "\n"; } // 2 greedy 扩展, yield value后续处理 ExpandMergedRegionOpsGreedyV2ForAIC(forOp, 
mergedRegions); - - // 复制当前region的除tt.dot、以及[wait - set]之间的op到后续的所有MergedRegion - // 倒序实现 + + // 复制当前region的除tt.dot、以及[wait - + // set]之间的op到后续的所有MergedRegion 倒序实现 // 记录clone和original的yield对应op的map DenseMap cloneAndOriYieldMap; - CopyOpsToAfterwardRegions(mergedRegions, yieldMap, cloneAndOriYieldMap, copiedForOps); + CopyOpsToAfterwardRegions(mergedRegions, yieldMap, cloneAndOriYieldMap, + copiedForOps); - // 4 先确定每个MergedRegion的tt.dot的operand的来源是for的哪个iter_arg(递归查找), 假设为%arg0, 依据yieldMap可以得到oriYield - // 遍历当前MergedRegion的所有op, 确定哪条op对应的cloneAndOriYieldMap的second是oriYield, 假设为%45 + // 4 + // 先确定每个MergedRegion的tt.dot的operand的来源是for的哪个iter_arg(递归查找), + // 假设为%arg0, 依据yieldMap可以得到oriYield 遍历当前MergedRegion的所有op, + // 确定哪条op对应的cloneAndOriYieldMap的second是oriYield, 假设为%45 // 最后替换for yield op对应位置的operand为%45 - RebuildForYielValuesForAIC( - forOp, - mergedRegions, - yieldMap, - cloneAndOriYieldMap); - + RebuildForYielValuesForAIC(forOp, mergedRegions, yieldMap, + cloneAndOriYieldMap); } } @@ -3505,8 +3428,7 @@ void MergeWaitSetRegions(SmallVector ®ions, mr.opsToMove.append(regions[i].opsToMove); int j = i; - while (!regions[j].hasCopyOrFixpipe && - j + 1 < regions.size()) { + while (!regions[j].hasCopyOrFixpipe && j + 1 < regions.size()) { j++; mr.regions.push_back(®ions[j]); mr.opsToMove.append(regions[j].opsToMove); @@ -3528,8 +3450,7 @@ void MergeWaitSetRegions(SmallVector ®ions, bool usedOutside = false; for (OpOperand &use : v.getUses()) { Operation *user = use.getOwner(); - if (!opSet.contains(user) && - user->getBlock() == op->getBlock()) { + if (!opSet.contains(user) && user->getBlock() == op->getBlock()) { usedOutside = true; break; } @@ -3549,14 +3470,14 @@ void GetBlockInfos(SmallVector ®ions, Block &body) { auto waitOp = dyn_cast(op); if (!waitOp) { - it++; - continue; + it++; + continue; } auto pipeS = hivm::PipeAttr::get(op->getContext(), hivm::PIPE::PIPE_S); if (auto syncWait = dyn_cast(op)) { if (syncWait.getTpipe() == pipeS || syncWait.getPipe() == pipeS) { 
- return; + return; } } Operation *lastSetOp = nullptr; @@ -3568,18 +3489,19 @@ void GetBlockInfos(SmallVector ®ions, Block &body) { SmallVector opsInRegion; for (; curIt != body.end(); ++curIt) { Operation *curOp = &*curIt; - if (isa(curOp) && setOpCount >= 1) break; + if (isa(curOp) && setOpCount >= 1) + break; if (isa(curOp)) { setOpCount++; - endIt = curIt; //setop的位置 - lastSetOp = curOp; // 最后一个 set + endIt = curIt; // setop的位置 + lastSetOp = curOp; // 最后一个 set } } if (!lastSetOp) { it = curIt; continue; - }// 没有 set, 不包 + } // 没有 set, 不包 // 收集 [wait, ..., lastSet] 之间的 ops bool hasCopyOrFixpipe = false; @@ -3590,67 +3512,72 @@ void GetBlockInfos(SmallVector ®ions, Block &body) { hasCopyOrFixpipe = true; } } - + it = endIt++; regions.push_back({waitOp, lastSetOp, opsInRegion, hasCopyOrFixpipe}); } } Value findIterArg(Value v, Type t) { - SmallVector worklist = {v}; - SmallPtrSet visited; + SmallVector worklist = {v}; + SmallPtrSet visited; - while (!worklist.empty()) { - Value cur = worklist.front(); - worklist.erase(worklist.begin()); - if (!visited.insert(cur).second) - continue; + while (!worklist.empty()) { + Value cur = worklist.front(); + worklist.erase(worklist.begin()); + if (!visited.insert(cur).second) + continue; - // 匹配scf.for原始迭代参数, 直接返回 - if (auto b = mlir::dyn_cast(cur)) { - auto forOp = mlir::dyn_cast(b.getOwner()->getParentOp()); - if (forOp && b.getType() == t) { - for (Value iterArg : forOp.getRegionIterArgs()) { - if (iterArg.getAsOpaquePointer() == b.getAsOpaquePointer()) { - return b; - } - } - } + // 匹配scf.for原始迭代参数, 直接返回 + if (auto b = mlir::dyn_cast(cur)) { + auto forOp = mlir::dyn_cast(b.getOwner()->getParentOp()); + if (forOp && b.getType() == t) { + for (Value iterArg : forOp.getRegionIterArgs()) { + if (iterArg.getAsOpaquePointer() == b.getAsOpaquePointer()) { + return b; + } } + } + } - Operation *defOp = cur.getDefiningOp(); - if (!defOp) continue; - - // 核心逻辑:如果当前值是scf.if的结果 - // 进入then块找源头 - if (auto ifOp = mlir::dyn_cast(defOp)) 
{ - Block &thenBlock = ifOp.getThenRegion().front(); - // 找到then块最后一个op(scf.yield) - // 取其operands(即ifOp结果的源头值) - for (auto &innerOp : llvm::reverse(thenBlock)) { - if (auto yieldOp = mlir::dyn_cast(&innerOp)) { - // 按索引匹配: cur是ifOp的第n个结果, 取yieldOp的第n个operand - for (auto [idx, res] : llvm::enumerate(ifOp.getResults())) { - if (res.getAsOpaquePointer() == cur.getAsOpaquePointer()) { - Value srcVal = yieldOp.getOperand(idx); - if (!visited.count(srcVal)) worklist.push_back(srcVal); - break; - } - } - break; // 找到yield即退出, 无需遍历其他op - } - } - } else { - // 非if结果值 - // 正常往前追溯operands - for (Value operand : defOp->getOperands()) { - if (!visited.count(operand)) worklist.push_back(operand); + Operation *defOp = cur.getDefiningOp(); + if (!defOp) + continue; + + // 核心逻辑:如果当前值是scf.if的结果 + // 进入then块找源头 + if (auto ifOp = mlir::dyn_cast(defOp)) { + Block &thenBlock = ifOp.getThenRegion().front(); + // 找到then块最后一个op(scf.yield) + // 取其operands(即ifOp结果的源头值) + for (auto &innerOp : llvm::reverse(thenBlock)) { + if (auto yieldOp = mlir::dyn_cast(&innerOp)) { + // 按索引匹配: cur是ifOp的第n个结果, 取yieldOp的第n个operand + for (auto [idx, res] : llvm::enumerate(ifOp.getResults())) { + if (res.getAsOpaquePointer() == cur.getAsOpaquePointer()) { + Value srcVal = yieldOp.getOperand(idx); + if (!visited.count(srcVal)) + worklist.push_back(srcVal); + break; } + } + break; // 找到yield即退出, 无需遍历其他op } + } + } else { + // 非if结果值 + // 正常往前追溯operands + for (Value operand : defOp->getOperands()) { + if (!visited.count(operand)) + worklist.push_back(operand); + } } + } - llvm::outs() << "未找到迭代参数, 返回原值: "; v.print(llvm::outs()); llvm::outs() << "\n"; - return v; + llvm::outs() << "未找到迭代参数, 返回原值: "; + v.print(llvm::outs()); + llvm::outs() << "\n"; + return v; } // 如果 v 最终被 scf.for 的 yield 使用 @@ -3681,7 +3608,8 @@ Value findIterArgForAll(Value v, Type t) { return v; } -void FindDependValues (SmallVector &dependValues, SmallVector mergedRegions) { +void FindDependValues(SmallVector &dependValues, + SmallVector 
mergedRegions) { dependValues.clear(); for (auto &curMR : mergedRegions) { for (Value yieldValue : curMR.yieldValues) { @@ -3694,12 +3622,14 @@ void FindDependValues (SmallVector &dependValues, SmallVector &dependValues, SmallVector otherOps; for (Operation *op : otherMR.opsToMove) { - CollectAllNestedOps(op, otherOps); // 完整展开嵌套 + CollectAllNestedOps(op, otherOps); // 完整展开嵌套 } if (otherOps.contains(userOp)) { isUserInOtherRegion = true; break; } - } // 无重复的添加依赖变量 @@ -3734,7 +3663,8 @@ void FindDependValues (SmallVector &dependValues, SmallVector &mergedRegions, IRMapping &mapper) { +void UpdateMergedRegionsWithNewForOp(SmallVector &mergedRegions, + IRMapping &mapper) { for (auto &mr : mergedRegions) { // WaitSetRegion 后续已经不使用了,直接释放,否则会出现野指针 SmallVector newRegions; @@ -3771,13 +3701,15 @@ void UpdateMergedRegionsWithNewForOp(SmallVector &mergedRegions, I } } -void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, SmallVector &mergedRegions, ModuleOp module) { +void AddArgsForDependValues(scf::ForOp forOp, SmallVector &dependValues, + SmallVector &mergedRegions, + ModuleOp module) { OpBuilder moduleBuilder(module.getContext()); SmallVector valueTypes; valueTypes.clear(); if (dependValues.empty()) { - return ; + return; } else { for (Value v : dependValues) { Type valueType = v.getType(); @@ -3794,39 +3726,51 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, for (Type valueType : valueTypes) { auto tensorType = dyn_cast(valueType); triton::PointerType ptrType; - ptrType = (tensorType) ? dyn_cast(tensorType.getElementType()) : dyn_cast(valueType); + ptrType = + (tensorType) + ? dyn_cast(tensorType.getElementType()) + : dyn_cast(valueType); if (ptrType) { - // 如果依赖变量是一个ptr类型 - // 1. 创建 i64 0 - // 2. cast 成 !tt.ptr<...> - Value zero = moduleBuilder.create(constOp.getLoc(), 0, 64); - Value ptrValue = moduleBuilder.create(constOp.getLoc(), ptrType, zero); - if (tensorType) { - // 3. 
splat 成 tensor<...x!tt.ptr<...>> - Value ptrTensor = moduleBuilder.create(constOp.getLoc(), tensorType, ptrValue); - initTensors.push_back(ptrTensor); - } else { - initTensors.push_back(ptrValue); - } + // 如果依赖变量是一个ptr类型 + // 1. 创建 i64 0 + // 2. cast 成 !tt.ptr<...> + Value zero = moduleBuilder.create( + constOp.getLoc(), 0, 64); + Value ptrValue = moduleBuilder.create( + constOp.getLoc(), ptrType, zero); + if (tensorType) { + // 3. splat 成 tensor<...x!tt.ptr<...>> + Value ptrTensor = moduleBuilder.create( + constOp.getLoc(), tensorType, ptrValue); + initTensors.push_back(ptrTensor); + } else { + initTensors.push_back(ptrValue); + } } else if (auto memrefType = dyn_cast(valueType)) { // 如果中间变量是一个memref类型,为iterarg创建一个 alloc = memref - // 仅支持#hivm.address_space,对于#hivm.address_space,不存在 copy cbuf to cbuf 行为 - auto spaceAttr = cast(memrefType.getMemorySpace()); - if (spaceAttr && spaceAttr.getAddressSpace() == hivm::AddressSpace::L1) { - llvm::dbgs() << "AddArgsForDependValues: dependValue type is a memref hivm::AddressSpace::L1 type!!!\n"; + // 仅支持#hivm.address_space,对于#hivm.address_space,不存在 + // copy cbuf to cbuf 行为 + auto spaceAttr = + cast(memrefType.getMemorySpace()); + if (spaceAttr && + spaceAttr.getAddressSpace() == hivm::AddressSpace::L1) { + llvm::dbgs() << "AddArgsForDependValues: dependValue type is a " + "memref hivm::AddressSpace::L1 type!!!\n"; return mlir::WalkResult::interrupt(); } else { - mlir::Value alloc = moduleBuilder.create(constOp.getLoc(), memrefType); + mlir::Value alloc = moduleBuilder.create( + constOp.getLoc(), memrefType); initTensors.push_back(alloc); } } else { // 非 ptr 类型创建零值常量 auto zeroAttr = moduleBuilder.getZeroAttr(valueType); - Value zeroTensor = moduleBuilder.create(constOp.getLoc(), zeroAttr); + Value zeroTensor = moduleBuilder.create( + constOp.getLoc(), zeroAttr); initTensors.push_back(zeroTensor); } } - return mlir::WalkResult::interrupt(); + return mlir::WalkResult::interrupt(); } return mlir::WalkResult::advance(); }); @@ 
-3847,7 +3791,8 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, // 创建新的 ForOp,插入点位于原操作之前 OpBuilder builder(forOp); - auto newForOp = builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); + auto newForOp = + builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); // 获取新循环的 region 块(已自动包含循环索引和迭代参数) Block &newBlock = newForOp.getRegion().front(); @@ -3856,7 +3801,7 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, // 建立块参数的映射:原块参数 -> 新块参数 IRMapping mapper; for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) { - mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); + mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); } // 将原循环体中的操作(不包括终结符)克隆到新块中 // 同时按照顺序克隆新的 dependValues @@ -3864,17 +3809,17 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, int cnt = 0; builder.setInsertionPointToStart(&newBlock); for (auto &op : oldBlock) { - auto newOp = builder.clone(op, mapper); - // dependValue 的定义OP 可能有多个 result - for (size_t i = 0; i < dependValues.size(); i++) { - Operation *defineOp = dependValues[i].getDefiningOp(); - if (defineOp == &op) { - unsigned int index = cast(dependValues[i]).getResultNumber(); - newDependValues[i] = newOp->getResult(index); - cnt++; - break; - } + auto newOp = builder.clone(op, mapper); + // dependValue 的定义OP 可能有多个 result + for (size_t i = 0; i < dependValues.size(); i++) { + Operation *defineOp = dependValues[i].getDefiningOp(); + if (defineOp == &op) { + unsigned int index = cast(dependValues[i]).getResultNumber(); + newDependValues[i] = newOp->getResult(index); + cnt++; + break; } + } } // 判断是否找到了所有的 dependValue if (newDependValues.size() != cnt) { @@ -3885,7 +3830,7 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, // 更新 mergedRegions 中的 op 为新的for循环的 op UpdateMergedRegionsWithNewForOp(mergedRegions, mapper); - + // 创建新的循环 yield 操作:原操作数 + dependValues auto oldYield = cast(newBlock.getTerminator()); 
SmallVector newYieldOps(oldYield.getOperands()); @@ -3899,55 +3844,63 @@ void AddArgsForDependValues (scf::ForOp forOp, SmallVector &dependValues, // 将原 forOp 的所有使用替换为新 forOp int oldResultNum = forOp->getResults().size(); - for (auto it : llvm::zip(forOp->getResults(), newForOp->getResults().take_front(oldResultNum))) { - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + for (auto it : llvm::zip(forOp->getResults(), + newForOp->getResults().take_front(oldResultNum))) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); } forOp.erase(); } -void ComputeElseYieldValues (MergedRegion mergedRegion, SmallVector &elseYieldValues, SmallVector dependValues) { +void ComputeElseYieldValues(MergedRegion mergedRegion, + SmallVector &elseYieldValues, + SmallVector dependValues) { int idx = 0; for (Value v : mergedRegion.yieldValues) { - Type yieldType = mergedRegion.resultTypes[idx]; - elseYieldValues.push_back(findIterArg(v, yieldType)); - idx++; + Type yieldType = mergedRegion.resultTypes[idx]; + elseYieldValues.push_back(findIterArg(v, yieldType)); + idx++; } } -void ComputeElseYieldValuesV2 (MergedRegion mergedRegion, SmallVector &elseYieldValues, SmallVector dependValues) { - // 对于yieldValues,其中的 yield value 一定是被 for op yield 所引用,或者被其他 region 所使用 - auto forOp = dyn_cast(mergedRegion.yieldValues[0].getDefiningOp()->getBlock()->getParentOp()); +void ComputeElseYieldValuesV2(MergedRegion mergedRegion, + SmallVector &elseYieldValues, + SmallVector dependValues) { + // 对于yieldValues,其中的 yield value 一定是被 for op yield + // 所引用,或者被其他 region 所使用 + auto forOp = dyn_cast( + mergedRegion.yieldValues[0].getDefiningOp()->getBlock()->getParentOp()); if (!forOp) { llvm::outs() << "define op's parent is not ForOp \n"; return; } auto iterArgs = forOp.getRegionIterArgs(); auto forYieldValues = forOp.getYieldedValues(); - - // 新增的与 dependvalue 相关的 initarg 是接在原本for循环args后面,数量与dependvalue数量相等 + + // 新增的与 dependvalue 相关的 initarg + // 是接在原本for循环args后面,数量与dependvalue数量相等 int baseDependIdx 
= iterArgs.size() - dependValues.size(); int idx = 0; for (Value v : mergedRegion.yieldValues) { - Type yieldType = mergedRegion.resultTypes[idx]; - // yieldValue 中是dependvalue 的情况下 - // else yield value 使用对应的新增 iterargs - if (llvm::is_contained(dependValues, v)) { - int dependIdx = 0; - for (; dependIdx < dependValues.size(); dependIdx++) { - if (v == dependValues[dependIdx]) { - break; - } + Type yieldType = mergedRegion.resultTypes[idx]; + // yieldValue 中是dependvalue 的情况下 + // else yield value 使用对应的新增 iterargs + if (llvm::is_contained(dependValues, v)) { + int dependIdx = 0; + for (; dependIdx < dependValues.size(); dependIdx++) { + if (v == dependValues[dependIdx]) { + break; } - // llvm::outs()<<"v2for:"< &mergedRegions, SmallVector dependValues) { +void CreateIfOps(SmallVector &mergedRegions, + SmallVector dependValues) { for (auto ®ion : mergedRegions) { // 去重yieldvalues @@ -4008,8 +3962,8 @@ void CreateIfOps (SmallVector &mergedRegions, SmallVector d Operation *insertPt = region.opsToMove.front(); OpBuilder builder(insertPt); Location loc = insertPt->getLoc(); - Value cond = builder.create( - loc, builder.getI1Type(), builder.getBoolAttr(true)); + Value cond = builder.create(loc, builder.getI1Type(), + builder.getBoolAttr(true)); bool needsYield = !region.yieldValues.empty(); scf::IfOp ifOp; @@ -4024,13 +3978,15 @@ void CreateIfOps (SmallVector &mergedRegions, SmallVector d // 获取if yield value 在 else块 返回值 SmallVector elseYieldValues; - llvm::outs()<<"before ComputeElseYieldValuesV2"<<"\n"; + llvm::outs() << "before ComputeElseYieldValuesV2" + << "\n"; if (needsYield) { - // ComputeElseYieldValues(region, elseYieldValues, dependValues); + // ComputeElseYieldValues(region, elseYieldValues, dependValues); ComputeElseYieldValuesV2(region, elseYieldValues, dependValues); } - llvm::outs()<<"after ComputeElseYieldValuesV2"<<"\n"; + llvm::outs() << "after ComputeElseYieldValuesV2" + << "\n"; // 将op移进then块 Block &thenBlock = ifOp.getThenRegion().front(); for 
(Operation *m : llvm::reverse(region.opsToMove)) { @@ -4043,7 +3999,7 @@ void CreateIfOps (SmallVector &mergedRegions, SmallVector d thenBuilder.setInsertionPointToEnd(&thenBlock); thenBuilder.create(loc, region.yieldValues); - // else block + // else block Block &elseBlock = ifOp.getElseRegion().front(); OpBuilder elseBuilder(&elseBlock, elseBlock.end()); elseBuilder.create(loc, elseYieldValues); @@ -4064,8 +4020,9 @@ void CreateIfOps (SmallVector &mergedRegions, SmallVector d // for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { // Operation *user = use.getOwner(); - // // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / else) - // if (user->getBlock() != ifOp->getBlock() || !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) + // // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / + // else) if (user->getBlock() != ifOp->getBlock() || + // !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) // continue; // usesToReplace.push_back(&use); // } @@ -4073,14 +4030,13 @@ void CreateIfOps (SmallVector &mergedRegions, SmallVector d // for (OpOperand *use : usesToReplace) // use->set(newVal); // } - } - llvm::outs() <<"Create ifOp: "<< *ifOp << "\n"; + llvm::outs() << "Create ifOp: " << *ifOp << "\n"; } } -void CreateIfOpsOrigin (SmallVector &mergedRegions) { +void CreateIfOpsOrigin(SmallVector &mergedRegions) { for (auto ®ion : mergedRegions) { // 去重yieldvalues @@ -4089,9 +4045,9 @@ void CreateIfOpsOrigin (SmallVector &mergedRegions) { Operation *insertPt = region.opsToMove.front(); OpBuilder builder(insertPt); Location loc = insertPt->getLoc(); - Value cond = builder.create( - loc, builder.getI1Type(), builder.getBoolAttr(true)); - + Value cond = builder.create(loc, builder.getI1Type(), + builder.getBoolAttr(true)); + bool needsYield = !region.yieldValues.empty(); scf::IfOp ifOp; if (needsYield) @@ -4101,56 +4057,57 @@ void CreateIfOpsOrigin (SmallVector &mergedRegions) { // 加标记 ifOp->setAttr("ssbuffer", builder.getUnitAttr()); - + // 
将op移进then块 Block &thenBlock = ifOp.getThenRegion().front(); for (Operation *m : llvm::reverse(region.opsToMove)) { m->moveBefore(&thenBlock, thenBlock.begin()); } - + // 创建 then/else yield if (needsYield) { OpBuilder thenBuilder(builder.getContext()); thenBuilder.setInsertionPointToEnd(&thenBlock); thenBuilder.create(loc, region.yieldValues); - + // else block SmallVector elseYieldValues; int idx = 0; for (Value v : region.yieldValues) { - Type yieldType = region.resultTypes[idx]; - elseYieldValues.push_back(findIterArgForAll(v, yieldType)); - idx++; + Type yieldType = region.resultTypes[idx]; + elseYieldValues.push_back(findIterArgForAll(v, yieldType)); + idx++; } Block &elseBlock = ifOp.getElseRegion().front(); OpBuilder elseBuilder(&elseBlock, elseBlock.end()); elseBuilder.create(loc, elseYieldValues); - + // 替换外部使用 Block *block = ifOp->getBlock(); auto ifIt = Block::iterator(ifOp); - + for (size_t i = 0; i < region.yieldValues.size(); ++i) { Value oldVal = region.yieldValues[i]; Value newVal = ifOp.getResult(i); - + SmallVector usesToReplace; - + for (OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) { Operation *user = use.getOwner(); - // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / else) - if (user->getBlock() != ifOp->getBlock() || !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) + // 同一个 block, user 必须在 ifOp 之后, 不能在 ifOp 内部(then / + // else) + if (user->getBlock() != ifOp->getBlock() || + !ifOp->isBeforeInBlock(user) || user->getParentOp() == ifOp) continue; usesToReplace.push_back(&use); } - + for (OpOperand *use : usesToReplace) use->set(newVal); } - } - - llvm::outs() <<"Create ifOp: "<< *ifOp << "\n"; + + llvm::outs() << "Create ifOp: " << *ifOp << "\n"; } } @@ -4176,7 +4133,7 @@ void AddIfCondition(ModuleOp module) { // 处理forop的末尾对于iter_arg的自增操作, 如tt.advance, 移进对应的if op MoveIterArgUsersIntoIf(forOp, mergedRegions); - + // 获取if yield的value, 并更新if内op的user为yield value for (MergedRegion &mr : mergedRegions) { // 
ComputeYieldForMergedRegion(mr, body); @@ -4191,9 +4148,9 @@ void AddIfCondition(ModuleOp module) { regionList.push_back(mergedRegions); }); - llvm::outs()<<"CopyForOp:\n"; - for(auto op : copiedForOps){ - llvm::outs()<<*op<<"\n"; + llvm::outs() << "CopyForOp:\n"; + for (auto op : copiedForOps) { + llvm::outs() << *op << "\n"; } SmallVector tmpOps; @@ -4240,7 +4197,7 @@ void AddIfCondition(ModuleOp module) { if (dependValues.size() != 0) { copyLoadCalculation(oldForOp, dependValues, newMergedRegions); - + // repeat previous operations for (MergedRegion &mr : newMergedRegions) { mr.yieldValues.clear(); @@ -4249,12 +4206,12 @@ void AddIfCondition(ModuleOp module) { } FindDependValues(dependValues, newMergedRegions); } - + // 如果存在VV或CC依赖,更新ForOp添加新的对应args if (dependValues.size() != 0) { AddArgsForDependValues(oldForOp, dependValues, newMergedRegions, module); } - + // 创建最终的if op llvm::outs() << "before create if ops" << '\n'; CreateIfOps(newMergedRegions, dependValues); @@ -4279,16 +4236,20 @@ void ChangeAdvanceOpForm(ModuleOp module) { break; } } - if (!advanceOp) continue; + if (!advanceOp) + continue; // base 必须是 for的iter_arg Value base = advanceOp.getPtr(); auto barg = dyn_cast(base); - if (!barg || barg.getOwner() != &body) continue; + if (!barg || barg.getOwner() != &body) + continue; // yield 去掉 advance 的返回值 - auto thenYield = cast(ifOp.getThenRegion().front().getTerminator()); - auto elseYield = cast(ifOp.getElseRegion().front().getTerminator()); + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); int advanceIdx = -1; for (auto it : llvm::enumerate(thenYield.getOperands())) { @@ -4297,19 +4258,22 @@ void ChangeAdvanceOpForm(ModuleOp module) { break; } } - - if (advanceIdx == -1) continue; + + if (advanceIdx == -1) + continue; // 删除 advance - SmallVector thenOps(thenYield.getOperands().begin(), thenYield.getOperands().end()); - SmallVector 
elseOps(elseYield.getOperands().begin(), elseYield.getOperands().end()); + SmallVector thenOps(thenYield.getOperands().begin(), + thenYield.getOperands().end()); + SmallVector elseOps(elseYield.getOperands().begin(), + elseYield.getOperands().end()); thenOps.erase(thenOps.begin() + advanceIdx); elseOps.erase(elseOps.begin() + advanceIdx); thenYield->setOperands(thenOps); elseYield->setOperands(elseOps); - + // 重建 ifOp(去掉 advance 对应的 result) OpBuilder ifBuilder(ifOp); ifBuilder.setInsertionPoint(ifOp); @@ -4322,11 +4286,9 @@ void ChangeAdvanceOpForm(ModuleOp module) { } // 创建新的 if - auto newIf = ifBuilder.create( - ifOp.getLoc(), - newResultTypes, - ifOp.getCondition(), - /*withElseRegion=*/true); + auto newIf = ifBuilder.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), + /*withElseRegion=*/true); newIf->setAttr("ssbuffer", ifBuilder.getUnitAttr()); // 把已经修改过 yield 的 region 搬过去 newIf.getThenRegion().takeBody(ifOp.getThenRegion()); @@ -4348,15 +4310,15 @@ void ChangeAdvanceOpForm(ModuleOp module) { SmallVector newOffsets; for (Value off : advanceOp.getOffsets()) { auto intTy = cast(off.getType()); - auto zero = builder.create( - newIf.getLoc(), 0, intTy.getWidth()); - auto sel = builder.create( - newIf.getLoc(), flag, off, zero); + auto zero = builder.create(newIf.getLoc(), 0, + intTy.getWidth()); + auto sel = + builder.create(newIf.getLoc(), flag, off, zero); newOffsets.push_back(sel); } auto newAdvance = builder.create( - newIf.getLoc(), base.getType(), base, newOffsets); + newIf.getLoc(), base.getType(), base, newOffsets); // 原 if 的 advance result 的 users,接到 newAdvance ifOp.getResult(advanceIdx).replaceAllUsesWith(newAdvance.getResult()); @@ -4369,240 +4331,243 @@ void ChangeAdvanceOpForm(ModuleOp module) { } void processRedudantIf(ModuleOp module) { - SmallVector forOps; - llvm::outs()< forOps; + llvm::outs() << module << " wwwww\n\n\n"; + module.walk([&](scf::ForOp forOp) { + auto initArgs = forOp.getInitArgs(); + if (initArgs.size() == 5) { + 
forOps.push_back(forOp); + } + }); - for (auto forOp : forOps) { - auto initArgs = forOp.getInitArgs(); - Value newInit = initArgs[2]; + for (auto forOp : forOps) { + auto initArgs = forOp.getInitArgs(); + Value newInit = initArgs[2]; - // 构建新的初始化参数列表 - SmallVector newInitArgs(initArgs.begin(), initArgs.end()); - newInitArgs.push_back(newInit); + // 构建新的初始化参数列表 + SmallVector newInitArgs(initArgs.begin(), initArgs.end()); + newInitArgs.push_back(newInit); - // 获取原循环的边界和步长 - Value lb = forOp.getLowerBound(); - Value ub = forOp.getUpperBound(); - Value step = forOp.getStep(); + // 获取原循环的边界和步长 + Value lb = forOp.getLowerBound(); + Value ub = forOp.getUpperBound(); + Value step = forOp.getStep(); - // 创建新的 ForOp,插入点位于原操作之前 - OpBuilder builder(forOp); - auto newForOp = builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); + // 创建新的 ForOp,插入点位于原操作之前 + OpBuilder builder(forOp); + auto newForOp = + builder.create(forOp.getLoc(), lb, ub, step, newInitArgs); - // 获取新循环的 region 块(已自动包含循环索引和迭代参数) - Block &newBlock = newForOp.getRegion().front(); - Block &oldBlock = forOp.getRegion().front(); + // 获取新循环的 region 块(已自动包含循环索引和迭代参数) + Block &newBlock = newForOp.getRegion().front(); + Block &oldBlock = forOp.getRegion().front(); - // 建立块参数的映射:原块参数 -> 新块参数(前6个对应) - IRMapping mapper; - for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) { - mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); - } - // 将原循环体中的操作(不包括终结符)克隆到新块中 - builder.setInsertionPointToStart(&newBlock); - for (auto &op : oldBlock) { - auto newOp = builder.clone(op, mapper); - } + // 建立块参数的映射:原块参数 -> 新块参数(前6个对应) + IRMapping mapper; + for (unsigned i = 0; i < oldBlock.getNumArguments(); ++i) { + mapper.map(oldBlock.getArgument(i), newBlock.getArgument(i)); + } + // 将原循环体中的操作(不包括终结符)克隆到新块中 + builder.setInsertionPointToStart(&newBlock); + for (auto &op : oldBlock) { + auto newOp = builder.clone(op, mapper); + } - // 在新块中查找第一个 scf::IfOp(即原代码中的第一个 if) - scf::IfOp firstIfOp = nullptr; - for (auto &op : 
newBlock.getOperations()) { - if (auto ifOp = dyn_cast(&op)) { - firstIfOp = ifOp; - break; - } - } - assert(firstIfOp && "Expected at least one if op in the loop body"); - - // 修改第一个 if 的 else 分支的 yield 操作: - // 将其第二个操作数(索引1)从原来的 %arg9 改为新迭代参数(新块参数索引6) - Block &elseBlock = firstIfOp.getElseRegion().front(); - auto elseYield = cast(elseBlock.getTerminator()); - SmallVector newElseYieldOps(elseYield.getOperands()); - newElseYieldOps[1] = newBlock.getArgument(6); // 新迭代参数 - builder.setInsertionPoint(elseYield); - builder.create(elseYield.getLoc(), newElseYieldOps); - elseYield->erase(); - - // 创建新的循环 yield 操作:原5个操作数 + 第一个 if 的第二个结果 - auto oldYield = cast(newBlock.getTerminator()); - SmallVector newYieldOps(oldYield.getOperands()); - newYieldOps.push_back(firstIfOp.getResult(1)); // 第一个 if 的第二个结果 - builder.setInsertionPointToEnd(&newBlock); - builder.create(oldYield.getLoc(), newYieldOps); - oldYield.erase(); - - // 将原 forOp 的所有使用替换为新 forOp 的前5个结果 - for (auto it : llvm::zip(forOp->getResults(), newForOp->getResults().take_front(5))) { - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - } + // 在新块中查找第一个 scf::IfOp(即原代码中的第一个 if) + scf::IfOp firstIfOp = nullptr; + for (auto &op : newBlock.getOperations()) { + if (auto ifOp = dyn_cast(&op)) { + firstIfOp = ifOp; + break; + } } - for (auto forOp : forOps) { - forOp.erase(); + assert(firstIfOp && "Expected at least one if op in the loop body"); + + // 修改第一个 if 的 else 分支的 yield 操作: + // 将其第二个操作数(索引1)从原来的 %arg9 改为新迭代参数(新块参数索引6) + Block &elseBlock = firstIfOp.getElseRegion().front(); + auto elseYield = cast(elseBlock.getTerminator()); + SmallVector newElseYieldOps(elseYield.getOperands()); + newElseYieldOps[1] = newBlock.getArgument(6); // 新迭代参数 + builder.setInsertionPoint(elseYield); + builder.create(elseYield.getLoc(), newElseYieldOps); + elseYield->erase(); + + // 创建新的循环 yield 操作:原5个操作数 + 第一个 if 的第二个结果 + auto oldYield = cast(newBlock.getTerminator()); + SmallVector newYieldOps(oldYield.getOperands()); + 
newYieldOps.push_back(firstIfOp.getResult(1)); // 第一个 if 的第二个结果 + builder.setInsertionPointToEnd(&newBlock); + builder.create(oldYield.getLoc(), newYieldOps); + oldYield.erase(); + + // 将原 forOp 的所有使用替换为新 forOp 的前5个结果 + for (auto it : + llvm::zip(forOp->getResults(), newForOp->getResults().take_front(5))) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); } + } + for (auto forOp : forOps) { + forOp.erase(); + } } // 针对依赖变量,对原本的for op增加double buffer相关的迭代参数 -scf::ForOp addDoubleBuffForArgs(ModuleOp module, SmallVector uniqueDeps, int bufferNum) { - mlir::OpBuilder builder(module.getContext()); - SmallVector depValueForIdxs; - - // ========== 找到scf.if所在的scf::ForOp ========== - if (!isa(uniqueDeps[0].getDefiningOp()->getParentOp())) { - llvm::errs() << "Error: parent op of scf.if is not scf.for"; - } - scf::ForOp forOp = dyn_cast(uniqueDeps[0].getDefiningOp()->getParentOp()); - - for(Value dependencyValue : uniqueDeps){ - // ========== 步骤1:验证目标Value是scf.if的返回值,并找到对应的scf::IfOp ========== - Operation *ifOp = dependencyValue.getDefiningOp(); - if (!ifOp || !isa(ifOp)) { - llvm::errs() << "Error: 目标Value不是scf.if的返回值\n"; - return nullptr; - } - scf::IfOp targetIfOp = dyn_cast(ifOp); - - // 确认当前Value是scf.if的第几个返回值 - int64_t depValueIdx = -1; - for (auto [idx, result] : llvm::enumerate(targetIfOp.getResults())) { - if (result == dependencyValue) { - depValueIdx = idx; - break; - } - } +scf::ForOp addDoubleBuffForArgs(ModuleOp module, SmallVector uniqueDeps, + int bufferNum) { + mlir::OpBuilder builder(module.getContext()); + SmallVector depValueForIdxs; + + // ========== 找到scf.if所在的scf::ForOp ========== + if (!isa(uniqueDeps[0].getDefiningOp()->getParentOp())) { + llvm::errs() << "Error: parent op of scf.if is not scf.for"; + } + scf::ForOp forOp = + dyn_cast(uniqueDeps[0].getDefiningOp()->getParentOp()); + + for (Value dependencyValue : uniqueDeps) { + // ========== 步骤1:验证目标Value是scf.if的返回值,并找到对应的scf::IfOp + // ========== + Operation *ifOp = 
dependencyValue.getDefiningOp(); + if (!ifOp || !isa(ifOp)) { + llvm::errs() << "Error: 目标Value不是scf.if的返回值\n"; + return nullptr; + } + scf::IfOp targetIfOp = dyn_cast(ifOp); - // ========== 步骤2:找到%38#2关联的scf.for迭代参数以及索引 ========== - // %38#2对应scf.if else分支yield的第2个操作数 → 即%arg10 - Operation *elseYield = targetIfOp.elseYield(); - Value dependencyArg = elseYield->getOperand(depValueIdx); // depValueIdx=2,对应else yield的第2个参数 + // 确认当前Value是scf.if的第几个返回值 + int64_t depValueIdx = -1; + for (auto [idx, result] : llvm::enumerate(targetIfOp.getResults())) { + if (result == dependencyValue) { + depValueIdx = idx; + break; + } + } - int64_t depValueForIdx = -1; - for (auto [idx, result] : llvm::enumerate(forOp.getRegionIterArgs())) { - if (result == dependencyArg) { - depValueForIdx = idx; - break; - } - } - depValueForIdxs.push_back(depValueForIdx); - llvm::outs() << "depValueForIdx: " << depValueForIdx << '\n'; - } - - llvm::outs() << "oldFor: " << forOp << '\n'; - - // 获取原始循环的信息 - Value originalLowerBound = forOp.getLowerBound(); - Value originalUpperBound = forOp.getUpperBound(); - Value originalStep = forOp.getStep(); - SmallVector originalInitArgs = forOp.getInitArgs(); - SmallVector iterArgs; - for (auto arg : originalInitArgs) { - iterArgs.push_back(arg); - } - auto yields = forOp.getBody()->getTerminator(); - - // 创建计数器初始零值 - Value counterInit = nullptr; - mlir::Operation* parentOp = forOp->getParentOp(); - mlir::Operation* scopeOp = nullptr; - // 向上遍历查找scope.scope操作 - while (parentOp) { - if (dyn_cast(parentOp)) { - scopeOp = parentOp; - break; - } - parentOp = parentOp->getParentOp(); + // ========== 步骤2:找到%38#2关联的scf.for迭代参数以及索引 ========== + // %38#2对应scf.if else分支yield的第2个操作数 → 即%arg10 + Operation *elseYield = targetIfOp.elseYield(); + Value dependencyArg = elseYield->getOperand( + depValueIdx); // depValueIdx=2,对应else yield的第2个参数 + + int64_t depValueForIdx = -1; + for (auto [idx, result] : llvm::enumerate(forOp.getRegionIterArgs())) { + if (result == 
dependencyArg) { + depValueForIdx = idx; + break; + } } + depValueForIdxs.push_back(depValueForIdx); + llvm::outs() << "depValueForIdx: " << depValueForIdx << '\n'; + } - builder.setInsertionPoint(scopeOp); - Location loc = forOp.getLoc(); - auto boundType = originalLowerBound.getType(); - counterInit = builder.create(loc, 0, boundType); + llvm::outs() << "oldFor: " << forOp << '\n'; - // 添加和depValueForIdxs相同的迭代参数和计数器 - for (int64_t idx : depValueForIdxs) { - for (int i = 0; i < bufferNum - 1; i++) { - iterArgs.push_back(originalInitArgs[idx]); - } - - // 在迭代参数中添加计数器 - for (int i = 0; i < 2; i++) { - iterArgs.push_back(counterInit); - } + // 获取原始循环的信息 + Value originalLowerBound = forOp.getLowerBound(); + Value originalUpperBound = forOp.getUpperBound(); + Value originalStep = forOp.getStep(); + SmallVector originalInitArgs = forOp.getInitArgs(); + SmallVector iterArgs; + for (auto arg : originalInitArgs) { + iterArgs.push_back(arg); + } + auto yields = forOp.getBody()->getTerminator(); + + // 创建计数器初始零值 + Value counterInit = nullptr; + mlir::Operation *parentOp = forOp->getParentOp(); + mlir::Operation *scopeOp = nullptr; + // 向上遍历查找scope.scope操作 + while (parentOp) { + if (dyn_cast(parentOp)) { + scopeOp = parentOp; + break; } + parentOp = parentOp->getParentOp(); + } - builder.setInsertionPoint(forOp); - // 创建新的for循环 - auto newForOp = builder.create( - forOp.getLoc(), - originalLowerBound, - originalUpperBound, - originalStep, - iterArgs); - - // 设置IR映射表,将旧循环的变量映射到新循环 - IRMapping mapper; - - // 映射迭代变量 - mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - // 映射迭代参数 - for (auto [oldArg, newArg] : - llvm::zip(forOp.getRegionIterArgs(), - newForOp.getRegionIterArgs())) { - mapper.map(oldArg, newArg); - } - - SmallVector newArgs; - for (int i = forOp.getRegionIterArgs().size(); i < newForOp.getRegionIterArgs().size(); i++) { - newArgs.push_back(newForOp.getRegionIterArgs()[i]); - } - // 克隆循环体内容到新循环 - auto &newLoopBody = *newForOp.getBody(); - 
builder.setInsertionPointToStart(&newLoopBody); - - for (auto &op : forOp.getBody()->without_terminator()) { - builder.clone(op, mapper); - } - - // 克隆yield操作 - if (auto yieldOp = dyn_cast(yields)) { - SmallVector newYieldOperands; - for (auto operand : yieldOp.getOperands()) { - newYieldOperands.push_back(mapper.lookupOrDefault(operand)); - } - // 将新增的迭代参数添加到yield操作数中 - for (auto currentCounter : newArgs) { - newYieldOperands.push_back(currentCounter); - } - builder.create(yieldOp.getLoc(), newYieldOperands); + builder.setInsertionPoint(scopeOp); + Location loc = forOp.getLoc(); + auto boundType = originalLowerBound.getType(); + counterInit = builder.create(loc, 0, boundType); + + // 添加和depValueForIdxs相同的迭代参数和计数器 + for (int64_t idx : depValueForIdxs) { + for (int i = 0; i < bufferNum - 1; i++) { + iterArgs.push_back(originalInitArgs[idx]); } - - // 替换原循环的结果 - unsigned numOriginalResults = forOp.getNumResults(); - SmallVector originalResults; - for (unsigned i = 0; i < numOriginalResults; i++) { - originalResults.push_back(newForOp.getResult(i)); + + // 在迭代参数中添加计数器 + for (int i = 0; i < 2; i++) { + iterArgs.push_back(counterInit); } - forOp.replaceAllUsesWith(originalResults); - - // 8. 
删除原循环 - forOp.erase(); + } + + builder.setInsertionPoint(forOp); + // 创建新的for循环 + auto newForOp = + builder.create(forOp.getLoc(), originalLowerBound, + originalUpperBound, originalStep, iterArgs); + + // 设置IR映射表,将旧循环的变量映射到新循环 + IRMapping mapper; + + // 映射迭代变量 + mapper.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // 映射迭代参数 + for (auto [oldArg, newArg] : + llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) { + mapper.map(oldArg, newArg); + } + + SmallVector newArgs; + for (int i = forOp.getRegionIterArgs().size(); + i < newForOp.getRegionIterArgs().size(); i++) { + newArgs.push_back(newForOp.getRegionIterArgs()[i]); + } + // 克隆循环体内容到新循环 + auto &newLoopBody = *newForOp.getBody(); + builder.setInsertionPointToStart(&newLoopBody); + + for (auto &op : forOp.getBody()->without_terminator()) { + builder.clone(op, mapper); + } + + // 克隆yield操作 + if (auto yieldOp = dyn_cast(yields)) { + SmallVector newYieldOperands; + for (auto operand : yieldOp.getOperands()) { + newYieldOperands.push_back(mapper.lookupOrDefault(operand)); + } + // 将新增的迭代参数添加到yield操作数中 + for (auto currentCounter : newArgs) { + newYieldOperands.push_back(currentCounter); + } + builder.create(yieldOp.getLoc(), newYieldOperands); + } + + // 替换原循环的结果 + unsigned numOriginalResults = forOp.getNumResults(); + SmallVector originalResults; + for (unsigned i = 0; i < numOriginalResults; i++) { + originalResults.push_back(newForOp.getResult(i)); + } + forOp.replaceAllUsesWith(originalResults); - llvm::outs() << "for op erased!\n"; - return newForOp; + // 8. 
删除原循环 + forOp.erase(); + + llvm::outs() << "for op erased!\n"; + return newForOp; } SmallVector buildNBufferProducer(OpBuilder &builder, Location loc, Value frontCnt, Value newDepVal, ArrayRef buffs, ArrayRef constants) { - // N-buffer producer: determines which buffer is written to newDepVal based on frontCnt % N + // N-buffer producer: determines which buffer is written to newDepVal based on + // frontCnt % N const int N = buffs.size(); SmallVector results; @@ -4615,16 +4580,20 @@ SmallVector buildNBufferProducer(OpBuilder &builder, Location loc, bufferIndex, constants[0]); auto dstShapedType = mlir::dyn_cast(newDepVal.getType()); - auto maskType = RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); + auto maskType = + RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); Value mask = builder.create(loc, maskType, isBuffer0); - Value newBuff0 = builder.create(loc, mask, newDepVal, buffs[0]); + Value newBuff0 = + builder.create(loc, mask, newDepVal, buffs[0]); results.push_back(newBuff0); - // 2. Double-buffer specialization (when N == 2, a direct select is sufficient) + // 2. Double-buffer specialization (when N == 2, a direct select is + // sufficient) if (N == 2) { - Value newBuff1 = builder.create(loc, mask, buffs[1], newDepVal); + Value newBuff1 = + builder.create(loc, mask, buffs[1], newDepVal); auto nextCnt = builder.create(loc, frontCnt, constants[1]); @@ -4635,7 +4604,8 @@ SmallVector buildNBufferProducer(OpBuilder &builder, Location loc, } // 3. 
Build the root IF: when idx == 0, - // use the first buffer; otherwise enter the nestedIf chain to use other buffers + // use the first buffer; otherwise enter the nestedIf chain to use other + // buffers SmallVector resultTypes; for (int i = 1; i < N; ++i) resultTypes.push_back(buffs[i].getType()); @@ -4667,21 +4637,25 @@ SmallVector buildNBufferProducer(OpBuilder &builder, Location loc, // Update buffer[i] dstShapedType = mlir::dyn_cast(newDepVal.getType()); - maskType = RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + maskType = + RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); mask = builder.create(loc, maskType, isCurrent); - Value updatedBuffer = builder.create(loc, mask, newDepVal, buffs[i]); + Value updatedBuffer = + builder.create(loc, mask, newDepVal, buffs[i]); // If this is the last level: directly yield both buffers if (i == N - 2) { - dstShapedType = mlir::dyn_cast(newDepVal.getType()); - maskType = RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); - mask = builder.create(loc, maskType, isCurrent); - Value lastBuffer = builder.create(loc, mask, buffs[N - 1], newDepVal); + dstShapedType = mlir::dyn_cast(newDepVal.getType()); + maskType = + RankedTensorType::get(dstShapedType.getShape(), isCurrent.getType()); + mask = builder.create(loc, maskType, isCurrent); + Value lastBuffer = + builder.create(loc, mask, buffs[N - 1], newDepVal); - builder.create(loc, ValueRange {updatedBuffer, lastBuffer}); + builder.create(loc, ValueRange{updatedBuffer, lastBuffer}); - break; + break; } // Create the next nested if @@ -4740,14 +4714,16 @@ SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, builder.create(loc, postCnt, constants[bufferNum]); Value isBuffer0 = builder.create(loc, arith::CmpIPredicate::eq, - bufferIndex, constants[0]); + bufferIndex, constants[0]); auto dstShapedType = mlir::dyn_cast(oldBuffs[0].getType()); - auto maskType = RankedTensorType::get(dstShapedType.getShape(), 
isBuffer0.getType()); + auto maskType = + RankedTensorType::get(dstShapedType.getShape(), isBuffer0.getType()); auto mask = builder.create(loc, maskType, isBuffer0); // 1. Double-buffer specialization (avoid generating scf.if) if (bufferNum == 2) { - Value selected = builder.create(loc, mask, oldBuffs[0], oldBuffs[1]); + Value selected = + builder.create(loc, mask, oldBuffs[0], oldBuffs[1]); auto nextCnt = builder.create(loc, postCnt, constants[1]); results.push_back(selected); @@ -4757,7 +4733,8 @@ SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, } // 2. Build the root IF: - // when idx == 0, use the first buffer; otherwise enter the nestedIf chain to use other buffers + // when idx == 0, use the first buffer; otherwise enter the nestedIf chain to + // use other buffers SmallVector resultTypes{oldBuffs[0].getType()}; auto rootIf = builder.create(loc, resultTypes, isBuffer0, true); @@ -4810,7 +4787,8 @@ SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, maskType = RankedTensorType::get(dstShapedType.getShape(), isLast.getType()); mask = builder.create(loc, maskType, isLast); - Value finalSelect = builder.create(loc, mask, oldBuffs[last], oldBuffs[last + 1]); + Value finalSelect = builder.create(loc, mask, oldBuffs[last], + oldBuffs[last + 1]); builder.create(loc, finalSelect); @@ -4827,13 +4805,9 @@ SmallVector buildNBufferConsumer(OpBuilder &builder, Location loc, return results; } -void replaceDepsMap( - scf::IfOp oldIfOp, - scf::IfOp newIfOp, - SmallVector &newDeps, - bool isFront, - DenseMap> &newIfResultDeps) -{ +void replaceDepsMap(scf::IfOp oldIfOp, scf::IfOp newIfOp, + SmallVector &newDeps, bool isFront, + DenseMap> &newIfResultDeps) { mlir::IRMapping valueMap; // old result -> new result @@ -4860,15 +4834,12 @@ void replaceDepsMap( } } -scf::IfOp addResultsForFrontIfOp(scf::IfOp frontIfOp, OpBuilder builder, - int bufferNum, Value depValue, - SmallVector constants, - SmallVector buffs, Value frontCnt, - Value postCnt, - 
SmallVector &extraResultIndices, - SmallVector &newDeps, - DenseMap> &newIfResultDeps) -{ +scf::IfOp addResultsForFrontIfOp( + scf::IfOp frontIfOp, OpBuilder builder, int bufferNum, Value depValue, + SmallVector constants, SmallVector buffs, Value frontCnt, + Value postCnt, SmallVector &extraResultIndices, + SmallVector &newDeps, + DenseMap> &newIfResultDeps) { OpBuilder::InsertionGuard guard(builder); Location loc = frontIfOp.getLoc(); @@ -5005,16 +4976,14 @@ scf::IfOp addResultsForFrontIfOp(scf::IfOp frontIfOp, OpBuilder builder, return newIfOp; } -scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, - OpBuilder builder, int bufferNum, - Value newDepValue, SmallVector constants, - SmallVector buffs, Value frontCnt, - Value postCnt, - SmallVector &extraResultIndices, - SmallVector &newDeps, - DenseMap> &newIfResultDeps) -{ - // 1. Parse the extra result indices produced by frontIf (added buffers and counters) +scf::IfOp addResultsForPostIfOp( + scf::IfOp postIfOp, scf::IfOp newfrontIfOp, OpBuilder builder, + int bufferNum, Value newDepValue, SmallVector constants, + SmallVector buffs, Value frontCnt, Value postCnt, + SmallVector &extraResultIndices, SmallVector &newDeps, + DenseMap> &newIfResultDeps) { + // 1. Parse the extra result indices produced by frontIf (added buffers and + // counters) SmallVector bufferIndices(extraResultIndices.begin(), extraResultIndices.end() - 1); int frontCntIndex = extraResultIndices[bufferNum]; @@ -5036,7 +5005,8 @@ scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, mlir::IRMapping mapping; - // 3. THEN region: clone the original logic, insert the multibuffer consumer and update dependency buffers + // 3. 
THEN region: clone the original logic, insert the multibuffer consumer + // and update dependency buffers auto &newThenBlock = newIfOp.getThenRegion().front(); builder.setInsertionPointToStart(&newThenBlock); @@ -5045,8 +5015,9 @@ scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, builder.clone(op, mapping); builder.setInsertionPointToStart(&newThenBlock); - // Find dependency uses that need to be replaced (located inside the current IfOp) - SmallVector replaceUses; + // Find dependency uses that need to be replaced (located inside the current + // IfOp) + SmallVector replaceUses; for (auto &use : newDepValue.getUses()) { if (newIfOp == dyn_cast(use.getOwner()->getParentOp())) { replaceUses.push_back(&use); @@ -5067,7 +5038,7 @@ scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, // Replace dependent buffer for (auto *usePtr : replaceUses) { - usePtr->set(selectedBuffer); + usePtr->set(selectedBuffer); } // Create then yield @@ -5112,8 +5083,8 @@ scf::IfOp addResultsForPostIfOp(scf::IfOp postIfOp, scf::IfOp newfrontIfOp, } void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, - DenseMap> &ifResultDeps, scf::ForOp &newForOp, int bufferNum) -{ + DenseMap> &ifResultDeps, + scf::ForOp &newForOp, int bufferNum) { // ============================================================ // Overall Idea @@ -5129,10 +5100,8 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, int processedDepCount = 0; SmallVector postIfOps; - newForOp.walk([&](scf::IfOp postIfOp) { - postIfOps.push_back(postIfOp); - }); - for (auto postIfOp:postIfOps) { + newForOp.walk([&](scf::IfOp postIfOp) { postIfOps.push_back(postIfOp); }); + for (auto postIfOp : postIfOps) { if (!ifResultDeps.count(postIfOp)) { continue; } @@ -5174,7 +5143,8 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, // Other buffers come from for iter args for (int i = 1; i < bufferNum; ++i) { - 
buffers.push_back(newForOp.getRegionIterArgs()[extraArgBaseIdx + i - 1]); + buffers.push_back( + newForOp.getRegionIterArgs()[extraArgBaseIdx + i - 1]); } // Two counters @@ -5182,15 +5152,15 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum - 1]; Value postCnt = newForOp.getRegionIterArgs()[extraArgBaseIdx + bufferNum]; - // Step 3. Create constants (0 ~ bufferNum) for rem / cmp buffer selection logic + // Step 3. Create constants (0 ~ bufferNum) for rem / cmp buffer selection + // logic SmallVector constants; builder.setInsertionPoint(frontIfOp); auto dataType = frontCnt.getType(); for (int i = 0; i <= bufferNum; ++i) { constants.push_back(builder.create( - frontIfOp.getLoc(), dataType, - builder.getIntegerAttr(dataType, i))); + frontIfOp.getLoc(), dataType, builder.getIntegerAttr(dataType, i))); } // Record the positions of newly added results in the IfOp @@ -5204,7 +5174,7 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, // buffer result indices SmallVector bufferResultIndices(extraResultIndices.begin(), - extraResultIndices.end() - 1); + extraResultIndices.end() - 1); int frontCntResultIndex = extraResultIndices[bufferNum]; @@ -5214,7 +5184,8 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, scf::IfOp postIfOp = nullptr; for (auto &use : newDepValue.getUses()) { - if (auto candidate = dyn_cast(use.getOwner()->getParentOp())) { + if (auto candidate = + dyn_cast(use.getOwner()->getParentOp())) { postIfOp = candidate; break; } @@ -5229,7 +5200,8 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, scf::IfOp newPostIfOp = addResultsForPostIfOp( postIfOp, newFrontIfOp, builder, bufferNum, newDepValue, constants, - buffers, frontCnt, postCnt, extraResultIndices, newDeps, ifResultDeps); + buffers, frontCnt, postCnt, extraResultIndices, newDeps, + ifResultDeps); llvm::outs() << "after addResultsForPostIfOp.\n"; @@ -5246,7 
+5218,8 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, if (yieldIdx < forYield->getNumOperands() && bufferResultIndices[i] < newFrontIfOp.getNumResults()) { - forYield->setOperand(yieldIdx, newFrontIfOp.getResult(bufferResultIndices[i])); + forYield->setOperand(yieldIdx, + newFrontIfOp.getResult(bufferResultIndices[i])); llvm::outs() << "Replaced yield operand " << yieldIdx << "\n"; } else { @@ -5282,13 +5255,11 @@ void addMultiBuffCaculate(ModuleOp module, SmallVector newUniqueDeps, } } - llvm::outs() << "multibuffer end!\n"; } // Compute the nesting level of an ifOp within the specified forOp -static int computeIfLevel(scf::IfOp ifOp, scf::ForOp rootForOp) -{ +static int computeIfLevel(scf::IfOp ifOp, scf::ForOp rootForOp) { int level = 1; Operation *parent = ifOp->getParentOp(); @@ -5303,8 +5274,7 @@ static int computeIfLevel(scf::IfOp ifOp, scf::ForOp rootForOp) return level; } -int assignIfOpLevels(scf::ForOp forOp) -{ +int assignIfOpLevels(scf::ForOp forOp) { SmallVector targetIfOps; int maxLevel = 0; // Collect all ifOp assigned with ssbuffer tag @@ -5319,14 +5289,12 @@ int assignIfOpLevels(scf::ForOp forOp) int level = computeIfLevel(ifOp, forOp); maxLevel = std::max(level, maxLevel); Builder builder(ifOp.getContext()); - ifOp->setAttr("ssbuffer.level", - builder.getI32IntegerAttr(level)); + ifOp->setAttr("ssbuffer.level", builder.getI32IntegerAttr(level)); } return maxLevel; } -static bool hasSSBufferIf(scf::ForOp forOp) -{ +static bool hasSSBufferIf(scf::ForOp forOp) { bool found = false; forOp.walk([&](scf::IfOp ifOp) { @@ -5340,8 +5308,7 @@ static bool hasSSBufferIf(scf::ForOp forOp) return found; } -static bool hasAncestorSSBufferFor(scf::ForOp forOp) -{ +static bool hasAncestorSSBufferFor(scf::ForOp forOp) { Operation *parent = forOp->getParentOp(); while (parent) { @@ -5355,8 +5322,7 @@ static bool hasAncestorSSBufferFor(scf::ForOp forOp) return false; } -static bool hasAncestorRootFor(scf::ForOp forOp) -{ +static bool 
hasAncestorRootFor(scf::ForOp forOp) { Operation *parent = forOp->getParentOp(); while (parent) { @@ -5369,11 +5335,9 @@ static bool hasAncestorRootFor(scf::ForOp forOp) return false; } -SmallVector collectIfInfo( - scf::ForOp &curForOp, - DenseMap> &ifDeps, - int level) -{ +SmallVector +collectIfInfo(scf::ForOp &curForOp, + DenseMap> &ifDeps, int level) { // Find all dependency variables based on the inputs and outputs of ifOp SmallVector allDeps; DenseSet producedValues; @@ -5386,7 +5350,7 @@ SmallVector collectIfInfo( // Levels match → check the direct parent if (auto parentFor = dyn_cast(ifOp->getParentOp())) { - newForOp = parentFor; // 更新 + newForOp = parentFor; // 更新 } // Stop walking regardless of whether the parent is a for-loop @@ -5406,20 +5370,22 @@ SmallVector collectIfInfo( ifOps.push_back(ifOp); return WalkResult::advance(); }); - llvm::outs()<<"ifOps:"<getOperands():"<getOperands().size()<<"\n"; + llvm::outs() << "ifOp->getOperands():" << ifOp->getOperands().size() + << "\n"; SmallVector deps; if (producedValues.empty()) { - llvm::outs()<<"producedValues为空!"<<"\n"; + llvm::outs() << "producedValues为空!" 
+ << "\n"; } - + // inputs Region &thenRegion = ifOp.getThenRegion(); for (Operation &op : thenRegion.front()) { @@ -5473,7 +5439,6 @@ void WalkAIVNestedForAndProcess( SmallVector targetFors; scope.walk([&](scf::ForOp forOp) { - // Must contain an ssbuffer if if (!hasSSBufferIf(forOp)) return WalkResult::advance(); @@ -5495,15 +5460,16 @@ void WalkAIVNestedForAndProcess( maxLevels = assignIfOpLevels(currentFor); for (int level = 1; level <= maxLevels; level++) { auto uniqueDeps = collectIfInfo(currentFor, ifResultDeps, level); - llvm::outs()<<"maxLevels:"<> newIfResultDeps; auto uniqueList = collectIfInfo(newForOp, newIfResultDeps, level); - addMultiBuffCaculate(module, uniqueList, newIfResultDeps, newForOp, bufferNum); + addMultiBuffCaculate(module, uniqueList, newIfResultDeps, newForOp, + bufferNum); } } }); @@ -5526,8 +5492,6 @@ void DAGSSBufferPass::runOnOperation() { return; } -std::unique_ptr> -mlir::triton::createDAGSSBufferPass() { +std::unique_ptr> mlir::triton::createDAGSSBufferPass() { return std::make_unique(); } - diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp index 0f19131232..ed82b46084 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGScope.cpp @@ -22,26 +22,26 @@ #include "TritonAffinityOpt/Passes.h" -#include "bishengir/Dialect/Scope/IR/Scope.h" #include "bishengir/Dialect/HIVM/IR/HIVM.h" #include "bishengir/Dialect/HIVM/IR/HIVMImpl.h" -#include "bishengir/Dialect/HIVM/Transforms/Passes.h" #include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" +#include "bishengir/Dialect/HIVM/Transforms/Passes.h" #include "bishengir/Dialect/HIVM/Utils/Utils.h" +#include "bishengir/Dialect/Scope/IR/Scope.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Block.h" 
+#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/Triton/IR/Dialect.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Block.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include #include "TritonAffinityOpt/DAG.h" @@ -57,128 +57,131 @@ using namespace mlir; using namespace hivm; namespace { -struct DAGScopePass - : public mlir::triton::impl::DAGScopeBase< - DAGScopePass> { +struct DAGScopePass : public mlir::triton::impl::DAGScopeBase { void runOnOperation() override; }; } // namespace +static std::pair +encapsulateWithScope(triton::FuncOp funcOp) { + Block &entryBlock = funcOp.getBody().front(); + Block &lastBlock = funcOp.getBody().back(); + Operation *terminator = lastBlock.getTerminator(); -static std::pair encapsulateWithScope(triton::FuncOp funcOp) { - Block &entryBlock = funcOp.getBody().front(); - Block &lastBlock = funcOp.getBody().back(); - Operation *terminator = lastBlock.getTerminator(); - - - // 辅助函数:判断操作是否应该被跳过 - auto shouldSkipOp = [](Operation *op) -> bool { - return isa(op) || isa(op) || isa(op); - }; + // 辅助函数:判断操作是否应该被跳过 + auto shouldSkipOp = [](Operation *op) -> bool { + return isa(op) || isa(op) || + isa(op); + }; - // 第三步:准备要移动的操作列表(按顺序) - SmallVector opsToMove; - DenseMap opOrder; - int order = 0; + // 第三步:准备要移动的操作列表(按顺序) + SmallVector opsToMove; + DenseMap opOrder; + int order = 0; - // 记录原始顺序并收集需要移动的操作 - for (Operation &op : lastBlock.without_terminator()) { - opOrder[&op] = order++; - if (!shouldSkipOp(&op)) { - opsToMove.push_back(&op); - } + // 记录原始顺序并收集需要移动的操作 + for (Operation &op : 
lastBlock.without_terminator()) { + opOrder[&op] = order++; + if (!shouldSkipOp(&op)) { + opsToMove.push_back(&op); } + } - // 按原始顺序排序 - std::sort(opsToMove.begin(), opsToMove.end(), - [&](Operation *a, Operation *b) { - return opOrder[a] < opOrder[b]; - }); + // 按原始顺序排序 + std::sort( + opsToMove.begin(), opsToMove.end(), + [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); - if (opsToMove.empty()) { - return std::make_pair(nullptr, nullptr); - } + if (opsToMove.empty()) { + return std::make_pair(nullptr, nullptr); + } - // 第四步:创建scope操作并移动操作 - Operation *lastOpToMove = opsToMove.back(); - OpBuilder builder(&lastBlock, ++lastOpToMove->getIterator()); + // 第四步:创建scope操作并移动操作 + Operation *lastOpToMove = opsToMove.back(); + OpBuilder builder(&lastBlock, ++lastOpToMove->getIterator()); - // 创建第一个scope - auto scopeOp = builder.create(builder.getUnknownLoc(), llvm::ArrayRef{}); - scopeOp.getBodyRegion().emplaceBlock(); - Block *scopeBody = &scopeOp.getBodyRegion().front(); + // 创建第一个scope + auto scopeOp = builder.create(builder.getUnknownLoc(), + llvm::ArrayRef{}); + scopeOp.getBodyRegion().emplaceBlock(); + Block *scopeBody = &scopeOp.getBodyRegion().front(); - // 移动操作到scope中 - OpBuilder scopeBuilder(scopeBody, scopeBody->end()); - DenseMap valueMapping; + // 移动操作到scope中 + OpBuilder scopeBuilder(scopeBody, scopeBody->end()); + DenseMap valueMapping; - for (Operation *op : opsToMove) { - SmallVector originalResults = op->getResults(); - op->remove(); - scopeBuilder.insert(op); + for (Operation *op : opsToMove) { + SmallVector originalResults = op->getResults(); + op->remove(); + scopeBuilder.insert(op); - // 更新值的映射 - for (size_t i = 0; i < originalResults.size(); ++i) { - valueMapping[originalResults[i]] = op->getResult(i); - } + // 更新值的映射 + for (size_t i = 0; i < originalResults.size(); ++i) { + valueMapping[originalResults[i]] = op->getResult(i); } + } - // 添加return操作 - scopeBuilder.create(builder.getUnknownLoc()); + // 添加return操作 + 
scopeBuilder.create(builder.getUnknownLoc()); - // 创建第二个scope(如果需要) - scopeBuilder.setInsertionPointAfter(scopeOp); - auto newScopeOp = scopeBuilder.create(builder.getUnknownLoc(), llvm::ArrayRef{}); - newScopeOp.getRegion().emplaceBlock(); + // 创建第二个scope(如果需要) + scopeBuilder.setInsertionPointAfter(scopeOp); + auto newScopeOp = scopeBuilder.create( + builder.getUnknownLoc(), llvm::ArrayRef{}); + newScopeOp.getRegion().emplaceBlock(); - OpBuilder newScopeBuilder(&newScopeOp.getRegion().front(), - newScopeOp.getRegion().front().begin()); - newScopeBuilder.create(scopeOp->getLoc()); + OpBuilder newScopeBuilder(&newScopeOp.getRegion().front(), + newScopeOp.getRegion().front().begin()); + newScopeBuilder.create(scopeOp->getLoc()); - // 设置属性 - auto vecAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::VECTOR); - auto aicAttr = hivm::TCoreTypeAttr::get( - builder.getContext(), - hivm::TCoreType::CUBE); + // 设置属性 + auto vecAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto aicAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - scopeOp->setAttr(hivm::TCoreTypeAttr::name, vecAttr); - newScopeOp->setAttr(hivm::TCoreTypeAttr::name, aicAttr); + scopeOp->setAttr(hivm::TCoreTypeAttr::name, vecAttr); + newScopeOp->setAttr(hivm::TCoreTypeAttr::name, aicAttr); - return std::make_pair(scopeOp, newScopeOp); + return std::make_pair(scopeOp, newScopeOp); } struct OpMoveInfo { - Operation* op; - Operation* targetParent; // 目标父操作(nullptr表示aicScope本身) - }; + Operation *op; + Operation *targetParent; // 目标父操作(nullptr表示aicScope本身) +}; // 递归遍历函数 - 优化版本 -void collectOpsToMove(Operation* op, AffinityDAG::Graph& graph, - Operation* parentFor, llvm::SmallVector& aivToMove, llvm::SmallVector& cubeToMove) { +void collectOpsToMove(Operation *op, AffinityDAG::Graph &graph, + Operation *parentFor, + llvm::SmallVector &aivToMove, + llvm::SmallVector &cubeToMove) { // 检查当前操作是否需要移动 bool needsMoveAiv = false; 
bool needsMoveCube = false; - auto& valueTypes = graph.getValueTypes(); + auto &valueTypes = graph.getValueTypes(); // 检查结果类型 int i = 0; for (auto res : op->getResults()) { i++; - if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::VECTOR_ONLY)) { + if (AffinityDAG::intersects(valueTypes[res], + AffinityDAG::CoreType::VECTOR_ONLY)) { needsMoveAiv = true; } - if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::CUBE_ONLY)) { + if (AffinityDAG::intersects(valueTypes[res], + AffinityDAG::CoreType::CUBE_ONLY)) { needsMoveCube = true; } } if (isa(op)) { auto res = op->getOperand(0); - if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::VECTOR_ONLY)) { + if (AffinityDAG::intersects(valueTypes[res], + AffinityDAG::CoreType::VECTOR_ONLY)) { needsMoveAiv = true; } - if (AffinityDAG::intersects(valueTypes[res], AffinityDAG::CoreType::CUBE_ONLY)) { + if (AffinityDAG::intersects(valueTypes[res], + AffinityDAG::CoreType::CUBE_ONLY)) { needsMoveCube = true; } } @@ -199,75 +202,75 @@ void collectOpsToMove(Operation* op, AffinityDAG::Graph& graph, } if (isa(op)) { - if (auto storeOp = dyn_cast(op)) { - // 获取所有操作数列表 - auto operands = storeOp.getOperands(); - bool typeMatched = false; - - // 按顺序检查第1个、第0个、第2个操作数 - std::vector checkOrder = {1, 0, 2}; - for (size_t idx : checkOrder) { - // 先判断操作数索引是否有效,避免越界访问 - if (idx >= operands.size()) { - continue; - } - auto operand = operands[idx]; - auto coreType = valueTypes[operand]; - - if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { - needsMoveAiv = true; - typeMatched = true; - } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { - needsMoveCube = true; - typeMatched = true; - } + if (auto storeOp = dyn_cast(op)) { + // 获取所有操作数列表 + auto operands = storeOp.getOperands(); + bool typeMatched = false; + + // 按顺序检查第1个、第0个、第2个操作数 + std::vector checkOrder = {1, 0, 2}; + for (size_t idx : checkOrder) { + // 先判断操作数索引是否有效,避免越界访问 + if (idx >= operands.size()) { + continue; } - // 
所有指定操作数都不匹配时,执行原else逻辑 - if (!typeMatched) { + auto operand = operands[idx]; + auto coreType = valueTypes[operand]; + + if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { needsMoveAiv = true; + typeMatched = true; + } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { needsMoveCube = true; + typeMatched = true; } } + // 所有指定操作数都不匹配时,执行原else逻辑 + if (!typeMatched) { + needsMoveAiv = true; + needsMoveCube = true; + } + } } if (isa(op)) { - if (auto assertOp = dyn_cast(op)) { - // 获取所有操作数列表 - auto operand = assertOp.getCondition(); - - auto coreType = valueTypes[operand]; - if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { - needsMoveAiv = true; - } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { - needsMoveCube = true; - } else { - needsMoveAiv = true; - needsMoveCube = true; - } + if (auto assertOp = dyn_cast(op)) { + // 获取所有操作数列表 + auto operand = assertOp.getCondition(); + + auto coreType = valueTypes[operand]; + if (coreType == AffinityDAG::CoreType::VECTOR_ONLY) { + needsMoveAiv = true; + } else if (coreType == AffinityDAG::CoreType::CUBE_ONLY) { + needsMoveCube = true; + } else { + needsMoveAiv = true; + needsMoveCube = true; } + } } // 检查 Sync 操作的 tcore_type 属性 if ((isa(op) || isa(op))) { mlir::OpBuilder builder(op); - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); if (op->getAttr("tcore_type") == coreAttr) { needsMoveCube = true; - } - else { + } else { needsMoveAiv = true; } } // 如果不需要移动,直接返回 if (!needsMoveAiv && !needsMoveCube) { - llvm::outs()<<"Unsupport Op: "<< *op<<" \n"; + llvm::outs() << "Unsupport Op: " << *op << " \n"; } // 处理 for 循环 if (auto forOp = dyn_cast(op)) { // 确定父级 for 循环 - Operation* targetParent = parentFor != nullptr ? parentFor : nullptr; + Operation *targetParent = parentFor != nullptr ? 
parentFor : nullptr; aivToMove.push_back({op, targetParent}); cubeToMove.push_back({op, targetParent}); @@ -279,7 +282,7 @@ void collectOpsToMove(Operation* op, AffinityDAG::Graph& graph, } } else if (auto ifOp = dyn_cast(op)) { // 确定父级 for 循环 - Operation* targetParent = parentFor != nullptr ? parentFor : nullptr; + Operation *targetParent = parentFor != nullptr ? parentFor : nullptr; aivToMove.push_back({op, targetParent}); cubeToMove.push_back({op, targetParent}); @@ -304,16 +307,16 @@ void collectOpsToMove(Operation* op, AffinityDAG::Graph& graph, if (needsMoveCube) { cubeToMove.push_back({op, parentFor}); } - } } -mlir::Block* getBlockByIndex(mlir::Region& region, int blockIndex) { +mlir::Block *getBlockByIndex(mlir::Region ®ion, int blockIndex) { // 边界校验:索引非法时返回nullptr - if (blockIndex < 0) return nullptr; + if (blockIndex < 0) + return nullptr; int currentIdx = 0; - for (auto& block : region) { + for (auto &block : region) { if (currentIdx == blockIndex) { return █ // 找到对应索引的Block,直接返回 } @@ -323,29 +326,26 @@ mlir::Block* getBlockByIndex(mlir::Region& region, int blockIndex) { return nullptr; } -void processOperationToMove(const OpMoveInfo& info, - llvm::DenseMap& parentMap, - mlir::OpBuilder& builder, - mlir::IRMapping& mapper, - mlir::Block* aivBlock, - mlir::Operation* terminator, - AffinityDAG::Graph& graph, - int MoveType) { +void processOperationToMove( + const OpMoveInfo &info, + llvm::DenseMap &parentMap, + mlir::OpBuilder &builder, mlir::IRMapping &mapper, mlir::Block *aivBlock, + mlir::Operation *terminator, AffinityDAG::Graph &graph, int MoveType) { // llvm::outs()<<*info.op<<" ssss\n\n\n"; // llvm::outs().flush(); // 获取原始Block信息并计算索引 - mlir::Block* originalBlock = info.op->getBlock(); + mlir::Block *originalBlock = info.op->getBlock(); int originalRegionIndex = -1; int originalBlockIndex = -1; int blockCounter = 0; - auto& valueTypes = graph.getValueTypes(); + auto &valueTypes = graph.getValueTypes(); if (originalBlock) { - mlir::Operation* 
parentOp = info.op->getParentOp(); // 原始父操作 - if (parentOp) { // 确保父操作存在 + mlir::Operation *parentOp = info.op->getParentOp(); // 原始父操作 + if (parentOp) { // 确保父操作存在 // 老版本MLIR用 getParent() 替代 getParentRegion(),返回值就是Region* - mlir::Region* blockBelongsToRegion = originalBlock->getParent(); + mlir::Region *blockBelongsToRegion = originalBlock->getParent(); int regionCounter = 0; - for (auto& region : parentOp->getRegions()) { // 遍历父操作的所有region + for (auto ®ion : parentOp->getRegions()) { // 遍历父操作的所有region // 直接对比指针,判断当前region是否是block所属的region if (®ion == blockBelongsToRegion) { originalRegionIndex = regionCounter; @@ -357,7 +357,7 @@ void processOperationToMove(const OpMoveInfo& info, } if (originalBlock) { - for (auto& block : originalBlock->getParent()->getBlocks()) { + for (auto &block : originalBlock->getParent()->getBlocks()) { if (&block == originalBlock) { originalBlockIndex = blockCounter; break; @@ -396,12 +396,9 @@ void processOperationToMove(const OpMoveInfo& info, // 创建新的for循环 auto aivForOp = builder.create( - forOp.getLoc(), - getMapped(forOp.getLowerBound()), - getMapped(forOp.getUpperBound()), - getMapped(forOp.getStep()), - llvm::to_vector(llvm::map_range(aivInputs, getMapped)) - ); + forOp.getLoc(), getMapped(forOp.getLowerBound()), + getMapped(forOp.getUpperBound()), getMapped(forOp.getStep()), + llvm::to_vector(llvm::map_range(aivInputs, getMapped))); // 清空循环体 if (!aivForOp.getBody()->empty()) { @@ -410,7 +407,8 @@ void processOperationToMove(const OpMoveInfo& info, // 处理原始循环的yield操作 auto oldBody = forOp.getBody(); - auto oldYield = mlir::dyn_cast(oldBody->getTerminator()); + auto oldYield = + mlir::dyn_cast(oldBody->getTerminator()); assert(oldYield && "scf::ForOp must have a yield terminator"); llvm::SmallVector aivYieldOperands; @@ -428,13 +426,14 @@ void processOperationToMove(const OpMoveInfo& info, int oldInputIndex = it->first; int mappedNewIndex = it->second; mapper.map(oldBodyArgs[oldInputIndex], aivBodyArgs[mappedNewIndex]); - 
mapper.map((*info.op).getResults()[oldInputIndex - 1], aivForOp->getResults()[mappedNewIndex - 1]); + mapper.map((*info.op).getResults()[oldInputIndex - 1], + aivForOp->getResults()[mappedNewIndex - 1]); } mapper.map(oldBodyArgs[0], aivBodyArgs[0]); // 将新循环移动到目标位置 if (info.targetParent == nullptr) { - mlir::Block* targetBlock = aivBlock; + mlir::Block *targetBlock = aivBlock; if (terminator) { aivForOp->moveBefore(terminator); } else { @@ -443,20 +442,19 @@ void processOperationToMove(const OpMoveInfo& info, parentMap[forOp] = aivForOp; } else { auto targetParent = parentMap[info.targetParent]; - auto& region = targetParent->getRegion(originalRegionIndex); + auto ®ion = targetParent->getRegion(originalRegionIndex); if (region.empty()) { region.push_back(new mlir::Block()); } - mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); if (targetBlock) { aivForOp->moveBefore(targetBlock, targetBlock->end()); parentMap[forOp] = aivForOp; } else { - llvm::outs()<<"Can't find block by index\n"; + llvm::outs() << "Can't find block by index\n"; } - } } @@ -465,7 +463,8 @@ void processOperationToMove(const OpMoveInfo& info, auto yieldOp = mlir::cast(info.op); // 处理父节点为 scf::ForOp 的情况 - if (auto parentForOp = mlir::dyn_cast(info.targetParent)) { + if (auto parentForOp = + mlir::dyn_cast(info.targetParent)) { auto it = parentMap.find(parentForOp); if (it == parentMap.end()) { return; @@ -486,13 +485,15 @@ void processOperationToMove(const OpMoveInfo& info, } } - auto newYieldOp = builder.create(yieldOp.getLoc(), newYieldOperands); - auto& region = newForOp->getRegion(0); - mlir::Block* targetBlock = ®ion.front(); + auto newYieldOp = builder.create(yieldOp.getLoc(), + newYieldOperands); + auto ®ion = newForOp->getRegion(0); + mlir::Block *targetBlock = ®ion.front(); newYieldOp->moveBefore(targetBlock, targetBlock->end()); } // 处理父节点为 scf::IfOp 的情况 - else if (auto parentIfOp = 
mlir::dyn_cast(info.targetParent)) { + else if (auto parentIfOp = + mlir::dyn_cast(info.targetParent)) { auto it = parentMap.find(parentIfOp); if (it == parentMap.end()) { return; @@ -513,13 +514,14 @@ void processOperationToMove(const OpMoveInfo& info, } } - auto& region = newIfOp->getRegion(originalRegionIndex); - auto newYieldOp = builder.create(yieldOp.getLoc(), newYieldOperands); - mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + auto ®ion = newIfOp->getRegion(originalRegionIndex); + auto newYieldOp = builder.create(yieldOp.getLoc(), + newYieldOperands); + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); if (targetBlock) { newYieldOp->moveBefore(targetBlock, targetBlock->end()); } else { - llvm::outs()<<"Can't find block by index\n"; + llvm::outs() << "Can't find block by index\n"; } } } @@ -549,31 +551,28 @@ void processOperationToMove(const OpMoveInfo& info, // 创建新的if操作 auto aivIfOp = builder.create( - ifOp.getLoc(), - aivResultTypes, - getMapped(condition) - ); + ifOp.getLoc(), aivResultTypes, getMapped(condition)); // 映射if操作结果 - for (auto& [oldIdx, newIdx] : aivResultMap) { + for (auto &[oldIdx, newIdx] : aivResultMap) { mapper.map(ifOp.getResult(oldIdx), aivIfOp.getResult(newIdx)); } // 初始化then和else区域 - mlir::Region& thenRegion = aivIfOp.getThenRegion(); - mlir::Block* thenBlock = new mlir::Block(); + mlir::Region &thenRegion = aivIfOp.getThenRegion(); + mlir::Block *thenBlock = new mlir::Block(); thenRegion.push_back(thenBlock); - mlir::Region& elseRegion = ifOp.getElseRegion(); + mlir::Region &elseRegion = ifOp.getElseRegion(); if (!elseRegion.empty()) { - mlir::Region& elseRegion = aivIfOp.getElseRegion(); - mlir::Block* elseBlock = new mlir::Block(); - elseRegion.push_back(elseBlock); + mlir::Region &elseRegion = aivIfOp.getElseRegion(); + mlir::Block *elseBlock = new mlir::Block(); + elseRegion.push_back(elseBlock); } // 将新if操作移动到目标位置 if (info.targetParent == nullptr) { - mlir::Block* targetBlock = 
aivBlock; + mlir::Block *targetBlock = aivBlock; if (terminator) { aivIfOp->moveBefore(terminator); } else { @@ -581,17 +580,18 @@ void processOperationToMove(const OpMoveInfo& info, } parentMap[ifOp] = aivIfOp; } else { - auto& region = parentMap[info.targetParent]->getRegion(originalRegionIndex); + auto ®ion = + parentMap[info.targetParent]->getRegion(originalRegionIndex); if (region.empty()) { region.push_back(new mlir::Block()); } - mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); if (targetBlock) { aivIfOp->moveBefore(targetBlock, targetBlock->end()); parentMap[ifOp] = aivIfOp; } else { - llvm::outs()<<"Can't find block by index\n"; + llvm::outs() << "Can't find block by index\n"; } } } @@ -605,30 +605,31 @@ void processOperationToMove(const OpMoveInfo& info, } if (info.targetParent == nullptr) { - mlir::Block* targetBlock = aivBlock; + mlir::Block *targetBlock = aivBlock; clonedOp->moveBefore(terminator); parentMap[info.op] = clonedOp; } else { auto parentIt = parentMap.find(info.targetParent); auto mappedParentOp = parentIt->second; - auto& region = mappedParentOp->getRegion(originalRegionIndex); + auto ®ion = mappedParentOp->getRegion(originalRegionIndex); if (region.empty()) { region.push_back(new mlir::Block()); } - mlir::Block* targetBlock = getBlockByIndex(region, originalBlockIndex); + mlir::Block *targetBlock = getBlockByIndex(region, originalBlockIndex); if (targetBlock) { clonedOp->moveBefore(targetBlock, targetBlock->end()); } else { - llvm::outs()<<"Can't find block by index\n"; + llvm::outs() << "Can't find block by index\n"; } - } } } -static void SplitScope(triton::FuncOp funcOp, AffinityDAG::Graph& graph, Operation* aivScope, Operation* aicScope, ModuleOp module) { +static void SplitScope(triton::FuncOp funcOp, AffinityDAG::Graph &graph, + Operation *aivScope, Operation *aicScope, + ModuleOp module) { llvm::SmallVector aivToMove; llvm::SmallVector 
cubeToMove; for (auto &block : aivScope->getRegion(0)) { @@ -638,440 +639,405 @@ static void SplitScope(triton::FuncOp funcOp, AffinityDAG::Graph& graph, Operati } mlir::IRMapping aivmapper; mlir::OpBuilder builder(aivScope); - llvm::DenseMap aivparentMap; + llvm::DenseMap aivparentMap; // 第二遍:实际移动操作 // 先移动for循环 - mlir::Block* aivBlock = &aivScope->getRegion(0).front(); // 或者使用合适的block - SmallVector deleteOp; - auto* terminator = aivBlock->getTerminator(); - // 如果操作已被使用,直接跳过 - llvm::SmallVector aivUsedOp; // 改为函数内静态,保持原有逻辑 - for (const auto& info : aivToMove) { - if (std::find(aivUsedOp.begin(), aivUsedOp.end(), info.op) != aivUsedOp.end()) { + mlir::Block *aivBlock = + &aivScope->getRegion(0).front(); // 或者使用合适的block + SmallVector deleteOp; + auto *terminator = aivBlock->getTerminator(); + // 如果操作已被使用,直接跳过 + llvm::SmallVector + aivUsedOp; // 改为函数内静态,保持原有逻辑 + for (const auto &info : aivToMove) { + if (std::find(aivUsedOp.begin(), aivUsedOp.end(), info.op) != + aivUsedOp.end()) { return; } aivUsedOp.push_back(info.op); - processOperationToMove(info, aivparentMap, builder, aivmapper, aivBlock, terminator, graph, AffinityDAG::CoreType::CUBE_ONLY); + processOperationToMove(info, aivparentMap, builder, aivmapper, aivBlock, + terminator, graph, AffinityDAG::CoreType::CUBE_ONLY); } - llvm::DenseMap aicparentMap; + llvm::DenseMap aicparentMap; mlir::IRMapping aicmapper; - mlir::Block* aicBlock = &aicScope->getRegion(0).front(); // 或者使用合适的block + mlir::Block *aicBlock = + &aicScope->getRegion(0).front(); // 或者使用合适的block terminator = aicBlock->getTerminator(); - llvm::SmallVector aicUsedOp; // 改为函数内静态,保持原有逻辑 - for (const auto& info : cubeToMove) { - if (std::find(aicUsedOp.begin(), aicUsedOp.end(), info.op) != aicUsedOp.end()) { + llvm::SmallVector + aicUsedOp; // 改为函数内静态,保持原有逻辑 + for (const auto &info : cubeToMove) { + if (std::find(aicUsedOp.begin(), aicUsedOp.end(), info.op) != + aicUsedOp.end()) { return; } aicUsedOp.push_back(info.op); - processOperationToMove(info, 
aicparentMap, builder, aicmapper, aicBlock, terminator, graph, AffinityDAG::CoreType::VECTOR_ONLY); + processOperationToMove(info, aicparentMap, builder, aicmapper, aicBlock, + terminator, graph, + AffinityDAG::CoreType::VECTOR_ONLY); } - for (const auto& info : aivToMove) { - if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == deleteOp.end()) { - deleteOp.push_back(info.op); + for (const auto &info : aivToMove) { + if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == + deleteOp.end()) { + deleteOp.push_back(info.op); } } - for (const auto& info : cubeToMove) { - if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == deleteOp.end()) { - deleteOp.push_back(info.op); + for (const auto &info : cubeToMove) { + if (std::find(deleteOp.begin(), deleteOp.end(), info.op) == + deleteOp.end()) { + deleteOp.push_back(info.op); } } // llvm::outs() << "\n" << module<<" ====== ddd ====== \n\n\n"; // llvm::outs().flush(); for (auto it = deleteOp.rbegin(); it != deleteOp.rend(); ++it) { - (*it)->erase(); // 解引用反向迭代器,调用 erase 方法 + (*it)->erase(); // 解引用反向迭代器,调用 erase 方法 } return; +} +/// 创建setop +static hivm::SyncBlockSetOp +createSyncBlockSetOp(OpBuilder &builder, Location loc, hivm::TCoreType coreType, + hivm::PIPE setPipeEnum, hivm::PIPE waitPipeEnum, + int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, + flagId); } - /// 创建setop - static hivm::SyncBlockSetOp createSyncBlockSetOp( - OpBuilder &builder, - Location loc, - hivm::TCoreType coreType, - hivm::PIPE setPipeEnum, - hivm::PIPE waitPipeEnum, - int64_t flag) { - MLIRContext *ctx = builder.getContext(); - auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); - auto setPipe = hivm::PipeAttr::get(ctx, 
setPipeEnum); - auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - return builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } - - /// 创建waitop - static hivm::SyncBlockWaitOp createSyncBlockWaitOp( - OpBuilder &builder, - Location loc, - hivm::TCoreType coreType, - hivm::PIPE setPipeEnum, - hivm::PIPE waitPipeEnum, - int64_t flag) { - MLIRContext *ctx = builder.getContext(); - auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); - auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); - auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - return builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } - - // 在scope return前插入wait - static void insertWaitBeforeFinalReturn(Region *region, OpBuilder &builder, int64_t flag, bool coretypebool) { - for (Block &block : *region) { - if (auto returnOp = dyn_cast_or_null(block.getTerminator())) { - builder.setInsertionPoint(returnOp); - if (coretypebool) { - createSyncBlockWaitOp( - builder, - returnOp->getLoc(), - hivm::TCoreType::CUBE, - hivm::PIPE::PIPE_V, - hivm::PIPE::PIPE_FIX, - flag - ); - return; - } - else { - createSyncBlockWaitOp( - builder, - returnOp->getLoc(), - hivm::TCoreType::VECTOR, - hivm::PIPE::PIPE_M, - hivm::PIPE::PIPE_MTE3, - flag - ); - return; - } - } - } - } - - /// 在scope内起始位置加上set - static void insertSetAtRegionStart(Region *region, OpBuilder &builder, int64_t flag, bool coretypebool) { - if (!region->empty()) { - Block &entry = region->front(); - Location loc = entry.empty() ? 
region->getParentOp()->getLoc() : entry.front().getLoc(); - builder.setInsertionPointToStart(&entry); - if (coretypebool) { - createSyncBlockSetOp( - builder, - loc, - hivm::TCoreType::VECTOR, - hivm::PIPE::PIPE_V, - hivm::PIPE::PIPE_FIX, - flag - ); - } - else { - createSyncBlockSetOp( - builder, - loc, - hivm::TCoreType::CUBE, - hivm::PIPE::PIPE_M, - hivm::PIPE::PIPE_MTE3, - flag - ); - } - } - } - - static Operation *findNextSyncBlockSetAfter(Operation *startOp) { - Block *block = startOp->getBlock(); - auto it = ++startOp->getIterator(); - for (; it != block->end(); ++it) { - if (isa(*it)) - return &*it; - } - return nullptr; - } - - static hivm::SyncBlockWaitOp findWaitOpInRegionWithFlag(Region *region, int64_t flag) { - hivm::SyncBlockWaitOp result; - region->walk([&](hivm::SyncBlockWaitOp op) { - auto flagAttr = op->getAttrOfType("static_flag_id"); - if (flagAttr && flagAttr.getInt() == flag) { - result = op; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return result; - } - - static Operation *findInsertionPointAfterWaitForAIV(Operation *waitOp) { - Block *block = waitOp->getBlock(); - auto it = ++waitOp->getIterator(); - - for (; it != block->end(); ++it) { - if (isa(*it) || isa(*it)) { - break; - } - } +/// 创建waitop +static hivm::SyncBlockWaitOp +createSyncBlockWaitOp(OpBuilder &builder, Location loc, + hivm::TCoreType coreType, hivm::PIPE setPipeEnum, + hivm::PIPE waitPipeEnum, int64_t flag) { + MLIRContext *ctx = builder.getContext(); + auto coreAttr = hivm::TCoreTypeAttr::get(ctx, coreType); + auto setPipe = hivm::PipeAttr::get(ctx, setPipeEnum); + auto waitPipe = hivm::PipeAttr::get(ctx, waitPipeEnum); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + return builder.create(loc, coreAttr, setPipe, waitPipe, + flagId); +} - while (it != block->begin()) { - auto prevIt = std::prev(it); - if (isa(*prevIt)) { - it = prevIt; - } else { - break; - } +// 在scope return前插入wait +static void 
insertWaitBeforeFinalReturn(Region *region, OpBuilder &builder, + int64_t flag, bool coretypebool) { + for (Block &block : *region) { + if (auto returnOp = + dyn_cast_or_null(block.getTerminator())) { + builder.setInsertionPoint(returnOp); + if (coretypebool) { + createSyncBlockWaitOp(builder, returnOp->getLoc(), + hivm::TCoreType::CUBE, hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, flag); + return; + } else { + createSyncBlockWaitOp(builder, returnOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, flag); + return; } + } + } +} + +/// 在scope内起始位置加上set +static void insertSetAtRegionStart(Region *region, OpBuilder &builder, + int64_t flag, bool coretypebool) { + if (!region->empty()) { + Block &entry = region->front(); + Location loc = entry.empty() ? region->getParentOp()->getLoc() + : entry.front().getLoc(); + builder.setInsertionPointToStart(&entry); + if (coretypebool) { + createSyncBlockSetOp(builder, loc, hivm::TCoreType::VECTOR, + hivm::PIPE::PIPE_V, hivm::PIPE::PIPE_FIX, flag); + } else { + createSyncBlockSetOp(builder, loc, hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_M, hivm::PIPE::PIPE_MTE3, flag); + } + } +} +static Operation *findNextSyncBlockSetAfter(Operation *startOp) { + Block *block = startOp->getBlock(); + auto it = ++startOp->getIterator(); + for (; it != block->end(); ++it) { + if (isa(*it)) return &*it; - } - - static Operation *findInsertionPointAfterWaitForAIC(Operation *waitOp) { - Block *block = waitOp->getBlock(); - auto it = ++waitOp->getIterator(); - for (; it != block->end(); ++it) { - if (auto fixpipe = dyn_cast(*it)) { - if (it != block->begin()) { - auto prev = std::prev(it); - if (isa(*prev)) - return &*prev; - } - return &*it; - } - if (isa(*it)) - return &*it; - } - return nullptr; - } - - // 查找 FixpipeOp 下一行的 sync_block_set 操作的 flag 值 - static int findFixPipeFlagSafe(hivm::FixpipeOp fixpipeOp) { - mlir::Operation *fixpipeOperation = fixpipeOp.getOperation(); - if (!fixpipeOperation || 
!fixpipeOperation->getBlock()) { - return -1; - } + } + return nullptr; +} - // 获取 FixpipeOp 的迭代器 - auto it = ++fixpipeOperation->getIterator(); - - // 遍历后续操作直到找到 sync_block_set - while (it != fixpipeOperation->getBlock()->end()) { - mlir::Operation &op = *it++; - - if (op.getName().getStringRef() == "hivm.hir.sync_block_set") { - auto staticFlagAttr = op.getAttrOfType("static_flag_id"); - return staticFlagAttr.getInt(); - break; - } - } +static hivm::SyncBlockWaitOp findWaitOpInRegionWithFlag(Region *region, + int64_t flag) { + hivm::SyncBlockWaitOp result; + region->walk([&](hivm::SyncBlockWaitOp op) { + auto flagAttr = op->getAttrOfType("static_flag_id"); + if (flagAttr && flagAttr.getInt() == flag) { + result = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} - return -1; +static Operation *findInsertionPointAfterWaitForAIV(Operation *waitOp) { + Block *block = waitOp->getBlock(); + auto it = ++waitOp->getIterator(); + for (; it != block->end(); ++it) { + if (isa(*it) || isa(*it)) { + break; } + } - /// cube处理逻辑 - static void processFixpipeOpsInAIC( - Region *aicRegion, - Region *aivRegion) { - - MLIRContext *ctx = aicRegion->getContext(); - OpBuilder builder(ctx); - SmallVector fixpipes; - aicRegion->walk([&](hivm::FixpipeOp op) { - fixpipes.push_back(op); - }); - - - for (auto fixpipeOp : fixpipes) { - - auto newflag = findFixPipeFlagSafe(fixpipeOp); - // 1. 在 FixpipeOp 前插 Wait - builder.setInsertionPoint(fixpipeOp); - createSyncBlockWaitOp( - builder, - fixpipeOp->getLoc(), - hivm::TCoreType::CUBE, - hivm::PIPE::PIPE_V, - hivm::PIPE::PIPE_FIX, - newflag); - bool coretypebool = true; - - // 2. 在 aicRegion 末尾 Return 前插 Wait - insertWaitBeforeFinalReturn(aicRegion, builder, newflag, coretypebool); - - // 3. 在 aivRegion 开头插 Set - insertSetAtRegionStart(aivRegion, builder, newflag, coretypebool); - - // 4. 
在 aicRegion 向后找 SyncBlockSetOp - if (auto *nextSetOp = findNextSyncBlockSetAfter(fixpipeOp)) { - auto setFlagAttr = nextSetOp->getAttrOfType("static_flag_id"); - // 调试:打印set - // llvm::dbgs() << "aicnextSetOp:"; - // nextSetOp->dump(); - if (!setFlagAttr) { - llvm::dbgs() << "AIC can not find setop in aic\n"; - continue; - } - int64_t setflag = setFlagAttr.getInt(); - - // 5. 在 aivRegion 中找 flag=setflag 的 WaitOp - auto targetWait = findWaitOpInRegionWithFlag(aivRegion, setflag); - if (!targetWait) { - llvm::dbgs() << "AIC can not find waitop in aiv\n"; - continue; - } - - // 调试:打印wait - // llvm::dbgs() << "aictargetWait:"; - // llvm::dbgs() << targetWait << "\n"; - - // 6. 从该 Wait 向下找 ToMemrefOp 或 Yield,插 Set(newflag) - if (auto *insertPt = findInsertionPointAfterWaitForAIV(targetWait)) { - builder.setInsertionPoint(insertPt); - createSyncBlockSetOp( - builder, - fixpipeOp->getLoc(), - hivm::TCoreType::VECTOR, - hivm::PIPE::PIPE_V, - hivm::PIPE::PIPE_FIX, - newflag); - } - } + while (it != block->begin()) { + auto prevIt = std::prev(it); + if (isa(*prevIt)) { + it = prevIt; + } else { + break; + } + } + + return &*it; +} + +static Operation *findInsertionPointAfterWaitForAIC(Operation *waitOp) { + Block *block = waitOp->getBlock(); + auto it = ++waitOp->getIterator(); + for (; it != block->end(); ++it) { + if (auto fixpipe = dyn_cast(*it)) { + if (it != block->begin()) { + auto prev = std::prev(it); + if (isa(*prev)) + return &*prev; } + return &*it; } + if (isa(*it)) + return &*it; + } + return nullptr; +} + +// 查找 FixpipeOp 下一行的 sync_block_set 操作的 flag 值 +static int findFixPipeFlagSafe(hivm::FixpipeOp fixpipeOp) { + mlir::Operation *fixpipeOperation = fixpipeOp.getOperation(); + if (!fixpipeOperation || !fixpipeOperation->getBlock()) { + return -1; + } + + // 获取 FixpipeOp 的迭代器 + auto it = ++fixpipeOperation->getIterator(); + + // 遍历后续操作直到找到 sync_block_set + while (it != fixpipeOperation->getBlock()->end()) { + mlir::Operation &op = *it++; + + if 
(op.getName().getStringRef() == "hivm.hir.sync_block_set") { + auto staticFlagAttr = + op.getAttrOfType("static_flag_id"); + return staticFlagAttr.getInt(); + break; + } + } - // 查找 copyOp 下一行的 sync_block_set 操作的 flag 值 - static int findCopyFlagSafe(bufferization::ToMemrefOp toMemrefOp) { - mlir::Operation *toMemrefOperation = toMemrefOp.getOperation(); - if (!toMemrefOperation || !toMemrefOperation->getBlock()) { - return -1; + return -1; +} + +/// cube处理逻辑 +static void processFixpipeOpsInAIC(Region *aicRegion, Region *aivRegion) { + + MLIRContext *ctx = aicRegion->getContext(); + OpBuilder builder(ctx); + SmallVector fixpipes; + aicRegion->walk([&](hivm::FixpipeOp op) { fixpipes.push_back(op); }); + + for (auto fixpipeOp : fixpipes) { + + auto newflag = findFixPipeFlagSafe(fixpipeOp); + // 1. 在 FixpipeOp 前插 Wait + builder.setInsertionPoint(fixpipeOp); + createSyncBlockWaitOp(builder, fixpipeOp->getLoc(), hivm::TCoreType::CUBE, + hivm::PIPE::PIPE_V, hivm::PIPE::PIPE_FIX, newflag); + bool coretypebool = true; + + // 2. 在 aicRegion 末尾 Return 前插 Wait + insertWaitBeforeFinalReturn(aicRegion, builder, newflag, coretypebool); + + // 3. 在 aivRegion 开头插 Set + insertSetAtRegionStart(aivRegion, builder, newflag, coretypebool); + + // 4. 在 aicRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(fixpipeOp)) { + auto setFlagAttr = + nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set + // llvm::dbgs() << "aicnextSetOp:"; + // nextSetOp->dump(); + if (!setFlagAttr) { + llvm::dbgs() << "AIC can not find setop in aic\n"; + continue; } + int64_t setflag = setFlagAttr.getInt(); - // 获取 copyOp 的迭代器 - auto it = ++toMemrefOperation->getIterator(); - - // 遍历后续操作直到找到 sync_block_set - while (it != toMemrefOperation->getBlock()->end()) { - mlir::Operation &op = *it++; - - if (op.getName().getStringRef() == "hivm.hir.sync_block_set") { - auto staticFlagAttr = op.getAttrOfType("static_flag_id"); - return staticFlagAttr.getInt(); - break; - } + // 5. 
在 aivRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aivRegion, setflag); + if (!targetWait) { + llvm::dbgs() << "AIC can not find waitop in aiv\n"; + continue; } - return -1; + // 调试:打印wait + // llvm::dbgs() << "aictargetWait:"; + // llvm::dbgs() << targetWait << "\n"; - } - /// vector处理逻辑 - static void processToMemrefOpsInAIV( - Region *aivRegion, - Region *aicRegion) { - - MLIRContext *ctx = aivRegion->getContext(); - OpBuilder builder(ctx); - SmallVector toMemrefs; - aivRegion->walk([&](bufferization::ToMemrefOp op) { - toMemrefs.push_back(op); - }); - - for (auto toMemrefOp : toMemrefs) { - auto newflag = findCopyFlagSafe(toMemrefOp); - - // 1. 在 ToMemrefOp 前插 Wait - builder.setInsertionPoint(toMemrefOp); - createSyncBlockWaitOp( - builder, - toMemrefOp->getLoc(), - hivm::TCoreType::VECTOR, - hivm::PIPE::PIPE_M, - hivm::PIPE::PIPE_MTE3, - newflag); - bool coretypebool = false; - - // 2. 在 aivRegion 末尾 Return 前插 Wait - insertWaitBeforeFinalReturn(aivRegion, builder, newflag, coretypebool); - - // 3. 在 aicRegion 开头插 Set - insertSetAtRegionStart(aicRegion, builder, newflag, coretypebool); - - // 4. 在 aivRegion 向后找 SyncBlockSetOp - if (auto *nextSetOp = findNextSyncBlockSetAfter(toMemrefOp)) { - auto setFlagAttr = nextSetOp->getAttrOfType("static_flag_id"); - // 调试:打印set及其所有attribute - // llvm::dbgs() << "aivnextSetOp:"; - // nextSetOp->dump(); - // llvm::dbgs() << "Attributes:\n"; - // for (auto namedAttr : nextSetOp->getAttrs()) { - // llvm::dbgs() << " " << namedAttr.getName() << " = "; - // namedAttr.getValue().print(llvm::dbgs()); - // llvm::dbgs() << "\n"; - // } - if (!setFlagAttr) { - llvm::dbgs() << "AIV can not find setop in aiv\n"; - continue; - } - int64_t setflag = setFlagAttr.getInt(); - - // 5. 
在 aicRegion 中找 flag=setflag 的 WaitOp - auto targetWait = findWaitOpInRegionWithFlag(aicRegion, setflag); - - if (!targetWait) { - llvm::dbgs() << "AIV can not find waitop in aic\n"; - continue; - } - - // 调试:打印wait - // llvm::dbgs() << "aivtargetWait:"; - // llvm::dbgs() << targetWait << "\n"; - - // 6. 从该 Wait 向下找 Fixpipe 前 Wait 或 Yield,插 Set(newflag) - if (auto *insertPt = findInsertionPointAfterWaitForAIC(targetWait)) { - builder.setInsertionPoint(insertPt); - createSyncBlockSetOp( - builder, - toMemrefOp->getLoc(), - hivm::TCoreType::CUBE, - hivm::PIPE::PIPE_M, - hivm::PIPE::PIPE_MTE3, - newflag); - } - } + // 6. 从该 Wait 向下找 ToMemrefOp 或 Yield,插 Set(newflag) + if (auto *insertPt = findInsertionPointAfterWaitForAIV(targetWait)) { + builder.setInsertionPoint(insertPt); + createSyncBlockSetOp(builder, fixpipeOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_V, + hivm::PIPE::PIPE_FIX, newflag); } } + } +} - /// 同步点增强 - void addSyncOpsForBufferWait(ModuleOp module) { - for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { - if (funcOp.getBody().empty()) { - continue; - } +// 查找 copyOp 下一行的 sync_block_set 操作的 flag 值 +static int findCopyFlagSafe(bufferization::ToMemrefOp toMemrefOp) { + mlir::Operation *toMemrefOperation = toMemrefOp.getOperation(); + if (!toMemrefOperation || !toMemrefOperation->getBlock()) { + return -1; + } - Region *aicRegion = nullptr; - Region *aivRegion = nullptr; + // 获取 copyOp 的迭代器 + auto it = ++toMemrefOperation->getIterator(); - funcOp.walk([&](scope::ScopeOp scopeOp) { - auto coreTypeAttr = scopeOp->getAttrOfType( - hivm::TCoreTypeAttr::name); - if (!coreTypeAttr) return; + // 遍历后续操作直到找到 sync_block_set + while (it != toMemrefOperation->getBlock()->end()) { + mlir::Operation &op = *it++; - if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) { - aicRegion = &scopeOp.getRegion(); - } - if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) { - aivRegion = &scopeOp.getRegion(); - } - }); + if 
(op.getName().getStringRef() == "hivm.hir.sync_block_set") { + auto staticFlagAttr = + op.getAttrOfType("static_flag_id"); + return staticFlagAttr.getInt(); + break; + } + } - if (!aicRegion || !aivRegion) { - continue; - } + return -1; +} +/// vector处理逻辑 +static void processToMemrefOpsInAIV(Region *aivRegion, Region *aicRegion) { + + MLIRContext *ctx = aivRegion->getContext(); + OpBuilder builder(ctx); + SmallVector toMemrefs; + aivRegion->walk( + [&](bufferization::ToMemrefOp op) { toMemrefs.push_back(op); }); + + for (auto toMemrefOp : toMemrefs) { + auto newflag = findCopyFlagSafe(toMemrefOp); + + // 1. 在 ToMemrefOp 前插 Wait + builder.setInsertionPoint(toMemrefOp); + createSyncBlockWaitOp(builder, toMemrefOp->getLoc(), + hivm::TCoreType::VECTOR, hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, newflag); + bool coretypebool = false; + + // 2. 在 aivRegion 末尾 Return 前插 Wait + insertWaitBeforeFinalReturn(aivRegion, builder, newflag, coretypebool); + + // 3. 在 aicRegion 开头插 Set + insertSetAtRegionStart(aicRegion, builder, newflag, coretypebool); + + // 4. 在 aivRegion 向后找 SyncBlockSetOp + if (auto *nextSetOp = findNextSyncBlockSetAfter(toMemrefOp)) { + auto setFlagAttr = + nextSetOp->getAttrOfType("static_flag_id"); + // 调试:打印set及其所有attribute + // llvm::dbgs() << "aivnextSetOp:"; + // nextSetOp->dump(); + // llvm::dbgs() << "Attributes:\n"; + // for (auto namedAttr : nextSetOp->getAttrs()) { + // llvm::dbgs() << " " << namedAttr.getName() << " = "; + // namedAttr.getValue().print(llvm::dbgs()); + // llvm::dbgs() << "\n"; + // } + if (!setFlagAttr) { + llvm::dbgs() << "AIV can not find setop in aiv\n"; + continue; + } + int64_t setflag = setFlagAttr.getInt(); + + // 5. 
在 aicRegion 中找 flag=setflag 的 WaitOp + auto targetWait = findWaitOpInRegionWithFlag(aicRegion, setflag); + + if (!targetWait) { + llvm::dbgs() << "AIV can not find waitop in aic\n"; + continue; + } + + // 调试:打印wait + // llvm::dbgs() << "aivtargetWait:"; + // llvm::dbgs() << targetWait << "\n"; - processFixpipeOpsInAIC(aicRegion, aivRegion); - processToMemrefOpsInAIV(aivRegion, aicRegion); + // 6. 从该 Wait 向下找 Fixpipe 前 Wait 或 Yield,插 Set(newflag) + if (auto *insertPt = findInsertionPointAfterWaitForAIC(targetWait)) { + builder.setInsertionPoint(insertPt); + createSyncBlockSetOp(builder, toMemrefOp->getLoc(), + hivm::TCoreType::CUBE, hivm::PIPE::PIPE_M, + hivm::PIPE::PIPE_MTE3, newflag); } } + } +} +/// 同步点增强 +void addSyncOpsForBufferWait(ModuleOp module) { + for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { + if (funcOp.getBody().empty()) { + continue; + } + + Region *aicRegion = nullptr; + Region *aivRegion = nullptr; + + funcOp.walk([&](scope::ScopeOp scopeOp) { + auto coreTypeAttr = scopeOp->getAttrOfType( + hivm::TCoreTypeAttr::name); + if (!coreTypeAttr) + return; + + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::CUBE) { + aicRegion = &scopeOp.getRegion(); + } + if (coreTypeAttr.getTcoretype() == hivm::TCoreType::VECTOR) { + aivRegion = &scopeOp.getRegion(); + } + }); + + if (!aicRegion || !aivRegion) { + continue; + } + + processFixpipeOpsInAIC(aicRegion, aivRegion); + processToMemrefOpsInAIV(aivRegion, aicRegion); + } +} void DAGScopePass::runOnOperation() { auto module = getOperation(); @@ -1079,60 +1045,59 @@ void DAGScopePass::runOnOperation() { mlir::OpBuilder builder(&getContext()); - for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { // skip invalid function if (funcOp.getBody().empty()) { continue; } // 收集所有 memref.alloc 操作 - llvm::SmallVector allocOps; + llvm::SmallVector allocOps; // 遍历函数中的所有操作(包括嵌套区域中的操作) funcOp.walk([&](mlir::Operation *op) 
{ - if (mlir::isa(op)) { - allocOps.push_back(op); - } + if (mlir::isa(op)) { + allocOps.push_back(op); + } }); - mlir::Block& entryBlock = funcOp.getBody().front(); + mlir::Block &entryBlock = funcOp.getBody().front(); mlir::Block::iterator insertPos = entryBlock.begin(); // 将 alloc 操作移动到函数的最前面 - for (mlir::Operation* allocOp : allocOps) { - // 如果 alloc 操作已经是最前面的操作,跳过 - if (allocOp->getBlock() == &entryBlock && - allocOp->isBeforeInBlock(&*insertPos)) { - continue; - } + for (mlir::Operation *allocOp : allocOps) { + // 如果 alloc 操作已经是最前面的操作,跳过 + if (allocOp->getBlock() == &entryBlock && + allocOp->isBeforeInBlock(&*insertPos)) { + continue; + } - // 将 alloc 操作移动到指定位置 - allocOp->moveBefore(&entryBlock, insertPos); + // 将 alloc 操作移动到指定位置 + allocOp->moveBefore(&entryBlock, insertPos); } auto funcName = funcOp.getName(); - auto* graph_ptr = AffinityDAG::GraphManager::getInstance().getGraph(funcName); + auto *graph_ptr = + AffinityDAG::GraphManager::getInstance().getGraph(funcName); if (!graph_ptr) { continue; } - auto& main_graph = *graph_ptr; - + auto &main_graph = *graph_ptr; auto ScopeList = encapsulateWithScope(funcOp); - auto aivScope = ScopeList.first; // 第一个元素 + auto aivScope = ScopeList.first; // 第一个元素 auto aicScope = ScopeList.second; // 第二个元素 SplitScope(funcOp, main_graph, aivScope, aicScope, module); } - addSyncOpsForBufferWait(module); - // llvm::outs()<> -mlir::triton::createDAGScopePass() { +std::unique_ptr> mlir::triton::createDAGScopePass() { return std::make_unique(); } - diff --git a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp index c0fc70b255..2d01330aed 100644 --- a/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp +++ b/third_party/ascend/lib/TritonAffinityOpt/DAGSync.cpp @@ -1,7 +1,6 @@ #include "TritonAffinityOpt/Passes.h" #include "bishengir/Dialect/Annotation/IR/Annotation.h" -#include "bishengir/Dialect/Scope/IR/Scope.h" #include "bishengir/Dialect/HIVM/IR/HIVM.h" #include 
"bishengir/Dialect/HIVM/IR/HIVMImpl.h" #include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h" @@ -38,336 +37,360 @@ namespace mlir { namespace triton { #define GEN_PASS_DEF_DAGSYNC #include "ascend/include/TritonAffinityOpt/Passes.h.inc" -} // namespace triton -} // namespace mlir +} // namespace triton +} // namespace mlir // 使用 DAG 命名空间 using namespace mlir; using namespace hivm; using namespace AffinityDAG; -llvm::DenseMap* valueTypes; +llvm::DenseMap *valueTypes; // 修改类声明,将数据搬运逻辑集成到同步插入中 namespace { struct DAGSyncPass : public mlir::triton::impl::DAGSyncBase { - void runOnOperation() override; + void runOnOperation() override; private: - // 原有的辅助函数 - CoreType getNodeDeviceType(OpNode *node, llvm::DenseMap *valueTypes); - bool needVectorCubeSync(CoreType src, CoreType dst); - - // 修改后的同步插入函数,包含数据搬运 - void insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, - CoreType srcType, CoreType dstType, - mlir::OpBuilder &builder, int flag, llvm::DenseMap* valueMap, Graph &mainGraph); - - // 新增:处理跨 block 的同步和数据搬运 - void insertSyncAndMovementForCrossBlock(mlir::Operation *srcOp, mlir::Operation *dstOp, - CoreType srcType, CoreType dstType, - mlir::OpBuilder &builder, int flag, - bool dstIsInnerBlock, llvm::DenseMap* valueMap, Graph &mainGraph); - - // 新增:处理 scf.for 循环迭代参数的同步 - void processScfForSync(mlir::scf::ForOp forOp, - Node* forNode, - llvm::DenseMap *valueTypes, - mlir::OpBuilder &builder, - int &flag); - - // 数据搬运相关的辅助函数 - void insertCubeToVectorDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, - mlir::Value srcResult, mlir::OpBuilder &builder, - mlir::Location loc, mlir::Value iterArgs); - - void insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, Operation * posOp, - mlir::Value srcResult, mlir::OpBuilder &builder, - mlir::Location loc, llvm::DenseMap* valueMap); - - // 获取或创建合适的 memref.alloc - mlir::Value getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, - hivm::AddressSpace addressSpace, - 
mlir::OpBuilder &builder, mlir::Location loc); - - // 获取 tensor 的形状和元素类型 - mlir::RankedTensorType getTensorType(mlir::Value tensorValue); - - // 替换 dstOp 中使用 srcResult 的操作数 - void replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, - mlir::Value newValue); - - // Find sync position - Operation* FindLastestPosition(Operation* srcOp, Graph &mainGraph, OpBuilder &builder); - Operation* FindEarliestPosition(Operation* dstOp, Graph &mainGraph, OpBuilder &builder); + // 原有的辅助函数 + CoreType getNodeDeviceType(OpNode *node, + llvm::DenseMap *valueTypes); + bool needVectorCubeSync(CoreType src, CoreType dst); + + // 修改后的同步插入函数,包含数据搬运 + void insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + CoreType srcType, CoreType dstType, + mlir::OpBuilder &builder, int flag, + llvm::DenseMap *valueMap, + Graph &mainGraph); + + // 新增:处理跨 block 的同步和数据搬运 + void insertSyncAndMovementForCrossBlock( + mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType, + CoreType dstType, mlir::OpBuilder &builder, int flag, + bool dstIsInnerBlock, llvm::DenseMap *valueMap, + Graph &mainGraph); + + // 新增:处理 scf.for 循环迭代参数的同步 + void processScfForSync(mlir::scf::ForOp forOp, Node *forNode, + llvm::DenseMap *valueTypes, + mlir::OpBuilder &builder, int &flag); + + // 数据搬运相关的辅助函数 + void insertCubeToVectorDataMovement(mlir::Operation *srcOp, + mlir::Operation *dstOp, + mlir::Value srcResult, + mlir::OpBuilder &builder, + mlir::Location loc, mlir::Value iterArgs); + + void + insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, + Operation *posOp, mlir::Value srcResult, + mlir::OpBuilder &builder, mlir::Location loc, + llvm::DenseMap *valueMap); + + // 获取或创建合适的 memref.alloc + mlir::Value getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, + hivm::AddressSpace addressSpace, + mlir::OpBuilder &builder, + mlir::Location loc); + + // 获取 tensor 的形状和元素类型 + mlir::RankedTensorType getTensorType(mlir::Value tensorValue); + + // 替换 
dstOp 中使用 srcResult 的操作数 + void replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, + mlir::Value newValue); + + // Find sync position + Operation *FindLastestPosition(Operation *srcOp, Graph &mainGraph, + OpBuilder &builder); + Operation *FindEarliestPosition(Operation *dstOp, Graph &mainGraph, + OpBuilder &builder); }; -} // namespace - -void DAGSyncPass::processScfForSync(mlir::scf::ForOp forOp, - Node* forNode, - llvm::DenseMap *valueTypes, - mlir::OpBuilder &builder, - int &flag) { - - mlir::Block* loopBody = forOp.getBody(); - mlir::scf::YieldOp yieldOp = nullptr; - for (mlir::Operation &op : *loopBody) { - if (auto yield = mlir::dyn_cast(&op)) { - yieldOp = yield; - break; - } +} // namespace + +void DAGSyncPass::processScfForSync( + mlir::scf::ForOp forOp, Node *forNode, + llvm::DenseMap *valueTypes, mlir::OpBuilder &builder, + int &flag) { + + mlir::Block *loopBody = forOp.getBody(); + mlir::scf::YieldOp yieldOp = nullptr; + for (mlir::Operation &op : *loopBody) { + if (auto yield = mlir::dyn_cast(&op)) { + yieldOp = yield; + break; } - Location loc = forOp.getLoc(); - - for (int i = 0; i < forOp.getInitArgs().size(); i++) { - mlir::BlockArgument iterArg = loopBody->getArgument(i+1); - // 找到首次使用 - mlir::Operation* firstUser = nullptr; - - for (mlir::Operation &op : *loopBody) { - // 跳过 yield 操作 - if (mlir::isa(&op)) { - continue; - } + } + Location loc = forOp.getLoc(); - // 检查是否使用该迭代参数 - bool usesIterArg = false; - for (mlir::Value operand : op.getOperands()) { - if (operand == iterArg) { - usesIterArg = true; - break; - } - } + for (int i = 0; i < forOp.getInitArgs().size(); i++) { + mlir::BlockArgument iterArg = loopBody->getArgument(i + 1); + // 找到首次使用 + mlir::Operation *firstUser = nullptr; - if (usesIterArg) { - firstUser = &op; - break; - } - } - // map 内找到对应的iterType,iterType由首次在loop内使用到的op定义 - if (!firstUser) { - continue; - } - CoreType iterType = CoreType::CUBE_AND_VECTOR; - if (valueTypes->find(firstUser->getResult(0)) != 
valueTypes->end()) { - iterType = valueTypes->find(firstUser->getResult(0))->second; - } + for (mlir::Operation &op : *loopBody) { + // 跳过 yield 操作 + if (mlir::isa(&op)) { + continue; + } - // 获取对应yield - mlir::Value yieldOperand = yieldOp->getOperand(i); - CoreType yieldType = CoreType::CUBE_AND_VECTOR; - if (valueTypes->find(yieldOperand) != valueTypes->end()) { - yieldType = valueTypes->find(yieldOperand)->second; + // 检查是否使用该迭代参数 + bool usesIterArg = false; + for (mlir::Value operand : op.getOperands()) { + if (operand == iterArg) { + usesIterArg = true; + break; } - mlir::Operation* yieldDefiningOp = yieldOperand.getDefiningOp(); - - if (yieldType == CoreType::CUBE_ONLY && iterType == CoreType::VECTOR_ONLY) { - - // 2. 插入同步指令 - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + } - // set 在 yieldDefiningOp 后 - builder.setInsertionPointAfter(yieldDefiningOp); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + if (usesIterArg) { + firstUser = &op; + break; + } + } + // map 内找到对应的iterType,iterType由首次在loop内使用到的op定义 + if (!firstUser) { + continue; + } + CoreType iterType = CoreType::CUBE_AND_VECTOR; + if (valueTypes->find(firstUser->getResult(0)) != valueTypes->end()) { + iterType = valueTypes->find(firstUser->getResult(0))->second; + } - mlir::Value srcResult = yieldDefiningOp->getResult(0); + // 获取对应yield + mlir::Value yieldOperand = yieldOp->getOperand(i); + CoreType yieldType = CoreType::CUBE_AND_VECTOR; + if (valueTypes->find(yieldOperand) != valueTypes->end()) { + yieldType = valueTypes->find(yieldOperand)->second; + } + mlir::Operation *yieldDefiningOp = yieldOperand.getDefiningOp(); - // // 1. 
插入数据搬运 - insertCubeToVectorDataMovement(yieldDefiningOp, firstUser, srcResult, builder, loc, iterArg); + if (yieldType == CoreType::CUBE_ONLY && iterType == CoreType::VECTOR_ONLY) { - // wait 在 firstUser 前 - builder.setInsertionPoint(firstUser); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser << "\n"; - // llvm::outs() << "Inserted CUBE->VECTOR sync and data movement (flag=" << flag << ")\n"; + // 2. 插入同步指令 + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 yieldDefiningOp 后 + builder.setInsertionPointAfter(yieldDefiningOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + mlir::Value srcResult = yieldDefiningOp->getResult(0); + + // // 1. 插入数据搬运 + insertCubeToVectorDataMovement(yieldDefiningOp, firstUser, srcResult, + builder, loc, iterArg); + + // wait 在 firstUser 前 + builder.setInsertionPoint(firstUser); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser + // << "\n"; llvm::outs() << "Inserted CUBE->VECTOR sync and data movement + // (flag=" << flag << ")\n"; } // VECTOR -> CUBE - else if (yieldType == CoreType::VECTOR_ONLY && iterType == CoreType::CUBE_ONLY) { - - // 2. 
插入同步指令 - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - - // set 在 yieldDefiningOp 后 - builder.setInsertionPointAfter(yieldDefiningOp); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - - // 1. 插入数据搬运 - // insertVectorToCubeDataMovement(yieldDefiningOp, firstUser, srcResult, builder, loc, iterArg); - - // wait 在 firstUser 前 - builder.setInsertionPoint(firstUser); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser << "\n"; - // llvm::outs() << "Inserted VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; - } + else if (yieldType == CoreType::VECTOR_ONLY && + iterType == CoreType::CUBE_ONLY) { + + // 2. 插入同步指令 + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = + PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 yieldDefiningOp 后 + builder.setInsertionPointAfter(yieldDefiningOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运 + // insertVectorToCubeDataMovement(yieldDefiningOp, firstUser, srcResult, + // builder, loc, iterArg); + + // wait 在 firstUser 前 + builder.setInsertionPoint(firstUser); + coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // llvm::outs() << "yieldOp" << yieldDefiningOp << "iterargs" << firstUser + // << "\n"; llvm::outs() << "Inserted VECTOR->CUBE sync and data movement + // (flag=" << flag << ")\n"; } + } } // 获取节点的设备类型 -CoreType DAGSyncPass::getNodeDeviceType(OpNode *node, llvm::DenseMap *valueTypes) -{ - if (!node || !node->op) { - return CoreType::CUBE_AND_VECTOR; +CoreType DAGSyncPass::getNodeDeviceType( + OpNode *node, llvm::DenseMap *valueTypes) { + if (!node || !node->op) { + return CoreType::CUBE_AND_VECTOR; + } + + // 尝试从节点的结果中获取设备类型 + // 通常使用第一个结果来代表节点的设备类型 + if (node->op->getNumResults() > 0) { + mlir::Value result = node->op->getResult(0); + auto it = valueTypes->find(result); + if (it != valueTypes->end()) { + return it->second; } + } - // 尝试从节点的结果中获取设备类型 - // 通常使用第一个结果来代表节点的设备类型 - if (node->op->getNumResults() > 0) { - mlir::Value result = node->op->getResult(0); - auto it = valueTypes->find(result); - if (it != valueTypes->end()) { - return it->second; - } - } + // 如果没有找到,检查操作数 + // for (mlir::Value operand : node->op->getOperands()) { + // auto it = valueTypes->find(operand); + // if (it != valueTypes->end()) { + // return it->second; + // } + // } - // 如果没有找到,检查操作数 - // for (mlir::Value operand : node->op->getOperands()) { - // auto it = valueTypes->find(operand); - // if (it != valueTypes->end()) { - // return it->second; - // } - // } - - return CoreType::CUBE_AND_VECTOR; // 默认 + return CoreType::CUBE_AND_VECTOR; // 默认 } // 判断是否需要vector<->cube同步 -bool DAGSyncPass::needVectorCubeSync(CoreType src, CoreType dst) -{ - return (src == CoreType::VECTOR_ONLY && dst == CoreType::CUBE_ONLY) || - (src == CoreType::CUBE_ONLY && dst == 
CoreType::VECTOR_ONLY); +bool DAGSyncPass::needVectorCubeSync(CoreType src, CoreType dst) { + return (src == CoreType::VECTOR_ONLY && dst == CoreType::CUBE_ONLY) || + (src == CoreType::CUBE_ONLY && dst == CoreType::VECTOR_ONLY); } // 获取 tensor 类型 mlir::RankedTensorType DAGSyncPass::getTensorType(mlir::Value tensorValue) { - if (auto tensorType = dyn_cast(tensorValue.getType())) { - return tensorType; - } - return nullptr; + if (auto tensorType = + dyn_cast(tensorValue.getType())) { + return tensorType; + } + return nullptr; } // 替换操作数 -void DAGSyncPass::replaceOperandWithNewValue(mlir::Operation *dstOp, mlir::Value oldValue, - mlir::Value newValue) { - for (unsigned i = 0; i < dstOp->getNumOperands(); ++i) { - if (dstOp->getOperand(i) == oldValue) { - dstOp->setOperand(i, newValue); - // llvm::outs() << "Replaced operand " << i << " of " << dstOp->getName().getStringRef() - // << " with new value\n"; - } +void DAGSyncPass::replaceOperandWithNewValue(mlir::Operation *dstOp, + mlir::Value oldValue, + mlir::Value newValue) { + for (unsigned i = 0; i < dstOp->getNumOperands(); ++i) { + if (dstOp->getOperand(i) == oldValue) { + dstOp->setOperand(i, newValue); + // llvm::outs() << "Replaced operand " << i << " of " << + // dstOp->getName().getStringRef() + // << " with new value\n"; } + } } // 修改 getOrCreateAllocation 函数,将 alloc 提到函数最外层 -mlir::Value DAGSyncPass::getOrCreateAllocation(mlir::Operation *op, mlir::Type tensorType, +mlir::Value DAGSyncPass::getOrCreateAllocation(mlir::Operation *op, + mlir::Type tensorType, hivm::AddressSpace addressSpace, - mlir::OpBuilder &builder, mlir::Location loc) { - auto rankedTensorType = cast(tensorType); - auto elementType = rankedTensorType.getElementType(); - auto shape = rankedTensorType.getShape(); - - auto addressSpaceAttr = hivm::AddressSpaceAttr::get(builder.getContext(), addressSpace); - auto memrefType = mlir::MemRefType::get(shape, elementType, /*layout=*/nullptr, addressSpaceAttr); - - // 查找是否已经存在相同类型的 allocation(在函数的 
entry block 中) - mlir::Operation* funcOp = op; - while (funcOp && !mlir::isa(funcOp)) { - funcOp = funcOp->getParentOp(); - } - - if (auto func = mlir::dyn_cast(funcOp)) { - // 在函数的 entry block 中查找现有的 allocation - mlir::Block& entryBlock = func.getBody().front(); - // for (auto& blockOp : entryBlock) { - // if (auto allocOp = mlir::dyn_cast(&blockOp)) { - // if (allocOp.getType() == memrefType) { - // // 找到匹配的 allocation,直接复用 - // llvm::outs() << "Reusing existing allocation: " << allocOp << "\n"; - // return allocOp.getResult(); - // } - // } - // } - - // 没有找到现有的 allocation,在函数开头创建新的 - builder.setInsertionPointToStart(&entryBlock); - return builder.create(loc, memrefType); - } + mlir::OpBuilder &builder, + mlir::Location loc) { + auto rankedTensorType = cast(tensorType); + auto elementType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + + auto addressSpaceAttr = + hivm::AddressSpaceAttr::get(builder.getContext(), addressSpace); + auto memrefType = mlir::MemRefType::get(shape, elementType, + /*layout=*/nullptr, addressSpaceAttr); + + // 查找是否已经存在相同类型的 allocation(在函数的 entry block 中) + mlir::Operation *funcOp = op; + while (funcOp && !mlir::isa(funcOp)) { + funcOp = funcOp->getParentOp(); + } + + if (auto func = mlir::dyn_cast(funcOp)) { + // 在函数的 entry block 中查找现有的 allocation + mlir::Block &entryBlock = func.getBody().front(); + // for (auto& blockOp : entryBlock) { + // if (auto allocOp = mlir::dyn_cast(&blockOp)) { + // if (allocOp.getType() == memrefType) { + // // 找到匹配的 allocation,直接复用 + // llvm::outs() << "Reusing existing allocation: " << allocOp << + // "\n"; return allocOp.getResult(); + // } + // } + // } - // 如果没有找到函数,回退到原逻辑 - builder.setInsertionPoint(op); + // 没有找到现有的 allocation,在函数开头创建新的 + builder.setInsertionPointToStart(&entryBlock); return builder.create(loc, memrefType); + } + + // 如果没有找到函数,回退到原逻辑 + builder.setInsertionPoint(op); + return builder.create(loc, memrefType); } // 插入 CUBE -> VECTOR 数据搬运 -void 
DAGSyncPass::insertCubeToVectorDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, - mlir::Value srcResult, mlir::OpBuilder &builder, - mlir::Location loc, mlir::Value iterArgs) { - auto srcTensorType = getTensorType(srcResult); - if (!srcTensorType) { - return; - } - - // 1. 在 srcOp 之后创建 UB 空间的 memref.alloc - builder.setInsertionPointAfter(srcOp); - mlir::Value ubAlloc = getOrCreateAllocation(srcOp, srcTensorType, - hivm::AddressSpace::UB, builder, loc); - - // 2. 创建 fixpipe 指令 - builder.setInsertionPointAfter(srcOp); - FixpipeDMAModeAttr dmaModeAttr = FixpipeDMAModeAttr::get(builder.getContext(), FixpipeDMAMode::NZ2ND); - - auto fixpipeOp = builder.create( - loc, - mlir::TypeRange{}, // 没有返回值 - srcResult, // src - ubAlloc, // dst - /*unit_flag_cond=*/mlir::ValueRange{}, - /*dma_mode=*/dmaModeAttr, - /*dual_dst_mode=*/nullptr, - /*pre_quant=*/nullptr, - /*pre_relu=*/nullptr, - /*channel_split=*/nullptr, - /*unit_flag_mode=*/mlir::ArrayAttr{}); - - llvm::outs() << "Inserted fixpipe after " << srcOp->getName().getStringRef() - << " for CUBE->VECTOR data movement\n"; - - // 3. 在 dstOp 前创建 memory_space_cast 和 to_tensor - builder.setInsertionPoint(dstOp); - - // memory_space_cast(如果需要) - mlir::Value plainMemref = ubAlloc; - auto memrefType = cast(ubAlloc.getType()); - if (memrefType.getMemorySpace()) { - auto plainMemrefType = mlir::MemRefType::get(memrefType.getShape(), - memrefType.getElementType()); - plainMemref = builder.create(loc, plainMemrefType, ubAlloc); - (*valueTypes)[plainMemref] = CoreType::VECTOR_ONLY; - } - - // 4. 创建 to_tensor - auto toTensorOp = builder.create( - loc, - srcTensorType, // 原始的 tensor 类型 - plainMemref, - /*restrict=*/true, - /*writable=*/true - ); - (*valueTypes)[toTensorOp.getResult()] = CoreType::VECTOR_ONLY; - - // 5. 
替换 dstOp 的操作数 - if (!iterArgs) { - replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); - } else { - replaceOperandWithNewValue(dstOp, iterArgs, toTensorOp.getResult()); - } +void DAGSyncPass::insertCubeToVectorDataMovement( + mlir::Operation *srcOp, mlir::Operation *dstOp, mlir::Value srcResult, + mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iterArgs) { + auto srcTensorType = getTensorType(srcResult); + if (!srcTensorType) { + return; + } + + // 1. 在 srcOp 之后创建 UB 空间的 memref.alloc + builder.setInsertionPointAfter(srcOp); + mlir::Value ubAlloc = getOrCreateAllocation( + srcOp, srcTensorType, hivm::AddressSpace::UB, builder, loc); + + // 2. 创建 fixpipe 指令 + builder.setInsertionPointAfter(srcOp); + FixpipeDMAModeAttr dmaModeAttr = + FixpipeDMAModeAttr::get(builder.getContext(), FixpipeDMAMode::NZ2ND); + + auto fixpipeOp = + builder.create(loc, mlir::TypeRange{}, // 没有返回值 + srcResult, // src + ubAlloc, // dst + /*unit_flag_cond=*/mlir::ValueRange{}, + /*dma_mode=*/dmaModeAttr, + /*dual_dst_mode=*/nullptr, + /*pre_quant=*/nullptr, + /*pre_relu=*/nullptr, + /*channel_split=*/nullptr, + /*unit_flag_mode=*/mlir::ArrayAttr{}); + + llvm::outs() << "Inserted fixpipe after " << srcOp->getName().getStringRef() + << " for CUBE->VECTOR data movement\n"; + + // 3. 在 dstOp 前创建 memory_space_cast 和 to_tensor + builder.setInsertionPoint(dstOp); + + // memory_space_cast(如果需要) + mlir::Value plainMemref = ubAlloc; + auto memrefType = cast(ubAlloc.getType()); + if (memrefType.getMemorySpace()) { + auto plainMemrefType = mlir::MemRefType::get(memrefType.getShape(), + memrefType.getElementType()); + plainMemref = builder.create( + loc, plainMemrefType, ubAlloc); + (*valueTypes)[plainMemref] = CoreType::VECTOR_ONLY; + } + + // 4. 创建 to_tensor + auto toTensorOp = builder.create( + loc, + srcTensorType, // 原始的 tensor 类型 + plainMemref, + /*restrict=*/true, + /*writable=*/true); + (*valueTypes)[toTensorOp.getResult()] = CoreType::VECTOR_ONLY; + + // 5. 
替换 dstOp 的操作数 + if (!iterArgs) { + replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); + } else { + replaceOperandWithNewValue(dstOp, iterArgs, toTensorOp.getResult()); + } } static uint64_t getElemBytesForAlign(Type t) { @@ -394,7 +417,8 @@ static FailureOr getBlockElemsFor32BAlign(Type elemType) { return kAlignBytes / elemBytes; } -static std::optional> newCbubAllocShape(memref::AllocOp allocOp) { +static std::optional> +newCbubAllocShape(memref::AllocOp allocOp) { auto type = dyn_cast(allocOp.getType()); // 仅支持静态 2D MemRef if (!type || type.getRank() != 2) @@ -419,432 +443,460 @@ static std::optional> newCbubAllocShape(memref::AllocOp } // 修改 VECTOR->CUBE 数据搬运函数 -void DAGSyncPass::insertVectorToCubeDataMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, Operation* posOp, - mlir::Value srcResult, mlir::OpBuilder &builder, - mlir::Location loc, llvm::DenseMap* valueMap) { - auto srcTensorType = getTensorType(srcResult); - if (!srcTensorType) { - return; - } - if (isa(srcOp) && isa(dstOp)) { - return; - } - - // 1. 在 srcOp 之后创建 UB 空间的 memref.alloc(用于 to_memref) - builder.setInsertionPointAfter(srcOp); - - // 首先创建 UB 空间的 memref type - auto ubSpaceAttr = hivm::AddressSpaceAttr::get(builder.getContext(), hivm::AddressSpace::UB); - auto ubMemrefType = mlir::MemRefType::get(srcTensorType.getShape(), - srcTensorType.getElementType(), - /*layout=*/nullptr, - ubSpaceAttr); - - // 创建 bufferization.to_memref - if (srcOp->getBlock() == dstOp->getBlock()) { - builder.setInsertionPoint(posOp); - } - auto toMemrefOp = builder.create( - loc, - ubMemrefType, - srcResult - ); - - // 2. 
创建 CBUF 空间的 memref.alloc(用于 copy 的目标) - mlir::Value cbufAllocOld = getOrCreateAllocation(srcOp, srcTensorType, - hivm::AddressSpace::L1, builder, loc); - auto cbufShape = *newCbubAllocShape(dyn_cast(cbufAllocOld.getDefiningOp())); - // 获取旧的memref类型并创建新的类型 - auto oldType = dyn_cast(cbufAllocOld.getType()); - - // 获取新的维度数量 - unsigned newRank = cbufShape.size(); - - // 方法1:创建新的恒等布局映射 - AffineMap identityMap = builder.getMultiDimIdentityMap(newRank); - MemRefLayoutAttrInterface layout = AffineMapAttr::get(identityMap); - - // 方法2:如果旧类型有布局,尝试调整它(更安全的选择) - // 先检查旧类型是否有布局 - if (auto oldLayout = oldType.getLayout()) { - if (auto affineMapAttr = dyn_cast(oldLayout)) { - // 如果旧布局是AffineMap,尝试创建新的恒等映射 - // 因为维度改变,旧的affine map可能不再有效 - layout = AffineMapAttr::get(identityMap); - } else { - // 对于其他类型的布局,可能需要特殊处理 - layout = oldLayout; - } +void DAGSyncPass::insertVectorToCubeDataMovement( + mlir::Operation *srcOp, mlir::Operation *dstOp, Operation *posOp, + mlir::Value srcResult, mlir::OpBuilder &builder, mlir::Location loc, + llvm::DenseMap *valueMap) { + auto srcTensorType = getTensorType(srcResult); + if (!srcTensorType) { + return; + } + if (isa(srcOp) && isa(dstOp)) { + return; + } + + // 1. 在 srcOp 之后创建 UB 空间的 memref.alloc(用于 to_memref) + builder.setInsertionPointAfter(srcOp); + + // 首先创建 UB 空间的 memref type + auto ubSpaceAttr = + hivm::AddressSpaceAttr::get(builder.getContext(), hivm::AddressSpace::UB); + auto ubMemrefType = mlir::MemRefType::get(srcTensorType.getShape(), + srcTensorType.getElementType(), + /*layout=*/nullptr, ubSpaceAttr); + + // 创建 bufferization.to_memref + if (srcOp->getBlock() == dstOp->getBlock()) { + builder.setInsertionPoint(posOp); + } + auto toMemrefOp = + builder.create(loc, ubMemrefType, srcResult); + + // 2. 
创建 CBUF 空间的 memref.alloc(用于 copy 的目标) + mlir::Value cbufAllocOld = getOrCreateAllocation( + srcOp, srcTensorType, hivm::AddressSpace::L1, builder, loc); + auto cbufShape = *newCbubAllocShape( + dyn_cast(cbufAllocOld.getDefiningOp())); + // 获取旧的memref类型并创建新的类型 + auto oldType = dyn_cast(cbufAllocOld.getType()); + + // 获取新的维度数量 + unsigned newRank = cbufShape.size(); + + // 方法1:创建新的恒等布局映射 + AffineMap identityMap = builder.getMultiDimIdentityMap(newRank); + MemRefLayoutAttrInterface layout = AffineMapAttr::get(identityMap); + + // 方法2:如果旧类型有布局,尝试调整它(更安全的选择) + // 先检查旧类型是否有布局 + if (auto oldLayout = oldType.getLayout()) { + if (auto affineMapAttr = dyn_cast(oldLayout)) { + // 如果旧布局是AffineMap,尝试创建新的恒等映射 + // 因为维度改变,旧的affine map可能不再有效 + layout = AffineMapAttr::get(identityMap); + } else { + // 对于其他类型的布局,可能需要特殊处理 + layout = oldLayout; } - - // 创建新的alloc类型 - auto newAllocType = MemRefType::get( - cbufShape, - oldType.getElementType(), - layout, // 使用新创建的布局 - oldType.getMemorySpace() - ); - - builder.setInsertionPoint(cbufAllocOld.getDefiningOp()); - // 创建新的alloc操作 - auto cbufAlloc = builder.create( - cbufAllocOld.getDefiningOp()->getLoc(), - newAllocType - ); - - builder.setInsertionPointAfter(toMemrefOp); - // 3. 创建 copy 指令(src 是 ub memref,dst 是 cbuf memref) - auto copyOp = builder.create( - loc, - mlir::TypeRange{}, // 没有返回值 - toMemrefOp.getResult(), // src (memref in UB) - cbufAlloc // dst (memref in CBUF) - ); - - // llvm::outs() << "Inserted copy after " << srcOp->getName().getStringRef() - // << " for VECTOR->CUBE data movement\n"; - - // 4. 在 dstOp 前创建 convert_layout - builder.setInsertionPoint(dstOp); - auto ndLayout = hivm::DataLayoutAttr::get(builder.getContext(), hivm::DataLayout::ND); - // 创建 convert_layout - auto convertLayoutOp = builder.create( - loc, - cbufAllocOld.getType(), // 输出类型与输入相同 - cbufAlloc, - ndLayout, // srcLayout - ndLayout // dstLayout - ); - (*valueTypes)[convertLayoutOp.getResult()] = CoreType::CUBE_ONLY; - - // 5. 
创建 memory_space_cast - auto cbufMemrefType = cast(convertLayoutOp.getType()); - auto plainMemrefType = mlir::MemRefType::get(cbufMemrefType.getShape(), - cbufMemrefType.getElementType()); - - auto memspaceCastOp = builder.create( - loc, - plainMemrefType, - convertLayoutOp.getResult() - ); - (*valueTypes)[memspaceCastOp.getResult()] = CoreType::CUBE_ONLY; - - // 6. 创建 to_tensor - auto toTensorOp = builder.create( - loc, - srcTensorType, // 原始的 tensor 类型 - memspaceCastOp.getResult(), - /*restrict=*/true, - /*writable=*/true - ); - (*valueTypes)[toTensorOp.getResult()] = CoreType::CUBE_ONLY; - - // 7. 替换 dstOp 的操作数 - replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); + } + + // 创建新的alloc类型 + auto newAllocType = MemRefType::get(cbufShape, oldType.getElementType(), + layout, // 使用新创建的布局 + oldType.getMemorySpace()); + + builder.setInsertionPoint(cbufAllocOld.getDefiningOp()); + // 创建新的alloc操作 + auto cbufAlloc = builder.create( + cbufAllocOld.getDefiningOp()->getLoc(), newAllocType); + + builder.setInsertionPointAfter(toMemrefOp); + // 3. 创建 copy 指令(src 是 ub memref,dst 是 cbuf memref) + auto copyOp = + builder.create(loc, mlir::TypeRange{}, // 没有返回值 + toMemrefOp.getResult(), // src (memref in UB) + cbufAlloc // dst (memref in CBUF) + ); + + // llvm::outs() << "Inserted copy after " << srcOp->getName().getStringRef() + // << " for VECTOR->CUBE data movement\n"; + + // 4. 在 dstOp 前创建 convert_layout + builder.setInsertionPoint(dstOp); + auto ndLayout = + hivm::DataLayoutAttr::get(builder.getContext(), hivm::DataLayout::ND); + // 创建 convert_layout + auto convertLayoutOp = builder.create( + loc, + cbufAllocOld.getType(), // 输出类型与输入相同 + cbufAlloc, + ndLayout, // srcLayout + ndLayout // dstLayout + ); + (*valueTypes)[convertLayoutOp.getResult()] = CoreType::CUBE_ONLY; + + // 5. 
创建 memory_space_cast + auto cbufMemrefType = cast(convertLayoutOp.getType()); + auto plainMemrefType = mlir::MemRefType::get(cbufMemrefType.getShape(), + cbufMemrefType.getElementType()); + + auto memspaceCastOp = builder.create( + loc, plainMemrefType, convertLayoutOp.getResult()); + (*valueTypes)[memspaceCastOp.getResult()] = CoreType::CUBE_ONLY; + + // 6. 创建 to_tensor + auto toTensorOp = builder.create( + loc, + srcTensorType, // 原始的 tensor 类型 + memspaceCastOp.getResult(), + /*restrict=*/true, + /*writable=*/true); + (*valueTypes)[toTensorOp.getResult()] = CoreType::CUBE_ONLY; + + // 7. 替换 dstOp 的操作数 + replaceOperandWithNewValue(dstOp, srcResult, toTensorOp.getResult()); } -Operation* DAGSyncPass::FindLastestPosition(Operation* srcOp, Graph &mainGraph, OpBuilder &builder) { - Operation* insertPos = nullptr; - auto opMap = mainGraph.getOpMapLegacy(); - auto valueTypes = &mainGraph.getValueTypes(); - // Find the first cube-dependent vector core operation. - for(auto nextOp = srcOp->getNextNode();nextOp!=nullptr; nextOp=nextOp->getNextNode()) { - auto nextType = getNodeDeviceType(opMap[nextOp], valueTypes); - if(nextType == CoreType::CUBE_ONLY) continue; - // No memref ops in IR yet; directly tracing operands - for(auto operand: nextOp->getOperands()) { - auto defOp = operand.getDefiningOp(); - auto defType = getNodeDeviceType(opMap[defOp], valueTypes); - if(defType == CoreType::CUBE_ONLY) { - //To prevent UB overflow, we need to break the dependency at the point where the result shape is minimized - // — i.e., trace upward to find the first broadcast. 
- for(auto prevOp = nextOp->getPrevNode(); prevOp != nullptr && prevOp != srcOp; prevOp = prevOp->getPrevNode()) { - if(isa(prevOp)) { - if(prevOp->getPrevNode() && isa(prevOp->getPrevNode())) { - return prevOp->getPrevNode(); - } - return prevOp; - } - } - // Can't find the result shape is minimized - return nextOp; +Operation *DAGSyncPass::FindLastestPosition(Operation *srcOp, Graph &mainGraph, + OpBuilder &builder) { + Operation *insertPos = nullptr; + auto opMap = mainGraph.getOpMapLegacy(); + auto valueTypes = &mainGraph.getValueTypes(); + // Find the first cube-dependent vector core operation. + for (auto nextOp = srcOp->getNextNode(); nextOp != nullptr; + nextOp = nextOp->getNextNode()) { + auto nextType = getNodeDeviceType(opMap[nextOp], valueTypes); + if (nextType == CoreType::CUBE_ONLY) + continue; + // No memref ops in IR yet; directly tracing operands + for (auto operand : nextOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + auto defType = getNodeDeviceType(opMap[defOp], valueTypes); + if (defType == CoreType::CUBE_ONLY) { + // To prevent UB overflow, we need to break the dependency at the point + // where the result shape is minimized + // — i.e., trace upward to find the first broadcast. + for (auto prevOp = nextOp->getPrevNode(); + prevOp != nullptr && prevOp != srcOp; + prevOp = prevOp->getPrevNode()) { + if (isa(prevOp)) { + if (prevOp->getPrevNode() && + isa(prevOp->getPrevNode())) { + return prevOp->getPrevNode(); } + return prevOp; + } } + // Can't find the result shape is minimized + return nextOp; + } + } - // Once meet SyncBlockWaitOp, return now! - if(auto waitOp = dyn_cast(nextOp)) { - if(waitOp.getTcoreType() == hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR)) { - return nextOp; - } - } - insertPos = nextOp; + // Once meet SyncBlockWaitOp, return now! 
+ if (auto waitOp = dyn_cast(nextOp)) { + if (waitOp.getTcoreType() == + hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR)) { + return nextOp; + } } - return insertPos; + insertPos = nextOp; + } + return insertPos; } -Operation* DAGSyncPass::FindEarliestPosition(Operation* dstOp, Graph &mainGraph, OpBuilder &builder) -{ - auto insertPos = dstOp; - auto opMap = mainGraph.getOpMapLegacy(); - auto valueTypes = &mainGraph.getValueTypes(); - for (auto prevOp = dstOp->getPrevNode(); prevOp != nullptr; prevOp = prevOp->getPrevNode()) { - if (dstOp->getBlock() != prevOp->getBlock()) continue; - // Once meet SyncBlockSetOp, return now! - if (auto waitOp = dyn_cast(prevOp)) { - if (waitOp.getTcoreType() == hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR)) { - return insertPos; - } - } - insertPos = prevOp; +Operation *DAGSyncPass::FindEarliestPosition(Operation *dstOp, Graph &mainGraph, + OpBuilder &builder) { + auto insertPos = dstOp; + auto opMap = mainGraph.getOpMapLegacy(); + auto valueTypes = &mainGraph.getValueTypes(); + for (auto prevOp = dstOp->getPrevNode(); prevOp != nullptr; + prevOp = prevOp->getPrevNode()) { + if (dstOp->getBlock() != prevOp->getBlock()) + continue; + // Once meet SyncBlockSetOp, return now! 
+ if (auto waitOp = dyn_cast(prevOp)) { + if (waitOp.getTcoreType() == + hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR)) { + return insertPos; + } } - return insertPos; + insertPos = prevOp; + } + return insertPos; } // 主要的同步和数据搬运插入函数 -void DAGSyncPass::insertSyncAndMovement(mlir::Operation *srcOp, mlir::Operation *dstOp, - CoreType srcType, CoreType dstType, - mlir::OpBuilder &builder, int flag, llvm::DenseMap* valueMap, Graph &mainGraph) { - mlir::Location loc = srcOp->getLoc(); - // 保存当前的插入点 - mlir::OpBuilder::InsertionGuard guard(builder); - - // 检查是否是跨 block - mlir::Block *srcBlock = srcOp->getBlock(); - mlir::Block *dstBlock = dstOp->getBlock(); - bool sameBlock = (srcBlock == dstBlock); - - if (!sameBlock) { - // 检查是否是外层到内层的依赖 - bool dstIsInnerBlock = false; - mlir::Operation *dstParentOp = dstBlock->getParentOp(); - while (dstParentOp) { - if (dstParentOp->getBlock() == srcBlock) { - dstIsInnerBlock = true; - break; - } - if (dstParentOp->getBlock()) { - dstParentOp = dstParentOp->getBlock()->getParentOp(); - } else { - break; - } - } - - if (dstIsInnerBlock) { - insertSyncAndMovementForCrossBlock(srcOp, dstOp, srcType, dstType, builder, flag, true, valueMap, mainGraph); - return; - } +void DAGSyncPass::insertSyncAndMovement( + mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType, + CoreType dstType, mlir::OpBuilder &builder, int flag, + llvm::DenseMap *valueMap, Graph &mainGraph) { + mlir::Location loc = srcOp->getLoc(); + // 保存当前的插入点 + mlir::OpBuilder::InsertionGuard guard(builder); + + // 检查是否是跨 block + mlir::Block *srcBlock = srcOp->getBlock(); + mlir::Block *dstBlock = dstOp->getBlock(); + bool sameBlock = (srcBlock == dstBlock); + + if (!sameBlock) { + // 检查是否是外层到内层的依赖 + bool dstIsInnerBlock = false; + mlir::Operation *dstParentOp = dstBlock->getParentOp(); + while (dstParentOp) { + if (dstParentOp->getBlock() == srcBlock) { + dstIsInnerBlock = true; + break; + } + if (dstParentOp->getBlock()) { + dstParentOp = 
dstParentOp->getBlock()->getParentOp(); + } else { + break; + } } - // 同一 block 内的处理 - // 获取 srcOp 的输出(假设第一个结果) - if (srcOp->getNumResults() == 0) { - return; + if (dstIsInnerBlock) { + insertSyncAndMovementForCrossBlock(srcOp, dstOp, srcType, dstType, + builder, flag, true, valueMap, + mainGraph); + return; } - mlir::Value srcResult = srcOp->getResult(0); - - // CUBE -> VECTOR - if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) { - - // 2. 插入同步指令 - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); - auto lastSetPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); - auto lastWaitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); - auto lastFlagAddId = builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); - - // set 在 srcOp 后 - builder.setInsertionPointAfter(srcOp); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - - // wait 在 dstOp 前 - - auto posOp = FindEarliestPosition(dstOp, mainGraph, builder); - builder.setInsertionPoint(posOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } - // 1. 插入数据搬运 - insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, nullptr); + // 同一 block 内的处理 + // 获取 srcOp 的输出(假设第一个结果) + if (srcOp->getNumResults() == 0) { + return; + } + mlir::Value srcResult = srcOp->getResult(0); + + // CUBE -> VECTOR + if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) { + + // 2. 
插入同步指令 + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto lastSetPipe = + PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto lastWaitPipe = + PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); + auto lastFlagAddId = + builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); - // llvm::outs() << "Inserted CUBE->VECTOR sync and data movement (flag=" << flag << ")\n"; + // set 在 srcOp 后 + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // wait 在 dstOp 前 + + auto posOp = FindEarliestPosition(dstOp, mainGraph, builder); + builder.setInsertionPoint(posOp); + coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 插入数据搬运 + insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, + nullptr); + + // llvm::outs() << "Inserted CUBE->VECTOR sync and data movement (flag=" << + // flag << ")\n"; + } + // VECTOR -> CUBE + else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { + + // 2. 
插入同步指令 + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto lastSetPipe = + PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto lastWaitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); + auto lastFlagAddId = + builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); + + // set 在 srcOp 后 + // builder.setInsertionPointAfter(srcOp); + auto posOp = FindLastestPosition(srcOp, mainGraph, builder); + if (posOp) { + builder.setInsertionPoint(posOp); + } else { + builder.setInsertionPointAfter(srcOp); } - // VECTOR -> CUBE - else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { - - // 2. 插入同步指令 - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); - auto lastSetPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); - auto lastWaitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - auto flagAddId = builder.getIntegerAttr(builder.getI64Type(), flag * 2); - auto lastFlagAddId = builder.getIntegerAttr(builder.getI64Type(), (flag - 1) * 2); - - // set 在 srcOp 后 - // builder.setInsertionPointAfter(srcOp); - auto posOp = FindLastestPosition(srcOp, mainGraph, builder); - if (posOp) { - builder.setInsertionPoint(posOp); - } else { - builder.setInsertionPointAfter(srcOp); - } - auto setOp = builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + auto setOp = builder.create(loc, coreAttr, setPipe, + 
waitPipe, flagId); - // wait 在 dstOp 前 - builder.setInsertionPoint(dstOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // wait 在 dstOp 前 + builder.setInsertionPoint(dstOp); + coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - // 1. 插入数据搬运 - insertVectorToCubeDataMovement(srcOp, dstOp, setOp, srcResult, builder, loc, valueMap); + // 1. 插入数据搬运 + insertVectorToCubeDataMovement(srcOp, dstOp, setOp, srcResult, builder, loc, + valueMap); - // llvm::outs() << "Inserted VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; - } + // llvm::outs() << "Inserted VECTOR->CUBE sync and data movement (flag=" << + // flag << ")\n"; + } } // 跨 block 的同步和数据搬运 -void DAGSyncPass::insertSyncAndMovementForCrossBlock(mlir::Operation *srcOp, mlir::Operation *dstOp, - CoreType srcType, CoreType dstType, - mlir::OpBuilder &builder, int flag, - bool dstIsInnerBlock, llvm::DenseMap* valueMap, Graph &mainGraph) { - if (!dstIsInnerBlock) { - insertSyncAndMovement(srcOp, dstOp, srcType, dstType, builder, flag, valueMap, mainGraph); - return; - } - - mlir::Location loc = srcOp->getLoc(); - mlir::Block *dstBlock = dstOp->getBlock(); +void DAGSyncPass::insertSyncAndMovementForCrossBlock( + mlir::Operation *srcOp, mlir::Operation *dstOp, CoreType srcType, + CoreType dstType, mlir::OpBuilder &builder, int flag, bool dstIsInnerBlock, + llvm::DenseMap *valueMap, Graph &mainGraph) { + if (!dstIsInnerBlock) { + insertSyncAndMovement(srcOp, dstOp, srcType, dstType, builder, flag, + valueMap, mainGraph); + return; + } - // 获取 srcOp 的输出 - if (srcOp->getNumResults() == 0) { - return; - } - mlir::Value srcResult = srcOp->getResult(0); + mlir::Location loc = srcOp->getLoc(); + mlir::Block *dstBlock = dstOp->getBlock(); - // CUBE -> VECTOR - if (srcType == CoreType::CUBE_ONLY && dstType == 
CoreType::VECTOR_ONLY) { + // 获取 srcOp 的输出 + if (srcOp->getNumResults() == 0) { + return; + } + mlir::Value srcResult = srcOp->getResult(0); - // 2. 插入同步指令(跨 block 特殊处理) - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + // CUBE -> VECTOR + if (srcType == CoreType::CUBE_ONLY && dstType == CoreType::VECTOR_ONLY) { - // set 在 srcOp 后(外层) - builder.setInsertionPointAfter(srcOp); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + // 2. 插入同步指令(跨 block 特殊处理) + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_FIX); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_V); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - // 1. 插入数据搬运(同 block 内逻辑) - insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, nullptr); + // set 在 srcOp 后(外层) + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - // wait 在内层 block 入口前 - mlir::Operation *parentOp = dstBlock->getParentOp(); - if (parentOp) { - while (srcOp->getBlock() != parentOp->getBlock()) { - parentOp = parentOp->getBlock()->getParentOp(); - } - builder.setInsertionPoint(parentOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } else { - builder.setInsertionPoint(dstOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } + // 1. 
插入数据搬运(同 block 内逻辑) + insertCubeToVectorDataMovement(srcOp, dstOp, srcResult, builder, loc, + nullptr); + // wait 在内层 block 入口前 + mlir::Operation *parentOp = dstBlock->getParentOp(); + if (parentOp) { + while (srcOp->getBlock() != parentOp->getBlock()) { + parentOp = parentOp->getBlock()->getParentOp(); + } + builder.setInsertionPoint(parentOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } else { + builder.setInsertionPoint(dstOp); + coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), + hivm::TCoreType::VECTOR); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); } - // VECTOR -> CUBE - else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { - - // 2. 插入同步指令(跨 block 特殊处理) - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); - auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); - auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - - // set 在 srcOp 后(外层) - builder.setInsertionPointAfter(srcOp); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - - // 1. 
插入数据搬运(同 block 内逻辑) - insertVectorToCubeDataMovement(srcOp, dstOp, srcOp, srcResult, builder, loc, valueMap); - - // wait 在内层 block 入口前 - mlir::Operation *parentOp = dstBlock->getParentOp(); - if (parentOp) { - builder.setInsertionPoint(parentOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } else { - builder.setInsertionPoint(dstOp); - coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); - builder.create(loc, coreAttr, setPipe, waitPipe, flagId); - } - // llvm::outs() << "Inserted cross-block VECTOR->CUBE sync and data movement (flag=" << flag << ")\n"; + } + // VECTOR -> CUBE + else if (srcType == CoreType::VECTOR_ONLY && dstType == CoreType::CUBE_ONLY) { + + // 2. 插入同步指令(跨 block 特殊处理) + auto coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::VECTOR); + auto setPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE3); + auto waitPipe = PipeAttr::get(builder.getContext(), hivm::PIPE::PIPE_MTE1); + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + + // set 在 srcOp 后(外层) + builder.setInsertionPointAfter(srcOp); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + + // 1. 
插入数据搬运(同 block 内逻辑) + insertVectorToCubeDataMovement(srcOp, dstOp, srcOp, srcResult, builder, loc, + valueMap); + + // wait 在内层 block 入口前 + mlir::Operation *parentOp = dstBlock->getParentOp(); + if (parentOp) { + builder.setInsertionPoint(parentOp); + coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); + } else { + builder.setInsertionPoint(dstOp); + coreAttr = + hivm::TCoreTypeAttr::get(builder.getContext(), hivm::TCoreType::CUBE); + builder.create(loc, coreAttr, setPipe, waitPipe, flagId); } + + // llvm::outs() << "Inserted cross-block VECTOR->CUBE sync and data movement + // (flag=" << flag << ")\n"; + } } void LegalizeDot(triton::FuncOp funcOp) { - mlir::OpBuilder builder(funcOp); - funcOp.walk([&](triton::DotOp dotOp) { - // 获取dot操作的输入 - Value a = dotOp.getOperands()[0]; - Value b = dotOp.getOperands()[1]; - Value c = dotOp.getOperands()[2]; // 累加器参数 - - // 检查累加器是否为全零常量 - bool isZeroAccumulator = false; - - // 检查是否直接是arith.constant 0 - if (auto constantOp = c.getDefiningOp()) { - if (auto denseAttr = dyn_cast(constantOp.getValue())) { - if (denseAttr.isSplat() && denseAttr.getSplatValue().getValueAsDouble() == 0.0) { - isZeroAccumulator = true; - } + mlir::OpBuilder builder(funcOp); + funcOp.walk([&](triton::DotOp dotOp) { + // 获取dot操作的输入 + Value a = dotOp.getOperands()[0]; + Value b = dotOp.getOperands()[1]; + Value c = dotOp.getOperands()[2]; // 累加器参数 + + // 检查累加器是否为全零常量 + bool isZeroAccumulator = false; + + // 检查是否直接是arith.constant 0 + if (auto constantOp = c.getDefiningOp()) { + if (auto denseAttr = dyn_cast(constantOp.getValue())) { + if (denseAttr.isSplat() && + denseAttr.getSplatValue().getValueAsDouble() == 0.0) { + isZeroAccumulator = true; } } + } - if (!isZeroAccumulator) { - // 创建新的零累加器 - Location loc = dotOp.getLoc(); - auto resultType = dotOp.getResult().getType(); + if (!isZeroAccumulator) { + // 创建新的零累加器 + Location loc = dotOp.getLoc(); + auto resultType 
= dotOp.getResult().getType(); - Value originalResult = dotOp.getResult(); - builder.setInsertionPoint(dotOp); - // 创建全零张量 - auto zeroAttr = DenseElementsAttr::get( - dyn_cast(resultType), - APFloat(0.0f)); - auto zeroConstant = builder.create(loc, zeroAttr); + Value originalResult = dotOp.getResult(); + builder.setInsertionPoint(dotOp); + // 创建全零张量 + auto zeroAttr = DenseElementsAttr::get( + dyn_cast(resultType), APFloat(0.0f)); + auto zeroConstant = builder.create(loc, zeroAttr); - // 创建新的dot操作,使用零作为累加器 - auto newDot = builder.create( - loc, resultType, a, b, zeroConstant); + // 创建新的dot操作,使用零作为累加器 + auto newDot = + builder.create(loc, resultType, a, b, zeroConstant); - // 创建加法操作,将新的dot结果与原来的累加器c相加 - auto addOp = builder.create(loc, newDot, c); + // 创建加法操作,将新的dot结果与原来的累加器c相加 + auto addOp = builder.create(loc, newDot, c); - // 用addOp替换原来的dotOp - originalResult.replaceAllUsesWith(addOp.getResult()); + // 用addOp替换原来的dotOp + originalResult.replaceAllUsesWith(addOp.getResult()); - // 删除原dotOp(如果它没有其他用途) - if (dotOp.use_empty()) { - dotOp.erase(); - } + // 删除原dotOp(如果它没有其他用途) + if (dotOp.use_empty()) { + dotOp.erase(); } - - }); + } + }); } -static void rewriteCopyChainForCbub( - hivm::CopyOp copyOp, - ArrayRef newShape, - OpBuilder &builder) { +static void rewriteCopyChainForCbub(hivm::CopyOp copyOp, + ArrayRef newShape, + OpBuilder &builder) { // 获取 copy 的输入(ins),应为 to_memref 的结果 Value insVal = copyOp.getOperands()[0]; @@ -867,8 +919,10 @@ static void rewriteCopyChainForCbub( SmallVector intermediateShape3D = {M, N / blk, blk}; SmallVector intermediateShapetrans = {N / blk, M, blk}; auto elementType = inputTensorType.getElementType(); - auto interTensor3DType = RankedTensorType::get(intermediateShape3D, elementType); - auto interTensortransType = RankedTensorType::get(intermediateShapetrans, elementType); + auto interTensor3DType = + RankedTensorType::get(intermediateShape3D, elementType); + auto interTensortransType = + RankedTensorType::get(intermediateShapetrans, 
elementType); auto finalTensorType = RankedTensorType::get(newShape, elementType); @@ -879,56 +933,52 @@ static void rewriteCopyChainForCbub( builder.setInsertionPointAfter(tensorOp); // 插入 triton.reshape 将 2D tensor 展开为 3D - auto reshape3DOp = builder.create( - loc, interTensor3DType, inputTensor); + auto reshape3DOp = + builder.create(loc, interTensor3DType, inputTensor); (*valueTypes)[reshape3DOp.getResult()] = CoreType::VECTOR_ONLY; // nark tiling dim for reshapeop auto markOp3d = builder.create(loc, reshape3DOp); auto tilingDimAttr3d = builder.getDictionaryAttr(SmallVector{ - NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); + NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); markOp3d->setAttr("tiling_dim_mapping", tilingDimAttr3d); // 插入 triton.trans 调整维度顺序 Insert tt.trans {order = [1, 0, 2]} SmallVector order = {1, 0, 2}; - auto orderAttr = builder.getDenseI32ArrayAttr(order); // OpBuilder supports this + auto orderAttr = + builder.getDenseI32ArrayAttr(order); // OpBuilder supports this auto transOp = builder.create( loc, interTensortransType, reshape3DOp.getResult(), orderAttr); (*valueTypes)[transOp.getResult()] = CoreType::VECTOR_ONLY; // 插入 triton.reshape 将 3D tensor 展开为 4D - auto reshape4DOp = builder.create( - loc, finalTensorType, transOp.getResult()); + auto reshape4DOp = builder.create(loc, finalTensorType, + transOp.getResult()); (*valueTypes)[reshape4DOp.getResult()] = CoreType::VECTOR_ONLY; // nark tiling dim for reshapeop auto markOp4d = builder.create(loc, reshape4DOp); auto tilingDimAttr4d = builder.getDictionaryAttr(SmallVector{ - NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); + NamedAttribute(builder.getStringAttr("1"), builder.getIndexAttr(1))}); markOp4d->setAttr("tiling_dim_mapping", tilingDimAttr4d); // Create new to_memref builder.setInsertionPoint(toMemRefOp); - auto newMemRefType = MemRefType::get( - newShape, - elementType, - mlir::AffineMap{}, - 
toMemRefOp.getType().getMemorySpace()); + auto newMemRefType = MemRefType::get(newShape, elementType, mlir::AffineMap{}, + toMemRefOp.getType().getMemorySpace()); auto newToMemRefOp = builder.create( - toMemRefOp.getLoc(), - newMemRefType, - reshape4DOp.getResult()); + toMemRefOp.getLoc(), newMemRefType, reshape4DOp.getResult()); (*valueTypes)[newToMemRefOp.getResult()] = CoreType::VECTOR_ONLY; // Create NEW copyOp (replacing the old one) builder.setInsertionPoint(copyOp); auto resultTypes = copyOp->getResultTypes(); - auto newCopyOp = builder.create( - copyOp.getLoc(), - resultTypes, // TypeRange - reshape4DOp.getResult(), // src (ins) - copyOp.getOperands()[1] // dst (outs) - ); + auto newCopyOp = + builder.create(copyOp.getLoc(), + resultTypes, // TypeRange + reshape4DOp.getResult(), // src (ins) + copyOp.getOperands()[1] // dst (outs) + ); // 替换 uses 并清理旧 op copyOp.replaceAllUsesWith(newCopyOp); @@ -939,394 +989,403 @@ static void rewriteCopyChainForCbub( } template -OpTy createBlockSync(OpBuilder builder, - hivm::TCoreType coreType, - hivm::PIPE srcPipe, - hivm::PIPE dstPipe, - int flag, - Operation *cause) -{ - auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); - auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), coreType); - auto setPipe = PipeAttr::get(builder.getContext(), srcPipe); - auto waitPipe = PipeAttr::get(builder.getContext(), dstPipe); - return builder.create(cause->getLoc(), coreAttr, setPipe, waitPipe, flagId); +OpTy createBlockSync(OpBuilder builder, hivm::TCoreType coreType, + hivm::PIPE srcPipe, hivm::PIPE dstPipe, int flag, + Operation *cause) { + auto flagId = builder.getIntegerAttr(builder.getI64Type(), flag); + auto coreAttr = hivm::TCoreTypeAttr::get(builder.getContext(), coreType); + auto setPipe = PipeAttr::get(builder.getContext(), srcPipe); + auto waitPipe = PipeAttr::get(builder.getContext(), dstPipe); + return builder.create(cause->getLoc(), coreAttr, setPipe, waitPipe, + flagId); } // since we do not 
have llvm::set_intersects in this version... -template bool intersects(S1Ty &s1, S2Ty &s2) -{ - if (s1.size() > s2.size()) { - return intersects(s2, s1); - } +template bool intersects(S1Ty &s1, S2Ty &s2) { + if (s1.size() > s2.size()) { + return intersects(s2, s1); + } - return llvm::any_of(s1, [&](auto e) { return s2.count(e); }); + return llvm::any_of(s1, [&](auto e) { return s2.count(e); }); } -bool mayAlias(DataFlowSolver &solver, Value ptrA, Value ptrB) -{ - if (ptrA == ptrB) { - return true; - } - const auto *stateA = solver.lookupState>(ptrA); - const auto *stateB = solver.lookupState>(ptrB); - if (!stateA || !stateB) { // not triton ptr type - return true; - } - auto infoA = stateA->getValue(); - auto infoB = stateB->getValue(); - - return intersects(infoA.getAllocs(), infoB.getAllocs()); +bool mayAlias(DataFlowSolver &solver, Value ptrA, Value ptrB) { + if (ptrA == ptrB) { + return true; + } + const auto *stateA = solver.lookupState>(ptrA); + const auto *stateB = solver.lookupState>(ptrB); + if (!stateA || !stateB) { // not triton ptr type + return true; + } + auto infoA = stateA->getValue(); + auto infoB = stateB->getValue(); + + return intersects(infoA.getAllocs(), infoB.getAllocs()); } const size_t MAX_EXPECTED_PARENTS_COUNT = 8; -std::optional> findAncestorCommonBlock(mlir::Operation *opA, mlir::Operation *opB) -{ - if (opA->getBlock() == opB->getBlock()) { - return std::make_pair(opA, opB); - } - - // record all ancestors of opA - llvm::SmallPtrSet ancestorsA; - mlir::Operation *curr = opA; - while (curr) { - ancestorsA.insert(curr); - curr = curr->getParentOp(); - } - - // find the last ancestor of opB which is also the ancestor of opA - mlir::Operation *commonAncOp = nullptr; - curr = opB; - while (curr) { - if (ancestorsA.count(curr)) { - commonAncOp = curr; - break; - } - curr = curr->getParentOp(); - } - - if (!commonAncOp) { - return std::nullopt; +std::optional> +findAncestorCommonBlock(mlir::Operation *opA, mlir::Operation *opB) { + if 
(opA->getBlock() == opB->getBlock()) { + return std::make_pair(opA, opB); + } + + // record all ancestors of opA + llvm::SmallPtrSet ancestorsA; + mlir::Operation *curr = opA; + while (curr) { + ancestorsA.insert(curr); + curr = curr->getParentOp(); + } + + // find the last ancestor of opB which is also the ancestor of opA + mlir::Operation *commonAncOp = nullptr; + curr = opB; + while (curr) { + if (ancestorsA.count(curr)) { + commonAncOp = curr; + break; } + curr = curr->getParentOp(); + } - // find the ancestors in the given region - for (mlir::Region ®ion : commonAncOp->getRegions()) { - for (mlir::Block &block : region) { - auto *ancA = block.findAncestorOpInBlock(*opA); - auto *ancB = block.findAncestorOpInBlock(*opB); - if (ancA && ancB) { - return std::make_pair(ancA, ancB); - } - } - } + if (!commonAncOp) { return std::nullopt; + } + + // find the ancestors in the given region + for (mlir::Region ®ion : commonAncOp->getRegions()) { + for (mlir::Block &block : region) { + auto *ancA = block.findAncestorOpInBlock(*opA); + auto *ancB = block.findAncestorOpInBlock(*opB); + if (ancA && ancB) { + return std::make_pair(ancA, ancB); + } + } + } + return std::nullopt; } struct SyncCandidate { - CoreType srcCoreType; - Operation *setCause; - Operation *setAfter; - Operation *waitCause; - Operation *waitBefore; + CoreType srcCoreType; + Operation *setCause; + Operation *setAfter; + Operation *waitCause; + Operation *waitBefore; }; // setOp, waitOp -void createBlockSyncBetween(OpBuilder builder, - hivm::PIPE srcPipe, - hivm::PIPE dstPipe, - SyncCandidate candidate, - int flag) -{ - auto srcCoreType = toHivm(candidate.srcCoreType); - auto dstCoreType = toHivm(!candidate.srcCoreType); - - builder.setInsertionPointAfter(candidate.setAfter); - auto setOp = createBlockSync(builder, srcCoreType, srcPipe, dstPipe, flag, candidate.setCause); - builder.setInsertionPoint(candidate.waitBefore); - auto waitOp = createBlockSync(builder, dstCoreType, srcPipe, dstPipe, flag, 
candidate.waitCause); +void createBlockSyncBetween(OpBuilder builder, hivm::PIPE srcPipe, + hivm::PIPE dstPipe, SyncCandidate candidate, + int flag) { + auto srcCoreType = toHivm(candidate.srcCoreType); + auto dstCoreType = toHivm(!candidate.srcCoreType); + + builder.setInsertionPointAfter(candidate.setAfter); + auto setOp = createBlockSync( + builder, srcCoreType, srcPipe, dstPipe, flag, candidate.setCause); + builder.setInsertionPoint(candidate.waitBefore); + auto waitOp = createBlockSync( + builder, dstCoreType, srcPipe, dstPipe, flag, candidate.waitCause); }; -void addMemEffectsSync(triton::FuncOp funcOp, Graph *graph, OpBuilder &builder, int &syncFlag) -{ - DominanceInfo domInfo(funcOp); - PostDominanceInfo postDomInfo(funcOp); - DataFlowSolver solver; - solver.load(); - solver.load(); - - if (failed(solver.initializeAndRun(funcOp))) { - funcOp->emitWarning("SharedMemoryAliasAnalysis failed! This could lead to potential memory related issues! \n"); +void addMemEffectsSync(triton::FuncOp funcOp, Graph *graph, OpBuilder &builder, + int &syncFlag) { + DominanceInfo domInfo(funcOp); + PostDominanceInfo postDomInfo(funcOp); + DataFlowSolver solver; + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(funcOp))) { + funcOp->emitWarning("SharedMemoryAliasAnalysis failed! This could lead to " + "potential memory related issues! 
\n"); + } + + // [(node, EffectInstance, LinearisationPt)] + llvm::SmallVector> memOps; + + // [(setAfter, waitBefore, srcOP, dstOp)][CoreType] + llvm::SmallVector candidates; + + funcOp.walk([&](MemoryEffectOpInterface memIface) { + auto *op = memIface.getOperation(); + if (llvm::isa(op)) { + return; } - // [(node, EffectInstance, LinearisationPt)] - llvm::SmallVector> memOps; + auto *currNode = graph->getOpMap()[op].get(); + SmallVector effects; - // [(setAfter, waitBefore, srcOP, dstOp)][CoreType] - llvm::SmallVector candidates; + memIface.getEffects(effects); - funcOp.walk([&](MemoryEffectOpInterface memIface) { - auto *op = memIface.getOperation(); - if (llvm::isa(op)) { - return; - } - - auto *currNode = graph->getOpMap()[op].get(); - SmallVector effects; - - memIface.getEffects(effects); - - for (auto &effect : effects) { - if (!isa(effect.getEffect())) { - continue; - } - memOps.emplace_back(currNode, effect); - bool isWrite = isa(effect.getEffect()); - for (auto &[prevNode, prevEffect] : memOps) { - if ((isa(prevEffect.getEffect()) || isWrite) && - mayAlias(solver, prevEffect.getValue(), effect.getValue()) && - prevNode->isOn() != currNode->isOn() // write is forced on single core type, so we are safe to judge - // based on whether the core types are different - ) { - CoreType srcCoreType = isWrite ? 
!currNode->isOn() : prevNode->isOn(); - auto opPair = findAncestorCommonBlock(prevNode->op, currNode->op); - if (!opPair.has_value()) { - op->emitWarning( - llvm::formatv("Unable to find ancestors in common block with {0}\n", *prevNode->op)); - continue; - } - auto [setAfter, waitBefore] = opPair.value(); - if (setAfter == waitBefore) { - continue; - } - candidates.push_back(SyncCandidate {srcCoreType, prevNode->op, setAfter, op, waitBefore}); - } - } + for (auto &effect : effects) { + if (!isa(effect.getEffect())) { + continue; + } + memOps.emplace_back(currNode, effect); + bool isWrite = isa(effect.getEffect()); + for (auto &[prevNode, prevEffect] : memOps) { + if ((isa(prevEffect.getEffect()) || isWrite) && + mayAlias(solver, prevEffect.getValue(), effect.getValue()) && + prevNode->isOn() != + currNode->isOn() // write is forced on single core type, so we + // are safe to judge based on whether the core + // types are different + ) { + CoreType srcCoreType = isWrite ? !currNode->isOn() : prevNode->isOn(); + auto opPair = findAncestorCommonBlock(prevNode->op, currNode->op); + if (!opPair.has_value()) { + op->emitWarning(llvm::formatv( + "Unable to find ancestors in common block with {0}\n", + *prevNode->op)); + continue; + } + auto [setAfter, waitBefore] = opPair.value(); + if (setAfter == waitBefore) { + continue; + } + candidates.push_back(SyncCandidate{srcCoreType, prevNode->op, + setAfter, op, waitBefore}); } - }); - - auto addBlockSyncCommon = [&builder, &syncFlag](SyncCandidate cand) { - llvm::dbgs() << "\n\n=== Insert sync between ===\n" - << *cand.setAfter << "\n" - << *cand.waitBefore << "\n=== Insert Sync End ===\n\n"; - - auto srcPipe = cand.srcCoreType == CoreType::CUBE_ONLY ? 
hivm::PIPE::PIPE_FIX : hivm::PIPE::PIPE_MTE2; - auto dstPipe = hivm::PIPE::PIPE_S; - createBlockSyncBetween(builder, srcPipe, dstPipe, cand, syncFlag % 14); - syncFlag++; - }; - - if (candidates.empty()) { - return; + } } + }); + + auto addBlockSyncCommon = [&builder, &syncFlag](SyncCandidate cand) { + llvm::dbgs() << "\n\n=== Insert sync between ===\n" + << *cand.setAfter << "\n" + << *cand.waitBefore << "\n=== Insert Sync End ===\n\n"; + + auto srcPipe = cand.srcCoreType == CoreType::CUBE_ONLY + ? hivm::PIPE::PIPE_FIX + : hivm::PIPE::PIPE_MTE2; + auto dstPipe = hivm::PIPE::PIPE_S; + createBlockSyncBetween(builder, srcPipe, dstPipe, cand, syncFlag % 14); + syncFlag++; + }; + + if (candidates.empty()) { + return; + } - auto setAfterDominate = [&domInfo](Operation *a, Operation *b) { - if (domInfo.dominates(a, b)) { - return true; - } - if (domInfo.dominates(b, a)) { - return false; - } - if (a->isAncestor(b)) { - return false; - } - if (b->isAncestor(a)) { - return true; - } - return false; - }; - - auto waitBeforePostDominate = [&postDomInfo](Operation *a, Operation *b) { - if (postDomInfo.postDominates(a, b)) { - return true; - } - if (postDomInfo.postDominates(b, a)) { - return false; - } - if (a->isAncestor(b)) { - return true; - } - if (b->isAncestor(a)) { - return false; - } - return false; - }; + auto setAfterDominate = [&domInfo](Operation *a, Operation *b) { + if (domInfo.dominates(a, b)) { + return true; + } + if (domInfo.dominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { + return false; + } + if (b->isAncestor(a)) { + return true; + } + return false; + }; - llvm::sort(candidates, [&](const SyncCandidate &a, const SyncCandidate &b) { - if (a.setAfter != b.setAfter) { - return setAfterDominate(a.setAfter, b.setAfter); - } + auto waitBeforePostDominate = [&postDomInfo](Operation *a, Operation *b) { + if (postDomInfo.postDominates(a, b)) { + return true; + } + if (postDomInfo.postDominates(b, a)) { + return false; + } + if (a->isAncestor(b)) { 
+ return true; + } + if (b->isAncestor(a)) { + return false; + } + return false; + }; - if (a.waitBefore != b.waitBefore) { - return waitBeforePostDominate(a.waitBefore, b.waitBefore); - } + llvm::sort(candidates, [&](const SyncCandidate &a, const SyncCandidate &b) { + if (a.setAfter != b.setAfter) { + return setAfterDominate(a.setAfter, b.setAfter); + } - return false; - }); + if (a.waitBefore != b.waitBefore) { + return waitBeforePostDominate(a.waitBefore, b.waitBefore); + } - for (auto [i, cand] : llvm::enumerate(candidates)) { - bool shouldInsert = true; - for (auto otherCand : ArrayRef(candidates).drop_front(i + 1)) { - bool duplicated = (cand.waitBefore == otherCand.waitBefore && cand.setAfter == otherCand.setAfter && - cand.srcCoreType == otherCand.srcCoreType); - bool containsOther = - (cand.srcCoreType == otherCand.srcCoreType && setAfterDominate(cand.setAfter, otherCand.setAfter) && - waitBeforePostDominate(cand.waitBefore, otherCand.waitBefore)); - if (duplicated || containsOther) { - shouldInsert = false; - break; - } - } + return false; + }); + + for (auto [i, cand] : llvm::enumerate(candidates)) { + bool shouldInsert = true; + for (auto otherCand : ArrayRef(candidates).drop_front(i + 1)) { + bool duplicated = (cand.waitBefore == otherCand.waitBefore && + cand.setAfter == otherCand.setAfter && + cand.srcCoreType == otherCand.srcCoreType); + bool containsOther = + (cand.srcCoreType == otherCand.srcCoreType && + setAfterDominate(cand.setAfter, otherCand.setAfter) && + waitBeforePostDominate(cand.waitBefore, otherCand.waitBefore)); + if (duplicated || containsOther) { + shouldInsert = false; + break; + } + } - if (shouldInsert) { - addBlockSyncCommon(cand); - } + if (shouldInsert) { + addBlockSyncCommon(cand); } + } } -void DAGSyncPass::runOnOperation() -{ - auto module = getOperation(); - mlir::OpBuilder builder(&getContext()); +void DAGSyncPass::runOnOperation() { + auto module = getOperation(); + mlir::OpBuilder builder(&getContext()); + + // 遍历所有函数 + 
for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { + // 跳过无效函数 + LegalizeDot(funcOp); + if (funcOp.getBody().empty()) { + continue; + } - // 遍历所有函数 - for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { - // 跳过无效函数 - LegalizeDot(funcOp); - if (funcOp.getBody().empty()) { - continue; - } + // llvm::outs() << "\n====================================\n"; + // llvm::outs() << "处理函数: " << funcOp.getName() << "\n"; + // llvm::outs() << "====================================\n"; - // llvm::outs() << "\n====================================\n"; - // llvm::outs() << "处理函数: " << funcOp.getName() << "\n"; - // llvm::outs() << "====================================\n"; + auto unique_graph = Graph::fromMultiBlockFunc(funcOp); + std::shared_ptr shared_graph = std::move(unique_graph); + auto &main_graph = *shared_graph; - auto unique_graph = Graph::fromMultiBlockFunc(funcOp); - std::shared_ptr shared_graph = std::move(unique_graph); - auto& main_graph = *shared_graph; + auto funcName = funcOp.getName(); - auto funcName = funcOp.getName(); + // 获取 DAG 图的映射 + auto opMapRaw = main_graph.getOpMapLegacy(); + valueTypes = &main_graph.getValueTypes(); + auto *opMap = &opMapRaw; - // 获取 DAG 图的映射 - auto opMapRaw = main_graph.getOpMapLegacy(); - valueTypes = &main_graph.getValueTypes(); - auto *opMap = &opMapRaw; + if (!opMap || !valueTypes) { + llvm::errs() << "Warning: Failed to create DAG graph for function " + << funcOp.getName() << "\n"; + continue; + } - if (!opMap || !valueTypes) { - llvm::errs() << "Warning: Failed to create DAG graph for function " << funcOp.getName() << "\n"; - continue; - } + // 用于避免重复插入同步 + llvm::DenseSet> + processedPairs; + int syncFlag = 1; + addMemEffectsSync(funcOp, shared_graph.get(), builder, syncFlag); + + // 3. 
使用 walk 遍历函数中的所有操作 + funcOp.walk([&](mlir::Operation *op) { + // 查找当前操作对应的 Node + auto nodeIt = opMap->find(op); + if (nodeIt == opMap->end()) { + // 这个操作不在 entry block 的 DAG 图中 + // 可能是嵌套在控制流内部的操作 + return; + } - // 用于避免重复插入同步 - llvm::DenseSet> processedPairs; - int syncFlag = 1; - addMemEffectsSync(funcOp, shared_graph.get(), builder, syncFlag); - - // 3. 使用 walk 遍历函数中的所有操作 - funcOp.walk([&](mlir::Operation *op) { - // 查找当前操作对应的 Node - auto nodeIt = opMap->find(op); - if (nodeIt == opMap->end()) { - // 这个操作不在 entry block 的 DAG 图中 - // 可能是嵌套在控制流内部的操作 - return; - } + OpNode *currentNode = nodeIt->second; - OpNode *currentNode = nodeIt->second; + // 检查是否是 scf.for 操作 + if (auto forOp = mlir::dyn_cast(op)) { + // 处理 scf.for 循环的特殊同步逻辑 + int temp = syncFlag % 14; + processScfForSync(forOp, currentNode, valueTypes, builder, temp); + } - // 检查是否是 scf.for 操作 - if (auto forOp = mlir::dyn_cast(op)) { - // 处理 scf.for 循环的特殊同步逻辑 - int temp = syncFlag % 14; - processScfForSync(forOp, currentNode, valueTypes, builder, temp); - } + // 获取当前节点的设备类型 + CoreType currentType = getNodeDeviceType(currentNode, valueTypes); + + // 打印操作信息(可选) + // if (!llvm::isa(op->getDialect())) { + // llvm::outs() << "操作: " << *op + // << " 设备类型: " + // << (currentType == CoreType::VECTOR_ONLY ? "VECTOR" : + // currentType == CoreType::CUBE_ONLY ? "CUBE" : + // "SCALAR") + // << "\n"; + // } + + // 4. 遍历当前节点的所有输入节点 + for (ValueNode *inputValNode : currentNode->getInputs()) { + auto inputOp = inputValNode->value.getDefiningOp(); + if (!inputOp || !opMap->contains(inputOp)) { + continue; + } - // 获取当前节点的设备类型 - CoreType currentType = getNodeDeviceType(currentNode, valueTypes); - - // 打印操作信息(可选) - // if (!llvm::isa(op->getDialect())) { - // llvm::outs() << "操作: " << *op - // << " 设备类型: " - // << (currentType == CoreType::VECTOR_ONLY ? "VECTOR" : - // currentType == CoreType::CUBE_ONLY ? "CUBE" : "SCALAR") - // << "\n"; - // } - - // 4. 
遍历当前节点的所有输入节点 - for (ValueNode *inputValNode : currentNode->getInputs()) { - auto inputOp = inputValNode->value.getDefiningOp(); - if (!inputOp || !opMap->contains(inputOp)) { - continue; + auto inputNode = (*opMap)[inputOp]; + + // 获取输入节点的设备类型 + CoreType inputType = getNodeDeviceType(inputNode, valueTypes); + + // 5. 判断是否需要插入同步和数据搬运 + if (needVectorCubeSync(inputType, currentType)) { + // 检查是否已经处理过这对操作 + auto opPair = std::make_pair(inputOp, op); + if (processedPairs.insert(opPair).second) { + // 插入同步和数据搬运指令 + // 检查是否是跨 block 的依赖 + mlir::Block *srcBlock = inputOp->getBlock(); + mlir::Block *dstBlock = op->getBlock(); + + if (srcBlock == dstBlock) { + // 同一 block 内 + insertSyncAndMovement(inputOp, op, inputType, currentType, + builder, syncFlag % 14, valueTypes, + main_graph); + syncFlag++; + } else { + // 跨 block,判断是否是外层到内层 + llvm::outs() << "#########\n"; + bool dstIsInnerBlock = false; + mlir::Operation *dstParentOp = dstBlock->getParentOp(); + + // 向上查找,看 dstBlock 是否在 srcBlock 的区域内 + while (dstParentOp) { + if (dstParentOp->getBlock() == srcBlock) { + dstIsInnerBlock = true; + break; } - - auto inputNode = (*opMap)[inputOp]; - - // 获取输入节点的设备类型 - CoreType inputType = getNodeDeviceType(inputNode, valueTypes); - - // 5. 
判断是否需要插入同步和数据搬运 - if (needVectorCubeSync(inputType, currentType)) { - // 检查是否已经处理过这对操作 - auto opPair = std::make_pair(inputOp, op); - if (processedPairs.insert(opPair).second) { - // 插入同步和数据搬运指令 - // 检查是否是跨 block 的依赖 - mlir::Block *srcBlock = inputOp->getBlock(); - mlir::Block *dstBlock = op->getBlock(); - - if (srcBlock == dstBlock) { - // 同一 block 内 - insertSyncAndMovement(inputOp, op, inputType, currentType, builder, syncFlag % 14, valueTypes, main_graph); - syncFlag ++; - } else { - // 跨 block,判断是否是外层到内层 - llvm::outs() << "#########\n"; - bool dstIsInnerBlock = false; - mlir::Operation *dstParentOp = dstBlock->getParentOp(); - - // 向上查找,看 dstBlock 是否在 srcBlock 的区域内 - while (dstParentOp) { - if (dstParentOp->getBlock() == srcBlock) { - dstIsInnerBlock = true; - break; - } - if (dstParentOp->getBlock()) { - dstParentOp = dstParentOp->getBlock()->getParentOp(); - } else { - break; - } - } - if (dstIsInnerBlock) { - - insertSyncAndMovementForCrossBlock(inputOp, op, inputType, currentType, - builder, syncFlag % 14, dstIsInnerBlock, valueTypes, main_graph); - syncFlag ++; - } - } - } + if (dstParentOp->getBlock()) { + dstParentOp = dstParentOp->getBlock()->getParentOp(); + } else { + break; } + } + if (dstIsInnerBlock) { + + insertSyncAndMovementForCrossBlock( + inputOp, op, inputType, currentType, builder, syncFlag % 14, + dstIsInnerBlock, valueTypes, main_graph); + syncFlag++; + } } - }); - - // llvm::outs() << "\n函数 " << funcOp.getName() << " 统计:\n"; - // llvm::outs() << " - 插入的总同步操作数: " << syncFlag << "\n"; - funcOp.walk([&](hivm::CopyOp copyOp) { - llvm::outs()<(copyOp.getOperands()[1].getType()).getShape(), builder); - }); - GraphManager::getInstance().registerGraph(funcName, shared_graph); - } + } + } + } + }); - // llvm::outs()<(copyOp.getOperands()[1].getType()).getShape(), + builder); + }); + GraphManager::getInstance().registerGraph(funcName, shared_graph); + } + // llvm::outs()<> mlir::triton::createDAGSyncPass() -{ - return std::make_unique(); 
+std::unique_ptr> mlir::triton::createDAGSyncPass() { + return std::make_unique(); } diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index 400c376a66..f131669ef3 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -231,8 +231,9 @@ void init_triton_ir(py::module &&m) { py::class_(m, "context", py::module_local()) .def(py::init<>()) - .def("__enter__", [](MLIRContext &self) -> MLIRContext& { return self; }, - py::return_value_policy::reference) + .def( + "__enter__", [](MLIRContext &self) -> MLIRContext & { return self; }, + py::return_value_policy::reference) .def("__exit__", [](MLIRContext &, py::object, py::object, py::object) -> bool { // Keep context alive for the duration of the scope. @@ -666,13 +667,13 @@ void init_triton_ir(py::module &&m) { "get_unit_attr", [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) .def("get_i64_array_attr", - [](TritonOpBuilder &self, const std::vector& array) { - return self.getBuilder().getI64ArrayAttr(array); - }) + [](TritonOpBuilder &self, const std::vector &array) { + return self.getBuilder().getI64ArrayAttr(array); + }) .def("get_type_array_attr", - [](TritonOpBuilder &self, const std::vector& array) { - return self.getBuilder().getTypeArrayAttr(array); - }) + [](TritonOpBuilder &self, const std::vector &array) { + return self.getBuilder().getTypeArrayAttr(array); + }) // Use arith.ConstantOp to create constants // Constants .def("get_int1", @@ -1719,58 +1720,61 @@ void init_triton_ir(py::module &&m) { printingFlags); } }) - .def("run", [](PassManager &self, ModuleOp &mod) { - // TODO: maybe dump module to file and print error for better - // diagnostics - - auto *context = mod.getContext(); - if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) - context->disableMultithreading(); - - auto reproducerPath = - triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); - if (!reproducerPath.empty()) { - auto anchorName = 
self.getOpAnchorName(); - auto passes = self.getPasses(); - Operation *op = mod.getOperation(); - makeReproducer(anchorName, passes, op, reproducerPath); - context->disableMultithreading(); - } - - if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { - ::llvm::DebugFlag = true; - } - - if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); - !debugOnly.empty()) { - llvm::SmallVector split; - llvm::SmallVector storage; - llvm::SmallVector debugTypes; - - StringRef(debugOnly.c_str()).split(split, ','); - llvm::transform(split, std::back_inserter(debugTypes), - [&storage](StringRef str) { - // StringRefs are not always null-terminated. - // The purpose for this storage pattern is to - // produce a collection of C-strings that are. - storage.push_back(str.str()); - return storage.back().c_str(); - }); - - ::llvm::DebugFlag = true; - using namespace llvm; - setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); - } - - bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); - if (haveTiming) { - self.enableTiming(); - } - - if (failed(self.run(mod.getOperation()))) - throw std::runtime_error("PassManager::run failed"); - }, - py::call_guard()); + .def( + "run", + [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto *context = mod.getContext(); + if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) + context->disableMultithreading(); + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + context->disableMultithreading(); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = + triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + 
!debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }, + py::call_guard()); } void init_triton_env_vars(py::module &m) { diff --git a/third_party/ascend/triton_ascend.cc b/third_party/ascend/triton_ascend.cc index c13d394fd6..7e3a6a2e88 100644 --- a/third_party/ascend/triton_ascend.cc +++ b/third_party/ascend/triton_ascend.cc @@ -6,18 +6,18 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "AutoBlockify/Passes.h" -#include "TritonToHIVM/Passes.h" +#include "TritonAffinityOpt/Passes.h" #include "TritonToHFusion/Passes.h" +#include "TritonToHIVM/Passes.h" #include "TritonToLLVM/Passes.h" -#include "TritonAffinityOpt/Passes.h" #include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "ir.h" // TritonOpBuilder +#include "triton/Dialect/Triton/IR/Dialect.h" #include @@ -28,296 +28,297 @@ using namespace mlir; void init_triton_ascend_ir(py::module &&m) { auto *builder_cls = ir::getBuilderClass(); builder_cls - ->def("create_extract_scalar", - [](TritonOpBuilder 
&self, Value &src, std::vector &indices) -> Value { - llvm::SmallVector arg_indices; - for (const auto &i : indices) { - auto iTy = i.getType(); - if (!iTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), i); - arg_indices.push_back(v); - } else { - arg_indices.push_back(i); - } - } - auto ret = self.create(src, arg_indices); - return ret; - }) - .def("create_extract_slice", - [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, - std::vector &sizs_vec, std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get(retSizes, - cast(ful.getType()).getElementType()); - - return self.create(retTy, ful, offsets, sizes, strides); - }) - .def("create_insert_slice", - [](TritonOpBuilder &self, Value &ful, Value &sub, - std::vector &offs_vec, std::vector &sizs_vec, - std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - 
cast(ful.getType()).getElementType()); - auto ret = self.create(sub, ful, offsets, - sizes, strides); - return ret; - }) - .def("create_custom_op_for_inter_core_sync", - [](TritonOpBuilder &self, std::string &op_name, - std::string &mode_or_sender, int id) -> void { - auto args = self.getBuilder().getArrayAttr( - {self.getBuilder().getStringAttr(mode_or_sender), - self.getBuilder().getI32IntegerAttr(id)} - ); - self.create(op_name, args, ValueRange()); - }) - .def("create_index_select_simd", - [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, - std::vector &srcShape, std::vector &srcOffset, - std::vector &readShape, std::vector &returnShape) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - // Get element type from source pointer - Type elemType; - if (auto ptrTy = dyn_cast(src.getType())) { - elemType = ptrTy.getPointeeType(); - } else { - llvm::report_fatal_error("index_select_simd: src must be pointer type"); - } - - // Create return tensor type - llvm::SmallVector retShape; - for (const auto &s : returnShape) { - retShape.push_back(s); - } - auto retTensorType = RankedTensorType::get(retShape, elemType); - - // Convert srcShape and srcOffset values to index type if needed - llvm::SmallVector srcShapeIndex; - for (auto val : srcShape) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), val); - } - srcShapeIndex.push_back(val); - } - - llvm::SmallVector srcOffsetIndex; - for (auto val : srcOffset) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), val); - } - srcOffsetIndex.push_back(val); - } - - // Create attributes - auto dimAttr = builder.getI32IntegerAttr(dim); - auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); - - // Create the IndexSelectSimdOp - // Parameter order must match TritonOps.td definition: - // src, index, dim, src_shape, src_offset, read_shape - auto indexSelectSimdOp = builder.create( - loc, - retTensorType, // result type - 
src, // src pointer - index, // index tensor - dimAttr, // dim attribute - srcShapeIndex, // src_shape (variadic, index type) - srcOffsetIndex, // src_offset (variadic, index type) - readShapeAttr // read_shape attribute - ); - - return indexSelectSimdOp.getResult(); - }) - .def("create_index_put", - [](TritonOpBuilder &self, Value &ptr, Value &index, - Value &value, const int32_t dim, const int64_t indexBoundary, - std::vector &endOffset, std::vector &startOffset, - std::vector &dstStride) -> void { - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = self.create(indexBoundary, BoundI64Ty); - - self.create( - ptr, - index, - value, - dim_val, - bound_val, - endOffset, - startOffset, - dstStride - ); - }) - .def("create_gather_out_to_ub", - [](TritonOpBuilder &self, Value &src, Value &index, const int64_t indexBoundary, - const int32_t dim, std::vector &srcStride, std::vector &endOffset, - std::vector &startOffset, std::optional &other) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(index.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - auto resType = RankedTensorType::get(retShape, elemTy); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - return self.create( - resType, - src, - index, - bound_val, - dim_val, - srcStride, - endOffset, - startOffset, - other.value_or(Value()) - ); - }) - .def("create_scatter_ub_to_out", - [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, - const int64_t indexBoundary, const int32_t dim, std::vector &dstStride, - 
std::vector &endOffset, std::vector &startOffset) -> void { - auto idxTy = cast(index.getType()); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - - self.create( - ptr, - value, - index, - bound_val, - dim_val, - dstStride, - endOffset, - startOffset - ); - }) - // Add sort - .def("create_sort", - [](TritonOpBuilder &self, Value src, int64_t dim, bool descending) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - auto descendingAttr = builder.getBoolAttr(descending); - - auto op = builder.create(loc, src, dimAttr, descendingAttr); - - return op->getResult(0); - }) - // Add flip - .def("create_flip", - [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - - auto op = builder.create(loc, src, dimAttr); - - return op->getResult(0); - }) - .def("create_tanh", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) - // Add an annotation - .def("create_annotation", - [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, - Attribute &attrVal) { - auto annotationOp = self.create(ptr); - annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), - attrVal); - }); + ->def("create_extract_scalar", + [](TritonOpBuilder &self, Value &src, + std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + 
.def("create_extract_slice", + [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, + sizes, strides); + }) + .def("create_insert_slice", + [](TritonOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + .def("create_custom_op_for_inter_core_sync", + [](TritonOpBuilder &self, std::string &op_name, + std::string &mode_or_sender, int id) -> void { + auto args = self.getBuilder().getArrayAttr( + {self.getBuilder().getStringAttr(mode_or_sender), + 
self.getBuilder().getI32IntegerAttr(id)}); + self.create(op_name, args, ValueRange()); + }) + .def("create_index_select_simd", + [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, + std::vector &srcShape, std::vector &srcOffset, + std::vector &readShape, + std::vector &returnShape) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + // Get element type from source pointer + Type elemType; + if (auto ptrTy = dyn_cast(src.getType())) { + elemType = ptrTy.getPointeeType(); + } else { + llvm::report_fatal_error( + "index_select_simd: src must be pointer type"); + } + + // Create return tensor type + llvm::SmallVector retShape; + for (const auto &s : returnShape) { + retShape.push_back(s); + } + auto retTensorType = RankedTensorType::get(retShape, elemType); + + // Convert srcShape and srcOffset values to index type if needed + llvm::SmallVector srcShapeIndex; + for (auto val : srcShape) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), + val); + } + srcShapeIndex.push_back(val); + } + + llvm::SmallVector srcOffsetIndex; + for (auto val : srcOffset) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), + val); + } + srcOffsetIndex.push_back(val); + } + + // Create attributes + auto dimAttr = builder.getI32IntegerAttr(dim); + auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); + + // Create the IndexSelectSimdOp + // Parameter order must match TritonOps.td definition: + // src, index, dim, src_shape, src_offset, read_shape + auto indexSelectSimdOp = + builder.create( + loc, + retTensorType, // result type + src, // src pointer + index, // index tensor + dimAttr, // dim attribute + srcShapeIndex, // src_shape (variadic, index type) + srcOffsetIndex, // src_offset (variadic, index type) + readShapeAttr // read_shape attribute + ); + + return indexSelectSimdOp.getResult(); + }) + .def("create_index_put", + [](TritonOpBuilder &self, Value &ptr, Value &index, Value 
&value, + const int32_t dim, const int64_t indexBoundary, + std::vector &endOffset, std::vector &startOffset, + std::vector &dstStride) -> void { + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + + self.create(ptr, index, value, dim_val, + bound_val, endOffset, + startOffset, dstStride); + }) + .def("create_gather_out_to_ub", + [](TritonOpBuilder &self, Value &src, Value &index, + const int64_t indexBoundary, const int32_t dim, + std::vector &srcStride, std::vector &endOffset, + std::vector &startOffset, + std::optional &other) -> Value { + auto elemTy = cast(src.getType()).getPointeeType(); + auto idxTy = cast(index.getType()); + auto idxShape = idxTy.getShape(); + std::vector retShape(idxShape.begin(), idxShape.end()); + auto resType = RankedTensorType::get(retShape, elemTy); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + return self.create( + resType, src, index, bound_val, dim_val, srcStride, endOffset, + startOffset, other.value_or(Value())); + }) + .def("create_scatter_ub_to_out", + [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, + const int64_t indexBoundary, const int32_t dim, + std::vector &dstStride, std::vector &endOffset, + std::vector &startOffset) -> void { + auto idxTy = cast(index.getType()); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + + 
self.create( + ptr, value, index, bound_val, dim_val, dstStride, endOffset, + startOffset); + }) + // Add sort + .def("create_sort", + [](TritonOpBuilder &self, Value src, int64_t dim, + bool descending) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + auto descendingAttr = builder.getBoolAttr(descending); + + auto op = builder.create(loc, src, dimAttr, + descendingAttr); + + return op->getResult(0); + }) + // Add flip + .def("create_flip", + [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + + auto op = + builder.create(loc, src, dimAttr); + + return op->getResult(0); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + // Add an annotation + .def("create_annotation", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }); } void init_triton_ascend_passes_ttir(py::module &&m) { - m.def("add_auto_blockify", [](mlir::PassManager &pm, - int autoBlockifySize) { + m.def("add_auto_blockify", [](mlir::PassManager &pm, int autoBlockifySize) { AutoBlockifyOptions opts; opts.autoBlockifySize = autoBlockifySize; - pm.addPass(mlir::triton::createAutoBlockifyPass(opts));}); + pm.addPass(mlir::triton::createAutoBlockifyPass(opts)); + }); m.def("add_triton_to_hfusion", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToHFusionPass());}); + pm.addPass(mlir::triton::createTritonToHFusionPass()); + }); m.def("add_triton_to_hivm", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createTritonToHIVMPass());}); + pm.addPass(mlir::triton::createTritonToHIVMPass()); + }); m.def("add_triton_to_llvm", [](mlir::PassManager &pm) { - 
pm.addPass(mlir::triton::createTritonToLLVMPass());}); - + pm.addPass(mlir::triton::createTritonToLLVMPass()); + }); + m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createBubbleUpOperationPass());}); + pm.addPass(mlir::triton::createBubbleUpOperationPass()); + }); m.def("add_dag_sync", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createDAGSyncPass());}); - + pm.addPass(mlir::triton::createDAGSyncPass()); + }); + m.def("add_dag_scope", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createDAGScopePass());}); - + pm.addPass(mlir::triton::createDAGScopePass()); + }); + m.def("add_dag_ssbuffer", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createDAGSSBufferPass());}); + pm.addPass(mlir::triton::createDAGSSBufferPass()); + }); } // Forward declaration for ascend_ir bindings (defined in ascend_ir.cc) @@ -335,7 +336,7 @@ void init_triton_ascend(py::module &&m) { init_triton_ascend_passes_ttir(passes.def_submodule("ttir")); init_triton_ascend_ir(m.def_submodule("ascend_ir")); - + // Initialize ascend IR bindings (ascendnpu_ir_builder, scope/hivm dialects) init_ascend_ir(m.def_submodule("ir")); } diff --git a/third_party/ascend/tutorials/03-matrix-multiplication.py b/third_party/ascend/tutorials/03-matrix-multiplication.py index beae7d6f97..d4cbb0d35f 100644 --- a/third_party/ascend/tutorials/03-matrix-multiplication.py +++ b/third_party/ascend/tutorials/03-matrix-multiplication.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Matrix Multiplication =============== @@ -45,28 +44,19 @@ def get_autotune_config(): ) @triton.jit def matmul_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - # Matrix dimensions - M, - N, - K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. 
`stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, - stride_ak, # - stride_bk, - stride_bn, # - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, # - ACTIVATION: tl.constexpr, # + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -131,9 +121,7 @@ def matmul_kernel( # Comment out the following lines to enable split the workload to two vector cores SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 for s in extension.parallel(0, 2, bind_sub_block=True): - vec_sub_blk = extension.extract_slice( - accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) - ) + vec_sub_blk = extension.extract_slice(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1)) if ACTIVATION == "leaky_relu_custom": vec_sub_blk = leaky_relu_custom(vec_sub_blk) c_sub_blk = vec_sub_blk.to(tl.float16) @@ -172,24 +160,18 @@ def matmul(a, b, activation=""): K, N = b.shape # Allocates output. c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) matmul_kernel[grid]( - a, - b, - c, # - M, - N, - K, # - a.stride(0), - a.stride(1), # - b.stride(0), - b.stride(1), # - c.stride(0), - c.stride(1), # + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # ACTIVATION=activation, # ) return c @@ -214,4 +196,4 @@ def test(): if __name__ == "__main__": - test() \ No newline at end of file + test() diff --git a/third_party/ascend/tutorials/04-low-memory-dropout.py b/third_party/ascend/tutorials/04-low-memory-dropout.py index 0947d85ec3..2c5570a0f4 100644 --- a/third_party/ascend/tutorials/04-low-memory-dropout.py +++ b/third_party/ascend/tutorials/04-low-memory-dropout.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Low-Memory Dropout ================== @@ -119,7 +118,6 @@ def test(): ["output"] + output.tolist(), ])) - x = torch.randn(size=(10, ), device=DEV) # Compare this to the baseline - dropout mask is never instantiated! output = seeded_dropout(x, p=0.5, seed=123) @@ -136,4 +134,4 @@ def test(): if __name__ == "__main__": - test() \ No newline at end of file + test() diff --git a/third_party/ascend/tutorials/05-layer-norm.py b/third_party/ascend/tutorials/05-layer-norm.py index b7361e9300..8af05fc81f 100644 --- a/third_party/ascend/tutorials/05-layer-norm.py +++ b/third_party/ascend/tutorials/05-layer-norm.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Layer Normalization ============= diff --git a/third_party/ascend/tutorials/06-fused-attention.py b/third_party/ascend/tutorials/06-fused-attention.py index dfc03e21b8..8f67a6b3fb 100644 --- a/third_party/ascend/tutorials/06-fused-attention.py +++ b/third_party/ascend/tutorials/06-fused-attention.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Fused Attention =============== @@ -40,7 +39,6 @@ import triton.language as tl import triton.language.extra.cann.extension as extension - DEVICE = "npu" @@ -49,8 +47,10 @@ def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, quer K_block_ptr, V_block_ptr, # Key and value block pointers for current stage start_m, qk_scale, # Starting position of current query block, qk scale factor BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # Current stage flag, m and n offset indices - N_CTX: tl.constexpr, fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + STAGE: tl.constexpr, offs_m: tl.constexpr, + offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, + fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision # Set the processing range [lo, hi) for the current stage (in column block units) # Causal attention, as the name implies, restricts the flow of information during computation, # only allowing the model to see the current and previous positions. 
@@ -145,18 +145,12 @@ def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, quer @triton.jit -def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, - stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, - stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, - stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, - stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, - Z: tl.constexpr, H: tl.constexpr, - N_CTX: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - STAGE: tl.constexpr - ): +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, + stride_qk: tl.constexpr, stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, + stride_kk: tl.constexpr, stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, + stride_vk: tl.constexpr, stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, + stride_on: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, N_CTX: tl.constexpr, HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr): # Total number of blocks in sequence dimension (M) NUM_BLOCKS_M = N_CTX // BLOCK_M # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) @@ -215,11 +209,8 @@ def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, if HEAD_DIM < 256: acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) else: - acc_offset = ( - off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + - off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + - task_m_idx * BLOCK_M * HEAD_DIM - ) + acc_offset = (off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + task_m_idx * BLOCK_M * HEAD_DIM) 
acc_ptr = acc + acc_offset # load q: it will stay in SRAM throughout @@ -288,25 +279,17 @@ def forward(ctx, q, k, v, causal, sm_scale, BM, BN): out = torch.empty_like(q) stage = 3 if causal else 1 extra_kern_args = {} - # Number of NPU cores (adjust based on hardware) num_cores = 20 acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - _attn_fwd[(num_cores,)]( - q, k, v, M, out, acc, sm_scale, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - q.shape[0], q.shape[1], N_CTX=q.shape[2], - HEAD_DIM=HEAD_DIM_K, - BLOCK_M=BM, - BLOCK_N=BN, - STAGE=stage, - **extra_kern_args) + _attn_fwd[(num_cores, )](q, k, v, M, out, acc, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), + v.stride(2), v.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], HEAD_DIM=HEAD_DIM_K, BLOCK_M=BM, BLOCK_N=BN, + STAGE=stage, **extra_kern_args) ctx.save_for_backward(q, k, v, out, M) ctx.sm_scale = sm_scale @@ -314,6 +297,7 @@ def forward(ctx, q, k, v, causal, sm_scale, BM, BN): ctx.causal = causal return out + attention = _attention.apply @@ -340,16 +324,19 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): tri_out = attention(q, k, v, causal, sm_scale, BM, BN) ref_out = torch_npu.npu_fusion_attention( - q, k, v, H, - padding_mask=None, - atten_mask=None, - scale=sm_scale, - keep_prob=1.0, - input_layout="BNSD", - pre_tockens=65535, - next_tockens=65535, - sparse_mode=0, - )[0] + q, + k, + v, + H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + 
next_tockens=65535, + sparse_mode=0, + )[0] torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) print(f"[PASSED] Attention shape:({Z}, {H}, {N_CTX}, {HEAD_DIM}), BM: {BM}, BN: {BN}, dtype: {dtype}") diff --git a/third_party/ascend/tutorials/07-extern-functions.py b/third_party/ascend/tutorials/07-extern-functions.py index f48953b0f3..e433640245 100644 --- a/third_party/ascend/tutorials/07-extern-functions.py +++ b/third_party/ascend/tutorials/07-extern-functions.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Libdevice (`tl.extra.libdevice`) function ============================== @@ -69,8 +68,7 @@ def grid(meta): print(output_torch) print(output_triton) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') - + f'{torch.max(torch.abs(output_torch - output_triton))}') current_file = inspect.getfile(inspect.currentframe()) current_dir = Path(os.path.dirname(os.path.abspath(current_file))) @@ -82,7 +80,7 @@ def grid(meta): print(output_torch) print(output_triton) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton))}') if __name__ == "__main__": diff --git a/third_party/ascend/tutorials/08-grouped-gemm.py b/third_party/ascend/tutorials/08-grouped-gemm.py index 4ba59bf8e0..96739be81d 100644 --- a/third_party/ascend/tutorials/08-grouped-gemm.py +++ b/third_party/ascend/tutorials/08-grouped-gemm.py @@ -18,7 +18,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Group GEMM ============================ @@ -163,7 +162,7 @@ def group_gemm_fn(group_A, group_B): d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) def grid(meta): - return (meta['NUM_SM'],) + return (meta['NUM_SM'], ) grouped_matmul_kernel[grid]( d_a_ptrs, @@ -205,7 +204,7 @@ def test(): def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): def grid(meta): - return (meta['NUM_SM'],) + return (meta['NUM_SM'], ) grouped_matmul_kernel[grid]( a_ptrs, @@ -279,4 +278,4 @@ def bench_triton(): if __name__ == "__main__": - test() \ No newline at end of file + test() diff --git a/third_party/ascend/tutorials/09-persistent-matmul.py b/third_party/ascend/tutorials/09-persistent-matmul.py index f80c4852a9..0e085dc624 100644 --- a/third_party/ascend/tutorials/09-persistent-matmul.py +++ b/third_party/ascend/tutorials/09-persistent-matmul.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Persistent Matmul ===================== @@ -311,10 +310,8 @@ def validate(M, N, K): persistent_vs_torch = "✅" if torch.allclose(persistent_result, torch_result, atol=1.0) else "❌" naive_vs_persistent = "✅" if torch.allclose(naive_result, persistent_result, atol=1.0) else "❌" - print( - f"M={M}, N={N}, K={K} verification naive vs torch: {naive_vs_torch} " - f"persistent vs torch: {persistent_vs_torch} naive vs persistent: {naive_vs_persistent}" - ) + print(f"M={M}, N={N}, K={K} verification naive vs torch: {naive_vs_torch} " + f"persistent vs torch: {persistent_vs_torch} naive vs persistent: {naive_vs_persistent}") if __name__ == "__main__": diff --git a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir index b8074c87ff..bb98ba7bee 100644 --- a/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir +++ b/third_party/ascend/unittest/Conversion/950PR/TritonToLinalg/copy_use_analysis.mlir @@ -260,4 +260,3 @@ module { tt.return } } - diff --git a/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir index 2fcb46e6e3..c9a1dcb409 100644 --- a/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir +++ b/third_party/ascend/unittest/Conversion/General/AutoBlockify/auto_blockify.mlir @@ -131,4 +131,4 @@ tt.func @kernel2(%arg0: !tt.ptr) { scf.yield } tt.return -} \ No newline at end of file +} diff --git a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir index 71d8fb3db1..c6978c29bc 100644 --- a/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir +++ b/third_party/ascend/unittest/Conversion/General/DiscreteMaskAccess/atomic.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: tt.func 
@atomic_add_i32 // CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw add, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw add, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_add_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -23,7 +23,7 @@ tt.func @atomic_add_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_fadd_f32 // CHECK: %[[default:.*]] = arith.constant dense<0.000000e+00> : tensor<1024xf32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_fadd_f32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -43,7 +43,7 @@ tt.func @atomic_fadd_f32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_max_i32 // CHECK: %[[default:.*]] = arith.constant dense<-2147483648> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_max_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -63,7 +63,7 @@ tt.func @atomic_max_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_umax_i32 // CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], 
%[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw umax, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw umax, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_umax_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -83,7 +83,7 @@ tt.func @atomic_umax_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_min_i32 // CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw min, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw min, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_min_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -103,7 +103,7 @@ tt.func @atomic_min_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_umin_i32 // CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw umin, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw umin, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_umin_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -123,7 +123,7 @@ tt.func @atomic_umin_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_and_i32 // CHECK: %[[default:.*]] = arith.constant dense<2147483647> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw and, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw and, 
acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_and_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -143,7 +143,7 @@ tt.func @atomic_and_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_or_i32 // CHECK: %[[default:.*]] = arith.constant dense<0> : tensor<1024xi32> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw or, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw or, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_or_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -163,7 +163,7 @@ tt.func @atomic_or_i32(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_max_i16 // CHECK: %[[default:.*]] = arith.constant dense<-32768> : tensor<1024xi16> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_max_i16(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant dense<400> : tensor<1024xi32> @@ -183,7 +183,7 @@ tt.func @atomic_max_i16(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK-LABEL: tt.func @atomic_max_f16 // CHECK: %[[default:.*]] = arith.constant dense<0xFC00> : tensor<1024xf16> // CHECK: %[[value:.*]] = arith.select %[[mask:.*]], %[[origin:.*]], %[[default]] -// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] +// CHECK: %[[result:.*]] = tt.atomic_rmw max, acq_rel, gpu, %[[ptr:.*]], %[[value]] tt.func @atomic_max_f16(%arg0: !tt.ptr, %arg1: !tt.ptr) { %cst = arith.constant dense<200> : tensor<1024xi32> %cst_0 = arith.constant 
dense<400> : tensor<1024xi32> @@ -198,4 +198,4 @@ tt.func @atomic_max_f16(%arg0: !tt.ptr, %arg1: !tt.ptr) { %8 = tt.load %7 : tensor<1024x!tt.ptr> %9 = tt.atomic_rmw max, acq_rel, gpu, %5, %8, %3 : (tensor<1024x!tt.ptr>, tensor<1024xf16>, tensor<1024xi1>) -> tensor<1024xf16> tt.return -} \ No newline at end of file +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir index 30b3c375c2..d1e8eb8a2e 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonAscendAllPass/simplify_for_loop.mlir @@ -116,6 +116,3 @@ module { tt.return } } - - - diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir index cbabb03d89..fbff6cb341 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/fp_to_fp_rtz.mlir @@ -15,4 +15,4 @@ tt.func @test_fp32_to_fp16_rtz_fail(%arg0: tensor<1024xf32>) -> tensor<1024xf16> %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16> // CHECK: %{{.*}} = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32> -> tensor<1024xf16> tt.return %0 : tensor<1024xf16> -} \ No newline at end of file +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir index fd9b479006..d4c264913a 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToHFusion/mod.mlir @@ -8,4 +8,4 @@ module { %0 = ascend.mod %arg0, %arg1 : tensor<1xf32> tensor<1xf32> -> tensor<1xf32> tt.return %0 : tensor<1xf32> } 
-} \ No newline at end of file +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir index d004352a3e..186e7210fc 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToHIVM/sync_block_op_conversion.mlir @@ -17,4 +17,4 @@ tt.func @triton_func() { // CHECK: hivm.hir.sync_block_wait[, , ] flag = 2 // CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = // CHECK: hivm.hir.sync_block[, 1 : i16] tvector_pipe = -// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = tvector_pipe = \ No newline at end of file +// CHECK: hivm.hir.sync_block[, 1 : i16] tcube_pipe = tvector_pipe = diff --git a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir index c7ea71622c..a5ec468da5 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToLinalg/atomic_rmw_block.mlir @@ -2,45 +2,45 @@ module attributes {hacc.target = #hacc.target<"Ascend910B2">} { tt.func public @moe_align_block_size_stage4(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: !tt.ptr {tt.divisibility = 16 : i32} , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: i32) attributes {noinline = false} { - %cst = arith.constant dense<1> : tensor<1xi32> - %cst_0 = arith.constant dense<0> : tensor<1xi32> - %c250_i32 = arith.constant 250 : i32 - %c16_i32 = arith.constant 16 : i32 - %c1_i32 = arith.constant 1 : i32 - %0 = tt.get_program_id x : i32 - %1 = tt.addptr %arg4, %0 : !tt.ptr, i32 - %2 = tt.load %1 : !tt.ptr - %3 = tt.addptr %1, %c1_i32 : !tt.ptr, i32 - %4 = tt.load 
%3 : !tt.ptr + %cst = arith.constant dense<1> : tensor<1xi32> + %cst_0 = arith.constant dense<0> : tensor<1xi32> + %c250_i32 = arith.constant 250 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg4, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + %3 = tt.addptr %1, %c1_i32 : !tt.ptr, i32 + %4 = tt.load %3 : !tt.ptr scf.for %arg6 = %2 to %4 step %c16_i32 : i32 { - %22 = arith.divsi %arg6, %c16_i32 : i32 - %23 = tt.addptr %arg2, %22 : !tt.ptr, i32 - tt.store %23, %0 : !tt.ptr - } - %5 = arith.muli %0, %c250_i32 : i32 - %6 = tt.splat %0 : i32 -> tensor<1xi32> - %7 = arith.cmpi slt, %0, %arg5 : i32 - %8 = tt.splat %7 : i1 -> tensor<1xi1> - %9 = tt.addptr %arg0, %0 : !tt.ptr, i32 - %10 = tt.splat %9 : !tt.ptr -> tensor<1x!tt.ptr> - %11 = tt.load %10, %8, %cst_0 : tensor<1x!tt.ptr> - %12 = tt.addptr %arg3, %5 : !tt.ptr, i32 - %13 = tt.splat %12 : !tt.ptr -> tensor<1x!tt.ptr> - %14 = tt.addptr %13, %11 : tensor<1x!tt.ptr>, tensor<1xi32> - %15 = tt.atomic_rmw add, acq_rel, gpu, %14, %cst, %8 : (tensor<1x!tt.ptr>, tensor<1xi32>, tensor<1xi1>) -> tensor<1xi32> - %16 = tt.splat %arg4 : !tt.ptr -> tensor<1x!tt.ptr> - %17 = tt.addptr %16, %11 : tensor<1x!tt.ptr>, tensor<1xi32> - %18 = tt.load %17, %8, %cst_0 : tensor<1x!tt.ptr> - %19 = arith.addi %15, %18 : tensor<1xi32> - %20 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr> - %21 = tt.addptr %20, %19 : tensor<1x!tt.ptr>, tensor<1xi32> - tt.store %21, %6, %8 : tensor<1x!tt.ptr> - tt.return - } -} + %22 = arith.divsi %arg6, %c16_i32 : i32 + %23 = tt.addptr %arg2, %22 : !tt.ptr, i32 + tt.store %23, %0 : !tt.ptr + } + %5 = arith.muli %0, %c250_i32 : i32 + %6 = tt.splat %0 : i32 -> tensor<1xi32> + %7 = arith.cmpi slt, %0, %arg5 : i32 + %8 = tt.splat %7 : i1 -> tensor<1xi1> + %9 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %10 = tt.splat %9 : !tt.ptr -> tensor<1x!tt.ptr> + %11 = tt.load %10, %8, %cst_0 : tensor<1x!tt.ptr> + %12 = tt.addptr %arg3, %5 : 
!tt.ptr, i32 + %13 = tt.splat %12 : !tt.ptr -> tensor<1x!tt.ptr> + %14 = tt.addptr %13, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %15 = tt.atomic_rmw add, acq_rel, gpu, %14, %cst, %8 : (tensor<1x!tt.ptr>, tensor<1xi32>, tensor<1xi1>) -> tensor<1xi32> + %16 = tt.splat %arg4 : !tt.ptr -> tensor<1x!tt.ptr> + %17 = tt.addptr %16, %11 : tensor<1x!tt.ptr>, tensor<1xi32> + %18 = tt.load %17, %8, %cst_0 : tensor<1x!tt.ptr> + %19 = arith.addi %15, %18 : tensor<1xi32> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr> + %21 = tt.addptr %20, %19 : tensor<1x!tt.ptr>, tensor<1xi32> + tt.store %21, %6, %8 : tensor<1x!tt.ptr> + tt.return + } +} // CHECK-LABEL: func.func @moe_align_block_size_stage4 // CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[.*]] to offset: [%[[.*]]], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: ?>> // CHECK: %[[CAST2:.*]] = memref.alloc() : memref<1xi32> -// CHECK: memref.copy %[[CAST1]], %[[CAST2]] : memref<1xi32, strided<[1], offset: ?>> to memref<1xi32> \ No newline at end of file +// CHECK: memref.copy %[[CAST1]], %[[CAST2]] : memref<1xi32, strided<[1], offset: ?>> to memref<1xi32> diff --git a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir index 5c589cca91..abdae3fa16 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToStructured/parseCmp.mlir @@ -50,7 +50,7 @@ tt.func public @test_cmp_ult(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % // CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_22]], %[[VAL_4]] : tensor<2xi32> // CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_19]], %[[VAL_23]], %[[VAL_3]] : tensor<2x!tt.ptr> // CHECK: %[[VAL_25:.*]] = tensor.empty() : tensor<2x512xf32> -// CHECK: %[[VAL_26:.*]] = linalg.broadcast ins(%[[VAL_24]] : tensor<2xf32>) outs(%[[VAL_25]] : tensor<2x512xf32>) dimensions = [1] +// 
CHECK: %[[VAL_26:.*]] = linalg.broadcast ins(%[[VAL_24]] : tensor<2xf32>) outs(%[[VAL_25]] : tensor<2x512xf32>) dimensions = [1] // CHECK: %[[VAL_27:.*]] = tensor.reshape %[[VAL_26]](%[[VAL_2]]) : (tensor<2x512xf32>, tensor<1xi64>) -> tensor<1024xf32> // CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_7]], %[[VAL_6]] : i32 // CHECK: %[[VAL_29:.*]] = tt.splat %[[VAL_28]] : i32 -> tensor<1024xi32> @@ -73,12 +73,12 @@ tt.func public @test_cmp_uge(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c1024_i32 {tt.divisibility = dense<512> : tensor<1xi32>} : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> %4 = arith.addi %3, %2 : tensor<1024xi32> %5 = arith.divsi %4, %cst_1 : tensor<1024xi32> %6 = arith.cmpi uge, %cst_0, %5 : tensor<1024xi32> %7 = arith.muli %5, %cst_1 : tensor<1024xi32> - %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> %10 = tt.load %9, %6, %cst : tensor<1024x!tt.ptr> %11 = arith.muli %0, %c1024_i32 : i32 diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir index 13d653f80f..626f670d1d 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/bubbleupoperation.mlir @@ -124,4 +124,4 @@ tt.func @test_slice_all_bubbleup(%i: index, %c: f32) -> tensor<128xi32> { %0 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> %1 = tensor.extract_slice %0[0][128][1] : tensor<128xi32> to tensor<128xi32> tt.return %1 : tensor<128xi32> -} \ No newline at end of file +} diff --git 
a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir index 984141e3a2..77aed1ce16 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/nested_loop.mlir @@ -204,4 +204,4 @@ tt.func public @test_kernel2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % // CHECK: scf.yield %[[VAL_70]]#3, %[[VAL_69]], %[[VAL_71]], %[[VAL_72]] : tensor<128xi32>, tensor<128xi32>, tensor<128xi64>, tensor<128xi64> // CHECK: } // CHECK: tt.return -// CHECK: } \ No newline at end of file +// CHECK: } diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir index dbfa654c23..cc3accd712 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/splat.mlir @@ -11,4 +11,4 @@ tt.func@test_unstructure_splatandloadscenario(%base: !tt.ptr) -> tensor<128 %ptr = tt.addptr %base_tensor, %offset_tensor : tensor<128x!tt.ptr>, tensor<128xi64> %val = tt.load %ptr : tensor<128x!tt.ptr> tt.return %val : tensor<128xf32> -} \ No newline at end of file +} diff --git a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir index db36ad34c1..4516711f08 100644 --- a/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir +++ b/third_party/ascend/unittest/Conversion/General/TritonToUnstructure/unstructure_mix.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt --triton-to-unstructure %s | FileCheck %s +// RUN: triton-opt --triton-to-unstructure %s | FileCheck %s tt.func public @indirect_mix_kernel(%arg0: !tt.ptr {tt.divisibility = 
16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<16> : tensor<1x8xi32> @@ -79,4 +79,4 @@ tt.func public @indirect_mix_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : // CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_42]], %[[VAL_41]] : tensor<16x8x!tt.ptr>, tensor<16x8xi32> // CHECK: tt.store %[[VAL_43]], %[[VAL_35]] : tensor<16x8x!tt.ptr> // CHECK: tt.return -// CHECK: } \ No newline at end of file +// CHECK: } diff --git a/third_party/ascend/unittest/autotune_ut/01-vector-add.py b/third_party/ascend/unittest/autotune_ut/01-vector-add.py index 555b961d94..1219c1a2e3 100644 --- a/third_party/ascend/unittest/autotune_ut/01-vector-add.py +++ b/third_party/ascend/unittest/autotune_ut/01-vector-add.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Vector Add ============= @@ -33,19 +32,15 @@ from triton.backends.ascend.testing import do_bench_npu -@triton.autotune( - configs=[], - key=["n_elements"] -) +@triton.autotune(configs=[], key=["n_elements"]) @triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. -): +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): # There are multiple 'programs' processing different data. 
We identify which program # we are here: pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. @@ -73,7 +68,7 @@ def add_torch(x, y): def add_autotune(x, y): output = torch.empty_like(x) n_elements = output.numel() - add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements) return output diff --git a/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py index 09c7de0f80..66f6c7a371 100644 --- a/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py +++ b/third_party/ascend/unittest/autotune_ut/02-fused-softmax.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Fused Softmax ============= @@ -62,16 +61,10 @@ def softmax_kernel( # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols row = tl.load(input_ptrs, mask=mask, other=-float("inf")) # Subtract maximum for numerical stability - row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to( - XBLOCK_SUB, BLOCK_SIZE - ) + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) numerator = tl.exp(row_minus_max) - denominator = ( - tl.sum(numerator, axis=1) - .reshape(XBLOCK_SUB, 1) - .broadcast_to(XBLOCK_SUB, BLOCK_SIZE) - ) + denominator = (tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE)) softmax_output = numerator / denominator # Write back output to DRAM output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) @@ -89,9 +82,8 @@ def softmax_autotune(x): # Allocate output y = torch.empty_like(x) # Create a number of 
persistent programs. - softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)]( - y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE - ) + softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, + n_cols, BLOCK_SIZE=BLOCK_SIZE) return y diff --git a/third_party/ascend/unittest/autotune_ut/03-layer-norm.py b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py index 42e6b5b1bd..ab547cfa36 100644 --- a/third_party/ascend/unittest/autotune_ut/03-layer-norm.py +++ b/third_party/ascend/unittest/autotune_ut/03-layer-norm.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Layer Normalization ============= @@ -63,9 +62,7 @@ def _layer_norm_fwd_fused( col_idx = off + tl.arange(0, RBLOCK_SIZE) col_mask = col_idx < N mask = row_mask[:, None] & col_mask[None, :] - a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( - tl.float32 - ) + a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=1, keep_dims=True) / N # Compute variance @@ -74,9 +71,7 @@ def _layer_norm_fwd_fused( col_idx = off + tl.arange(0, RBLOCK_SIZE) col_mask = col_idx < N mask = row_mask[:, None] & col_mask[None, :] - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( - tl.float32 - ) + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) x = tl.where(mask, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=1, keep_dims=True) / N @@ -91,9 +86,7 @@ def _layer_norm_fwd_fused( mask = row_mask[:, None] & col_mask[None, :] w = tl.load(W + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) b = tl.load(B + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( - 
tl.float32 - ) + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -112,8 +105,8 @@ def layer_norm_autotune(args): # reshape input data into 2D tensor x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape - mean = torch.empty((M,), dtype=torch.float32, device=x.device) - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) # enqueue kernel _layer_norm_fwd_fused[lambda meta: (triton.cdiv(M, meta["XBLOCK_SIZE"]), 1, 1)]( # @@ -126,7 +119,7 @@ def test_layer_norm(shape, dtype, eps=1e-5): M, N = shape device = "npu" x_shape = shape - w_shape = (x_shape[-1],) + w_shape = (x_shape[-1], ) weight = torch.rand(w_shape, dtype=dtype, device=device) bias = torch.rand(w_shape, dtype=dtype, device=device) x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) diff --git a/third_party/ascend/unittest/autotune_ut/04-libentry.py b/third_party/ascend/unittest/autotune_ut/04-libentry.py index 72949d9cf9..e956ba034b 100644 --- a/third_party/ascend/unittest/autotune_ut/04-libentry.py +++ b/third_party/ascend/unittest/autotune_ut/04-libentry.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Vector Add with Libentry ============= @@ -43,19 +42,16 @@ triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': True}), triton.Config({'BLOCK_SIZE': 12 * 1024, 'multibuffer': False}), triton.Config({'BLOCK_SIZE': 8 * 1024, 'multibuffer': True}), - ], - key=["n_elements"] -) + ], key=["n_elements"]) @libentry() @triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. 
- BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. -): +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): # There are multiple 'programs' processing different data. We identify which program # we are here: pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. @@ -83,7 +79,7 @@ def add_torch(x, y): def add_autotune(x, y): output = torch.empty_like(x) n_elements = output.numel() - add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements) return output diff --git a/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py index 1ded361ae2..4ce48e6c19 100644 --- a/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py +++ b/third_party/ascend/unittest/autotune_ut/test_autotune_param_valid.py @@ -29,23 +29,20 @@ @triton.autotune( - configs=[], - key={"x": "n_elements"}, - hints={ + configs=[], key={"x": "n_elements"}, hints={ "split_params": {"x": "BLOCK_SIZE"}, "tiling_params": {"x": "BLOCK_SIZE_SUB"}, "low_dim_axes": ["x"], "reduction_axes": [], - } -) + }) @triton.jit def add_kernel( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - BLOCK_SIZE_SUB: tl.constexpr, + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_SUB: tl.constexpr, ): offset = tl.program_id(0) * BLOCK_SIZE loops1 = (BLOCK_SIZE + BLOCK_SIZE_SUB - 1) // BLOCK_SIZE_SUB @@ -65,12 +62,14 @@ def add_torch(x, y): def add_autotune(x, y): output = 
torch.empty_like(x) n_elements = output.numel() - add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)](x, y, output, n_elements) + add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )](x, y, output, n_elements) return output @pytest.mark.autotune -@pytest.mark.parametrize('size', [2048, ]) +@pytest.mark.parametrize('size', [ + 2048, +]) def test_add(size: int): x = torch.rand(size, device="npu") y = torch.rand(size, device="npu") @@ -83,15 +82,13 @@ def test_add(size: int): @pytest.mark.autotune def test_add_no_reduction_axes(): try: + @triton.autotune( - configs=[], - key={"x": "n_elements"}, - hints={ + configs=[], key={"x": "n_elements"}, hints={ "split_params": {"x": "BLOCK_SIZE"}, "tiling_params": {"x": "BLOCK_SIZE_SUB"}, "low_dim_axes": ["x"], - } - ) + }) @triton.jit def add_kernel_exception(): pass @@ -102,15 +99,13 @@ def add_kernel_exception(): @pytest.mark.autotune def test_add_no_low_dim_axes(): try: + @triton.autotune( - configs=[], - key={"x": "n_elements"}, - hints={ + configs=[], key={"x": "n_elements"}, hints={ "split_params": {"x": "BLOCK_SIZE"}, "tiling_params": {"x": "BLOCK_SIZE_SUB"}, "reduction_axes": [], - } - ) + }) @triton.jit def add_kernel_exception(): pass @@ -121,15 +116,12 @@ def add_kernel_exception(): @pytest.mark.autotune def test_add_no_tiling_params(): try: - @triton.autotune( - configs=[], - key={"x": "n_elements"}, - hints={ - "split_params": {"x": "BLOCK_SIZE"}, - "low_dim_axes": ["x"], - "reduction_axes": [], - } - ) + + @triton.autotune(configs=[], key={"x": "n_elements"}, hints={ + "split_params": {"x": "BLOCK_SIZE"}, + "low_dim_axes": ["x"], + "reduction_axes": [], + }) @triton.jit def add_kernel_exception(): pass @@ -140,15 +132,13 @@ def add_kernel_exception(): @pytest.mark.autotune def test_add_no_split_params(): try: + @triton.autotune( - configs=[], - key={"x": "n_elements"}, - hints={ + configs=[], key={"x": "n_elements"}, hints={ "tiling_params": {"x": "BLOCK_SIZE_SUB"}, 
"low_dim_axes": ["x"], "reduction_axes": [], - } - ) + }) @triton.jit def add_kernel_exception(): pass @@ -159,15 +149,13 @@ def add_kernel_exception(): @pytest.mark.autotune def test_add_no_keyname(): try: + @triton.autotune( - configs=[], - key={"x0": "n_elements"}, - hints={ + configs=[], key={"x0": "n_elements"}, hints={ "tiling_params": {"x": "BLOCK_SIZE_SUB"}, "low_dim_axes": ["x"], "reduction_axes": [], - } - ) + }) @triton.jit def add_kernel_exception(): pass diff --git a/third_party/ascend/unittest/autotune_ut/test_common.py b/third_party/ascend/unittest/autotune_ut/test_common.py index 50502ec774..8c80b45a7d 100644 --- a/third_party/ascend/unittest/autotune_ut/test_common.py +++ b/third_party/ascend/unittest/autotune_ut/test_common.py @@ -83,10 +83,7 @@ def normalize_axis_list(axis_list: list, sym_to_sem: dict) -> list: @pytest.fixture def mock_autotuner(): - with mock.patch( - "triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", - new=MockAutoTilingTunerRun - ): + with mock.patch("triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", new=MockAutoTilingTunerRun): yield diff --git a/third_party/ascend/unittest/autotune_ut/test_customized_config.py b/third_party/ascend/unittest/autotune_ut/test_customized_config.py index ee73db9f57..3431109569 100644 --- a/third_party/ascend/unittest/autotune_ut/test_customized_config.py +++ b/third_party/ascend/unittest/autotune_ut/test_customized_config.py @@ -35,23 +35,20 @@ triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 32}), triton.Config({'XBLOCK': 128, 'XBLOCK_SUB': 64}), triton.Config({'XBLOCK': 396, 'XBLOCK_SUB': 6}), - ], - key=["n_rows", "n_cols"], - hints={ + ], key=["n_rows", "n_cols"], hints={ "auto_gen_config": False, - } -) + }) @triton.jit def softmax_kernel( - output_ptr, - input_ptr, - input_row_stride, - output_row_stride, - n_rows, - n_cols, - BLOCK_SIZE: tl.constexpr, - XBLOCK: tl.constexpr, - XBLOCK_SUB: tl.constexpr, + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, 
+ n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, ): row_start = tl.program_id(0) * XBLOCK for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): @@ -62,15 +59,9 @@ def softmax_kernel( mask = xmask & ymask input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) row = tl.load(input_ptrs, mask=mask, other=-float("inf")) - row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to( - XBLOCK_SUB, BLOCK_SIZE - ) + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) numerator = tl.exp(row_minus_max) - denominator = ( - tl.sum(numerator, axis=1) - .reshape(XBLOCK_SUB, 1) - .broadcast_to(XBLOCK_SUB, BLOCK_SIZE) - ) + denominator = (tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE)) softmax_output = numerator / denominator output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) tl.store(output_ptrs, softmax_output, mask=mask) @@ -84,14 +75,15 @@ def softmax_autotune(x): n_rows, n_cols = x.shape BLOCK_SIZE = n_cols y = torch.empty_like(x) - softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)]( - y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE - ) + softmax_kernel[lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, + n_cols, BLOCK_SIZE=BLOCK_SIZE) return y @pytest.mark.autotune -@pytest.mark.parametrize('shape,dtype', [((16896, 1024), torch.float32), ]) +@pytest.mark.parametrize('shape,dtype', [ + ((16896, 1024), torch.float32), +]) def test_softmax(shape, dtype): x = torch.randn(shape, dtype=dtype, device="npu") y_torch = softmax_torch(x) diff --git a/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py index 6c0be27f0d..b44355baab 100644 --- a/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py +++ 
b/third_party/ascend/unittest/autotune_ut/test_low_dim_axes_parse.py @@ -26,14 +26,9 @@ def test_low_dim_axis_parse_base_case1(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["n_elements"] - ) + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_low_dim_axis_parse_base_case1( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def triton_low_dim_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE # <- Separate assignment @@ -53,7 +48,7 @@ def triton_low_dim_axis_parse_base_case1( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_low_dim_axis_parse_base_case1[grid]() check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/unittest/autotune_ut/test_mask_parse.py b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py index 1b7d2383da..7a18f9f9f8 100644 --- a/third_party/ascend/unittest/autotune_ut/test_mask_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_mask_parse.py @@ -30,13 +30,12 @@ def test_triton_dot_case1(mock_autotuner): """ import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["M", "N", "K"] - ) + @triton.autotune(configs=[], key=["M", "N", "K"]) @triton.jit def triton_dot_case1( - A, B, C, + A, + B, + C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, @@ -65,9 +64,9 @@ def triton_dot_case1( for loop_k in range(loops_k): kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB) - kdx_m = kdx[None, :] # <- - A_ptr = A + mdx * K + kdx_m - a_mask = (mdx < M) & (kdx_m < K) # Use res of Subscript in mask compare + kdx_m = kdx[None, :] # <- + A_ptr = A + mdx * K + kdx_m + a_mask = (mdx < M) & (kdx_m < K) # Use res of Subscript in mask compare a = tl.load(A_ptr, mask=a_mask, other=0.0) kdx_n = kdx[:, None] @@ -76,7 +75,7 @@ def 
triton_dot_case1( b = tl.load(B_ptr, mask=b_mask, other=0.0) acc += tl.dot(a, b) - + C_ptr = C + mdx * N + ndx c_mask = (mdx < M) & (ndx < N) tl.store(C_ptr, acc, mask=c_mask) @@ -102,13 +101,12 @@ def test_triton_dot_case2(mock_autotuner): """ import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["M", "N", "K"] - ) + @triton.autotune(configs=[], key=["M", "N", "K"]) @triton.jit def triton_dot_case2( - A, B, C, + A, + B, + C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, @@ -137,8 +135,8 @@ def triton_dot_case2( for loop_k in range(loops_k): kdx = loop_k * KBLOCK_SUB + tl.arange(0, KBLOCK_SUB) - A_ptr = A + mdx * K + kdx[None, :] # <- - a_mask = (mdx < M) & (kdx[None, :] < K) # Cal subsript directly in mask compare + A_ptr = A + mdx * K + kdx[None, :] # <- + a_mask = (mdx < M) & (kdx[None, :] < K) # Cal subsript directly in mask compare a = tl.load(A_ptr, mask=a_mask, other=0.0) B_ptr = B + kdx[:, None] * N + ndx @@ -146,7 +144,7 @@ def triton_dot_case2( b = tl.load(B_ptr, mask=b_mask, other=0.0) acc += tl.dot(a, b) - + C_ptr = C + mdx * N + ndx c_mask = (mdx < M) & (ndx < N) tl.store(C_ptr, acc, mask=c_mask) diff --git a/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py index af7945bf9d..f708cef694 100644 --- a/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_no_tiling_axis_parse.py @@ -37,20 +37,18 @@ def case_torch(x): return torch.permute(x, (1, 0)) -@triton.autotune( - configs=[], - key=['xnumel', 'ynumel'], - hints={ - "auto_gen_config": True, - } -) +@triton.autotune(configs=[], key=['xnumel', 'ynumel'], hints={ + "auto_gen_config": True, +}) @triton.jit -def triton_permute_2d(output_ptr, - x_ptr, - xnumel: tl.constexpr, - ynumel: tl.constexpr, - XBLOCK: tl.constexpr, - YBLOCK: tl.constexpr, ): +def triton_permute_2d( + output_ptr, + x_ptr, + xnumel: tl.constexpr, + 
ynumel: tl.constexpr, + XBLOCK: tl.constexpr, + YBLOCK: tl.constexpr, +): xpid = tl.program_id(0) ypid = tl.program_id(1) @@ -72,11 +70,13 @@ def case_triton(x_cal, is_simt_only=False): ynumel = x_cal.shape[1] output = torch.randint(1, (ynumel, xnumel), dtype=x_cal.dtype, device=x_cal.device) if is_simt_only: - (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)] - (output, x_cal, xnumel, ynumel, force_simt_only=True)) + (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)]( + output, x_cal, xnumel, ynumel, force_simt_only=True)) else: - (triton_permute_2d[lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)] - (output, x_cal, xnumel, ynumel)) + (triton_permute_2d[lambda meta: + (triton.cdiv(xnumel, meta['XBLOCK']), triton.cdiv(ynumel, meta['YBLOCK']), 1)](output, x_cal, + xnumel, + ynumel)) return output diff --git a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py index b1c2a8c97b..e6f3baa442 100644 --- a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py @@ -28,12 +28,12 @@ def test_triton_max_last_dim_case1(mock_autotuner): @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) @triton.jit def triton_max_last_dim1( - in_ptr0, - out_ptr0, - x0_numel, - r1_numel, - X0BLOCK: tl.constexpr, - X0BLOCK_SUB: tl.constexpr, + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, R1BLOCK_SUB: tl.constexpr, ): x0_offset = tl.program_id(0) * X0BLOCK @@ -51,7 +51,7 @@ def triton_max_last_dim1( tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) block_val = tl.maximum(block_val, tmp) # Reduce along axis = 1 (the last dimension in this 2D tensor) - block_res = 
tl.max(block_val, axis=1)[:, None] # <- explicit positive axis index + block_res = tl.max(block_val, axis=1)[:, None] # <- explicit positive axis index tl.store(out_ptr0 + x0, block_res, x0_mask) ref_res = { @@ -61,7 +61,7 @@ def triton_max_last_dim1( "low_dim_axes": ["ry"], "reduction_axes": ["ry"], } - grid = lambda meta: (meta["X0BLOCK"],) + grid = lambda meta: (meta["X0BLOCK"], ) act_res = triton_max_last_dim1[grid]() check_axes_parse_res(act_res, ref_res) @@ -70,18 +70,15 @@ def triton_max_last_dim1( def test_triton_max_last_dim_case2(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["x0_numel", "r1_numel"] - ) + @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) @triton.jit def triton_max_last_dim2( - in_ptr0, - out_ptr0, - x0_numel, - r1_numel, - X0BLOCK: tl.constexpr, - X0BLOCK_SUB: tl.constexpr, + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, R1BLOCK_SUB: tl.constexpr, ): x0_offset = tl.program_id(0) * X0BLOCK @@ -99,7 +96,7 @@ def triton_max_last_dim2( tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) block_val = tl.maximum(block_val, tmp) # Reduce along axis=-1 (the last dimension, equivalent to axis=1 in 2D) - block_res = tl.max(block_val, axis=-1)[:, None] # <- negative axis index (last dim) + block_res = tl.max(block_val, axis=-1)[:, None] # <- negative axis index (last dim) tl.store(out_ptr0 + x0, block_res, x0_mask) ref_res = { @@ -109,7 +106,7 @@ def triton_max_last_dim2( "low_dim_axes": ["ry"], "reduction_axes": ["ry"], } - grid = lambda meta: (meta["X0BLOCK"],) + grid = lambda meta: (meta["X0BLOCK"], ) act_res = triton_max_last_dim2[grid]() check_axes_parse_res(act_res, ref_res) @@ -118,18 +115,15 @@ def triton_max_last_dim2( def test_triton_max_last_dim_case3(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["x0_numel", "r1_numel"] - ) + 
@triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) @triton.jit def triton_max_last_dim3( - in_ptr0, - out_ptr0, - x0_numel, - r1_numel, - X0BLOCK: tl.constexpr, - X0BLOCK_SUB: tl.constexpr, + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, R1BLOCK_SUB: tl.constexpr, ): x0_offset = tl.program_id(0) * X0BLOCK @@ -147,7 +141,7 @@ def triton_max_last_dim3( tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) block_val = tl.maximum(block_val, tmp) # Reduce along axis=1, passed as a positional argument (not keyword `axis=...`) - block_res = tl.max(block_val, 1)[:, None] # <- explicit positive axis index + block_res = tl.max(block_val, 1)[:, None] # <- explicit positive axis index tl.store(out_ptr0 + x0, block_res, x0_mask) ref_res = { @@ -157,7 +151,7 @@ def triton_max_last_dim3( "low_dim_axes": ["ry"], "reduction_axes": ["ry"], } - grid = lambda meta: (meta["X0BLOCK"],) + grid = lambda meta: (meta["X0BLOCK"], ) act_res = triton_max_last_dim3[grid]() check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py index a2d70314a8..078873341f 100644 --- a/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_split_axis_parse.py @@ -29,14 +29,9 @@ def test_split_axis_parse_base_case1(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["n_elements"] - ) + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_split_axis_parse_base_case1( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def triton_split_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE # <- Separate assignment @@ -56,7 +51,7 @@ def 
triton_split_axis_parse_base_case1( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_split_axis_parse_base_case1[grid]() check_axes_parse_res(act_res, ref_res) @@ -65,14 +60,9 @@ def triton_split_axis_parse_base_case1( def test_split_axis_parse_base_case2(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["n_elements"] - ) + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_split_axis_parse_base_case2( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def triton_split_axis_parse_base_case2(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): block_start = tl.program_id(axis=0) * BLOCK_SIZE # <- Computed inline but still named offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -91,7 +81,7 @@ def triton_split_axis_parse_base_case2( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_split_axis_parse_base_case2[grid]() check_axes_parse_res(act_res, ref_res) @@ -99,15 +89,10 @@ def triton_split_axis_parse_base_case2( def test_split_axis_parse_base_case3(mock_autotuner): import triton.backends.ascend.runtime - - @triton.autotune( - configs=[], - key=["n_elements"] - ) + + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_split_axis_parse_base_case3( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def triton_split_axis_parse_base_case3(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # <- Fully fused mask = offsets < n_elements @@ -124,7 +109,7 @@ def triton_split_axis_parse_base_case3( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = 
triton_split_axis_parse_base_case3[grid]() check_axes_parse_res(act_res, ref_res) @@ -133,10 +118,7 @@ def triton_split_axis_parse_base_case3( def test_grid_stride_loop_block_only_tiling_semantics(mock_autotuner): import triton.backends.ascend.runtime - @triton.autotune( - configs=[], - key=["N", "index_len"] - ) + @triton.autotune(configs=[], key=["N", "index_len"]) @triton.jit def triton_grid_stride_loop_block_only_tiling_semantics( input_ptr, diff --git a/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py index 104535c781..1c5e150932 100644 --- a/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py +++ b/third_party/ascend/unittest/autotune_ut/test_tiling_axis_parse.py @@ -26,15 +26,11 @@ def test_tiling_axis_parse_base_case1(mock_autotuner): import triton.backends.ascend.runtime - - @triton.autotune( - configs=[], - key=["n_elements"] - ) + + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_tiling_axis_parse_base_case1( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr - ): + def triton_tiling_axis_parse_base_case1(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): offset = tl.program_id(axis=0) * BLOCK_SIZE base = tl.arange(0, BLOCK_SUB) loops = (BLOCK_SIZE + BLOCK_SUB - 1) // BLOCK_SUB # <- @@ -55,7 +51,7 @@ def triton_tiling_axis_parse_base_case1( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_tiling_axis_parse_base_case1[grid]() check_axes_parse_res(act_res, ref_res) @@ -64,15 +60,11 @@ def triton_tiling_axis_parse_base_case1( @pytest.mark.skip def test_tiling_axis_parse_base_case2(mock_autotuner): import triton.backends.ascend.runtime - - @triton.autotune( - configs=[], - key=["n_elements"] - ) + + @triton.autotune(configs=[], key=["n_elements"]) 
@triton.jit - def triton_tiling_axis_parse_base_case2( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr - ): + def triton_tiling_axis_parse_base_case2(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): offset = tl.program_id(axis=0) * BLOCK_SIZE base = tl.arange(0, BLOCK_SUB) for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB): @@ -92,7 +84,7 @@ def triton_tiling_axis_parse_base_case2( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_tiling_axis_parse_base_case2[grid]() check_axes_parse_res(act_res, ref_res) @@ -101,15 +93,11 @@ def triton_tiling_axis_parse_base_case2( @pytest.mark.skip def test_tiling_axis_parse_base_case3(mock_autotuner): import triton.backends.ascend.runtime - - @triton.autotune( - configs=[], - key=["n_elements"] - ) + + @triton.autotune(configs=[], key=["n_elements"]) @triton.jit - def triton_tiling_axis_parse_base_case3( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_SUB: tl.constexpr - ): + def triton_tiling_axis_parse_base_case3(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_SUB: tl.constexpr): offset = tl.program_id(axis=0) * BLOCK_SIZE base = tl.arange(0, BLOCK_SUB)[:] # <- for offset_sub in range(0, BLOCK_SIZE, BLOCK_SUB): @@ -129,7 +117,7 @@ def triton_tiling_axis_parse_base_case3( "low_dim_axes": ["x"], "reduction_axes": [], } - grid = lambda meta: (meta["BLOCK_SIZE"],) + grid = lambda meta: (meta["BLOCK_SIZE"], ) act_res = triton_tiling_axis_parse_base_case3[grid]() check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/unittest/custom_op/builtin_ops_demo.py b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py index 232b54fccc..e3e9fc6750 100644 --- a/third_party/ascend/unittest/custom_op/builtin_ops_demo.py +++ b/third_party/ascend/unittest/custom_op/builtin_ops_demo.py @@ -18,32 
+18,13 @@ def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): index = tl.full([8], 0, tl.int32) value = tl.full([8, 64], 0, tl.float32) tmp = tl.full([8], 0, tl.float32) - x = al.custom("__builtin_index_select", - x_ptr, index, - dim=0, - bound=100, - end_offset=(2, 2), - start_offset=(0, 0), - src_stride=(4, 1), - out=x) - al.custom("__builtin_index_put", - x_ptr, index, value, - dim=0, - bound=12, - dst_shape=(1, 2, 3), - dst_offset=(4, 5, 6), + x = al.custom("__builtin_index_select", x_ptr, index, dim=0, bound=100, end_offset=(2, 2), start_offset=(0, 0), + src_stride=(4, 1), out=x) + al.custom("__builtin_index_put", x_ptr, index, value, dim=0, bound=12, dst_shape=(1, 2, 3), dst_offset=(4, 5, 6), dst_stride=(8, 4, 1)) - tmp = al.custom("__builtin_gather_load", - y_ptr, index, - bound=100, - dim=0, - src_stride=(1,), - index_shape=(3,), - offsets=(0,), - out=tmp) - al.custom("__builtin_scatter_store", - out_ptr, value, index, - 1, 0, (1, ), (2, ), (1, )) + tmp = al.custom("__builtin_gather_load", y_ptr, index, bound=100, dim=0, src_stride=(1, ), index_shape=(3, ), + offsets=(0, ), out=tmp) + al.custom("__builtin_scatter_store", out_ptr, value, index, 1, 0, (1, ), (2, ), (1, )) y = al.custom("__builtin_indirect_load", x_ptr, index, mask=i < n, other=y, out=y) al.custom("__builtin_indirect_store", out_ptr, index, value) tl.store(out_ptr + i, y, mask=i < n) diff --git a/third_party/ascend/unittest/custom_op/custom_op_demo.py b/third_party/ascend/unittest/custom_op/custom_op_demo.py index 4817658627..a28d6cdea6 100644 --- a/third_party/ascend/unittest/custom_op/custom_op_demo.py +++ b/third_party/ascend/unittest/custom_op/custom_op_demo.py @@ -80,8 +80,7 @@ def example_op(src, index, offset, axis, _builder=None): # output can be provided here to make it easy to use. 
x = tl.semantic.full(src.shape, 0, tl.float32, _builder) y = tl.semantic.full(index.shape, 0, tl.float32, _builder) - return al.custom_semantic(_example_custom_op.name, - src, index, offset, axis, out=(x, y), _builder=_builder) + return al.custom_semantic(_example_custom_op.name, src, index, offset, axis, out=(x, y), _builder=_builder) @triton.jit diff --git a/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py index 4469d7eec0..fe2dbe351b 100644 --- a/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py +++ b/third_party/ascend/unittest/custom_op/custom_op_extra_buffer_demo.py @@ -102,10 +102,8 @@ def main() -> None: elif parsed: print("Note: parsed sizes differ from spec; inspect MLIR spelling below.") else: - print( - "Could not parse extra_buffers_sizes automatically; " - "search the dump for 'extra_buffers_sizes'." - ) + print("Could not parse extra_buffers_sizes automatically; " + "search the dump for 'extra_buffers_sizes'.") print("\n--- MLIR excerpt (lines containing hivm.hir.custom) ---") for line in mlir.splitlines(): diff --git a/third_party/ascend/unittest/custom_op/test_gather_load.py b/third_party/ascend/unittest/custom_op/test_gather_load.py index dc92988d28..03e83b9171 100644 --- a/third_party/ascend/unittest/custom_op/test_gather_load.py +++ b/third_party/ascend/unittest/custom_op/test_gather_load.py @@ -17,14 +17,8 @@ def test_gather_load_kernel(src_ptr, index_ptr, out_ptr): # gather load from GM to UB dst = tl.full(index.shape, 0, tl.float32) - gathered = al.custom("__builtin_gather_load", - src_ptr, index, - bound=4, - dim=0, - src_stride=(2, 1), - index_shape=(2, 2), - offsets=(0, 0), - out=dst) + gathered = al.custom("__builtin_gather_load", src_ptr, index, bound=4, dim=0, src_stride=(2, 1), index_shape=(2, 2), + offsets=(0, 0), out=dst) # store result to GM tl.store(out_ptr + rows * 2 + cols, gathered, mask) @@ -34,5 +28,5 @@ def 
test_gather_load_kernel(src_ptr, index_ptr, out_ptr): src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], device='npu') index = torch.tensor([[0, 1], [2, 3]], device='npu') out = torch.empty((2, 2), device='npu', dtype=torch.float32) - test_gather_load_kernel[(1,)](src, index, out) + test_gather_load_kernel[(1, )](src, index, out) print("result: ", out) # [[1., 4.], [5., 8.]] diff --git a/third_party/ascend/unittest/custom_op/test_index_select.py b/third_party/ascend/unittest/custom_op/test_index_select.py index 97c3e72502..d06174fde1 100644 --- a/third_party/ascend/unittest/custom_op/test_index_select.py +++ b/third_party/ascend/unittest/custom_op/test_index_select.py @@ -17,17 +17,15 @@ def builtin_index_select_kernel(src_ptr, index_ptr, out_ptr): dst = tl.full((2, 2), 0, dtype=tl.float32) # Invoke __builtin_index_select custom op to gather elements - out_tile = al.custom( - "__builtin_index_select", - src_ptr, # Pointer to source tensor in GM - idx, # Index tensor (in UB) for gathering - dim=0, # Dimension to gather along - bound=4, # Upper bound for valid index values (out-of-bound check) - end_offset=(2, 2),# End offsets of each dimension for the index tensor - start_offset=(0, 0), # Start offsets of each dimension for the source tensor - src_stride=(4, 1),# Stride of each dimension for the source tensor in GM - out=dst # Output tensor (in UB) to store gathered elements - ) + out_tile = al.custom("__builtin_index_select", src_ptr, # Pointer to source tensor in GM + idx, # Index tensor (in UB) for gathering + dim=0, # Dimension to gather along + bound=4, # Upper bound for valid index values (out-of-bound check) + end_offset=(2, 2), # End offsets of each dimension for the index tensor + start_offset=(0, 0), # Start offsets of each dimension for the source tensor + src_stride=(4, 1), # Stride of each dimension for the source tensor in GM + out=dst # Output tensor (in UB) to store gathered elements + ) # Store the gathered tile from UB to output tensor in 
GM tl.store(out_ptr + r * 2 + c, out_tile) @@ -35,15 +33,12 @@ def builtin_index_select_kernel(src_ptr, index_ptr, out_ptr): if __name__ == "__main__": src = torch.tensor( - [[10., 11., 12., 13.], - [20., 21., 22., 23.], - [30., 31., 32., 33.], - [40., 41., 42., 43.]], + [[10., 11., 12., 13.], [20., 21., 22., 23.], [30., 31., 32., 33.], [40., 41., 42., 43.]], device="npu", dtype=torch.float32, ) index = torch.tensor([2, 0], device="npu", dtype=torch.int32) out = torch.empty((2, 2), device="npu", dtype=torch.float32) ref = torch.index_select(src, 0, index.to(torch.int64))[:, :2] - builtin_index_select_kernel[(1,)](src, index, out) - torch.testing.assert_close(out, ref) # ref: [[30., 31.], [10., 11.]] + builtin_index_select_kernel[(1, )](src, index, out) + torch.testing.assert_close(out, ref) # ref: [[30., 31.], [10., 11.]] diff --git a/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py b/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py index 5b3554c7c2..8d59474dfb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_01_vector_add.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - """ Vector Addition - Pytest Version """ diff --git a/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py b/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py index e4b8c0a3f6..91fde18a95 100644 --- a/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py +++ b/third_party/ascend/unittest/pytest_ut/test_02_fused_softmax.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Fused Softmax ============= @@ -51,7 +50,8 @@ def naive_softmax(x): @triton.jit -def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr): +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, + BLOCK_SIZE: tl.constexpr): # starting row of the program row_start = tl.program_id(0) row_step = tl.num_programs(0) @@ -99,15 +99,7 @@ def softmax(x): num_programs = min(num_programs, n_rows) # Create a number of persistent programs. - kernel[(num_programs, 1, 1)]( - y, - x, - x.stride(0), - y.stride(0), - n_rows, - n_cols, - BLOCK_SIZE - ) + kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE) return y diff --git a/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py index 43aabfafc7..ec5f8e8655 100644 --- a/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py +++ b/third_party/ascend/unittest/pytest_ut/test_03_matrix_multiplication.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Matrix Multiplication =============== @@ -46,28 +45,19 @@ def get_autotune_config(): ) @triton.jit def matmul_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - # Matrix dimensions - M, - N, - K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). 
- stride_am, - stride_ak, # - stride_bk, - stride_bn, # - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, # - ACTIVATION: tl.constexpr, # + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -132,9 +122,7 @@ def matmul_kernel( # Comment out the following lines to enable split the workload to two vector cores SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 for s in extension.parallel(0, 2, bind_sub_block=True): - vec_sub_blk = extension.extract_slice( - accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) - ) + vec_sub_blk = extension.extract_slice(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1)) if ACTIVATION == "leaky_relu_custom": vec_sub_blk = leaky_relu_custom(vec_sub_blk) c_sub_blk = vec_sub_blk.to(tl.float16) @@ -173,24 +161,18 @@ def matmul(a, b, activation=""): K, N = b.shape # Allocates output. c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) matmul_kernel[grid]( - a, - b, - c, # - M, - N, - K, # - a.stride(0), - a.stride(1), # - b.stride(0), - b.stride(1), # - c.stride(0), - c.stride(1), # + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # ACTIVATION=activation, # ) return c @@ -212,7 +194,8 @@ def grid(META): "activation", [ "", - pytest.param("leaky_relu_custom", marks=pytest.mark.skip(reason="temporarily skip leaky_relu_custom ub overflow case")), + pytest.param("leaky_relu_custom", + marks=pytest.mark.skip(reason="temporarily skip leaky_relu_custom ub overflow case")), ], ) def test_matrix_multiplication(shape, activation): diff --git a/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py index 1aab7a20d6..f615608e7a 100644 --- a/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py +++ b/third_party/ascend/unittest/pytest_ut/test_04_low_memory_dropout.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Low-Memory Dropout ================== @@ -61,10 +60,10 @@ def dropout(x, x_keep, p): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - + def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output @@ -97,10 +96,10 @@ def seeded_dropout(x, p, seed): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - + def grid(meta): return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) return output @@ -117,7 +116,8 @@ def test_dropout_matches_reference(shape, p): torch.testing.assert_close(output, expected, atol=1e-6, rtol=0) -@pytest.mark.parametrize("shape,p,seed", [((10, ), 0.5, 123), ((256, ), 0.5, 123), ((513, ), 0.2, 7), ((32, 64), 0.35, 999)]) +@pytest.mark.parametrize("shape,p,seed", [((10, ), 0.5, 123), ((256, ), 0.5, 123), ((513, ), 0.2, 7), + ((32, 64), 0.35, 999)]) def test_seeded_dropout_is_deterministic(shape, p, seed): torch.manual_seed(0) x = torch.randn(size=shape, device=DEV, dtype=torch.float32) diff --git a/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py index ef14ac70f7..84e3268b9b 100644 --- a/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py +++ b/third_party/ascend/unittest/pytest_ut/test_05_layer_norm.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Layer Normalization ============= @@ -109,7 +108,7 @@ def test_layer_norm(dtype): device = 'npu' x_shape = (M, N) - w_shape = (x_shape[-1],) + w_shape = (x_shape[-1], ) weight = torch.rand(w_shape, dtype=dtype, device=device) bias = torch.rand(w_shape, dtype=dtype, device=device) x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) diff --git a/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py index 5d8c695c19..8f83d6a730 100644 --- a/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py +++ b/third_party/ascend/unittest/pytest_ut/test_06_fused_attention.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Fused Attention =============== @@ -48,8 +47,10 @@ def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, quer K_block_ptr, V_block_ptr, # Key and value block pointers for current stage start_m, qk_scale, # Starting position of current query block, qk scale factor BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # Current stage flag, m and n offset indices - N_CTX: tl.constexpr, fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision + STAGE: tl.constexpr, offs_m: tl.constexpr, + offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, + fp8_v: tl.constexpr): # Total context length, whether to enable FP8 for value precision # Set the processing range [lo, hi) for the current stage (in column block units) # Causal attention, as the name implies, restricts the flow of information during computation, # only allowing the model to see the current and previous positions. 
@@ -144,18 +145,12 @@ def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, quer @triton.jit -def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, - stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, - stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, - stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, - stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, - Z: tl.constexpr, H: tl.constexpr, - N_CTX: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - STAGE: tl.constexpr - ): +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, + stride_qk: tl.constexpr, stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, + stride_kk: tl.constexpr, stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, + stride_vk: tl.constexpr, stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, + stride_on: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, N_CTX: tl.constexpr, HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr): # Total number of blocks in sequence dimension (M) NUM_BLOCKS_M = N_CTX // BLOCK_M # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) @@ -214,11 +209,8 @@ def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, if HEAD_DIM < 256: acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) else: - acc_offset = ( - off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM - + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM - + task_m_idx * BLOCK_M * HEAD_DIM - ) + acc_offset = (off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + task_m_idx * BLOCK_M * HEAD_DIM) 
acc_ptr = acc + acc_offset # load q: it will stay in SRAM throughout @@ -293,18 +285,11 @@ def forward(ctx, q, k, v, causal, sm_scale, BM, BN): acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - _attn_fwd[(num_cores,)]( - q, k, v, M, out, acc, sm_scale, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - q.shape[0], q.shape[1], N_CTX=q.shape[2], - HEAD_DIM=HEAD_DIM_K, - BLOCK_M=BM, - BLOCK_N=BN, - STAGE=stage, - **extra_kern_args) + _attn_fwd[(num_cores, )](q, k, v, M, out, acc, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), + v.stride(2), v.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], HEAD_DIM=HEAD_DIM_K, BLOCK_M=BM, BLOCK_N=BN, + STAGE=stage, **extra_kern_args) ctx.save_for_backward(q, k, v, out, M) ctx.sm_scale = sm_scale @@ -338,7 +323,10 @@ def test_attention_fused(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): sm_scale = 0.5 tri_out = attention(q, k, v, causal, sm_scale, BM, BN) ref_out = torch_npu.npu_fusion_attention( - q, k, v, H, + q, + k, + v, + H, padding_mask=None, atten_mask=None, scale=sm_scale, diff --git a/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py b/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py index 3da074572f..01e41209c7 100644 --- a/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py +++ b/third_party/ascend/unittest/pytest_ut/test_07_extern_functions.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 
OTHER DEALINGS IN # THE SOFTWARE. - """ Libdevice (`tl.extra.libdevice`) function ============================== diff --git a/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py b/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py index c4954ebb16..830f9e57d6 100644 --- a/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py +++ b/third_party/ascend/unittest/pytest_ut/test_08_grouped_gemm.py @@ -18,7 +18,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Group GEMM ============================ @@ -164,7 +163,7 @@ def group_gemm_fn(group_A, group_B): d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) def grid(meta): - return (meta['NUM_SM'],) + return (meta['NUM_SM'], ) grouped_matmul_kernel[grid]( d_a_ptrs, @@ -218,7 +217,7 @@ def test_grouped_gemm_tutorial_example(group_m, group_n, group_k): def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): def grid(meta): - return (meta['NUM_SM'],) + return (meta['NUM_SM'], ) grouped_matmul_kernel[grid]( a_ptrs, diff --git a/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py b/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py index 3b2e3e3e48..f73a3fc46f 100644 --- a/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py +++ b/third_party/ascend/unittest/pytest_ut/test_09_persistent_matmul.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Persistent Matmul ===================== @@ -217,7 +216,7 @@ def matmul(a, b): c = torch.empty((M, N), device=a.device, dtype=a.dtype) def grid(meta): - return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),) + return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) matmul_kernel[grid]( a, @@ -252,7 +251,7 @@ def matmul_persistent(a, b): c = torch.empty((M, N), device=a.device, dtype=a.dtype) def grid(meta): - return (min(num_sms, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])),) + return (min(num_sms, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])), ) matmul_kernel_persistent[grid]( a, diff --git a/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py b/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py index 38614b9405..8fcc3f6b9e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py +++ b/third_party/ascend/unittest/pytest_ut/test_10_gather_sorted.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- """ Gather sorted =============== @@ -51,7 +50,9 @@ def torch_gather_sorted(embeddings, sorted_idxes, aux_idxes): # triton-version gather_sorted's kernel @triton.jit -def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): +def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, res_ptr, rows, cols, + DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, + COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1 emb_dtype = embeddings_ptr.type.element_ty @@ -60,7 +61,8 @@ def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, re core_idx = tl.program_id(0) # compute the the size and start index of block row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE - row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else (BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) + row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else ( + BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) # this version has 3-buffers, initilize for buffers row_block_size_0 = tl.cdiv(row_block_size, 3) @@ -126,7 +128,8 @@ def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, re # triton-version gather_sorted's host -def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, aux_indices: torch.Tensor, default_value=1.0): +def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, aux_indices: torch.Tensor, + default_value=1.0): # constant settings for npu ALIGNED = 32 USE_SIZE = 96 * 1024 @@ -140,7 +143,8 @@ def triton_gather_sorted(embeddings: 
torch.Tensor, sorted_indices: torch.Tensor, # when writing an npu kernel using triton, # you should note that the difference between BLOCK_SIZE and BLOCK_SIZE_SUB # BLOCK_SIZE specifies the size of data that are processed in one program - col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), ALIGNED) * ALIGNED // embeddings.element_size() + col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), + ALIGNED) * ALIGNED // embeddings.element_size() # the data are scattered to multiple programs, which can not be even # some process more data, some process less big_row_block_size = triton.cdiv(n_rows, CORE_NUM) @@ -151,7 +155,9 @@ def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size)) # launch the kernel - gather_sorted_kernel[grid](embeddings, sorted_indices, aux_indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) + gather_sorted_kernel[grid](embeddings, sorted_indices, aux_indices, output, n_rows, n_cols, default_value, + BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, + COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) return output @@ -185,7 +191,7 @@ def generate_inputs(index_shape, table_shape, dtype): @pytest.mark.parametrize("table_cols", [16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000]) @pytest.mark.parametrize("index_num", [19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000]) def test_gather_sorted(table_rows, table_cols, index_num): - table, sorted_indices, aux_indices = generate_inputs((index_num,), (table_rows, table_cols), torch.float) + table, sorted_indices, aux_indices = generate_inputs((index_num, ), (table_rows, table_cols), torch.float) expect = torch_gather_sorted(table, sorted_indices, aux_indices).cpu() 
torch.npu.synchronize() diff --git a/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py index fe8293b51c..6603a2a457 100644 --- a/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py +++ b/third_party/ascend/unittest/pytest_ut/test_11_rab_time.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Relative Attention Bias Timestamps =============== @@ -46,12 +45,10 @@ def create_pos_w(train_len: int, num_layers: int) -> torch.Tensor: def create_past_valid_lens(bs: int, past_len: int) -> torch.Tensor: - return torch.randint(0, past_len, (bs,)) + return torch.randint(0, past_len, (bs, )) -def create_timestamps( - train_len: int, candidate_len: int, past_valid_lens: torch.Tensor -) -> torch.Tensor: +def create_timestamps(train_len: int, candidate_len: int, past_valid_lens: torch.Tensor) -> torch.Tensor: bs = past_valid_lens.size(0) timestamps = torch.zeros(bs, train_len + candidate_len // 2) for i, valid_len in enumerate(past_valid_lens): @@ -66,11 +63,7 @@ def create_timestamps( def create_timestamps_weights(num_layers: int): - return ( - torch.arange(0, NUM_BUCKETS + 1) - .repeat(num_layers) - .reshape(NUM_BUCKETS + 1, num_layers) - ) + return (torch.arange(0, NUM_BUCKETS + 1).repeat(num_layers).reshape(NUM_BUCKETS + 1, num_layers)) def create_rab_time_grad(num_layers: int, batchsize: int, s: int): @@ -101,9 +94,7 @@ def rab_time_forward_kernel( col_iter_num = tl.cdiv(BLOCK_SIZE, COL_BLOCK_SIZE) for col_idx in tl.range(0, col_iter_num): - cols_offsets = ( - pid0 * BLOCK_SIZE + col_idx * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE) - ) + cols_offsets = (pid0 * BLOCK_SIZE + col_idx * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE)) cols_mask = cols_offsets < index_len out_mask = cols_offsets < index_len @@ -139,9 +130,7 @@ def rab_time_forward_triton(ts_w, 
timestamps, bucketization_divisor): num_buckets = ts_w.shape[0] - 1 timestamps_expanded = timestamps.unsqueeze(-1).repeat(1, 1, 2) - timestamps_expanded = timestamps_expanded.reshape( - bs, infer_len, 1 - ) - timestamps_expanded.reshape(bs, 1, infer_len) + timestamps_expanded = timestamps_expanded.reshape(bs, infer_len, 1) - timestamps_expanded.reshape(bs, 1, infer_len) timestamps_expanded = timestamps_expanded.view(-1) timestamps_expanded = timestamps_expanded.contiguous() @@ -150,9 +139,7 @@ def rab_time_forward_triton(ts_w, timestamps, bucketization_divisor): index_len = bs * infer_len * infer_len out = torch.empty((num_layers, index_len), dtype=ts_w.dtype, device=ts_w.device) - outer_loop_num, sub_num_layers, remain_layers = get_outer_loop_num( - num_layers, index_len - ) + outer_loop_num, sub_num_layers, remain_layers = get_outer_loop_num(num_layers, index_len) CORE_NUM = get_npu_properties()["num_vectorcore"] BLOCK_SIZE = math.ceil(index_len / CORE_NUM) @@ -184,15 +171,9 @@ def grid(meta): @triton.jit -def rab_time_backward_kernel( - inp, src, index, index_len, BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr -): +def rab_time_backward_kernel(inp, src, index, index_len, BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr): pid0 = tl.program_id(axis=0) - total_col_num = ( - BLOCK_SIZE - if pid0 * BLOCK_SIZE + BLOCK_SIZE < index_len - else index_len - pid0 * BLOCK_SIZE - ) + total_col_num = (BLOCK_SIZE if pid0 * BLOCK_SIZE + BLOCK_SIZE < index_len else index_len - pid0 * BLOCK_SIZE) COL_BLOCK_SIZE = min(COL_BLOCK_SIZE, total_col_num) col_iter_num = (total_col_num + COL_BLOCK_SIZE - 1) // COL_BLOCK_SIZE @@ -204,11 +185,8 @@ def rab_time_backward_kernel( acc_result = 0.0 acc_result = acc_result.to(inp.dtype.element_ty) - cur_col_num = ( - COL_BLOCK_SIZE - if col_start_offset + COL_BLOCK_SIZE < total_col_num - else total_col_num - col_start_offset - ) + cur_col_num = (COL_BLOCK_SIZE if col_start_offset + COL_BLOCK_SIZE < total_col_num else total_col_num - 
+ col_start_offset) for cur_idx in range(0, cur_col_num): cur_offset = pid0 * BLOCK_SIZE + col_start_offset + cur_idx @@ -229,32 +207,23 @@ def rab_time_backward_kernel( tl.atomic_add(inp + base_idx, acc_result) -def rab_time_backward_triton( - rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor -): +def rab_time_backward_triton(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): num_layers, b, s, _ = rab_time_grad.shape - tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to( - rab_time_grad.device - ) + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) - bucket_timestamps_expand = ( - bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) - .repeat(1, 1, 2, 1, 2) - .reshape(b, s, s) - .to(torch.int64) - ).view(-1) + bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, + 1).repeat(1, 1, 2, 1, 2).reshape(b, s, + s).to(torch.int64)).view(-1) index_len = bucket_timestamps_expand.numel() rab_time_grad_f32 = rab_time_grad.to(torch.float32) - sorted_bucket_timestamps_expand, sorted_idx = torch.sort( - bucket_timestamps_expand.view(-1) - ) + sorted_bucket_timestamps_expand, sorted_idx = torch.sort(bucket_timestamps_expand.view(-1)) torch.npu.synchronize() def grid(meta): - return (triton.cdiv(index_len, meta["BLOCK_SIZE"]),) + return (triton.cdiv(index_len, meta["BLOCK_SIZE"]), ) CORE_NUM = get_npu_properties()["num_vectorcore"] BLOCK_SIZE = math.ceil(index_len / CORE_NUM) @@ -275,9 +244,7 @@ def grid(meta): return tsw_grad -def rab_time_forward_golden( - ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float -) -> torch.Tensor: +def rab_time_forward_golden(ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float) -> torch.Tensor: """ torch realization of rab time forward for reference. 
""" @@ -286,15 +253,10 @@ def rab_time_forward_golden( num_layers = ts_w.shape[1] timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2) - diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape( - bs, 1, infer_len - ) + diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) - diff_timestamps = ( - torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) - / bucketization_divisor - ) + diff_timestamps = (torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / bucketization_divisor) bucket_timestamps = diff_timestamps.long() bucket_timestamps = bucket_timestamps.view(-1) result = torch.index_select(ts_w, dim=0, index=bucket_timestamps) @@ -305,35 +267,23 @@ def rab_time_forward_golden( return result -def rab_time_backward_golden( - rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor -): +def rab_time_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): """ torch realization of rab time backward for reference. 
""" num_layers, b, s, _ = rab_time_grad.shape - tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to( - rab_time_grad.device - ) + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) - bucket_timestamps_expand = ( - bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) - .repeat(1, 1, 2, 1, 2) - .reshape(b, s, s) - .to(torch.int64) - ) + bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, + 1).repeat(1, 1, 2, 1, 2).reshape(b, s, s).to(torch.int64)) for n, grad in enumerate(rab_time_grad.to(torch.float32)): - tsw_grad[n] = tsw_grad[n].scatter_add( - src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0 - ) + tsw_grad[n] = tsw_grad[n].scatter_add(src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0) return tsw_grad def run_rab_time_forward_case(num_layers, train_len, candidate_len, bs, dtype): past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) - timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to( - torch.int32 - ) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) timestamps_weights = create_timestamps_weights(num_layers).to(dtype) timestamps = timestamps.npu() timestamps_weights = timestamps_weights.npu() @@ -361,18 +311,12 @@ def run_rab_time_forward_case(num_layers, train_len, candidate_len, bs, dtype): def run_rab_time_backward_case(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): grad = create_rab_time_grad(num_layers, batchsize, s).to(dtype).npu() - bucket_timestamps = ( - create_bucket_timestamps(batchsize, s // 2).to(torch.int32).npu() - ) + bucket_timestamps = (create_bucket_timestamps(batchsize, s // 2).to(torch.int32).npu()) torch_npu.npu.synchronize() - golden_result = ( - rab_time_backward_golden(grad, bucket_timestamps).to(torch.float32).cpu() - ) - op_result = ( - rab_time_backward_triton(grad, bucket_timestamps).to(torch.float32).cpu() - ) + 
golden_result = (rab_time_backward_golden(grad, bucket_timestamps).to(torch.float32).cpu()) + op_result = (rab_time_backward_triton(grad, bucket_timestamps).to(torch.float32).cpu()) loss = 1e-4 if dtype == torch.float32 else 1e-3 torch.testing.assert_close(op_result, golden_result, rtol=loss, atol=loss) diff --git a/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py index 72d6ff6ed4..f1acf686b5 100644 --- a/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py +++ b/third_party/ascend/unittest/pytest_ut/test_12_hstu_attention.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ HSTU Attention =============== @@ -246,18 +245,68 @@ def _hstu_attn_fwd( # noqa C901 off_head = (start_m - seq_start * head_num // 2) // (seq_len // 2) start_m_1 = (start_m - seq_start * head_num // 2) % (seq_len // 2) start_m_2 = seq_len - start_m_1 - BLOCK_M - _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, - stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, - start_m_1, seq_start, seq_len, CAUSAL, HAS_BIAS, head_dim, head_dim, BLOCK_M, BLOCK_N, - mask_block=mask_block, - bias=bias, - ) - _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, - stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, - start_m_2, seq_start, seq_len, CAUSAL, HAS_BIAS, head_dim, head_dim, BLOCK_M, BLOCK_N, - mask_block=mask_block, - bias=bias, - ) + _hstu_attn_fwd_compute( + Q, + K, + V, + seq_offsets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + head_num, + MAX_SEQ_LEN, + off_batch, + off_head, + start_m_1, + seq_start, + seq_len, + 
CAUSAL, + HAS_BIAS, + head_dim, + head_dim, + BLOCK_M, + BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + _hstu_attn_fwd_compute( + Q, + K, + V, + seq_offsets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + head_num, + MAX_SEQ_LEN, + off_batch, + off_head, + start_m_2, + seq_start, + seq_len, + CAUSAL, + HAS_BIAS, + head_dim, + head_dim, + BLOCK_M, + BLOCK_N, + mask_block=mask_block, + bias=bias, + ) @triton.jit @@ -442,7 +491,13 @@ def _hstu_attn_bwd_one_col_block( # noqa C901 @triton.jit def _hstu_attn_bwd( # noqa C901 - Q, K, V, Grad, DQ, DK, DV, + Q, + K, + V, + Grad, + DQ, + DK, + DV, stride_qm: tl.constexpr, stride_qh: tl.constexpr, stride_kn: tl.constexpr, @@ -529,10 +584,34 @@ def triton_hstu_attention_fwd( core_num = get_npu_properties('num_aicore') tasks = total_seq * head_num // BLOCK_M // 2 grid = (core_num, 1, 1) - _hstu_attn_fwd[grid](q, k, v, seq_offsets, out, q.stride(0), q.stride(1), k.stride(0), k.stride(1), - v.stride(0), v.stride(1), out.stride(0), out.stride(1), alpha, batch, head_num, max_seq_len, head_dim, - causal, has_bias, core_num, tasks, BLOCK_M, BLOCK_N, mask, bias, - ) + _hstu_attn_fwd[grid]( + q, + k, + v, + seq_offsets, + out, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + out.stride(0), + out.stride(1), + alpha, + batch, + head_num, + max_seq_len, + head_dim, + causal, + has_bias, + core_num, + tasks, + BLOCK_M, + BLOCK_N, + mask, + bias, + ) return out @@ -555,12 +634,38 @@ def triton_hstu_attention_bwd( batch = seq_offsets.numel() - 1 _, head_num, head_dim = q.shape has_bias = bias is not None - grid = (batch * head_num, 1,) - _hstu_attn_bwd[grid](q, k, v, grad, dq, dk, dv, - q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), - grad.stride(0), grad.stride(1), seq_offsets, alpha, batch, head_num, max_seq_len, head_dim, - causal, has_bias, BLOCK_BWD, BLOCK_BWD, bias, - ) + grid = ( + 
batch * head_num, + 1, + ) + _hstu_attn_bwd[grid]( + q, + k, + v, + grad, + dq, + dk, + dv, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + grad.stride(0), + grad.stride(1), + seq_offsets, + alpha, + batch, + head_num, + max_seq_len, + head_dim, + causal, + has_bias, + BLOCK_BWD, + BLOCK_BWD, + bias, + ) return dq, dk, dv @@ -569,8 +674,8 @@ def jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, dataType) seq_lens = np.random.choice(seq_array, size=batch_size) if not np.isin(max_seq_len, seq_lens): seq_lens[np.random.randint(0, batch_size)] = max_seq_len - seq_offset = torch.concat((torch.zeros((1,), dtype=torch.int64), - torch.cumsum(torch.from_numpy(seq_lens), axis=0))).to(torch.int64).numpy() + seq_offset = torch.concat((torch.zeros((1, ), dtype=torch.int64), torch.cumsum(torch.from_numpy(seq_lens), + axis=0))).to(torch.int64).numpy() max_seq_len = np.max(seq_lens) total_seqs = np.sum(seq_lens) grad = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) @@ -596,7 +701,7 @@ def dense_to_jagged(q, dense_tensor, seq_lens): tensor = torch.zeros_like(q) offset = 0 for batch_id, seq_len in enumerate(seq_lens): - tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :] + tensor[offset:offset + seq_len, :, :] = dense_tensor[batch_id, 0:seq_len, :, :] offset = offset + seq_len return tensor @@ -605,7 +710,7 @@ def jagged_to_dense(jagged_tensor, seq_lens, head_nums, atten_dim): need_pad_seq = [] offset = 0 for _, seq_len in enumerate(seq_lens): - src_tensor = jagged_tensor[offset: offset + seq_len, :, :].reshape(seq_len, head_nums, atten_dim) + src_tensor = jagged_tensor[offset:offset + seq_len, :, :].reshape(seq_len, head_nums, atten_dim) need_pad_seq.append(src_tensor) offset = offset + seq_len @@ -682,13 +787,14 @@ def run_fwd_case(batch_size, max_seq_len, num_heads, attention_dim, data_type): def golden_bwd(grad, q, k, v, bias, mask, max_seq_len, seq_offset, 
enable_mask, silu_scale, enable_bias, data_type): + def jagged_to_dense_bwd(jagged_tensor, seq_lens, max_seq_len, head_num, head_dim): batch_size = len(seq_lens) dense_tensor = torch.zeros(batch_size, max_seq_len, head_num, head_dim, dtype=jagged_tensor.dtype) offset = 0 for batch_id, seq_len in enumerate(seq_lens): - dense_tensor[batch_id, :seq_len, :, :] = jagged_tensor[offset: offset + seq_len, :, :] + dense_tensor[batch_id, :seq_len, :, :] = jagged_tensor[offset:offset + seq_len, :, :] offset = offset + seq_len return dense_tensor @@ -698,7 +804,7 @@ def dense_to_jagged_bwd(jagged_tensor, dense_tensor, seq_lens): offset = 0 for batch_id, seq_len in enumerate(seq_lens): - tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :] + tensor[offset:offset + seq_len, :, :] = dense_tensor[batch_id, 0:seq_len, :, :] offset = offset + seq_len return tensor @@ -710,7 +816,7 @@ def dense_to_jagged_bwd(jagged_tensor, dense_tensor, seq_lens): head_nums = grad.shape[1] head_dim = grad.shape[2] batch_size = bias.shape[0] - seq_lens = np.zeros((batch_size,)).astype(np.int64) + seq_lens = np.zeros((batch_size, )).astype(np.int64) for batch_id in range(batch_size): seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id] grad_dens = jagged_to_dense_bwd(grad, seq_lens, max_seq_len, head_nums, head_dim).to(data_type) diff --git a/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py index b71c10d682..27df8cf162 100644 --- a/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py +++ b/third_party/ascend/unittest/pytest_ut/test_13_matrix_multiplication_optimized.py @@ -45,12 +45,12 @@ def get_npu_properties(): triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), triton.Config({"BLOCK_M": 256, 
"BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), - ], - key=["M", "N", "K"] -) + ], key=["M", "N", "K"]) @triton.jit def matmul_kernel( - mat_a, mat_b, mat_c, + mat_a, + mat_b, + mat_c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, @@ -63,7 +63,6 @@ def matmul_kernel( pid = tl.program_id(axis=0) task_m_idx = 0 task_n_idx = 0 - ''' 水平分核方式每个任务块编号如下 [0, 1, 2, 3, 4, 5, 6, 7] @@ -107,56 +106,50 @@ def matmul_kernel( NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N # 当任务量较多时,可以使能对角线分核策略进行优化 if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD: - for block_idx in range( - pid, NUM_BLOCKS, num_cores - ): + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 8 * 8 对角线分核代码实现 - curThresholdM = BLOCK_TRESHHOLD if block_idx < (NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD + curThresholdM = BLOCK_TRESHHOLD if block_idx < ( + NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD - curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < (curThresholdM * NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD + curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < ( + curThresholdM * + NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM) - task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * NUM_BLOCKS_N) * BLOCK_TRESHHOLD + task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * + NUM_BLOCKS_N) * BLOCK_TRESHHOLD # 求最小公倍数,方便求基本块的坐标 x, y = curThresholdM, curThresholdN if curThresholdM > curThresholdN else curThresholdN, curThresholdM while y != 0: x, y = y, x % y lcm = curThresholdM * 
curThresholdN // x - task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD + task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % ( + BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD m_start = task_m_idx * BLOCK_M n_start = task_n_idx * BLOCK_N mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k_start in range(0, K, BLOCK_K): - mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( - k_start + tl.arange(0, BLOCK_K) - )[None, :] + mat_a_offset = ( + (m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (k_start + tl.arange(0, BLOCK_K))[None, :] mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (k_start + tl.arange(0, BLOCK_K)) < K - )[None, :] + (k_start + tl.arange(0, BLOCK_K)) < K)[None, :] mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0) extension.compile_hint(mat_a_block, "dot_pad_only_k") - mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( - n_start + tl.arange(0, BLOCK_N) - )[None, :] + mat_b_offset = ( + (k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N - )[None, :] + (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0) extension.compile_hint(mat_b_block, "dot_pad_only_k") mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) - mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( - n_start + tl.arange(0, BLOCK_N) - )[None, :] + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N - )[None, :] + 
(n_start + tl.arange(0, BLOCK_N)) < N)[None, :] tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask) else: # 传统顺序分核 - for block_idx in range( - pid, NUM_BLOCKS, num_cores - ): + for block_idx in range(pid, NUM_BLOCKS, num_cores): task_m_idx = block_idx // NUM_BLOCKS_N task_n_idx = block_idx % NUM_BLOCKS_N m_start = task_m_idx * BLOCK_M @@ -164,41 +157,33 @@ def matmul_kernel( mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k_start in range(0, K, BLOCK_K): - mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( - k_start + tl.arange(0, BLOCK_K) - )[None, :] + mat_a_offset = ( + (m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (k_start + tl.arange(0, BLOCK_K))[None, :] mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (k_start + tl.arange(0, BLOCK_K)) < K - )[None, :] + (k_start + tl.arange(0, BLOCK_K)) < K)[None, :] mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0) extension.compile_hint(mat_a_block, "dot_pad_only_k") - mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( - n_start + tl.arange(0, BLOCK_N) - )[None, :] + mat_b_offset = ( + (k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N - )[None, :] + (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0) extension.compile_hint(mat_b_block, "dot_pad_only_k") mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) - mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( - n_start + tl.arange(0, BLOCK_N) - )[None, :] + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N - )[None, :] + (n_start + tl.arange(0, BLOCK_N)) < 
N)[None, :] tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask) def triton_matmul( - mat_a, - mat_b, + mat_a, + mat_b, ): m = mat_a.shape[0] k = mat_a.shape[1] n = mat_b.shape[1] mat_c = torch.empty(m, n, dtype=mat_a.dtype, device=mat_a.device) - ''' NPU芯片更加亲和512B对齐场景,如下分块通用性能较好,可以使用autotune选取最优 BLOCK_M = 128 @@ -208,15 +193,7 @@ def triton_matmul( num_cores = get_npu_properties()["num_aicore"] - matmul_kernel[(num_cores,)]( - mat_a, - mat_b, - mat_c, - m, - n, - k, - num_cores - ) + matmul_kernel[(num_cores, )](mat_a, mat_b, mat_c, m, n, k, num_cores) return mat_c @@ -233,7 +210,7 @@ def test_matmul_extension(): golden = torch.matmul(mat_a, mat_b) mask = golden.abs() < 1.0 - tmpatol = tmprtol = 2 ** -6 + tmpatol = tmprtol = 2**-6 torch.testing.assert_close(result[mask], golden[mask], atol=tmpatol, rtol=0) torch.testing.assert_close(result[~mask], golden[~mask], atol=0, rtol=tmprtol) diff --git a/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py index f8618dd7bc..bb7720ec4a 100644 --- a/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py +++ b/third_party/ascend/unittest/pytest_ut/test_14_accuracy_comparison.py @@ -42,12 +42,11 @@ def torch_func(x0, x1): # 2. 定义 Triton kernel(在 NPU/GPU 上执行) @triton.jit - def triton_kernel_add( - out_ptr0, # 输出指针:结果存储位置 - in_ptr0, # 输入指针0:x0 的起始地址 - in_ptr1, # 输入指针1:x1 的起始地址 - XS: tl.constexpr # constexpr 参数:向量长度,在编译时确定 - ): + def triton_kernel_add(out_ptr0, # 输出指针:结果存储位置 + in_ptr0, # 输入指针0:x0 的起始地址 + in_ptr1, # 输入指针1:x1 的起始地址 + XS: tl.constexpr # constexpr 参数:向量长度,在编译时确定 + ): # 生成 [0, 1, 2, ..., XS-1] 的索引数组 idx = tl.arange(0, XS) # 从 in_ptr0 + idx 处加载 x0 的值 @@ -76,7 +75,8 @@ def triton_func(x0, x1): # 6. 
打印成功信息 print( - f"== dtype:{triton_cal.dtype} == The accuracy comparison between triton_result and torch_result was successful.") + f"== dtype:{triton_cal.dtype} == The accuracy comparison between triton_result and torch_result was successful." + ) def accuracy_comparison(y_cal, y_ref): @@ -102,13 +102,8 @@ def accuracy_comparison(y_cal, y_ref): torch.testing.assert_close(y_ref, y_cal, rtol=1e-3, atol=1e-3, equal_nan=True) elif tensor_dtype == torch.bfloat16: # bfloat16 精度更低,建议转为 float32 再比较 - torch.testing.assert_close( - y_ref.to(torch.float32), - y_cal.to(torch.float32), - rtol=1e-3, - atol=1e-3, - equal_nan=True - ) + torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-3, atol=1e-3, + equal_nan=True) elif tensor_dtype == torch.float32: # float32 精度较高,使用更严格的容差 torch.testing.assert_close(y_ref, y_cal, rtol=1e-4, atol=1e-4, equal_nan=True) @@ -136,14 +131,14 @@ def accuracy_comparison(y_cal, y_ref): def test_all_dtypes(dtype_name, dtype, low, high): N = 1024 if dtype == torch.bool: - x0 = torch.randint(low=low, high=high, size=(N,)).bool().npu() - x1 = torch.randint(low=low, high=high, size=(N,)).bool().npu() + x0 = torch.randint(low=low, high=high, size=(N, )).bool().npu() + x1 = torch.randint(low=low, high=high, size=(N, )).bool().npu() elif dtype.is_floating_point: - x0 = torch.rand((N,), dtype=dtype).npu() - x1 = torch.rand((N,), dtype=dtype).npu() + x0 = torch.rand((N, ), dtype=dtype).npu() + x1 = torch.rand((N, ), dtype=dtype).npu() else: - x0 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu() - x1 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu() + x0 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() + x1 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() print(f"Running test for {dtype_name}...") run_add(x0, x1) diff --git a/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py b/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py index 
8845a370b0..a63a0b7ec3 100644 --- a/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py +++ b/third_party/ascend/unittest/pytest_ut/test_15_demo_autotune.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Autotune ============= @@ -45,10 +44,9 @@ def get_autotune_config(): key=["numel"], ) @triton.jit -def triton_calc_kernel( - out_ptr0, in_ptr0, in_ptr1, numel, - XS: tl.constexpr # Block size controlling how many elements each thread block processes -): +def triton_calc_kernel(out_ptr0, in_ptr0, in_ptr1, numel, + XS: tl.constexpr # Block size controlling how many elements each thread block processes + ): pid = tl.program_id(0) idx = pid * XS + tl.arange(0, XS) msk = idx < numel @@ -81,8 +79,8 @@ def test_triton_autotune(): DEV = "npu" DTYPE = torch.float32 N = 192 * 1024 - x0 = torch.randn((N,), dtype=DTYPE, device=DEV) - x1 = torch.randn((N,), dtype=DTYPE, device=DEV) + x0 = torch.randn((N, ), dtype=DTYPE, device=DEV) + x1 = torch.randn((N, ), dtype=DTYPE, device=DEV) torch_ref = torch_calc_func(x0, x1) triton_cal = triton_calc_func(x0, x1) diff --git a/third_party/ascend/unittest/pytest_ut/test_16_profiler.py b/third_party/ascend/unittest/pytest_ut/test_16_profiler.py index d6f7d12ab4..393969cf63 100644 --- a/third_party/ascend/unittest/pytest_ut/test_16_profiler.py +++ b/third_party/ascend/unittest/pytest_ut/test_16_profiler.py @@ -32,23 +32,13 @@ def profiler_wrapper(fn, *args): stream = torch.npu.current_stream() experimental_config = torch_npu.profiler._ExperimentalConfig( aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False - ) + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) with torch_npu.profiler.profile( - activities=[ - 
torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU - ], + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), - record_shapes=True, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=experimental_config) as prof: stream.synchronize() for _ in range(skip_first + (wait + warmup + active) * repeat): @@ -103,17 +93,17 @@ def test_elementwise_ops(dtype, low, high): test_case_is_inductor = False if dtype == torch.bool: - x0 = torch.randint(low=low, high=high, size=(N,)).bool().npu() - x1 = torch.randint(low=low, high=high, size=(N,)).bool().npu() + x0 = torch.randint(low=low, high=high, size=(N, )).bool().npu() + x1 = torch.randint(low=low, high=high, size=(N, )).bool().npu() triton_cal = triton_or_func(x0, x1, N) ref = x0 | x1 else: if dtype.is_floating_point: - x0 = torch.rand((N,), dtype=dtype).npu() - x1 = torch.rand((N,), dtype=dtype).npu() + x0 = torch.rand((N, ), dtype=dtype).npu() + x1 = torch.rand((N, ), dtype=dtype).npu() else: - x0 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu() - x1 = torch.randint(low=low, high=high, size=(N,), dtype=dtype).npu() + x0 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() + x1 = torch.randint(low=low, high=high, size=(N, ), dtype=dtype).npu() triton_cal = triton_add_func(x0, x1, N) ref = x0 + x1 @@ -122,4 +112,5 @@ def test_elementwise_ops(dtype, low, high): def wrapper(): _ = triton_add_func(x0, x1, N) if dtype != torch.bool else triton_or_func(x0, x1, N) + profiler_wrapper(wrapper) diff --git 
a/third_party/ascend/unittest/pytest_ut/test_18_gather.py b/third_party/ascend/unittest/pytest_ut/test_18_gather.py index 5b2041ffcb..f8a4c70898 100644 --- a/third_party/ascend/unittest/pytest_ut/test_18_gather.py +++ b/third_party/ascend/unittest/pytest_ut/test_18_gather.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - """ Gather =============== @@ -53,7 +52,9 @@ def torch_gather(embeddings, idxes, default_value=0.0): # triton-version gather's kernel @triton.jit -def gather_kernel(embeddings_ptr, idxes_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): +def gather_kernel(embeddings_ptr, idxes_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr, + BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, + COL_BLOCK_SIZE_SUB: tl.constexpr): SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1 embedding_dtype = embeddings_ptr.type.element_ty @@ -63,7 +64,8 @@ def gather_kernel(embeddings_ptr, idxes_ptr, res_ptr, rows, cols, DEFAULT_VALUE: core_idx = tl.program_id(0) # compute the the size and start index of block row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE - row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else (BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) + row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else ( + BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) # process blocks witn shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB): @@ -101,7 +103,8 @@ def triton_gather(embeddings: torch.Tensor, indices: torch.Tensor, 
default_value # when writing an npu kernel using triton, # you should note that the difference between BLOCK_SIZE and BLOCK_SIZE_SUB # BLOCK_SIZE specifies the size of data that are processed in one program - col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), 32) * 32 // embeddings.element_size() + col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), + 32) * 32 // embeddings.element_size() # the data are scattered to multiple programs, which can not be even # some process more data, some process less big_row_block_size = triton.cdiv(n_rows, CORE_NUM) @@ -114,7 +117,9 @@ def triton_gather(embeddings: torch.Tensor, indices: torch.Tensor, default_value grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size)) # launch the kernel - gather_kernel[grid](embeddings, indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) + gather_kernel[grid](embeddings, indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, + BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, + COL_BLOCK_SIZE_SUB=col_block_size_sub) return output diff --git a/third_party/ascend/unittest/pytest_ut/test_add.py b/third_party/ascend/unittest/pytest_ut/test_add.py index 3dd8402761..be88ce1a5d 100644 --- a/third_party/ascend/unittest/pytest_ut/test_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_add.py @@ -76,11 +76,9 @@ def test_all_blocks_parallel(param_list, monkeypatch): monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) def test_auto_blockify(param_list, monkeypatch): monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") dtype, shape, ncore, xblock, xblock_sub = param_list @@ -90,4 
+88,4 @@ def test_auto_blockify(param_list, monkeypatch): y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub, auto_blockify_size=ncore) test_common.validate_cmp(dtype, y_cal, y_ref) - monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") \ No newline at end of file + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/unittest/pytest_ut/test_address_check.py b/third_party/ascend/unittest/pytest_ut/test_address_check.py index ac93fe35bf..3a8429c08f 100644 --- a/third_party/ascend/unittest/pytest_ut/test_address_check.py +++ b/third_party/ascend/unittest/pytest_ut/test_address_check.py @@ -18,7 +18,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - import torch import torch_npu import triton @@ -45,7 +44,7 @@ def test_npu_tensor_should_success(): y_npu = torch.rand(size, device='npu') output = torch.empty(size, device='npu') - simple_kernel[(1,)](x_npu, y_npu, output, size) + simple_kernel[(1, )](x_npu, y_npu, output, size) expected = x_npu + y_npu actual = output @@ -62,7 +61,7 @@ def test_cpu_tensor_should_fail(): output = torch.empty(size, device='npu') with pytest.raises(ValueError) as exc_info: - simple_kernel[(1,)](x_cpu, y_cpu, output, size) + simple_kernel[(1, )](x_cpu, y_cpu, output, size) error_msg = str(exc_info.value) assert "cannot be accessed from Triton (cpu tensor?)" in error_msg, \ diff --git a/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py index 76d7e8e1d1..f9052cd28e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py +++ b/third_party/ascend/unittest/pytest_ut/test_advance_ptr.py @@ -18,8 +18,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
- - import triton import triton.language as tl @@ -30,22 +28,10 @@ @triton.jit def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): - block_ptr_in = tl.make_block_ptr( - base=x_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, 2), - order=(2, 1, 0) - ) - block_ptr_out = tl.make_block_ptr( - base=output_ptr, - shape=(XB, YB, ZB), - strides=(YB * ZB, ZB, 1), - offsets=(0, 0, 0), - block_shape=(XB, YB, 2), - order=(2, 1, 0) - ) + block_ptr_in = tl.make_block_ptr(base=x_ptr, shape=(XB, YB, ZB), strides=(YB * ZB, ZB, 1), offsets=(0, 0, 0), + block_shape=(XB, YB, 2), order=(2, 1, 0)) + block_ptr_out = tl.make_block_ptr(base=output_ptr, shape=(XB, YB, ZB), strides=(YB * ZB, ZB, 1), offsets=(0, 0, 0), + block_shape=(XB, YB, 2), order=(2, 1, 0)) pid = tl.program_id(axis=0) # pid=0,1 BLOCK_SIZE_N=8 for _ in range(ZB // 2): X = tl.load(block_ptr_in, boundary_check=(0, 1, 2)) @@ -62,4 +48,4 @@ def test_advance_with_boundary_check(dtype, shape): expected = x output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) - torch.testing.assert_close(output, expected) \ No newline at end of file + torch.testing.assert_close(output, expected) diff --git a/third_party/ascend/unittest/pytest_ut/test_alloc.py b/third_party/ascend/unittest/pytest_ut/test_alloc.py index 53d1994153..450e0c5b98 100644 --- a/third_party/ascend/unittest/pytest_ut/test_alloc.py +++ b/third_party/ascend/unittest/pytest_ut/test_alloc.py @@ -30,7 +30,6 @@ from triton._C.libtriton import ir, buffer_ir from triton._C.libtriton.ascend import ir as ascend_ir - os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" @@ -50,7 +49,8 @@ def compile_kernel(kernel, signature, constants): ir.load_dialects(context) buffer_ir.load_dialects(context) ascend_ir.load_dialects(context) - module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, {}) + module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, + {}) return str(module) @@ -66,9 +66,7 @@ def allocate_local_buffer(XBLOCK: tl.constexpr): bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C) - bl.alloc( - tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True - ) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True) # ============== Main for manual testing ============== @@ -77,8 +75,6 @@ def allocate_local_buffer(XBLOCK: tl.constexpr): print("=" * 60) print("Test 1: Nested Scopes") print("=" * 60) - mlir = compile_kernel( - allocate_local_buffer, {}, {"XBLOCK": 256} - ) + mlir = compile_kernel(allocate_local_buffer, {}, {"XBLOCK": 256}) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) diff --git a/third_party/ascend/unittest/pytest_ut/test_arch.py 
b/third_party/ascend/unittest/pytest_ut/test_arch.py index 7ca66dc29a..4946ca3538 100644 --- a/third_party/ascend/unittest/pytest_ut/test_arch.py +++ b/third_party/ascend/unittest/pytest_ut/test_arch.py @@ -88,6 +88,7 @@ def test_arch(): print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) + # ============== Main for manual testing ============== if __name__ == "__main__": test_arch() diff --git a/third_party/ascend/unittest/pytest_ut/test_argmax.py b/third_party/ascend/unittest/pytest_ut/test_argmax.py index 16bfa28d78..bb26a40884 100644 --- a/third_party/ascend/unittest/pytest_ut/test_argmax.py +++ b/third_party/ascend/unittest/pytest_ut/test_argmax.py @@ -23,22 +23,20 @@ def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): tl.store(out_ptr1, tmp4, None) -@pytest.mark.parametrize('shape', [(128,), (256,), (37,), (741,)]) +@pytest.mark.parametrize('shape', [(128, ), (256, ), (37, ), (741, )]) @pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) def test_argmax_1d(dtype, shape): x0 = test_common.generate_tensor(shape, dtype).npu() triton_res = torch.empty(1, dtype=torch.int32).npu() numel = shape[0] - triton_argmax_1d[(1,)](x0, triton_res, numel, numel) + triton_argmax_1d[(1, )](x0, triton_res, numel, numel) torch_res = torch_argmax(x0, dim=0, keepdim=True) test_common.validate_cmp("int32", triton_res, torch_res) @triton.jit -def triton_argmax_2d(in_ptr0, out_ptr0, - dim: tl.constexpr, - M: tl.constexpr, N: tl.constexpr, - MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): +def triton_argmax_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): mblk_idx = tl.arange(0, MNUMEL) nblk_idx = tl.arange(0, NNUMEL) mmask = mblk_idx < M @@ -59,7 +57,9 @@ def triton_argmax_2d(in_ptr0, out_ptr0, def test_argmax_2d(dtype, shape, dim): shapex, shapey = shape x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([shape[1 - dim], ], 
dtype=torch.int32).npu() + triton_res = torch.empty([ + shape[1 - dim], + ], dtype=torch.int32).npu() triton_argmax_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) torch_res = torch_argmax(x0, dim=dim, keepdim=False) test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_argmin.py b/third_party/ascend/unittest/pytest_ut/test_argmin.py index 98018baa55..25d50b71f0 100644 --- a/third_party/ascend/unittest/pytest_ut/test_argmin.py +++ b/third_party/ascend/unittest/pytest_ut/test_argmin.py @@ -23,22 +23,20 @@ def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): tl.store(out_ptr1, tmp4, None) -@pytest.mark.parametrize('shape', [(128,), (256,), (37,), (741,)]) +@pytest.mark.parametrize('shape', [(128, ), (256, ), (37, ), (741, )]) @pytest.mark.parametrize('dtype', ['int32', 'float32', 'uint8', 'int8']) def test_argmin_1d(dtype, shape): x0 = test_common.generate_tensor(shape, dtype).npu() triton_res = torch.empty(1, dtype=torch.int32).npu() numel = shape[0] - triton_argmin_1d[(1,)](x0, triton_res, numel, numel) + triton_argmin_1d[(1, )](x0, triton_res, numel, numel) torch_res = torch_argmin(x0, dim=0, keepdim=True) test_common.validate_cmp("int32", triton_res, torch_res) @triton.jit -def triton_argmin_2d(in_ptr0, out_ptr0, - dim: tl.constexpr, - M: tl.constexpr, N: tl.constexpr, - MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): +def triton_argmin_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): mblk_idx = tl.arange(0, MNUMEL) nblk_idx = tl.arange(0, NNUMEL) mmask = mblk_idx < M @@ -59,7 +57,9 @@ def triton_argmin_2d(in_ptr0, out_ptr0, def test_argmin_2d(dtype, shape, dim): shapex, shapey = shape x0 = test_common.generate_tensor(shape, dtype).npu() - triton_res = torch.empty([shape[1 - dim], ], dtype=torch.int32).npu() + triton_res = torch.empty([ + shape[1 - dim], + ], dtype=torch.int32).npu() 
triton_argmin_2d[(1, 1)](x0, triton_res, dim, shapex, shapey, shapex, shapey) torch_res = torch_argmin(x0, dim=dim, keepdim=False) test_common.validate_cmp("int32", triton_res, torch_res) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm.py b/third_party/ascend/unittest/pytest_ut/test_asm.py index db668e99f2..189e8c2e4c 100644 --- a/third_party/ascend/unittest/pytest_ut/test_asm.py +++ b/third_party/ascend/unittest/pytest_ut/test_asm.py @@ -5,18 +5,21 @@ import pytest import test_common + def torch_add(x, y): res = x + y return res + @triton.jit -def triton_asm_add(x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) +def triton_asm_add( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements @@ -26,43 +29,39 @@ def triton_asm_add(x_ptr, asm=""" ADD.s64 $0, $1, $2 """, - constraints=( - "=l,l,l" - ), + constraints=("=l,l,l"), args=[x, y], dtype=tl.int64, is_pure=True, - pack=1, + pack=1, ) tl.store(output_ptr + offsets, output, mask=mask) -@pytest.mark.parametrize('param_list', - [ - ['int64', 4096, 1024], - ] - ) - +@pytest.mark.parametrize('param_list', [ + ['int64', 4096, 1024], +]) def test_case(param_list): dtype, length, block_size = param_list ncore = length // block_size - x = test_common.generate_tensor((length,), dtype).npu() - y = test_common.generate_tensor((length,), dtype).npu() + x = test_common.generate_tensor((length, ), dtype).npu() + y = test_common.generate_tensor((length, ), dtype).npu() res_ref = torch_add(x, y) - res_cal = torch.zeros((length,), dtype = eval('torch.' + dtype)).npu() - triton_asm_add[(ncore,)](x, y, res_cal, length, BLOCK_SIZE=block_size) + res_cal = torch.zeros((length, ), dtype=eval('torch.' 
+ dtype)).npu() + triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) test_common.validate_cmp(dtype, res_cal, res_ref) @triton.jit -def triton_asm_add_2d(x_ptr, - y_ptr, - output_ptr, - M, - N, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - ): +def triton_asm_add_2d( + x_ptr, + y_ptr, + output_ptr, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): pid = tl.program_id(axis=0) row_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M) col_offsets = tl.arange(0, BLOCK_N) @@ -74,9 +73,7 @@ def triton_asm_add_2d(x_ptr, asm=""" ADD.s64 $0, $1, $2 """, - constraints=( - "=l,l,l" - ), + constraints=("=l,l,l"), args=[x, y], dtype=tl.int64, is_pure=True, @@ -85,11 +82,9 @@ def triton_asm_add_2d(x_ptr, tl.store(output_ptr + offsets, output, mask=mask) -@pytest.mark.parametrize('param_list', - [ - ['int64', 64, 32, 16, 32], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['int64', 64, 32, 16, 32], +]) def test_case_2d(param_list): dtype, M, N, block_m, block_n = param_list ncore = M // block_m @@ -97,5 +92,5 @@ def test_case_2d(param_list): y = test_common.generate_tensor((M, N), dtype).npu() res_ref = torch_add(x, y) res_cal = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu() - triton_asm_add_2d[(ncore,)](x, y, res_cal, M, N, BLOCK_M=block_m, BLOCK_N=block_n) + triton_asm_add_2d[(ncore, )](x, y, res_cal, M, N, BLOCK_M=block_m, BLOCK_N=block_n) test_common.validate_cmp(dtype, res_cal, res_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py index 3d86c516aa..23b8f6553c 100644 --- a/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py +++ b/third_party/ascend/unittest/pytest_ut/test_asm_scalar.py @@ -6,9 +6,7 @@ @triton.jit -def triton_asm_time( - output_ptr, -): +def triton_asm_time(output_ptr, ): y = tl.inline_asm_elementwise( asm=""" MOV $0, SYS_CNT @@ -24,15 +22,11 @@ def triton_asm_time( @pytest.mark.parametrize( "param_list", - [ - [ - "int64", - ] - ], + [[ + "int64", + ]], ) def test_case(param_list): - (dtype,) = param_list - res_cal = torch.zeros((1,), dtype=eval("torch." + dtype)).npu() - triton_asm_time[(1,)]( - res_cal, - ) + (dtype, ) = param_list + res_cal = torch.zeros((1, ), dtype=eval("torch." 
+ dtype)).npu() + triton_asm_time[(1, )](res_cal, ) diff --git a/third_party/ascend/unittest/pytest_ut/test_assume1.py b/third_party/ascend/unittest/pytest_ut/test_assume1.py index b4c8ec7e3f..05d5630092 100644 --- a/third_party/ascend/unittest/pytest_ut/test_assume1.py +++ b/third_party/ascend/unittest/pytest_ut/test_assume1.py @@ -5,9 +5,7 @@ import pytest import test_common -from triton._internal_testing import ( - is_interpreter -) +from triton._internal_testing import (is_interpreter) @triton.jit @@ -31,4 +29,4 @@ def test_assume(dtype): if is_interpreter(): return - assert 'llvm.intr.assume' in pgm.asm['ttadapter'] \ No newline at end of file + assert 'llvm.intr.assume' in pgm.asm['ttadapter'] diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py index 957bd03e08..967717ad33 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py @@ -52,37 +52,31 @@ def atomic_add_supply(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr): @triton.jit -def atomic_add_for_load_offset( - index_ptr, in_ptr0, out_ptr0 -): +def atomic_add_for_load_offset(index_ptr, in_ptr0, out_ptr0): index = tl.atomic_add(index_ptr, 1) val = tl.load(in_ptr0 + index) tl.store(out_ptr0, val) @triton.jit -def atomic_add_for_store_offset( - index_ptr, out_ptr0 -): +def atomic_add_for_store_offset(index_ptr, out_ptr0): index = tl.atomic_add(index_ptr, 1) tl.store(out_ptr0 + index, 1) -@pytest.mark.parametrize('param_list', - [ - ['int64', (256, 32), 2], - ['int32', (32, 32), 2], - ['int16', (32, 32), 2], - ['int8', (32, 32), 2], - ['uint8', (32, 32), 2], - ['float32', (32, 32), 2], - ['float16', (64, 64), 4], - ['bfloat16', (64, 64), 4], - ['float32', (128, 128), 8], - ['float16', (128, 128), 16], - ['float32', (32768, 16), 32], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['int64', (256, 32), 2], + ['int32', (32, 32), 2], + ['int16', (32, 
32), 2], + ['int8', (32, 32), 2], + ['uint8', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['bfloat16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], + ['float32', (32768, 16), 32], +]) def test_atomic_add(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] / ncore @@ -189,7 +183,7 @@ def test_atomic_add_for_load_offset(): index_ref += 1 output_ref = output.clone() output_ref = input_tensor[index] - + atomic_add_for_load_offset[(1, )](index, input_tensor, output) torch.equal(index, index_ref) torch.equal(output, output_ref) @@ -202,7 +196,7 @@ def test_atomic_add_for_store_offset(): index_ref += 1 output_ref = output.clone() output_ref[index] = 1 - + atomic_add_for_store_offset[(1, )](index, output) torch.equal(index, index_ref) torch.equal(output, output_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py index be1250d887..675f65f773 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_and.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py @@ -39,15 +39,13 @@ def atomic_and(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr tl.store(out_ptr1 + (x1), tmp1, xmask) -@pytest.mark.parametrize('param_list', - [ - ['int64', (32, 32), 2], - ['int32', (32, 32), 2], - ['int16', (32, 32), 2], - ['int8', (16, 16), 4], - ['uint8', (16, 16), 4], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['int64', (32, 32), 2], + ['int32', (32, 32), 2], + ['int16', (32, 32), 2], + ['int8', (16, 16), 4], + ['uint8', (16, 16), 4], +]) def test_atomic_and(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] // ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py index 88bc7089b6..e834aa8d79 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py +++ 
b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py @@ -25,7 +25,6 @@ import torch import torch_npu - types_all = [ (torch.float32, 'float32'), ] @@ -49,7 +48,6 @@ def atomic_cas(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl. tl.store(out_ptr1 + (x1), tmp1, xmask) - @triton.jit def atomic_cas_with_full( ptr, @@ -61,11 +59,11 @@ def atomic_cas_with_full( x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = x < n_elements - cmp = tl.full((BLOCK_SIZE,), 2.0, tl.float32) - val = tl.full((BLOCK_SIZE,), 1.0, tl.float32) + cmp = tl.full((BLOCK_SIZE, ), 2.0, tl.float32) + val = tl.full((BLOCK_SIZE, ), 1.0, tl.float32) - old = tl.atomic_cas(ptr + x, cmp, val) # in_ptr(origin 2) -> ref: 1 X - tl.store(out + x, old, mask=mask) # out(origin 1) -> ref: old in_ptr(2) √ + old = tl.atomic_cas(ptr + x, cmp, val) # in_ptr(origin 2) -> ref: 1 X + tl.store(out + x, old, mask=mask) # out(origin 1) -> ref: old in_ptr(2) √ @triton.jit @@ -81,25 +79,22 @@ def atomic_cas_without_full( x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = x < n_elements - cmp = tl.load(cmp_ptr + x, mask) # 2 - val = tl.load(val_ptr + x, mask) # 1 + cmp = tl.load(cmp_ptr + x, mask) # 2 + val = tl.load(val_ptr + x, mask) # 1 - old = tl.atomic_cas(ptr + x, cmp, val) # old : 2 + old = tl.atomic_cas(ptr + x, cmp, val) # old : 2 tl.store(out_ptr + x, old, mask=mask) - -@pytest.mark.parametrize('param_list', - [ - ['int16', (8, 8), 2], - ['int32', (32, 32), 6], - ['int64', (32, 32), 2], - ['float32', (32, 32), 2], - ['float16', (64, 64), 4], - ['float32', (128, 128), 8], - ['float16', (128, 128), 16], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['int16', (8, 8), 2], + ['int32', (32, 32), 6], + ['int64', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], +]) def test_atomic_cas(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] // ncore @@ -163,19 +158,15 @@ def 
test_atomic_cas_return_value(param_list): test_common.validate_cmp(dtype, pointer_old, pointer_old_ref) - @pytest.mark.parametrize('dtype,sigtype', types_all) @pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) @pytest.mark.skip(reason="full tensor has problem, skipped") def test_atomic_cas_with_full(n_elements, BLOCK_SIZE, dtype, sigtype): - in_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() + in_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() out_ptr = torch.empty_like(in_ptr) grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) - atomic_cas_with_full[grid]( - in_ptr, out_ptr, n_elements, - BLOCK_SIZE=BLOCK_SIZE - ) + atomic_cas_with_full[grid](in_ptr, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE) # old should be all 2 (for in-range) torch.testing.assert_close(out_ptr, torch.full_like(out_ptr, 2.0)) @@ -187,17 +178,13 @@ def test_atomic_cas_with_full(n_elements, BLOCK_SIZE, dtype, sigtype): @pytest.mark.parametrize('dtype,sigtype', types_all) @pytest.mark.parametrize('n_elements, BLOCK_SIZE', [(4096, 256)]) def test_atomic_cas_without_full(n_elements, BLOCK_SIZE, dtype, sigtype): - in_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() - cmp_ptr = torch.full((n_elements,), 2, dtype=dtype).npu() - val_ptr = torch.full((n_elements,), 1, dtype=dtype).npu() - out_ptr = torch.full((n_elements,), 1, dtype=dtype).npu() # ref: in_ptr + in_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() + cmp_ptr = torch.full((n_elements, ), 2, dtype=dtype).npu() + val_ptr = torch.full((n_elements, ), 1, dtype=dtype).npu() + out_ptr = torch.full((n_elements, ), 1, dtype=dtype).npu() # ref: in_ptr grid = (ceil_div(n_elements, BLOCK_SIZE), 1, 1) - atomic_cas_without_full[grid]( - in_ptr, cmp_ptr, val_ptr, out_ptr, n_elements, - BLOCK_SIZE=BLOCK_SIZE - ) + atomic_cas_without_full[grid](in_ptr, cmp_ptr, val_ptr, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE) torch.testing.assert_close(in_ptr, torch.full_like(in_ptr, 1.0)) torch.testing.assert_close(out_ptr, 
torch.full_like(out_ptr, 2.0)) - diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py index 77b64eee4f..7a7a200e24 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_max.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py @@ -53,22 +53,20 @@ def triton_test_fn_atomic_max_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const # torch.max do not support int -@pytest.mark.parametrize('param_list', - [ - ['uint8', (32, 32), 2], - ['int16', (32, 32), 2], - ['bfloat16', (32, 32), 2], - ['float16', (32, 32), 2], - ['float32', (128, 128), 8], - ['float32', (32768, 16), 32], - ['int32', (32, 32), 2], - ['int32', (128, 128), 8], - ['int32', (32768, 16), 32], - ['int64', (32, 32), 2], - ['int64', (128, 128), 8], - ['int64', (8192, 16), 32], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['uint8', (32, 32), 2], + ['int16', (32, 32), 2], + ['bfloat16', (32, 32), 2], + ['float16', (32, 32), 2], + ['float32', (128, 128), 8], + ['float32', (32768, 16), 32], + ['int32', (32, 32), 2], + ['int32', (128, 128), 8], + ['int32', (32768, 16), 32], + ['int64', (32, 32), 2], + ['int64', (128, 128), 8], + ['int64', (8192, 16), 32], +]) def test_atomic_max(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] / ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py index 08dd777f76..461a3baa9c 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_min.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py @@ -50,18 +50,17 @@ def triton_test_fn_atomic_min_dma_supply(in_ptr0, out_ptr0, n_elements: tl.const tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.atomic_min(out_ptr0 + (x1), tmp0, xmask) -@pytest.mark.parametrize('param_list', - [ - ['uint8', (32, 32), 2], - ['int8', (32, 32), 2], - ['int16', (32, 32), 2], - ['int32', (32, 32), 2], - ['int64', (32, 32), 2], - 
['bfloat16', (64, 64), 4], - ['float16', (64, 64), 4], - ['float32', (32, 32), 2], - ] - ) + +@pytest.mark.parametrize('param_list', [ + ['uint8', (32, 32), 2], + ['int8', (32, 32), 2], + ['int16', (32, 32), 2], + ['int32', (32, 32), 2], + ['int64', (32, 32), 2], + ['bfloat16', (64, 64), 4], + ['float16', (64, 64), 4], + ['float32', (32, 32), 2], +]) def test_atomic_min(param_list): dtype, shape, ncore = param_list block_size = shape[0] * shape[1] / ncore diff --git a/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py index b3d91384e9..da1da5eafc 100644 --- a/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_rmw_useanalysis.py @@ -14,10 +14,10 @@ def atomic_rmw_useanalysis_kernel( ): pid = tl.program_id(0) base_idx = pid * 8 - + term1 = 15.0 * 15.0 term2 = 8.0 * (7.0 - base_idx) - + delta = term1 + term2 sqrt_delta = tl.sqrt(delta) @@ -36,7 +36,7 @@ def atomic_rmw_useanalysis_kernel( p = tl.exp(scaled) result = p * (data * 2.0 - d_val) - + output_offsets = offsets tl.atomic_add(output_ptr + output_offsets, result, mask=mask) @@ -52,7 +52,7 @@ def test_atomic_rmw_useanalysis(): d_data = torch.randn(N, dtype=torch.float32, device=DEVICE) output_data = torch.zeros(N, dtype=torch.float32, device=DEVICE) - grid = (8,) + grid = (8, ) atomic_rmw_useanalysis_kernel[grid]( input_data, @@ -63,7 +63,7 @@ def test_atomic_rmw_useanalysis(): BLOCK_SIZE=BLOCK_SIZE, ) output_sum = output_data.abs().sum().item() - + if output_sum == 0: raise AssertionError("UseAnalysis bug detected: atomic_rmw dependencies were erased") else: diff --git a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py index 80f85b1e3a..e8132a4508 100644 --- a/third_party/ascend/unittest/pytest_ut/test_block_ptr.py +++ b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py @@ -83,7 
+83,7 @@ def test_npu(para_type, data_type, XB, YB, ZB): print(a) fn_npu_[1, 1, 1](output, x, y, z, output1, XB=XB, YB=YB, ZB=ZB, debug=True) print(output) - torch.testing.assert_close(output,a) + torch.testing.assert_close(output, a) @triton.jit @@ -143,8 +143,8 @@ def ref_func(inputs, scale, cu_lens): outputs = torch.zeros_like(inputs) bsz = cu_lens.size(0) - 1 for bid in range(bsz): - tmp = inputs[cu_lens[bid]: cu_lens[bid + 1]].to(torch.float32) * scale[bid] - outputs[cu_lens[bid]: cu_lens[bid + 1]] = tmp.to(outputs.dtype) + tmp = inputs[cu_lens[bid]:cu_lens[bid + 1]].to(torch.float32) * scale[bid] + outputs[cu_lens[bid]:cu_lens[bid + 1]] = tmp.to(outputs.dtype) return outputs @@ -155,8 +155,14 @@ def tt_func(inputs, scale, cu_lens): assert head_dim <= 1024 BLOCK_SIZE_N = 1024 BLOCK_SIZE_M = 4 - dma_block_ptr[20, ]( - inputs, outputs, scale, bsz, cu_lens, + dma_block_ptr[ + 20, + ]( + inputs, + outputs, + scale, + bsz, + cu_lens, inputs.stride(0), inputs.stride(1), outputs.stride(0), @@ -169,16 +175,14 @@ def tt_func(inputs, scale, cu_lens): return outputs -@pytest.mark.parametrize('param_list', - [ - [8, 1024, 1024, True], - [8, 1024, 1024, False], - ] - ) +@pytest.mark.parametrize('param_list', [ + [8, 1024, 1024, True], + [8, 1024, 1024, False], +]) def test_func(param_list): bsz, max_len, max_n, test_align = param_list - lens = torch.randint(max_len // 2, max_len, (bsz,), dtype=torch.int32, device="npu") - n = torch.randint(max_n // 2, max_n, (1,), dtype=torch.int32, device="npu")[0].item() + lens = torch.randint(max_len // 2, max_len, (bsz, ), dtype=torch.int32, device="npu") + n = torch.randint(max_n // 2, max_n, (1, ), dtype=torch.int32, device="npu")[0].item() if test_align: lens = (lens + 1023) // 1024 * 1024 n = (n + 1023) // 1024 * 1024 diff --git a/third_party/ascend/unittest/pytest_ut/test_boundary_check.py b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py index 8773a0b038..c7dabda498 100644 --- 
a/third_party/ascend/unittest/pytest_ut/test_boundary_check.py +++ b/third_party/ascend/unittest/pytest_ut/test_boundary_check.py @@ -27,19 +27,13 @@ # ========== Test 1: Static base address + boundary_check ========== @triton.jit def static_base_boundary_check_kernel( - out_ptr, - in_ptr, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, ): - ptr = tl.make_block_ptr( - base=in_ptr, - shape=(BLOCK_SIZE * 2,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE,), - order=(0,) - ) - data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr, result) @@ -52,7 +46,7 @@ def test_static_base(): BLOCK_SIZE = 64 in_tensor = torch.randn(BLOCK_SIZE * 2, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() - static_base_boundary_check_kernel[(1,)]( + static_base_boundary_check_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, BLOCK_SIZE=BLOCK_SIZE, @@ -64,21 +58,15 @@ def test_static_base(): # ========== Test 2: Simple dynamic base address + boundary_check ========== @triton.jit def simple_dynamic_base_boundary_check_kernel( - out_ptr, - in_ptr, - offset: tl.int32, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + offset: tl.int32, + BLOCK_SIZE: tl.constexpr, ): base = in_ptr + offset - ptr = tl.make_block_ptr( - base=base, - shape=(BLOCK_SIZE * 2,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE,), - order=(0,) - ) - data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + ptr = tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr, result) @@ -88,7 +76,7 @@ def 
test_simple_dynamic_base(): offset = 32 in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() - simple_dynamic_base_boundary_check_kernel[(1,)]( + simple_dynamic_base_boundary_check_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, offset=offset, @@ -101,12 +89,12 @@ def test_simple_dynamic_base(): # ========== Test 3: Nested loop + dynamic base address + advance + boundary_check ========== @triton.jit def nested_dynamic_advance_boundary_kernel( - out_ptr, - in_ptr, - stride_in: tl.int32, - OUTER_LOOP: tl.constexpr, - INNER_LOOP: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + stride_in: tl.int32, + OUTER_LOOP: tl.constexpr, + INNER_LOOP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, ): """ Smallest reproducible code: The dynamic base address is in the outer loop, @@ -114,17 +102,11 @@ def nested_dynamic_advance_boundary_kernel( """ for i in range(OUTER_LOOP): base = in_ptr + i * stride_in - ptr = tl.make_block_ptr( - base=base, - shape=(INNER_LOOP * BLOCK_SIZE,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE,), - order=(0,) - ) + ptr = tl.make_block_ptr(base=base, shape=(INNER_LOOP * BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) for j in range(INNER_LOOP): - cur_ptr = tl.advance(ptr, (j * BLOCK_SIZE,)) - data = tl.load(cur_ptr, boundary_check=(0,), padding_option="zero") + cur_ptr = tl.advance(ptr, (j * BLOCK_SIZE, )) + data = tl.load(cur_ptr, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr + i * INNER_LOOP + j, result) @@ -147,7 +129,7 @@ def test_nested_dynamic(): INNER_LOOP = 2 in_tensor = torch.randn(OUTER_LOOP * INNER_LOOP * BLOCK_SIZE * 2, dtype=torch.float32).npu() out_tensor = torch.zeros(OUTER_LOOP * INNER_LOOP, dtype=torch.float32).npu() - nested_dynamic_advance_boundary_kernel[(1,)]( + nested_dynamic_advance_boundary_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, 
stride_in=INNER_LOOP * BLOCK_SIZE, @@ -162,19 +144,13 @@ def test_nested_dynamic(): # ========== Test 4: Explicit out-of-bounds access + zero padding + boundary_check ========== @triton.jit def out_of_bound_zero_padding_kernel( - out_ptr, - in_ptr, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, ): - ptr = tl.make_block_ptr( - base=in_ptr, - shape=(BLOCK_SIZE,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE * 2,), - order=(0,) - ) - data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE * 2, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr, result) @@ -183,7 +159,7 @@ def test_out_of_bound(): BLOCK_SIZE = 64 in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() - out_of_bound_zero_padding_kernel[(1,)]( + out_of_bound_zero_padding_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, BLOCK_SIZE=BLOCK_SIZE, @@ -195,19 +171,13 @@ def test_out_of_bound(): # ========== Test 5:padding_option = NAN + boundary_check========== @triton.jit def nan_padding_kernel( - out_ptr, - in_ptr, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, ): - ptr = tl.make_block_ptr( - base=in_ptr, - shape=(BLOCK_SIZE,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE * 2,), - order=(0,) - ) - data = tl.load(ptr, boundary_check=(0,), padding_option="nan") + ptr = tl.make_block_ptr(base=in_ptr, shape=(BLOCK_SIZE, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE * 2, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="nan") result = tl.sum(data) tl.store(out_ptr, result) @@ -217,7 +187,7 @@ def test_nan_padding(): in_tensor = torch.randn(BLOCK_SIZE, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() 
try: - nan_padding_kernel[(1,)]( + nan_padding_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, BLOCK_SIZE=BLOCK_SIZE, @@ -230,22 +200,16 @@ def test_nan_padding(): # ========== Test 6:Multi-layer advance + boundary_check ========== @triton.jit def multi_advance_kernel( - out_ptr, - in_ptr, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + BLOCK_SIZE: tl.constexpr, ): base = in_ptr - ptr0 = tl.make_block_ptr( - base=base, - shape=(BLOCK_SIZE * 4,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE,), - order=(0,) - ) - ptr1 = tl.advance(ptr0, (BLOCK_SIZE,)) - ptr2 = tl.advance(ptr1, (BLOCK_SIZE,)) - data = tl.load(ptr2, boundary_check=(0,), padding_option="zero") + ptr0 = tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 4, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + ptr1 = tl.advance(ptr0, (BLOCK_SIZE, )) + ptr2 = tl.advance(ptr1, (BLOCK_SIZE, )) + data = tl.load(ptr2, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr, result) @@ -254,7 +218,7 @@ def test_multi_advance(): BLOCK_SIZE = 64 in_tensor = torch.randn(BLOCK_SIZE * 4, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() - multi_advance_kernel[(1,)]( + multi_advance_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, BLOCK_SIZE=BLOCK_SIZE, @@ -266,23 +230,17 @@ def test_multi_advance(): # ========== Test 7:Complex base address calculation + boundary_check ========== @triton.jit def complex_base_calculation_kernel( - out_ptr, - in_ptr, - offset1: tl.int32, - offset2: tl.int32, - scale: tl.int32, - BLOCK_SIZE: tl.constexpr, + out_ptr, + in_ptr, + offset1: tl.int32, + offset2: tl.int32, + scale: tl.int32, + BLOCK_SIZE: tl.constexpr, ): base = in_ptr + offset1 * scale + offset2 - ptr = tl.make_block_ptr( - base=base, - shape=(BLOCK_SIZE * 2,), - strides=(1,), - offsets=(0,), - block_shape=(BLOCK_SIZE,), - order=(0,) - ) - data = tl.load(ptr, boundary_check=(0,), padding_option="zero") + ptr = 
tl.make_block_ptr(base=base, shape=(BLOCK_SIZE * 2, ), strides=(1, ), offsets=(0, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + data = tl.load(ptr, boundary_check=(0, ), padding_option="zero") result = tl.sum(data) tl.store(out_ptr, result) @@ -293,7 +251,7 @@ def test_complex_base(): total_offset = offset1 * scale + offset2 in_tensor = torch.randn(total_offset + BLOCK_SIZE * 2, dtype=torch.float32).npu() out_tensor = torch.zeros(1, dtype=torch.float32).npu() - complex_base_calculation_kernel[(1,)]( + complex_base_calculation_kernel[(1, )]( out_ptr=out_tensor, in_ptr=in_tensor, offset1=offset1, diff --git a/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py index 3fc6771f3a..4b102b9335 100644 --- a/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py +++ b/third_party/ascend/unittest/pytest_ut/test_cat_help_func.py @@ -33,7 +33,7 @@ def gen_1d_cat_shapes(min_val=1, max_val=4096): shape1 = random.randint(min_val, max_val) shape2 = random.randint(min_val, max_val) - return (shape1,), (shape2,), 0 + return (shape1, ), (shape2, ), 0 def gen_2d_cat_shapes(dim=0, min_val=1, max_val=4096): @@ -85,12 +85,7 @@ def gen_3d_cat_shapes(dim=0, min_val=1, max_val=4096): return shape1, shape2, dim -def gen_100_cat_shapes( - num_groups=100, - mix_ratio=(0.3, 0.3, 0.4), - min_val=1, - max_val=4096 -): +def gen_100_cat_shapes(num_groups=100, mix_ratio=(0.3, 0.3, 0.4), min_val=1, max_val=4096): shape_list = [] num_1d = int(num_groups * mix_ratio[0]) @@ -111,12 +106,8 @@ def gen_100_cat_shapes( random.shuffle(shape_list) return shape_list -full_shape = gen_100_cat_shapes( - num_groups=100, - mix_ratio=(0.3, 0.4, 0.3), - min_val=1, - max_val=4096 -) + +full_shape = gen_100_cat_shapes(num_groups=100, mix_ratio=(0.3, 0.4, 0.3), min_val=1, max_val=4096) @triton.jit @@ -130,7 +121,6 @@ def _cat_helper_func_2D_1( x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, - ): y0_offset = tl.program_id(0) * 
Y0BLOCK_SUB base_y0 = tl.arange(0, Y0BLOCK_SUB) @@ -152,7 +142,8 @@ def _cat_helper_func_2D_1( @triton.jit -def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): y0_offset = tl.program_id(0) * Y0BLOCK base_y0 = tl.arange(0, Y0BLOCK_SUB) loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB @@ -184,7 +175,7 @@ def triton_unk_fused_cat_dim0_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_n @triton.jit def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_numel, x1_numel, YBLOCK: tl.constexpr, - YBLOCK_2: tl.constexpr, YBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): + YBLOCK_2: tl.constexpr, YBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): y0_offset = tl.program_id(0) * YBLOCK base_y0 = tl.arange(0, YBLOCK_SUB) loops_y0 = (YBLOCK + YBLOCK_SUB - 1) // YBLOCK_SUB @@ -223,7 +214,8 @@ def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_n new_z0 = tl.arange(0, 2)[:, None, None] new_x2_mask = new_x2 < x1_numel new_y1_mask = new_y1 < min_numel - tl.store(output_ptr + (new_x2 + x1_numel * new_y1 + x1_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + tl.store(output_ptr + (new_x2 + x1_numel * new_y1 + x1_numel * y0_numel * new_z0), tmp13, + new_x2_mask & new_y1_mask) loops_y1 = (YBLOCK_2 + YBLOCK_SUB - 1) // YBLOCK_SUB y2_offset = tl.program_id(0) * YBLOCK_2 + min_numel @@ -263,7 +255,8 @@ def triton_unk_fused_cat_dim0_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, y1_n @triton.jit -def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_numel, 
Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): y0_offset = tl.program_id(0) * Y0BLOCK base_y0 = tl.arange(0, Y0BLOCK_SUB) loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB @@ -290,11 +283,13 @@ def triton_unk_fused_cat_dim1_sameshape(output_ptr, x_ptr, y_ptr, y0_numel, x1_n new_z0 = tl.arange(0, 2)[None, :, None] new_x2_mask = new_x2 < x1_numel new_y1_mask = new_y1 < y0_numel - tl.store(output_ptr + (new_x2 + 2 * x1_numel * new_y1 + x1_numel * new_z0), tmp13, new_x2_mask & new_y1_mask) + tl.store(output_ptr + (new_x2 + 2 * x1_numel * new_y1 + x1_numel * new_z0), tmp13, + new_x2_mask & new_y1_mask) @triton.jit -def triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, + Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): y0_offset = tl.program_id(0) * Y0BLOCK base_y0 = tl.arange(0, Y0BLOCK_SUB) loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB @@ -373,7 +368,8 @@ def triton_unk_fused_cat_dim1_diffshape(output_ptr, x_ptr, y_ptr, y0_numel, x0_n @triton.jit -def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y1_numel, x1_numel, ZBLOCK: tl.constexpr, ZBLOCK_2: tl.constexpr, ZBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y1_numel, x1_numel, ZBLOCK: tl.constexpr, + ZBLOCK_2: tl.constexpr, ZBLOCK_SUB: tl.constexpr, X1BLOCK_SUB: tl.constexpr): z0_offset = tl.program_id(0) * ZBLOCK base_z0 = tl.arange(0, ZBLOCK_SUB) loops_z0 = (ZBLOCK + ZBLOCK_SUB - 1) // ZBLOCK_SUB @@ -413,7 +409,8 @@ def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y new_z0 = tl.arange(0, 2)[:, None, None] new_x2_mask = new_x2 < xy_numel new_z1_mask = new_z1 < min_numel - tl.store(output_ptr 
+ (new_x2 + xy_numel * new_z1 + xy_numel * z0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask) + tl.store(output_ptr + (new_x2 + xy_numel * new_z1 + xy_numel * z0_numel * new_z0), tmp13, + new_x2_mask & new_z1_mask) loops_z1 = (ZBLOCK_2 + ZBLOCK_SUB - 1) // ZBLOCK_SUB z2_offset = tl.program_id(0) * ZBLOCK_2 + min_numel @@ -453,7 +450,8 @@ def triton_unk_fused_cat_3d_dim0(output_ptr, x_ptr, y_ptr, z0_numel, z1_numel, y @triton.jit -def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y1_numel, x0_numel, Z0BLOCK: tl.constexpr, Z0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y1_numel, x0_numel, + Z0BLOCK: tl.constexpr, Z0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): z0_offset = tl.program_id(0) * Z0BLOCK base_z0 = tl.arange(0, Z0BLOCK_SUB) loops_z0 = (Z0BLOCK + Z0BLOCK_SUB - 1) // Z0BLOCK_SUB @@ -493,7 +491,8 @@ def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y new_x2_mask = new_x2 < min_numel new_z1_mask = new_z1 < z0_numel sum_numel = min_numel + max_numel - tl.store(output_ptr + (new_x2 + sum_numel * new_z1 + x0_numel * y0_numel * new_z0), tmp13, new_x2_mask & new_z1_mask) + tl.store(output_ptr + (new_x2 + sum_numel * new_z1 + x0_numel * y0_numel * new_z0), tmp13, + new_x2_mask & new_z1_mask) if y0_numel == y1_numel: return @@ -535,7 +534,8 @@ def triton_unk_fused_cat_3d_dim1(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, y @triton.jit -def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x0_numel, x1_numel, Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): +def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x0_numel, x1_numel, + Y0BLOCK: tl.constexpr, Y0BLOCK_SUB: tl.constexpr, XBLOCK_SUB: tl.constexpr): y0_offset = tl.program_id(0) * Y0BLOCK base_y0 = tl.arange(0, Y0BLOCK_SUB) loops_y0 = (Y0BLOCK + Y0BLOCK_SUB - 1) // Y0BLOCK_SUB @@ 
-619,12 +619,12 @@ def triton_unk_fused_cat_3d_dim2(output_ptr, x_ptr, y_ptr, z0_numel, y0_numel, x testlist = [ # ===================== 1D场景(15组,dim=0) ===================== - ((3,), (3,), 0), - ((7,), (9,), 0), - ((13,), (11,), 0), - ((2047,), (2047,), 0), - ((2701,), (3003,), 0), - ((4093,), (3095,), 0), + ((3, ), (3, ), 0), + ((7, ), (9, ), 0), + ((13, ), (11, ), 0), + ((2047, ), (2047, ), 0), + ((2701, ), (3003, ), 0), + ((4093, ), (3095, ), 0), # ===================== 2D场景(20组,dim0/dim1) ===================== # dim0(行拼接,列维度一致) @@ -672,13 +672,16 @@ def test_cat_bigshape(testlists, dtype): if cat_dim == 0: ZBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core ZBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core - triton_unk_fused_cat_3d_dim0[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x0.shape[1], x0.shape[2], ZBLOCK, ZBLOCK_2, 1, 256) + triton_unk_fused_cat_3d_dim0[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x0.shape[1], + x0.shape[2], ZBLOCK, ZBLOCK_2, 1, 256) elif cat_dim == 1: Z0BLOCK = (x0.shape[0] + num_core - 1) // num_core - triton_unk_fused_cat_3d_dim1[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], x1.shape[2], Z0BLOCK, 1, 256) + triton_unk_fused_cat_3d_dim1[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], + x1.shape[2], Z0BLOCK, 1, 256) else: Y0BLOCK = (x0.shape[0] * x0.shape[1] + num_core - 1) // num_core - triton_unk_fused_cat_3d_dim2[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x0.shape[2], x1.shape[2], Y0BLOCK, 1, 256) + triton_unk_fused_cat_3d_dim2[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x0.shape[2], + x1.shape[2], Y0BLOCK, 1, 256) test_common.validate_cmp(dtype, torch_res, triton_res) return numel_large = torch_res.numel() > 512 and len(x0.shape) < 3 @@ -693,17 +696,21 @@ def test_cat_bigshape(testlists, dtype): if cat_dim == 1: Y0BLOCK = (x0.shape[0] + 
num_core - 1) // num_core if x0.shape[1] == x1.shape[1]: - triton_unk_fused_cat_dim1_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256) + triton_unk_fused_cat_dim1_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], + Y0BLOCK, 1, 256) else: - triton_unk_fused_cat_dim1_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], x1.shape[1], Y0BLOCK, 1, 256) + triton_unk_fused_cat_dim1_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], + x1.shape[1], Y0BLOCK, 1, 256) else: if x0.shape[0] == x1.shape[0]: Y0BLOCK = (x0.shape[0] + num_core - 1) // num_core - triton_unk_fused_cat_dim0_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], Y0BLOCK, 1, 256) + triton_unk_fused_cat_dim0_sameshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x0.shape[1], + Y0BLOCK, 1, 256) else: YBLOCK = (min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core YBLOCK_2 = (max(x0.shape[0], x1.shape[0]) - min(x0.shape[0], x1.shape[0]) + num_core - 1) // num_core - triton_unk_fused_cat_dim0_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], x1.shape[1], YBLOCK, YBLOCK_2, 1, 256) + triton_unk_fused_cat_dim0_diffshape[num_core, 1, 1](triton_res, x0, x1, x0.shape[0], x1.shape[0], + x1.shape[1], YBLOCK, YBLOCK_2, 1, 256) if squeeze_flag: triton_res = triton_res.squeeze() else: @@ -713,7 +720,8 @@ def test_cat_bigshape(testlists, dtype): x0 = torch.unsqueeze(x0, dim=0) x1 = torch.unsqueeze(x1, dim=0) triton_res = torch.unsqueeze(triton_res, dim=0) - _cat_helper_func_2D_1[num_core, 1, 1](x0, x1, triton_res, x0.shape[1], x1.shape[1], x0.shape[0], x0.shape[1] + x1.shape[1], 256, 16) + _cat_helper_func_2D_1[num_core, 1, 1](x0, x1, triton_res, x0.shape[1], x1.shape[1], x0.shape[0], + x0.shape[1] + x1.shape[1], 256, 16) if squeeze_flag: triton_res = triton_res.squeeze() diff --git a/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py 
b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py index c9cca12573..69c1f9b0a2 100644 --- a/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py +++ b/third_party/ascend/unittest/pytest_ut/test_celoss_indices.py @@ -18,7 +18,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - import torch import triton import triton.language as tl @@ -92,17 +91,22 @@ def test_celoss_indices_kernel(shape=(1, 2)): D = 1 inp = torch.randn(shape, dtype=dtype, device=device) - tgt = torch.randint(0, C, (N,), dtype=torch.int64, device=device) + tgt = torch.randint(0, C, (N, ), dtype=torch.int64, device=device) wgt = torch.randn(C, dtype=dtype, device=device) - out_triton = torch.empty((N * D,), dtype=torch.float32, device=device) - w_tgt_triton = torch.empty((N * D,), dtype=torch.float32, device=device) + out_triton = torch.empty((N * D, ), dtype=torch.float32, device=device) + w_tgt_triton = torch.empty((N * D, ), dtype=torch.float32, device=device) grid = (triton.cdiv(D, BLOCK_D), N) celoss_indices_kernel[grid]( - inp, tgt, wgt, out_triton, w_tgt_triton, + inp, + tgt, + wgt, + out_triton, + w_tgt_triton, ignore_index, - C, D, + C, + D, BLOCK_C=BLOCK_C, BLOCK_D=BLOCK_D, ) diff --git a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py index 7ab3173013..4582201936 100644 --- a/third_party/ascend/unittest/pytest_ut/test_compile_hint.py +++ b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py @@ -46,11 +46,9 @@ def triton_compile_hint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_ @pytest.mark.skip(reason="not supported after the NPUIR is updated in April, and will be fixed later") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) def test_compile_hint(param_list): dtype, shape, ncore, xblock, 
xblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() diff --git a/third_party/ascend/unittest/pytest_ut/test_complex_mask.py b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py index 7b26b447d6..a8f3ada327 100644 --- a/third_party/ascend/unittest/pytest_ut/test_complex_mask.py +++ b/third_party/ascend/unittest/pytest_ut/test_complex_mask.py @@ -56,7 +56,7 @@ def test_complex_mask_copy(): N = 1024 x = torch.randn(N, dtype=torch.float32).npu() y = torch.empty_like(x).npu() - copy_kernel[(1,)](x, y, N=N, NUMEL=N) + copy_kernel[(1, )](x, y, N=N, NUMEL=N) torch.testing.assert_close(x, y) @@ -65,5 +65,5 @@ def test_complex_mask_permute_copy(): N = 32 x = torch.randn(M * N, dtype=torch.float32).npu() y = torch.empty_like(x).npu() - permute_copy_kernel[(1,)](x, y, M=M, N=N, NUMEL=M * N) + permute_copy_kernel[(1, )](x, y, M=M, N=N, NUMEL=M * N) torch.testing.assert_close(x, y) diff --git a/third_party/ascend/unittest/pytest_ut/test_copy.py b/third_party/ascend/unittest/pytest_ut/test_copy.py index 99af972025..4cf9daa322 100644 --- a/third_party/ascend/unittest/pytest_ut/test_copy.py +++ b/third_party/ascend/unittest/pytest_ut/test_copy.py @@ -91,6 +91,7 @@ def test_copy(): print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) + # ============== Main for manual testing ============== if __name__ == "__main__": test_copy() diff --git a/third_party/ascend/unittest/pytest_ut/test_custom.py b/third_party/ascend/unittest/pytest_ut/test_custom.py index 0c1b6488a5..8b3b6d0743 100755 --- a/third_party/ascend/unittest/pytest_ut/test_custom.py +++ b/third_party/ascend/unittest/pytest_ut/test_custom.py @@ -64,7 +64,7 @@ def __init__(self, x, ptr1, ptr2, offset: tl.int64, other, out=None): # Tag ptr2 as an argument that should be aligned at dimension 1. # Tag 2nd argument that should be aligned at dimension 0. 
- self.align_dim = {"ptr2": 1, 1 : 0} + self.align_dim = {"ptr2": 1, 1: 0} @triton.jit @@ -192,18 +192,18 @@ def test_custom_op(): assert 'i32, ' not in line assert "iterator_types" in line for iterator_name in ( - "parallel", - "broadcast", - "transpose", - "reduction", - "interleave", - "deinterleave", - "inverse", - "pad", - "concat", - "gather", - "cumulative", - "opaque", + "parallel", + "broadcast", + "transpose", + "reduction", + "interleave", + "deinterleave", + "inverse", + "pad", + "concat", + "gather", + "cumulative", + "opaque", ): assert iterator_name in line @@ -212,10 +212,7 @@ def _custom_lines(mlir: str, op_name: str): # Match the MLIR string attribute exactly (avoid `my_custom_op` matching # `my_custom_op_extra_buf`). quoted = f'"{op_name}"' - return [ - line for line in mlir.splitlines() - if "hivm.hir.custom" in line and quoted in line - ] + return [line for line in mlir.splitlines() if "hivm.hir.custom" in line and quoted in line] def test_custom_op_extra_buffers_mixed_scalar_types(): @@ -295,7 +292,7 @@ def test_custom_op_without_extra_buffers_has_no_extra_buffer_attrs(): test_custom_op_extra_buffers_integer_variants() test_custom_op_extra_buffers_mixed_scalar_types() test_custom_op_extra_buffers_single_buffer() - mlir = compile_kernel(my_kernel, - {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + mlir = compile_kernel(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py index b50ecdf7d5..c0527c7dc4 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_atomic.py @@ -40,9 +40,9 @@ def single_disc_mask_atomic_add_kernel( def 
test_single_discrete_mask_atomic_add(BLOCK_N): in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') expected = in_tensor.clone() - single_disc_mask_atomic_add_kernel[(1,)](in_tensor, BLOCK_N=BLOCK_N) + single_disc_mask_atomic_add_kernel[(1, )](in_tensor, BLOCK_N=BLOCK_N) half = BLOCK_N // 2 expected[:half] += 1 assert torch.allclose(in_tensor, expected), \ - f"Expected:\n{expected.cpu()}\nGot:\n{in_tensor.cpu()}" \ No newline at end of file + f"Expected:\n{expected.cpu()}\nGot:\n{in_tensor.cpu()}" diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py index 52b4eccbad..cb03ff1bec 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py @@ -62,7 +62,7 @@ def test_single_discrete_mask_load(BLOCK_N): in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') out_tensor = torch.empty(BLOCK_N, dtype=torch.float16, device='npu') - single_disc_mask_load_kernel[(1,)](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + single_disc_mask_load_kernel[(1, )](in_tensor, out_tensor, BLOCK_N=BLOCK_N) half = BLOCK_N // 2 expected = torch.zeros(BLOCK_N, dtype=torch.float16, device='npu') @@ -91,12 +91,12 @@ def single_disc_mask_store_kernel( @pytest.mark.parametrize("BLOCK_N", [8]) def test_single_discrete_mask_store(BLOCK_N): in_tensor = torch.arange(BLOCK_N, dtype=torch.float16, device='npu') - out_tensor = torch.full((BLOCK_N,), -1.0, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_N, ), -1.0, dtype=torch.float16, device='npu') - single_disc_mask_store_kernel[(1,)](in_tensor, out_tensor, BLOCK_N=BLOCK_N) + single_disc_mask_store_kernel[(1, )](in_tensor, out_tensor, BLOCK_N=BLOCK_N) half = BLOCK_N // 2 - expected = torch.full((BLOCK_N,), -1.0, dtype=torch.float16, device='npu') + expected = torch.full((BLOCK_N, ), -1.0, dtype=torch.float16, 
device='npu') expected[:half] = in_tensor[:half] assert torch.allclose(out_tensor, expected), \ f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" @@ -113,7 +113,7 @@ def single_cont_mask_load_kernel( BLOCK_M: tl.constexpr, ): row_offs = tl.arange(0, BLOCK_M) - cont_mask = row_offs < M # Continuous mask + cont_mask = row_offs < M # Continuous mask ptr_in = in_ptr + row_offs ptr_out = out_ptr + row_offs data = tl.load(ptr_in, mask=cont_mask, other=0.0) @@ -123,11 +123,11 @@ def single_cont_mask_load_kernel( @pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) def test_single_continuous_mask_load(M, BLOCK_M): in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') - out_tensor = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') - single_cont_mask_load_kernel[(1,)](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + single_cont_mask_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) - expected = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + expected = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') expected[:M] = in_tensor[:M] assert torch.allclose(out_tensor, expected), \ f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" @@ -154,9 +154,9 @@ def single_cont_mask_store_kernel( @pytest.mark.parametrize("M,BLOCK_M", [(6, 8)]) def test_single_continuous_mask_store(M, BLOCK_M): in_tensor = torch.arange(BLOCK_M, dtype=torch.float16, device='npu') - out_tensor = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') - single_cont_mask_store_kernel[(1,)](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) - expected = torch.full((BLOCK_M,), -1.0, dtype=torch.float16, device='npu') + out_tensor = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, device='npu') + single_cont_mask_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M) + expected = torch.full((BLOCK_M, ), -1.0, dtype=torch.float16, 
device='npu') expected[:M] = in_tensor[:M] assert torch.allclose(out_tensor, expected), \ f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" @@ -191,8 +191,7 @@ def test_cont_disc_combined_mask_load(M, BLOCK_M, BLOCK_N): in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') - cont_disc_combined_mask_load_kernel[(1,)]( - in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + cont_disc_combined_mask_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) half_n = BLOCK_N // 2 expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') @@ -214,8 +213,8 @@ def cont_disc_combined_mask_store_kernel( ): row_offs = tl.arange(0, BLOCK_M) col_offs = tl.arange(0, BLOCK_N) - row_boundary = row_offs < M # continuous -> contLeaf - col_stride = (col_offs * 2) < BLOCK_N # discrete -> discLeaf + row_boundary = row_offs < M # continuous -> contLeaf + col_stride = (col_offs * 2) < BLOCK_N # discrete -> discLeaf mask = row_boundary[:, None] & col_stride[None, :] ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] @@ -227,8 +226,7 @@ def cont_disc_combined_mask_store_kernel( def test_cont_disc_combined_mask_store(M, BLOCK_M, BLOCK_N): in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') - cont_disc_combined_mask_store_kernel[(1,)]( - in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + cont_disc_combined_mask_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) half_n = BLOCK_N // 2 expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') expected[:M, :half_n] = 1.0 @@ -249,9 +247,9 @@ def interleave_cont_disc_mask_kernel( pid = tl.program_id(0) col_offs = tl.arange(0, N) 
even_col_offs = tl.arange(0, N // 2) * 2 - even_col_mask = even_col_offs < N # discrete: cmpi(muli(range,2), N) + even_col_mask = even_col_offs < N # discrete: cmpi(muli(range,2), N) row_offs = tl.arange(0, M) - row_mask = row_offs < M # continuous: cmpi(range_M, M) + row_mask = row_offs < M # continuous: cmpi(range_M, M) in_even_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] in_odd_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] + 1 even_data = tl.load(in_even_ptr, mask=row_mask[:, None] & even_col_mask[None, :], other=0.0) @@ -268,7 +266,7 @@ def test_discrete_mask_load_store(M, N): """Regression test: mask=row_mask & even_col_mask (continuous & discrete 2-way)""" input_tensor = torch.arange(M * N, dtype=torch.float16, device='npu').reshape(M, N) output_tensor = torch.empty_like(input_tensor) - interleave_cont_disc_mask_kernel[(1,)](input_tensor, output_tensor, M=M, N=N) + interleave_cont_disc_mask_kernel[(1, )](input_tensor, output_tensor, M=M, N=N) even_cols = input_tensor[:, 0::2] odd_cols = input_tensor[:, 1::2] ref_output = torch.empty_like(input_tensor) @@ -292,15 +290,12 @@ def multi_cont_disc_mask_load_store_kernel( row_offs = tl.arange(0, BLOCK_M) col_offs = tl.arange(0, BLOCK_N) - row_boundary = row_offs < M # continuous mask - col_boundary = col_offs < N # continuous mask - row_stride = (row_offs * 2) < BLOCK_M # discrete mask - col_stride = (col_offs * 2) < BLOCK_N # discrete mask + row_boundary = row_offs < M # continuous mask + col_boundary = col_offs < N # continuous mask + row_stride = (row_offs * 2) < BLOCK_M # discrete mask + col_stride = (col_offs * 2) < BLOCK_N # discrete mask - mask = (row_boundary[:, None] - & col_boundary[None, :] - & row_stride[:, None] - & col_stride[None, :]) + mask = (row_boundary[:, None] & col_boundary[None, :] & row_stride[:, None] & col_stride[None, :]) ptr_in = in_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] ptr_out = out_ptr + row_offs[:, None] * BLOCK_N + col_offs[None, :] @@ 
-317,18 +312,15 @@ def test_multi_cont_disc_mask_load_store(M, N, BLOCK_M, BLOCK_N): in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') out_tensor = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') - multi_cont_disc_mask_load_store_kernel[(1,)]( - in_tensor, out_tensor, M=M, N=N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + multi_cont_disc_mask_load_store_kernel[(1, )](in_tensor, out_tensor, M=M, N=N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) - half_m = BLOCK_M // 2 # = 4 - half_n = BLOCK_N // 2 # = 4 + half_m = BLOCK_M // 2 # = 4 + half_n = BLOCK_N // 2 # = 4 expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') expected[:half_m, :half_n] = 2.0 - assert torch.allclose(out_tensor, expected), ( - f"BLOCK=({BLOCK_M},{BLOCK_N}), valid=({M},{N})\n" - f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" - ) + assert torch.allclose(out_tensor, expected), (f"BLOCK=({BLOCK_M},{BLOCK_N}), valid=({M},{N})\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") # ============================================================================= @@ -361,8 +353,7 @@ def test_broadcast_cont_disc_2d_load(M, BLOCK_M, BLOCK_N): in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') out_tensor = torch.empty((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') - broadcast_cont_disc_2d_load_kernel[(1,)]( - in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + broadcast_cont_disc_2d_load_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) disc_true_rows = BLOCK_M // 2 both_true_rows = min(M, disc_true_rows) @@ -370,10 +361,8 @@ def test_broadcast_cont_disc_2d_load(M, BLOCK_M, BLOCK_N): expected = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') expected[:both_true_rows, :] = 1.0 - assert torch.allclose(out_tensor, expected), ( - f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" - f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" - ) + assert 
torch.allclose(out_tensor, expected), (f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") # ============================================================================= @@ -407,8 +396,7 @@ def test_broadcast_cont_disc_2d_load_store(M, BLOCK_M, BLOCK_N): in_tensor = torch.ones((BLOCK_M, BLOCK_N), dtype=torch.float16, device='npu') out_tensor = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') - broadcast_cont_disc_2d_load_store_kernel[(1,)]( - in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + broadcast_cont_disc_2d_load_store_kernel[(1, )](in_tensor, out_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) disc_true_rows = BLOCK_M // 2 both_true_rows = min(M, disc_true_rows) @@ -416,7 +404,5 @@ def test_broadcast_cont_disc_2d_load_store(M, BLOCK_M, BLOCK_N): expected = torch.full((BLOCK_M, BLOCK_N), -1.0, dtype=torch.float16, device='npu') expected[:both_true_rows, :] = 1.0 - assert torch.allclose(out_tensor, expected), ( - f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" - f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}" - ) + assert torch.allclose(out_tensor, expected), (f"M={M}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}\n" + f"Expected:\n{expected.cpu()}\nGot:\n{out_tensor.cpu()}") diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py index 1bb5925123..d8129d2677 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_tail_block_mte_oob.py @@ -108,8 +108,8 @@ def _fill_segment_to_boundary(dtype, device, in_bytes, target_free, chunk_max_by pool1 = torch.npu.memory_reserved(0) alloc1 = torch.npu.memory_allocated(0) - seg_size = pool1 - pool0 # should be 2 MB = 2097152 bytes - probe_actual = alloc1 - alloc0 # NPU 512-byte aligned → 512 bytes + seg_size = pool1 - 
pool0 # should be 2 MB = 2097152 bytes + probe_actual = alloc1 - alloc0 # NPU 512-byte aligned → 512 bytes print(f"\n[mte] Step 1: probe") print(f"[mte] segment_size = {seg_size} bytes ({seg_size // 1024} KB)") @@ -121,25 +121,23 @@ def _fill_segment_to_boundary(dtype, device, in_bytes, target_free, chunk_max_by # avoid opening a new segment via the large-alloc path. pre_fillers = [probe] - for chunk in [chunk_max_bytes, - chunk_max_bytes // 2, - chunk_max_bytes // 4, - chunk_max_bytes // 8, - 32 * 1024, 16 * 1024, 8 * 1024, - 4 * 1024, 2 * 1024, 1024, 512]: + for chunk in [ + chunk_max_bytes, chunk_max_bytes // 2, chunk_max_bytes // 4, chunk_max_bytes // 8, 32 * 1024, 16 * 1024, + 8 * 1024, 4 * 1024, 2 * 1024, 1024, 512 + ]: while True: free = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0) if free <= in_bytes: - break # not enough room even for in_tensor; stop + break # not enough room even for in_tensor; stop if free <= target_free: - break # already in target range; try smaller chunk + break # already in target range; try smaller chunk if free <= target_free + chunk: - break # this chunk would overshoot; try smaller chunk + break # this chunk would overshoot; try smaller chunk try: t = torch.empty(chunk // elem_size, dtype=dtype, device=device) pre_fillers.append(t) except RuntimeError: - break # segment exhausted; try smaller chunk + break # segment exhausted; try smaller chunk pool_free_after_fill = torch.npu.memory_reserved(0) - torch.npu.memory_allocated(0) pre_bytes = sum(t.numel() * elem_size for t in pre_fillers) @@ -159,7 +157,7 @@ def _fill_segment_to_boundary(dtype, device, in_bytes, target_free, chunk_max_by def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): """Regression: combined discrete mask load causes OOB on tail blocks. 
- Verifies that DiscreteMaskAccessConversionPass correctly bounds + Verifies that DiscreteMaskAccessConversionPass correctly bounds the memory copy to M rows (the contiguous range), not BLOCK_M rows (the full tile). Test outcome: @@ -170,12 +168,12 @@ def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): device = 'npu' elem_size = 2 # float16 - in_bytes = M * BLOCK_N * elem_size # 8192 bytes - oob_bytes = (BLOCK_M - M) * BLOCK_N * elem_size # 24576 bytes + in_bytes = M * BLOCK_N * elem_size # 8192 bytes + oob_bytes = (BLOCK_M - M) * BLOCK_N * elem_size # 24576 bytes # TARGET_FREE: midpoint between in_bytes and oob_bytes. # Ensures in_tensor fits AND gap < oob_bytes so OOB crosses segment boundary. - target_free = (in_bytes + oob_bytes) // 2 # 16384 bytes - chunk_max_bytes = 512 * 1024 # 512 KB + target_free = (in_bytes + oob_bytes) // 2 # 16384 bytes + chunk_max_bytes = 512 * 1024 # 512 KB print(f"\n[mte] BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N} M={M}") print(f"[mte] in_bytes = {in_bytes} bytes (in_tensor: {M}×{BLOCK_N}×{elem_size})") @@ -188,9 +186,8 @@ def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): in_tensor = None try: - pre_fillers, pool_free_after_fill, _ = _fill_segment_to_boundary( - dtype, device, in_bytes, target_free, chunk_max_bytes - ) + pre_fillers, pool_free_after_fill, _ = _fill_segment_to_boundary(dtype, device, in_bytes, target_free, + chunk_max_bytes) except Exception as exc: torch.npu.empty_cache() pytest.skip(f"Memory layout setup failed (allocator behaviour may differ): {exc}") @@ -200,11 +197,9 @@ def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): for t in reversed(pre_fillers): del t torch.npu.empty_cache() - pytest.skip( - f"pre_fill did not reach target range [{in_bytes}, {target_free}] bytes; " - f"got {pool_free_after_fill} bytes. " - f"Skipping MTE check (NPU allocator behaviour may differ)." - ) + pytest.skip(f"pre_fill did not reach target range [{in_bytes}, {target_free}] bytes; " + f"got {pool_free_after_fill} bytes. 
" + f"Skipping MTE check (NPU allocator behaviour may differ).") try: # Step 3: allocate in_tensor — lands at the very end of the segment. @@ -218,27 +213,21 @@ def test_mte_segment_boundary_oob(BLOCK_M, BLOCK_N, M): print(f"[mte] gap = {gap} bytes (in_tensor end → segment end)") if oob_bytes <= gap: - pytest.skip( - f"gap ({gap} bytes) >= oob_bytes ({oob_bytes} bytes): " - f"OOB would not cross the segment boundary. " - f"Skipping MTE check." - ) + pytest.skip(f"gap ({gap} bytes) >= oob_bytes ({oob_bytes} bytes): " + f"OOB would not cross the segment boundary. " + f"Skipping MTE check.") print(f"[mte] oob_bytes({oob_bytes} B) > gap({gap} B) → MTE expected if unfixed ✓") # Step 4: run kernel num_pids_m = math.ceil(M / BLOCK_M) print(f"\n[mte] Step 4: kernel (grid=({num_pids_m},))") - cont_disc_oob_inplace_2d_kernel[(num_pids_m,)]( - in_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) + cont_disc_oob_inplace_2d_kernel[(num_pids_m, )](in_tensor, M=M, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) torch.npu.synchronize() print("[mte] PASSED: fix is effective, no OOB.") except RuntimeError as exc: - pytest.fail( - f"MTE OOB triggered — DiscreteMaskAccessConversionPass fix " - f"may not be applied or is incomplete.\nError: {exc}" - ) + pytest.fail(f"MTE OOB triggered — DiscreteMaskAccessConversionPass fix " + f"may not be applied or is incomplete.\nError: {exc}") finally: if in_tensor is not None: diff --git a/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py index 6f44b1de5c..2d4f6437dc 100644 --- a/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py +++ b/third_party/ascend/unittest/pytest_ut/test_discrete_overlap_mask.py @@ -27,10 +27,10 @@ # --------------------------------------------------------------------------- # Fixed Constants # --------------------------------------------------------------------------- -_M_ROWS = 16 # Rows per program -_OFFS = 8 # Write window offset 
for two programs -_HALF = 12 # Mask threshold -_NUM_C = 24 # Rows of matrix C (= OFFS + M_ROWS, ensure pid=1 write window not out of bounds) +_M_ROWS = 16 # Rows per program +_OFFS = 8 # Write window offset for two programs +_HALF = 12 # Mask threshold +_NUM_C = 24 # Rows of matrix C (= OFFS + M_ROWS, ensure pid=1 write window not out of bounds) assert _OFFS < _HALF < _M_ROWS, "OFFS < HALF < M_ROWS to ensure True/False on both sides" assert _NUM_C >= _OFFS + _M_ROWS, "NUM_C must accommodate upper bound of pid=1 write window" @@ -41,7 +41,8 @@ # --------------------------------------------------------------------------- @triton.jit def _copy_matrix_kernel( - A_ptr, idx_ptr, + A_ptr, + idx_ptr, C_ptr, idx_stride, A_row_stride, @@ -64,8 +65,8 @@ def _copy_matrix_kernel( OFFS: tl.constexpr = 8 M_ROWS: tl.constexpr = 16 - N_BLOCK = N_id * BLOCK_N + tl.arange(0, BLOCK_N) # shape: (BLOCK_N,) - M_BLOCK = tl.arange(0, M_ROWS) # shape: (M_ROWS,) + N_BLOCK = N_id * BLOCK_N + tl.arange(0, BLOCK_N) # shape: (BLOCK_N,) + M_BLOCK = tl.arange(0, M_ROWS) # shape: (M_ROWS,) # Discrete row indices (loaded at runtime -> mask cannot be statically analyzed) idx = tl.load(idx_ptr + program_id * idx_stride + M_BLOCK) @@ -83,8 +84,7 @@ def _copy_matrix_kernel( # Write to C (mask=False rows rely on load-select-store to preserve original values) tl.store( - C_ptr + (OFFS * program_id + M_BLOCK[:, None]) * C_row_stride - + N_BLOCK[None, :] * C_col_stride, + C_ptr + (OFFS * program_id + M_BLOCK[:, None]) * C_row_stride + N_BLOCK[None, :] * C_col_stride, val, mask=mask[:, None], ) @@ -109,6 +109,7 @@ def _make_idx(device: str) -> torch.Tensor: First 4 values ∈ [OFFS, OFFS+4) -> mask=False Last HALF=12 values ∈ [HALF, HALF*2) -> mask=True """ + def shuffle_quads(lst: list) -> list: """Reverse each group of 4 elements (ignore if less than 4).""" out = lst[:] @@ -117,15 +118,15 @@ def shuffle_quads(lst: list) -> list: out[i + 3], out[i + 2], out[i + 1], out[i] return out - num_false = _M_ROWS - 
_HALF # = 4 + num_false = _M_ROWS - _HALF # = 4 - seg0_true = shuffle_quads(list(range(0, _HALF))) # 12 values, < 12 - seg0_false = shuffle_quads(list(range(_HALF, _HALF + num_false))) # 4 values, >= 12 - idx0 = seg0_true + seg0_false # Total length 16 + seg0_true = shuffle_quads(list(range(0, _HALF))) # 12 values, < 12 + seg0_false = shuffle_quads(list(range(_HALF, _HALF + num_false))) # 4 values, >= 12 + idx0 = seg0_true + seg0_false # Total length 16 - seg1_false = shuffle_quads(list(range(_OFFS, _OFFS + num_false))) # 4 values, < 12 - seg1_true = shuffle_quads(list(range(_HALF, _HALF + _HALF))) # 12 values, >= 12 - idx1 = seg1_false + seg1_true # Total length 16 + seg1_false = shuffle_quads(list(range(_OFFS, _OFFS + num_false))) # 4 values, < 12 + seg1_true = shuffle_quads(list(range(_HALF, _HALF + _HALF))) # 12 values, >= 12 + idx1 = seg1_false + seg1_true # Total length 16 assert len(idx0) == _M_ROWS, f"idx0 length error: {len(idx0)}" assert len(idx1) == _M_ROWS, f"idx1 length error: {len(idx1)}" @@ -175,10 +176,14 @@ def _run_once(BLOCK_N: int, dtype_str: str) -> None: grid = (2, 1) _copy_matrix_kernel[grid]( - A_ptr=A, idx_ptr=idx, C_ptr=C, + A_ptr=A, + idx_ptr=idx, + C_ptr=C, idx_stride=idx.stride(0), - A_row_stride=A.stride(0), A_col_stride=A.stride(1), - C_row_stride=C.stride(0), C_col_stride=C.stride(1), + A_row_stride=A.stride(0), + A_col_stride=A.stride(1), + C_row_stride=C.stride(0), + C_col_stride=C.stride(1), BLOCK_N=BLOCK_N, HALF=_HALF, enable_sync_block_lock=True, @@ -187,12 +192,10 @@ def _run_once(BLOCK_N: int, dtype_str: str) -> None: # Verification assert torch.all(C[:_HALF] == zero_val), ( f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " - f"C[:HALF] should all be {zero_val}, actual unique values: {C[:_HALF].unique().tolist()}" - ) + f"C[:HALF] should all be {zero_val}, actual unique values: {C[:_HALF].unique().tolist()}") assert torch.all(C[_HALF:] == one_val), ( f"[dtype={dtype_str}, BLOCK_N={BLOCK_N}] " - f"C[HALF:] should all be {one_val}, 
actual unique values: {C[_HALF:].unique().tolist()}" - ) + f"C[HALF:] should all be {one_val}, actual unique values: {C[_HALF:].unique().tolist()}") @pytest.mark.parametrize("param_list", [ diff --git a/third_party/ascend/unittest/pytest_ut/test_dot.py b/third_party/ascend/unittest/pytest_ut/test_dot.py index 1e6cf3c3a5..635df74c5e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_dot.py +++ b/third_party/ascend/unittest/pytest_ut/test_dot.py @@ -112,6 +112,7 @@ def triton_dot_2_ignore_tf32(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.co oidx = bidx[:, None] * D + didx[None, :] tl.store(output_ptr + oidx, ret, mask=out_mask) + testlist1 = [ (10, 13, 35, 39), ] @@ -136,8 +137,7 @@ def test_dot_2(restore_npu_hf32_setting, sigtype, B, C, D): @pytest.mark.xfail( - reason="Temporarily disabled: TA backend does not support allow_tf32 yet. Will be fixed in follow-up." -) + reason="Temporarily disabled: TA backend does not support allow_tf32 yet. Will be fixed in follow-up.") @pytest.mark.parametrize("B, C, D", testlist2) @pytest.mark.parametrize("sigtype", typelist) def test_dot_2_allow_tf32(restore_npu_hf32_setting, sigtype, B, C, D): diff --git a/third_party/ascend/unittest/pytest_ut/test_erfinv.py b/third_party/ascend/unittest/pytest_ut/test_erfinv.py index a41521417a..51d155a115 100644 --- a/third_party/ascend/unittest/pytest_ut/test_erfinv.py +++ b/third_party/ascend/unittest/pytest_ut/test_erfinv.py @@ -84,11 +84,9 @@ def test_all_blocks_parallel(param_list, monkeypatch): monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) def test_auto_blockify(param_list, monkeypatch): monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") dtype, shape, ncore, xblock, xblock_sub = param_list @@ -97,7 +95,7 @@ def test_auto_blockify(param_list, monkeypatch): x[0][0][1] = -1 # erfinv(-1) -> 
-∞ # Avoid numerical instability near ±1 - # Move values in (threshold, 1) to threshold and (-1, -threshold) to -threshold + # Move values in (threshold, 1) to threshold and (-1, -threshold) to -threshold threshold = 1 - 1.1e-4 too_close_pos = (x > threshold) & (x < 1) too_close_neg = (x < -threshold) & (x > -1) diff --git a/third_party/ascend/unittest/pytest_ut/test_expm1.py b/third_party/ascend/unittest/pytest_ut/test_expm1.py index d1eb5d2cda..7a06320d61 100644 --- a/third_party/ascend/unittest/pytest_ut/test_expm1.py +++ b/third_party/ascend/unittest/pytest_ut/test_expm1.py @@ -46,10 +46,9 @@ def triton_expm1(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constex @pytest.mark.skip(reason="expm1 failed sometimes, wait for fix") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) def test_expm1(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0_ref = test_common.generate_tensor(shape, dtype) diff --git a/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py index 7d0d9ebdb6..be534a5fdb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_dividef.py @@ -45,13 +45,10 @@ def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK tl.store(out_ptr0 + (x0), tmp2, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ['float16', (2, 4096, 8), 2, 32768, 1024], - ] - ) - +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], +]) def test_case(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() diff --git a/third_party/ascend/unittest/pytest_ut/test_fast_expf.py 
b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py index 7d0d9ebdb6..be534a5fdb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_fast_expf.py +++ b/third_party/ascend/unittest/pytest_ut/test_fast_expf.py @@ -45,13 +45,10 @@ def triton_fast_dividef(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK tl.store(out_ptr0 + (x0), tmp2, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ['float16', (2, 4096, 8), 2, 32768, 1024], - ] - ) - +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], +]) def test_case(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() diff --git a/third_party/ascend/unittest/pytest_ut/test_gamma.py b/third_party/ascend/unittest/pytest_ut/test_gamma.py index 19a33dc688..b3cdf3b9af 100644 --- a/third_party/ascend/unittest/pytest_ut/test_gamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_gamma.py @@ -70,11 +70,9 @@ def test_all_blocks_parallel(param_list, monkeypatch): monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 2048, 8), 2, 32768, 512], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 2048, 8), 2, 32768, 512], +]) def test_auto_blockify(param_list, monkeypatch): monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") dtype, shape, ncore, xblock, xblock_sub = param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_if_advance.py b/third_party/ascend/unittest/pytest_ut/test_if_advance.py index beb8c670c7..2c63e7ebba 100644 --- a/third_party/ascend/unittest/pytest_ut/test_if_advance.py +++ b/third_party/ascend/unittest/pytest_ut/test_if_advance.py @@ -4,35 +4,17 @@ import triton.language as tl import triton.language.extra.cann.extension as al + @triton.jit -def triton_if_advance_kernel(in_ptr0, in_ptr1, out_ptr, - xnumel, ynumel, k_loops, - 
XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): +def triton_if_advance_kernel(in_ptr0, in_ptr1, out_ptr, xnumel, ynumel, k_loops, XBLOCK: tl.constexpr, + YBLOCK: tl.constexpr): - K_block_ptr = tl.make_block_ptr( - base = in_ptr0, - shape = (xnumel, ynumel), - strides = (ynumel, 1), - offsets = (0, 0), - block_shape = (XBLOCK, YBLOCK), - order = (1, 0) - ) - V_block_ptr = tl.make_block_ptr( - base = in_ptr1, - shape = (ynumel, xnumel), - strides = (xnumel, 1), - offsets = (0, 0), - block_shape = (YBLOCK, XBLOCK), - order = (1, 0) - ) - O_block_ptr = tl.make_block_ptr( - base = out_ptr, - shape = (xnumel, xnumel), - strides = (xnumel, 1), - offsets = (0, 0), - block_shape = (XBLOCK, XBLOCK), - order = (1, 0) - ) + K_block_ptr = tl.make_block_ptr(base=in_ptr0, shape=(xnumel, ynumel), strides=(ynumel, 1), offsets=(0, 0), + block_shape=(XBLOCK, YBLOCK), order=(1, 0)) + V_block_ptr = tl.make_block_ptr(base=in_ptr1, shape=(ynumel, xnumel), strides=(xnumel, 1), offsets=(0, 0), + block_shape=(YBLOCK, XBLOCK), order=(1, 0)) + O_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(xnumel, xnumel), strides=(xnumel, 1), offsets=(0, 0), + block_shape=(XBLOCK, XBLOCK), order=(1, 0)) res = tl.zeros([XBLOCK, XBLOCK], tl.float32) for i in range(0, k_loops): if i > 0: @@ -40,16 +22,18 @@ def triton_if_advance_kernel(in_ptr0, in_ptr1, out_ptr, V_block_ptr = tl.advance(V_block_ptr, (YBLOCK, 0)) a = tl.load(K_block_ptr) b = tl.load(V_block_ptr) - res = tl.dot(a, b, acc = res) + res = tl.dot(a, b, acc=res) tl.store(O_block_ptr, res) + def test_if_advance(): x = torch.randn((64, 256), dtype=torch.float32, device="npu") y = torch.randn((256, 64), dtype=torch.float32, device="npu") out_tri = torch.empty((64, 64), dtype=torch.float32, device="npu") out_std = torch.empty((64, 64), dtype=torch.float32, device="npu") - torch.matmul(x, y, out = out_std) - triton_if_advance_kernel[1,1,1](x, y, out_tri, 64, 256, 4, 64, 64) - torch.testing.assert_close(out_std, out_tri, atol = 1e-2, rtol = 1e-2) + 
torch.matmul(x, y, out=out_std) + triton_if_advance_kernel[1, 1, 1](x, y, out_tri, 64, 256, 4, 64, 64) + torch.testing.assert_close(out_std, out_tri, atol=1e-2, rtol=1e-2) + -test_if_advance() \ No newline at end of file +test_if_advance() diff --git a/third_party/ascend/unittest/pytest_ut/test_if_load.py b/third_party/ascend/unittest/pytest_ut/test_if_load.py index 0c2e0c8e5a..7932366cb4 100644 --- a/third_party/ascend/unittest/pytest_ut/test_if_load.py +++ b/third_party/ascend/unittest/pytest_ut/test_if_load.py @@ -56,29 +56,27 @@ def triton_for_if_load(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.c tl.store(out_ptr0 + index, tmp0, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (32,), 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (32, ), 32], +]) def test_if_load(param_list): dtype, shape, xblock = param_list x0 = test_common.generate_tensor(shape, dtype).npu() y_ref = x0.clone() y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_if_load[(1,)](x0, y_cal, xblock) + triton_if_load[(1, )](x0, y_cal, xblock) test_common.validate_cmp(dtype, y_cal, y_ref) -@pytest.mark.parametrize('param_list', - [ - ['float32', (32,), 32, 16], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (32, ), 32, 16], +]) def test_if_load(param_list): dtype, shape, xblock, xblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() y_ref = x0.clone() y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - triton_for_if_load[(1,)](x0, y_cal, xblock, xblock_sub) + triton_for_if_load[(1, )](x0, y_cal, xblock, xblock_sub) test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py index 67e5c729c8..b7a539f392 100644 --- a/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py +++ b/third_party/ascend/unittest/pytest_ut/test_implicit_atomic.py @@ -24,7 +24,6 @@ import test_common import triton.language as tl - types_all = [ (torch.float32, 'float32'), ] @@ -46,7 +45,7 @@ def addptr_implicit_perm_atomic_add_2d( y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YB] x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XB, 1] - val = 1.0 + (x.to(tl.float32) * 0.01) + (y.to(tl.float32) * 0.001) # [XB, YB] + val = 1.0 + (x.to(tl.float32) * 0.01) + (y.to(tl.float32) * 0.001) # [XB, YB] xmask = x < xnumel ymask = y < ynumel old = tl.atomic_add(ptr + (x + 4 * y), val, xmask & ymask) @@ -85,18 +84,18 @@ def addptr_implicit_perm_atomic_cas_2d( @pytest.mark.parametrize('dtype,sigtype', types_all) @pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) def test_addptr_implicit_perm_atomic_add_2d( - dtype, sigtype, - xnumel, ynumel, - XBLOCK, YBLOCK, + dtype, + sigtype, + xnumel, + ynumel, + XBLOCK, + YBLOCK, ): - in_ptr = torch.zeros((ynumel * 4,), dtype=dtype).npu() + in_ptr = torch.zeros((ynumel * 4, ), dtype=dtype).npu() out_ptr = torch.ones_like(in_ptr) grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - addptr_implicit_perm_atomic_add_2d[grid]( - in_ptr, out_ptr, ynumel, xnumel, - YBLOCK=YBLOCK, XBLOCK=XBLOCK - ) + addptr_implicit_perm_atomic_add_2d[grid](in_ptr, out_ptr, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK) y_idx = torch.arange(ynumel).unsqueeze(1).npu() x_idx = torch.arange(xnumel).unsqueeze(0).npu() @@ -110,20 +109,21 @@ def 
test_addptr_implicit_perm_atomic_add_2d( @pytest.mark.parametrize('dtype,sigtype', types_all) @pytest.mark.parametrize('xnumel, ynumel, XBLOCK, YBLOCK', [(4, 512, 4, 64)]) def test_addptr_implicit_perm_atomic_cas_2d( - dtype, sigtype, - xnumel, ynumel, - XBLOCK, YBLOCK, + dtype, + sigtype, + xnumel, + ynumel, + XBLOCK, + YBLOCK, ): - in_ptr = torch.full((ynumel * 4,), 2, dtype=dtype).npu() - out_ptr = torch.full((ynumel * 4,), 1, dtype=dtype).npu() - cmp_ptr = torch.full((ynumel * 4,), 2, dtype=dtype).npu() - val_ptr = torch.full((ynumel * 4,), 1, dtype=dtype).npu() + in_ptr = torch.full((ynumel * 4, ), 2, dtype=dtype).npu() + out_ptr = torch.full((ynumel * 4, ), 1, dtype=dtype).npu() + cmp_ptr = torch.full((ynumel * 4, ), 2, dtype=dtype).npu() + val_ptr = torch.full((ynumel * 4, ), 1, dtype=dtype).npu() grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - addptr_implicit_perm_atomic_cas_2d[grid]( - in_ptr, out_ptr, cmp_ptr, val_ptr, ynumel, xnumel, - YBLOCK=YBLOCK, XBLOCK=XBLOCK - ) + addptr_implicit_perm_atomic_cas_2d[grid](in_ptr, out_ptr, cmp_ptr, val_ptr, ynumel, xnumel, YBLOCK=YBLOCK, + XBLOCK=XBLOCK) y_idx = torch.arange(ynumel).unsqueeze(1).npu() x_idx = torch.arange(xnumel).unsqueeze(0).npu() diff --git a/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py index f5c82bfafd..9aaafa2371 100644 --- a/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py +++ b/third_party/ascend/unittest/pytest_ut/test_implicit_permute.py @@ -24,7 +24,6 @@ import triton.language as tl import test_common - types_all = [ (torch.float32, 'float32'), ] @@ -49,19 +48,12 @@ # Triton kernel # ---------------------------------------------------------- @triton.jit -def addptr_implicit_perm_load_store_2d_static_stride( - ptr, - out, - ynumel, - xnumel, - stride_y: tl.constexpr, - stride_x: tl.constexpr, - YBLOCK: tl.constexpr, - XBLOCK: tl.constexpr -): +def 
addptr_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, stride_y: tl.constexpr, + stride_x: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): # logical indices (A^T view) - x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] - y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] mask = (x < xnumel) & (y < ynumel) # IMPORTANT: @@ -84,8 +76,8 @@ def addptr_implicit_perm_load_store_2d( XBLOCK: tl.constexpr, ): # logical indices (A^T view) - x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] - y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] mask = (x < xnumel) & (y < ynumel) @@ -172,9 +164,9 @@ def addptr_implicit_perm_load_store_4d_static_stride( YBLOCK: tl.constexpr, XBLOCK: tl.constexpr, ): - pid0 = tl.program_id(0) # covers (w, x) - pid1 = tl.program_id(1) # y - pid2 = tl.program_id(2) # z + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK @@ -215,9 +207,9 @@ def addptr_implicit_perm_load_store_4d( YBLOCK: tl.constexpr, XBLOCK: tl.constexpr, ): - pid0 = tl.program_id(0) # covers (w, x) - pid1 = tl.program_id(1) # y - pid2 = tl.program_id(2) # z + pid0 = tl.program_id(0) # covers (w, x) + pid1 = tl.program_id(1) # y + pid2 = tl.program_id(2) # z xblocks_per_w = (xnumel + XBLOCK - 1) // XBLOCK @@ -242,20 +234,13 @@ def addptr_implicit_perm_load_store_4d( @triton.jit -def make_tensor_ptr_implicit_perm_load_store_2d_static_stride( - ptr, - out, - ynumel, - xnumel, - STRIDE_Y: tl.constexpr, - STRIDE_X: tl.constexpr, - YBLOCK: tl.constexpr, 
- XBLOCK: tl.constexpr -): +def make_tensor_ptr_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): y0 = tl.program_id(1) * YBLOCK x0 = tl.program_id(0) * XBLOCK - y = y0 + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] - x = x0 + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] + y = y0 + tl.arange(0, YBLOCK)[None, :] # [1, YBLOCK] + x = x0 + tl.arange(0, XBLOCK)[:, None] # [XBLOCK, 1] xmask = x < xnumel ymask = y < ynumel mask = xmask & ymask @@ -359,9 +344,9 @@ def make_tensor_ptr_implicit_perm_load_store_3d( def make_tensor_ptr_implicit_perm_load_3d_static_stride( ptr, out, - znumel, # logical z (== X) - ynumel, # logical y (== Y) - xnumel, # logical x (== Z) + znumel, # logical z (== X) + ynumel, # logical y (== Y) + xnumel, # logical x (== Z) STRIDE_Z: tl.constexpr, STRIDE_Y: tl.constexpr, STRIDE_X: tl.constexpr, @@ -403,18 +388,10 @@ def make_tensor_ptr_implicit_perm_load_3d_static_stride( tl.store(out + out_offset, val, mask=mask) - @triton.jit -def advance_implicit_perm_load_store_2d_static_stride( - ptr, - out, - ynumel, - xnumel, - STRIDE_Y: tl.constexpr, - STRIDE_X: tl.constexpr, - YBLOCK: tl.constexpr, - XBLOCK: tl.constexpr -): +def advance_implicit_perm_load_store_2d_static_stride(ptr, out, ynumel, xnumel, STRIDE_Y: tl.constexpr, + STRIDE_X: tl.constexpr, YBLOCK: tl.constexpr, + XBLOCK: tl.constexpr): y0 = tl.program_id(1) * YBLOCK x0 = tl.program_id(0) * XBLOCK y = y0 + tl.arange(0, YBLOCK)[None, :] @@ -532,8 +509,8 @@ def test_addptr_implicit_perm_load_store_2d_static_stride( out = torch.zeros_like(A) # A^T logical shape - xnumel = Y # cols of A - ynumel = X # rows of A + xnumel = Y # cols of A + ynumel = X # rows of A # A^T logical stride stride_x = 1 @@ -582,8 +559,8 @@ def test_addptr_implicit_perm_load_store_2d( out = torch.zeros_like(A) # A^T logical shape - xnumel = Y # cols of A - ynumel = X # rows of A + xnumel = Y # cols of A + ynumel = 
X # rows of A # A^T logical stride stride_x = 1 @@ -611,9 +588,7 @@ def test_addptr_implicit_perm_load_store_2d( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_addptr_implicit_perm_load_store_3d_static_stride( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_addptr_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): """ Test goal: - Real memory layout: A[X, Y, Z], row-major (stride = (Y*Z, Z, 1)) @@ -662,9 +637,7 @@ def test_addptr_implicit_perm_load_store_3d_static_stride( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_addptr_implicit_perm_load_store_3d( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_addptr_implicit_perm_load_store_3d(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): """ Same as static-stride version, but stride passed as runtime values. 
""" @@ -705,9 +678,7 @@ def test_addptr_implicit_perm_load_store_3d( @pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_addptr_implicit_perm_load_store_4d_static_stride( - X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype -): +def test_addptr_implicit_perm_load_store_4d_static_stride(X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype): """ Test goal: - Real memory layout: A[X, Y, Z, W], row-major (stride = (Y*Z*W, Z*W, W, 1)) @@ -764,9 +735,7 @@ def test_addptr_implicit_perm_load_store_4d_static_stride( @pytest.mark.parametrize("X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK", case_4d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_addptr_implicit_perm_load_store_4d( - X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype -): +def test_addptr_implicit_perm_load_store_4d(X, Y, Z, W, XBLOCK, YBLOCK, ZBLOCK, WBLOCK, dtype, sigtype): """ Same as static-stride version, but stride passed as runtime values. """ @@ -817,9 +786,7 @@ def test_addptr_implicit_perm_load_store_4d( # ---------------------------------------------------------- @pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride( - X, Y, XBLOCK, YBLOCK, dtype, sigtype -): +def test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride(X, Y, XBLOCK, YBLOCK, dtype, sigtype): """ Test goal matches addptr_2d_static_stride, but uses tl.make_block_ptr + tl.load(tptr). 
Real layout: A[X,Y] row-major stride=(Y,1) @@ -852,9 +819,7 @@ def test_make_tensor_ptr_implicit_perm_load_store_2d_static_stride( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): """ Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) @@ -895,9 +860,7 @@ def test_make_tensor_ptr_implicit_perm_load_store_3d_static_stride( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_make_tensor_ptr_implicit_perm_load_store_3d( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_make_tensor_ptr_implicit_perm_load_store_3d(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): """ Same as static stride but STRIDE_* passed at runtime. 
""" @@ -937,9 +900,7 @@ def test_make_tensor_ptr_implicit_perm_load_store_3d( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_make_tensor_ptr_implicit_perm_load_3d_static_stride( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_make_tensor_ptr_implicit_perm_load_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): A = test_common.generate_tensor(shape=(X, Y, Z), dtype=sigtype).npu() _assert_row_major_3d(A, X, Y, Z) @@ -957,7 +918,7 @@ def test_make_tensor_ptr_implicit_perm_load_3d_static_stride( out = torch.empty((xnumel, ynumel, znumel), device="npu", dtype=A.dtype) assert out.is_contiguous() OUT_STRIDE0 = ynumel * znumel # Y*X - OUT_STRIDE1 = znumel # X + OUT_STRIDE1 = znumel # X OUT_STRIDE2 = 1 grid = ( @@ -993,9 +954,7 @@ def test_make_tensor_ptr_implicit_perm_load_3d_static_stride( # ---------------------------------------------------------- @pytest.mark.parametrize("X, Y, XBLOCK, YBLOCK", case_2d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_advance_implicit_perm_load_store_2d_static_stride( - X, Y, XBLOCK, YBLOCK, dtype, sigtype -): +def test_advance_implicit_perm_load_store_2d_static_stride(X, Y, XBLOCK, YBLOCK, dtype, sigtype): """ Same goal as addptr_2d_static_stride, but uses tl.make_block_ptr + tl.advance. 
""" @@ -1025,9 +984,7 @@ def test_advance_implicit_perm_load_store_2d_static_stride( @pytest.mark.parametrize("X, Y, Z, XBLOCK, YBLOCK, ZBLOCK", case_3d) @pytest.mark.parametrize("dtype, sigtype", types_all) -def test_advance_implicit_perm_load_store_3d_static_stride( - X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype -): +def test_advance_implicit_perm_load_store_3d_static_stride(X, Y, Z, XBLOCK, YBLOCK, ZBLOCK, dtype, sigtype): """ Real layout: A[X,Y,Z] row-major stride=(Y*Z, Z, 1) Kernel view (logical): shape=(Z,Y,X), strides=(1, Z, Y*Z) diff --git a/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py index c7879a5392..b9a34efccd 100644 --- a/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py +++ b/third_party/ascend/unittest/pytest_ut/test_indirect_scalar_load_offset.py @@ -45,7 +45,8 @@ def gather_after_reduce_kernel( mask = offsets < vocab_size vals = tl.load( logits_ptr + req_idx * logits_stride + offsets, - mask=mask, other=-float('inf'), + mask=mask, + other=-float('inf'), ) block_max = tl.max(vals) max_val = tl.maximum(max_val, block_max) @@ -79,13 +80,18 @@ def test_gather_after_reduce(num_rows, vocab_size): logits = logits_ref.npu() logits_flat = logits.reshape(-1) - topk_ids_ref = torch.randint(0, vocab_size, (num_rows,), dtype=torch.int64) + topk_ids_ref = torch.randint(0, vocab_size, (num_rows, ), dtype=torch.int64) topk_ids = topk_ids_ref.npu() output = torch.empty(num_rows, dtype=torch.float32).npu() - gather_after_reduce_kernel[(num_rows,)]( - logits_flat, topk_ids, output, vocab_size, vocab_size, BLOCK=BLOCK, + gather_after_reduce_kernel[(num_rows, )]( + logits_flat, + topk_ids, + output, + vocab_size, + vocab_size, + BLOCK=BLOCK, ) output_ref = torch_reference(logits_ref, topk_ids_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py 
b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py index e50bcdbf6a..7f491e0494 100644 --- a/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py +++ b/third_party/ascend/unittest/pytest_ut/test_interleave_optimizaiton.py @@ -59,7 +59,8 @@ def triton_interleave_load(q_ptr, k_ptr, head_dim_half: tl.constexpr, bias: tl.c @triton.jit -def triton_interleave_load_with_mask(q_ptr, k_ptr, head_dim_half: tl.constexpr, bias: tl.constexpr, numel: tl.constexpr): +def triton_interleave_load_with_mask(q_ptr, k_ptr, head_dim_half: tl.constexpr, bias: tl.constexpr, + numel: tl.constexpr): d_indices = tl.program_id(0) + tl.arange(0, head_dim_half) mask = d_indices < numel q_real = tl.load(q_ptr + d_indices * 2 + bias, mask) @@ -83,48 +84,42 @@ def triton_interleave_loadstore_with_mask(q_ptr, head_dim_half: tl.constexpr, bi tl.store(q_ptr + d_indices * 2 + 1 + bias, new_q_imag, mask) -@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias', - [ - ['float32', torch.float32, 16, 4], - ] - ) +@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias', [ + ['float32', torch.float32, 16, 4], +]) def test_interleave(para_type, data_type, head_dim_half, bias): length = bias + head_dim_half * 2 - q = torch.randn((length,), dtype=data_type).npu() + q = torch.randn((length, ), dtype=data_type).npu() k = torch.zeros_like(q, dtype=data_type).npu() k_ref = torch.zeros_like(q, dtype=data_type).npu() - triton_interleave_load[(1,)](q, k, head_dim_half, bias) + triton_interleave_load[(1, )](q, k, head_dim_half, bias) k_ref = torch_interleave_load(q, k_ref, head_dim_half, bias) assert torch.allclose(k, k_ref) -@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', - [ - ['float32', torch.float32, 16, 0, 8], - ] - ) +@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', [ + ['float32', torch.float32, 16, 0, 8], +]) def test_interleave_with_mask(para_type, data_type, head_dim_half, bias, numel): length = bias + 
head_dim_half * 2 - q = torch.randn((length,), dtype=data_type).npu() + q = torch.randn((length, ), dtype=data_type).npu() k = torch.zeros_like(q, dtype=data_type).npu() k_ref = torch.zeros_like(q, dtype=data_type).npu() - triton_interleave_load_with_mask[(1,)](q, k, head_dim_half, bias, numel) + triton_interleave_load_with_mask[(1, )](q, k, head_dim_half, bias, numel) k_ref = torch_interleave_load_with_mask(q, k_ref, head_dim_half, bias, numel) assert torch.allclose(k, k_ref) -@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', - [ - ['float32', torch.float32, 16, 0, 8], - ] - ) +@pytest.mark.parametrize('para_type,data_type,head_dim_half,bias,numel', [ + ['float32', torch.float32, 16, 0, 8], +]) def test_interleave_loadstore_with_mask(para_type, data_type, head_dim_half, bias, numel): length = bias + head_dim_half * 2 - q = torch.randn((length,), dtype=data_type).npu() + q = torch.randn((length, ), dtype=data_type).npu() q_ref = q.clone() - triton_interleave_loadstore_with_mask[(1,)](q, head_dim_half, bias, numel) + triton_interleave_loadstore_with_mask[(1, )](q, head_dim_half, bias, numel) q_ref = torch_interleave_loadstore_with_mask(q_ref, head_dim_half, bias, numel) assert torch.allclose(q, q_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_lgamma.py b/third_party/ascend/unittest/pytest_ut/test_lgamma.py index 0633922df0..fc56fffa24 100644 --- a/third_party/ascend/unittest/pytest_ut/test_lgamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_lgamma.py @@ -88,11 +88,9 @@ def test_all_blocks_parallel(param_list, monkeypatch): monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 2048, 8), 2, 32768, 512], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 2048, 8), 2, 32768, 512], +]) def test_auto_blockify(param_list, monkeypatch): monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") dtype, shape, ncore, xblock, xblock_sub = param_list @@ -104,11 
+102,11 @@ def test_auto_blockify(param_list, monkeypatch): threshold = torch.zeros_like(x) if neg_mask.any(): neg_ints = nearest_int[neg_mask] - threshold[neg_mask] = 5.75e-5 * (2.42 ** (-1 - neg_ints)) + threshold[neg_mask] = 5.75e-5 * (2.42**(-1 - neg_ints)) mask = (torch.abs(x - nearest_int) < threshold) & (nearest_int <= -1) if mask.any(): - x = torch.where(mask, nearest_int, x) - + x = torch.where(mask, nearest_int, x) + y_ref = torch.lgamma(x).npu() y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() triton_lgamma[ncore, 1, 1](x, y_cal, x.numel(), xblock, xblock_sub, auto_blockify_size=ncore) diff --git a/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py index 4382b701c5..b3993de102 100644 --- a/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py +++ b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py @@ -97,32 +97,26 @@ def triton_linearize_mask_broadcast(in_tensor, BLOCK_SIZE): N = in_tensor.shape[1] triton_output = torch.zeros_like(in_tensor) - grid = (ceil_div(2 * M * N, BLOCK_SIZE),) - - linearize_mask_broadcast_kernel[grid]( - in_tensor, - triton_output, - N=N, - M=M, - BLOCK_SIZE_N=BLOCK_SIZE, - optimize_dynamic_offset=True - ) + grid = (ceil_div(2 * M * N, BLOCK_SIZE), ) + + linearize_mask_broadcast_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE, + optimize_dynamic_offset=True) @triton.jit def rem_kernel(in_ptr0, in_ptr1, out_ptr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) x = tl.arange(0, BLOCK_SIZE) - + base_offset = pid * BLOCK_SIZE + x - + rem_result = base_offset % 128 mask = rem_result < 64 tmp0 = tl.load(in_ptr0 + base_offset, mask=mask, other=0.0) tmp1 = tl.load(in_ptr1 + base_offset, mask=mask, other=0.0) tmp2 = tmp0 + tmp1 - + tl.store(out_ptr + base_offset, tmp2, mask=mask) @@ -130,23 +124,23 @@ def test_linearize_mask_rem(): N = 1024 BLOCK_SIZE = 256 dtype = 'float32' - shape = (N,) + 
shape = (N, ) x0 = test_common.generate_tensor(shape, dtype).npu() x1 = test_common.generate_tensor(shape, dtype).npu() triton_res = torch.zeros(shape).npu() - - grid = (ceil_div(N, BLOCK_SIZE),) + + grid = (ceil_div(N, BLOCK_SIZE), ) rem_kernel[grid](x0, x1, triton_res, N, BLOCK_SIZE=BLOCK_SIZE) - + base_offsets = torch.arange(N).npu() rem_results = base_offsets % 128 mask_bool = rem_results < 64 - - torch_res = torch.zeros((N,)).npu() + + torch_res = torch.zeros((N, )).npu() torch_res[mask_bool] = x0[mask_bool] + x1[mask_bool] - - test_common.validate_cmp(dtype, triton_res, torch_res) + + test_common.validate_cmp(dtype, triton_res, torch_res) def profile_performance_test(M, N, dtype, BLOCK_SIZE): diff --git a/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py index abd13ba327..4d339997ab 100644 --- a/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py +++ b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_negative_padding.py @@ -84,9 +84,7 @@ def negative_padding_with_store_kernel( tl.store(out_ptr, in_val, boundary_check=(0, 1)) -@pytest.mark.parametrize('param_list', [ - (8, 8), (16, 16), (32, 32), (64, 64) -]) +@pytest.mark.parametrize('param_list', [(8, 8), (16, 16), (32, 32), (64, 64)]) def test_makeblockptr_load_with_negative_padding(param_list): shape = param_list torch.manual_seed(1) @@ -112,9 +110,7 @@ def test_makeblockptr_load_with_negative_padding(param_list): test_common.validate_cmp("int32", output, output_ref) -@pytest.mark.parametrize('param_list', [ - (8, 8), (16, 16), (32, 32), (64, 64) -]) +@pytest.mark.parametrize('param_list', [(8, 8), (16, 16), (32, 32), (64, 64)]) def test_makeblockptr_store_with_negative_padding(param_list): shape = param_list torch.manual_seed(1) diff --git a/third_party/ascend/unittest/pytest_ut/test_mod.py b/third_party/ascend/unittest/pytest_ut/test_mod.py index 73690260f9..fc91c60afb 
100644 --- a/third_party/ascend/unittest/pytest_ut/test_mod.py +++ b/third_party/ascend/unittest/pytest_ut/test_mod.py @@ -68,7 +68,7 @@ def test_case(param_list): y_ref = torch_pointwise(x0, x1, dtype) if dtype == "float16": y_ref = y_ref.to(torch.float16) - y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() triton_mod[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) #test_common.validate_cmp(dtype, y_cal, y_ref.npu()) if dtype == 'int8': diff --git a/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py b/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py index c69b5aef64..5ca4246331 100644 --- a/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py +++ b/third_party/ascend/unittest/pytest_ut/test_mul_reduce.py @@ -22,20 +22,16 @@ def triton_pw_rdc5d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr nblk_idx = tl.arange(0, N) kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) x0 = tl.load(in_ptr0 + idx) x1 = tl.load(in_ptr1 + idx) ret0 = x0 * x1 ret = tl.reduce(ret0, 4, minimum, keep_dims=True) zblk_idx = tl.arange(0, 1) - odx = (lblk_idx[:, None, None, None, None] * K * N * M + - mblk_idx[None, :, None, None, None] * K * N + - nblk_idx[None, None, :, None, None] * K + - kblk_idx[None, None, None, :, None] + + odx = (lblk_idx[:, None, None, None, None] * K * N * M + mblk_idx[None, :, None, None, None] * K * N + + nblk_idx[None, None, :, None, None] * K + kblk_idx[None, None, None, :, None] + zblk_idx[None, None, None, 
None, :]) tl.store(out_ptr0 + odx, ret) @@ -50,10 +46,7 @@ def test_pw_rdc5d(dtype, shape): expected = (a * b).to(dtype) - triton_pw_rdc5d[(1,)]( - a, b, out, - L=L, M=M, N=N, K=K, Z=Z - ) + triton_pw_rdc5d[(1, )](a, b, out, L=L, M=M, N=N, K=K, Z=Z) torch.testing.assert_close(out.cpu(), expected.cpu(), rtol=1e-3, atol=1e-3) diff --git a/third_party/ascend/unittest/pytest_ut/test_multibuffer.py b/third_party/ascend/unittest/pytest_ut/test_multibuffer.py index 4f3e5c0370..04d2471e22 100644 --- a/third_party/ascend/unittest/pytest_ut/test_multibuffer.py +++ b/third_party/ascend/unittest/pytest_ut/test_multibuffer.py @@ -63,13 +63,11 @@ def test_multibuffer(): print("=" * 60) print("Test 1: test_alloc_ub_multibuffer") print("=" * 60) - mlir = compile_kernel( - multibuffer, {}, {"XBLOCK": 256} - ) + mlir = compile_kernel(multibuffer, {}, {"XBLOCK": 256}) print(f"Generated MLIR ({len(mlir)} chars):\n") print(mlir) # ============== Main for manual testing ============== if __name__ == "__main__": - test_multibuffer() \ No newline at end of file + test_multibuffer() diff --git a/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py b/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py index 02f9e79d68..55ea0cb752 100644 --- a/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py +++ b/third_party/ascend/unittest/pytest_ut/test_negative_mask_dim.py @@ -35,16 +35,14 @@ def triton_negative_mask_dim(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): tl.store(out_ptr0 + index, tmp0, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (32,), 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (32, ), 32], +]) def test_negative_mask_dim(param_list): dtype, shape, xblock = param_list x0 = torch.ones(shape, dtype=eval('torch.' + dtype)).npu() y_ref = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() y_cal = torch.ones(shape, dtype=eval('torch.' 
+ dtype)).npu() - triton_negative_mask_dim[(1,)](x0, y_cal, xblock) + triton_negative_mask_dim[(1, )](x0, y_cal, xblock) assert torch.allclose(y_cal, y_ref) - diff --git a/third_party/ascend/unittest/pytest_ut/test_nextafter.py b/third_party/ascend/unittest/pytest_ut/test_nextafter.py index 45abbef2bf..007f74d9dc 100644 --- a/third_party/ascend/unittest/pytest_ut/test_nextafter.py +++ b/third_party/ascend/unittest/pytest_ut/test_nextafter.py @@ -47,11 +47,10 @@ def triton_nextafter(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SU tl.store(out_ptr0 + (x0), tmp2, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (2, 4096, 8), 2, 32768, 1024], - ['float16', (2, 4096, 8), 2, 32768, 1024], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], +]) def test_nextafter(param_list): dtype, shape, ncore, xblock, xblock_sub = param_list x0_ref = test_common.generate_tensor(shape, dtype) diff --git a/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py b/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py index 0b5df57c72..819f131f12 100644 --- a/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py +++ b/third_party/ascend/unittest/pytest_ut/test_paged_kvcache_krope.py @@ -45,15 +45,11 @@ def test_bubbleup_extract_nonzero_offset(): total_tokens = num_pages * PAGE_SIZE kv_cache = torch.zeros(total_tokens, head_dim, dtype=torch.float32, device=device) for token_id in range(total_tokens): - kv_cache[token_id, :head_dim_v] = ( - torch.arange(head_dim_v, dtype=torch.float32) + token_id * 100 - ) - kv_cache[token_id, head_dim_v:] = ( - torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000 - ) + kv_cache[token_id, :head_dim_v] = (torch.arange(head_dim_v, dtype=torch.float32) + token_id * 100) + kv_cache[token_id, head_dim_v:] = (torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000) output = 
torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) - rope_like_load_kernel[(1,)]( + rope_like_load_kernel[(1, )]( kv_cache.flatten(), req_to_tokens, output.flatten(), @@ -67,14 +63,10 @@ def test_bubbleup_extract_nonzero_offset(): expected = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) for token_id in range(BLOCK_N): - expected[token_id] = ( - torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000 - ) + expected[token_id] = (torch.arange(head_dim_v, head_dim, dtype=torch.float32) + token_id * 1000) buggy = torch.zeros(BLOCK_N, rope_dim, dtype=torch.float32, device=device) for token_id in range(BLOCK_N): - buggy[token_id] = ( - torch.arange(rope_dim, dtype=torch.float32) + token_id * 100 - ) + buggy[token_id] = (torch.arange(rope_dim, dtype=torch.float32) + token_id * 100) - assert torch.allclose(output, expected, atol=1e-5) \ No newline at end of file + assert torch.allclose(output, expected, atol=1e-5) diff --git a/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py b/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py index 7535c9ce40..9ff97ac885 100644 --- a/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py +++ b/third_party/ascend/unittest/pytest_ut/test_permuted_boundary_check.py @@ -4,45 +4,23 @@ import triton.language as tl import pytest + @triton.jit -def zj_fa_fwd_pattern( - in_ptr0, in_ptr1, out_ptr, - M, K, N, - MBLOCK: tl.constexpr, - NBLOCK: tl.constexpr, - KBLOCK: tl.constexpr -): - a_ptr = tl.make_block_ptr( - base = in_ptr0, - shape = (M, K), # 8, 3 - strides = (K, 1), - offsets = (0, 0), - block_shape = (MBLOCK, KBLOCK), - order = (1, 0) - ) +def zj_fa_fwd_pattern(in_ptr0, in_ptr1, out_ptr, M, K, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr, + KBLOCK: tl.constexpr): + a_ptr = tl.make_block_ptr(base=in_ptr0, shape=(M, K), # 8, 3 + strides=(K, 1), offsets=(0, 0), block_shape=(MBLOCK, KBLOCK), order=(1, 0)) - b_ptr = tl.make_block_ptr( - 
base = in_ptr1, - shape = (K, N), # 3, 8 - strides = (1, K), - offsets = (0, 0), - block_shape = (KBLOCK, NBLOCK), - order = (0, 1) - ) + b_ptr = tl.make_block_ptr(base=in_ptr1, shape=(K, N), # 3, 8 + strides=(1, K), offsets=(0, 0), block_shape=(KBLOCK, NBLOCK), order=(0, 1)) - c_ptr = tl.make_block_ptr( - base = out_ptr, - shape = (M, N), - strides = (1, M), - offsets = (0, 0), - block_shape = (MBLOCK, NBLOCK), - order = (0, 1) - ) + c_ptr = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(1, M), offsets=(0, 0), block_shape=(MBLOCK, NBLOCK), + order=(0, 1)) - a = tl.load(a_ptr, boundary_check = (0,), padding_option="zero") - b = tl.load(b_ptr, boundary_check = (0,), padding_option="zero") + a = tl.load(a_ptr, boundary_check=(0, ), padding_option="zero") + b = tl.load(b_ptr, boundary_check=(0, ), padding_option="zero") c = tl.dot(a, b) - tl.store(c_ptr, c, boundary_check = (0, 1)) + tl.store(c_ptr, c, boundary_check=(0, 1)) def test_permute_boundary_check(): @@ -52,9 +30,9 @@ def test_permute_boundary_check(): MBLOCK = 8 NBLOCK = 8 KBLOCK = 4 - a = torch.randn((M, K), device="npu") # 8, 3 - b = torch.randn((N, K), device="npu") # 8, 3 + a = torch.randn((M, K), device="npu") # 8, 3 + b = torch.randn((N, K), device="npu") # 8, 3 c = torch.empty((N, M), device="npu") - zj_fa_fwd_pattern[(1,1,1)](a, b, c, M, K, N, MBLOCK, NBLOCK, KBLOCK) + zj_fa_fwd_pattern[(1, 1, 1)](a, b, c, M, K, N, MBLOCK, NBLOCK, KBLOCK) std = a @ b.T - torch.testing.assert_close(std, c.T, atol = 1e-2, rtol = 1e-2) \ No newline at end of file + torch.testing.assert_close(std, c.T, atol=1e-2, rtol=1e-2) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py index e1fb74e5b9..ec29ccbbfe 100644 --- a/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_maximum.py @@ -64,15 +64,11 @@ def triton_max_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: 
tl.constexpr, N: kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None] * N * Z + - nblk_idx[None, :, None] * Z + - zblk_idx[None, None, :]) + odx = (lblk_idx[:, None, None] * N * Z + nblk_idx[None, :, None] * Z + zblk_idx[None, None, :]) x = tl.load(in_ptr0 + idx) ret_k = tl.reduce(x, 3, maximum) # [L, M, N, Z] @@ -89,16 +85,12 @@ def triton_max_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (mblk_idx[:, None, None, None] * N * K * Z + - nblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (mblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 0, maximum) @@ -114,16 +106,12 @@ def triton_max_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M 
+ - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * N * K * Z + - nblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 1, maximum) @@ -139,16 +127,12 @@ def triton_max_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * K * Z + - mblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * K * Z + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 2, maximum) @@ -164,16 +148,12 @@ def triton_max_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, 
:, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * N * Z + - mblk_idx[None, :, None, None] * N * Z + - nblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * N * Z + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 3, maximum) @@ -189,16 +169,12 @@ def triton_max_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * N * K + - mblk_idx[None, :, None, None] * N * K + - nblk_idx[None, None, :, None] * K + - kblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * N * K + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + kblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 4, maximum) @@ -214,10 +190,8 @@ def triton_max_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] 
* Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) x = tl.load(in_ptr0 + idx) @@ -233,23 +207,20 @@ def triton_max_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl (triton_max_5d_dim024, (1, 1, 1, 1, 1), "dim024"), (triton_max_5d_dim024, (2, 2, 2, 2, 2), "dim024"), (triton_max_5d_dim024, (3, 11, 1, 3, 42), "dim024"), - (triton_max_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), - (triton_max_5d_dim0, (2, 2, 2, 2, 2), "dim0"), (triton_max_5d_dim1, (2, 2, 2, 2, 2), "dim1"), (triton_max_5d_dim2, (2, 2, 2, 2, 2), "dim2"), (triton_max_5d_dim3, (2, 2, 2, 2, 2), "dim3"), (triton_max_5d_dim4, (2, 2, 2, 2, 2), "dim4"), - (triton_max_5d_all, (3, 11, 1, 3, 42), "all"), ] typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -ids = ["{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) - for testfunc, shape, dim_name in testlist - ] +ids = [ + "{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) for testfunc, shape, dim_name in testlist +] @pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) @@ -267,7 +238,7 @@ def test_max(testfunc, dtype, shape, dim_name): ans, _ = torch.max(x0, 4) ans, _ = torch.max(ans, 2) ans, _ = torch.max(ans, 0) - output = torch.zeros((shape[1],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[1], ) + (shape[3], ), dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim13": if 'int' in dtype: @@ -277,7 +248,7 @@ def test_max(testfunc, dtype, shape, dim_name): else: ans, _ = torch.max(x0, 3) ans, _ = torch.max(ans, 1) - output = torch.zeros((shape[0],) + (shape[2],) + (shape[4],), dtype=eval('torch.' 
+ dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[4], ), dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim0": if 'int' in dtype: @@ -285,7 +256,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.max(x0, 0) - output = torch.zeros((shape[1],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[1], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim1": if 'int' in dtype: @@ -293,7 +265,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.max(x0, 1) - output = torch.zeros((shape[0],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim2": if 'int' in dtype: @@ -301,7 +274,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.max(x0, 2) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim3": if 'int' in dtype: @@ -309,7 +283,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.max(x0, 3) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim4": if 'int' in dtype: @@ -317,7 +292,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' 
+ dtype)) else: ans, _ = torch.max(x0, 4) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[3], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "all": if 'int' in dtype: @@ -325,8 +301,8 @@ def test_max(testfunc, dtype, shape, dim_name): ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) else: ans = torch.tensor([torch.max(x0)], dtype=eval('torch.' + dtype)) - output = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((1, ), dtype=eval('torch.' + dtype)).npu() - testfunc[(1,)](x0, output, *shape) + testfunc[(1, )](x0, output, *shape) test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py index adf3104f52..3428508f69 100644 --- a/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_min_4_keepdim_True_with_index_op.py @@ -14,7 +14,7 @@ @triton.jit def promote_to_tensor(x): # Addition promotes to tensor for us - return x + tl.zeros((1,), tl.int1) + return x + tl.zeros((1, ), tl.int1) @triton.jit @@ -33,8 +33,8 @@ def minimum_with_index(a_value, a_index, b_value, b_index): @triton.jit -def triton_min_5d_dim4_keepdim(in_ptr0, in_ptr1, out_ptr0, out_ptr1, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - Z: tl.constexpr): +def triton_min_5d_dim4_keepdim(in_ptr0, in_ptr1, out_ptr0, out_ptr1, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr, + K: tl.constexpr, Z: tl.constexpr): lblk_idx = tl.arange(0, L) mblk_idx = tl.arange(0, M) nblk_idx = tl.arange(0, N) @@ -74,9 +74,7 @@ def triton_min_5d_dim4_keepdim(in_ptr0, in_ptr1, out_ptr0, out_ptr1, L: tl.const typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 
'float32', 'bfloat16', 'bool'] -ids = ["{}-{}".format(testfunc.__name__, "-".join(map(str, shape))) - for testfunc, shape in testlist -] +ids = ["{}-{}".format(testfunc.__name__, "-".join(map(str, shape))) for testfunc, shape in testlist] @pytest.mark.parametrize('testfunc, shape', testlist, ids=ids) @@ -93,6 +91,6 @@ def test_min_dim4_keepdim(testfunc, sigtype, shape): ans, ans1 = torch.min(x0, 4) output = torch.zeros(shape[0:4], dtype=dtype).npu() output1 = torch.zeros(shape[0:4], dtype=torch.int32).npu() - testfunc[(1,)](x0, x1, output, output1, *shape, debug=True) + testfunc[(1, )](x0, x1, output, output1, *shape, debug=True) test_common.validate_cmp(sigtype, output, ans) - test_common.validate_cmp('int32', output1, ans1) \ No newline at end of file + test_common.validate_cmp('int32', output1, ans1) diff --git a/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py index c3d3831b90..f472eef67e 100644 --- a/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py +++ b/third_party/ascend/unittest/pytest_ut/test_reduce_minimum.py @@ -64,15 +64,11 @@ def triton_min_5d_dim13(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None] * N * Z + - nblk_idx[None, :, None] * Z + - zblk_idx[None, None, :]) + odx = (lblk_idx[:, None, None] * N * Z + nblk_idx[None, :, None] * Z + zblk_idx[None, None, :]) x = tl.load(in_ptr0 + idx) ret_k = tl.reduce(x, 3, minimum) # [L, M, N, Z] 
@@ -89,16 +85,12 @@ def triton_min_5d_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (mblk_idx[:, None, None, None] * N * K * Z + - nblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (mblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 0, minimum) @@ -114,16 +106,12 @@ def triton_min_5d_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * N * K * Z + - nblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * N * K * Z + nblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 1, minimum) @@ -139,16 +127,12 @@ 
def triton_min_5d_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * K * Z + - mblk_idx[None, :, None, None] * K * Z + - kblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * K * Z + mblk_idx[None, :, None, None] * K * Z + + kblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 2, minimum) @@ -164,16 +148,12 @@ def triton_min_5d_dim3(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * N * Z + - mblk_idx[None, :, None, None] * N * Z + - nblk_idx[None, None, :, None] * Z + - zblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * N * Z + mblk_idx[None, :, None, None] * N * Z + + nblk_idx[None, None, :, None] * Z + zblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 3, minimum) @@ -189,16 +169,12 @@ def 
triton_min_5d_dim4(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: t kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) - odx = (lblk_idx[:, None, None, None] * M * N * K + - mblk_idx[None, :, None, None] * N * K + - nblk_idx[None, None, :, None] * K + - kblk_idx[None, None, None, :]) + odx = (lblk_idx[:, None, None, None] * M * N * K + mblk_idx[None, :, None, None] * N * K + + nblk_idx[None, None, :, None] * K + kblk_idx[None, None, None, :]) x = tl.load(in_ptr0 + idx) ret = tl.reduce(x, 4, minimum) @@ -214,10 +190,8 @@ def triton_min_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl kblk_idx = tl.arange(0, K) zblk_idx = tl.arange(0, Z) - idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + - mblk_idx[None, :, None, None, None] * Z * K * N + - nblk_idx[None, None, :, None, None] * Z * K + - kblk_idx[None, None, None, :, None] * Z + + idx = (lblk_idx[:, None, None, None, None] * Z * K * N * M + mblk_idx[None, :, None, None, None] * Z * K * N + + nblk_idx[None, None, :, None, None] * Z * K + kblk_idx[None, None, None, :, None] * Z + zblk_idx[None, None, None, None, :]) x = tl.load(in_ptr0 + idx) @@ -233,23 +207,20 @@ def triton_min_5d_all(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl (triton_min_5d_dim024, (1, 1, 1, 1, 1), "dim024"), (triton_min_5d_dim024, (2, 2, 2, 2, 2), "dim024"), (triton_min_5d_dim024, (3, 11, 1, 3, 42), "dim024"), - (triton_min_5d_dim13, (1, 1, 1, 1, 1024), "dim13"), - (triton_min_5d_dim0, (2, 2, 2, 2, 2), "dim0"), (triton_min_5d_dim1, (2, 2, 2, 2, 2), 
"dim1"), (triton_min_5d_dim2, (2, 2, 2, 2, 2), "dim2"), (triton_min_5d_dim3, (2, 2, 2, 2, 2), "dim3"), (triton_min_5d_dim4, (2, 2, 2, 2, 2), "dim4"), - (triton_min_5d_all, (3, 11, 1, 3, 42), "all"), ] typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] -ids = ["{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) - for testfunc, shape, dim_name in testlist - ] +ids = [ + "{}-{}-{}".format(testfunc.__name__, "-".join(map(str, shape)), dim_name) for testfunc, shape, dim_name in testlist +] @pytest.mark.parametrize('testfunc, shape, dim_name', testlist, ids=ids) @@ -267,7 +238,7 @@ def test_min(testfunc, dtype, shape, dim_name): ans, _ = torch.min(x0, 4) ans, _ = torch.min(ans, 2) ans, _ = torch.min(ans, 0) - output = torch.zeros((shape[1],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[1], ) + (shape[3], ), dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim13": if 'int' in dtype: @@ -277,7 +248,7 @@ def test_min(testfunc, dtype, shape, dim_name): else: ans, _ = torch.min(x0, 3) ans, _ = torch.min(ans, 1) - output = torch.zeros((shape[0],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[4], ), dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim0": if 'int' in dtype: @@ -285,7 +256,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.min(x0, 0) - output = torch.zeros((shape[1],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[1], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim1": if 'int' in dtype: @@ -293,7 +265,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' 
+ dtype)) else: ans, _ = torch.min(x0, 1) - output = torch.zeros((shape[0],) + (shape[2],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[2], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim2": if 'int' in dtype: @@ -301,7 +274,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.min(x0, 2) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[3],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[3], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim3": if 'int' in dtype: @@ -309,7 +283,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.min(x0, 3) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[4],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[4], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "dim4": if 'int' in dtype: @@ -317,7 +292,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = ans.to(dtype=eval('torch.' + dtype)) else: ans, _ = torch.min(x0, 4) - output = torch.zeros((shape[0],) + (shape[1],) + (shape[2],) + (shape[3],), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((shape[0], ) + (shape[1], ) + (shape[2], ) + (shape[3], ), + dtype=eval('torch.' + dtype)).npu() elif dim_name == "all": if 'int' in dtype: @@ -325,8 +301,8 @@ def test_min(testfunc, dtype, shape, dim_name): ans = torch.tensor([ans], dtype=eval('torch.' + dtype)) else: ans = torch.tensor([torch.min(x0)], dtype=eval('torch.' + dtype)) - output = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu() + output = torch.zeros((1, ), dtype=eval('torch.' 
+ dtype)).npu() - testfunc[(1,)](x0, output, *shape) + testfunc[(1, )](x0, output, *shape) test_common.validate_cmp(dtype, output, ans) diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py index 012ddc5fcd..6f7f99b110 100644 --- a/third_party/ascend/unittest/pytest_ut/test_select_analysis.py +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis.py @@ -40,10 +40,7 @@ def kernel_cal_select_mask_bool( false_tensor = tl.arange(0, BLOCK) >= numel mask = offs < indice res = tl.where(mask, true_tensor, false_tensor) - tl.store( - Output_ptr + offs, - res - ) + tl.store(Output_ptr + offs, res) @triton.jit @@ -62,20 +59,13 @@ def kernel_cal_select_mask( row_indices = tl.load(Indices_ptr) col_indices = tl.load(Indices_ptr + 1) - qk_ub = tl.load( - QK_ptr + offs - ) - other = tl.load( - Other_ptr + offs - ) + qk_ub = tl.load(QK_ptr + offs) + other = tl.load(Other_ptr + offs) mask_rows = rows < row_indices * stride_qk mask_cols = cols < col_indices res = tl.where(mask_rows[:, None] & mask_cols[None, :], qk_ub, other) - tl.store( - Output_ptr + offs, - res - ) + tl.store(Output_ptr + offs, res) def torch_cal_select_mask_bool( @@ -90,7 +80,7 @@ def torch_cal_select_mask_bool( res = torch.where(mask, true_tensor, false_tensor) return res - + def torch_cal_select_mask( QK: torch.Tensor, @@ -104,29 +94,21 @@ def torch_cal_select_mask( return Output -@pytest.mark.parametrize('param_list', - [ - ['bool', 64, 63] - ] - ) +@pytest.mark.parametrize('param_list', [['bool', 64, 63]]) def test_select_analysis_bool(param_list): dtype, SEQ_LEN, indice = param_list assert dtype == 'bool' qk_cal = torch.empty(SEQ_LEN).npu() indices = torch.tensor([indice]).npu() qk_ref = torch_cal_select_mask_bool(indice, SEQ_LEN, SEQ_LEN) - kernel_cal_select_mask_bool[(1,)]( - qk_cal, indices, SEQ_LEN, SEQ_LEN - ) + kernel_cal_select_mask_bool[(1, )](qk_cal, indices, SEQ_LEN, SEQ_LEN) 
test_common.validate_cmp(dtype, qk_cal, qk_ref) -@pytest.mark.parametrize('param_list', - [ - ['float16', 64, 63, 62], - ['float32', 64, 63, 62], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float16', 64, 63, 62], + ['float32', 64, 63, 62], +]) def test_select_analysis(param_list): dtype, SEQ_LEN, indice_x, indice_y = param_list assert dtype != 'bool' @@ -135,8 +117,5 @@ def test_select_analysis(param_list): other = torch.zeros_like(qk).npu() indices_tensor = torch.tensor([indice_x, indice_y]).npu() qk_ref = torch_cal_select_mask(qk, other, indices_tensor) - kernel_cal_select_mask[(1,)]( - qk, other, qk_cal, indices_tensor, - qk.stride(0), SEQ_LEN * SEQ_LEN, SEQ_LEN - ) - test_common.validate_cmp(dtype, qk_cal, qk_ref) \ No newline at end of file + kernel_cal_select_mask[(1, )](qk, other, qk_cal, indices_tensor, qk.stride(0), SEQ_LEN * SEQ_LEN, SEQ_LEN) + test_common.validate_cmp(dtype, qk_cal, qk_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py index 0610f0b21a..64abd9b3cb 100644 --- a/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_invert.py @@ -47,7 +47,7 @@ def cal_atten_mask_kernel( if cur_s1 >= SEQ_LEN: return - + beg_sbs = idx_sub_sbs * BLOCK_SBS // sparse_block_size end_sbs = ((idx_sub_sbs + 1) * BLOCK_SBS) // sparse_block_size @@ -61,11 +61,7 @@ def cal_atten_mask_kernel( mask_n = offs_n < SEQ_LEN mask_load = mask_m[:, None] & mask_n[None, :] - qk_ub = tl.load( - QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, - mask=mask_load, - other=0.0 - ) + qk_ub = tl.load(QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, mask=mask_load, other=0.0) for idx_k in range(beg_sbs, end_sbs): idx_s2 = tl.load(Indices_ptr + TOPK_BASE + idx_k * stride_ik) @@ -78,19 +74,10 @@ def cal_atten_mask_kernel( 
mask_higher_sbs = tl.arange(0, BLOCK_SBS) < idx_higher_sbs qk_ub = tl.where((mask_lower_sbs & mask_higher_sbs)[None, :], float("-inf"), qk_ub) - tl.store( - QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, - qk_ub, - mask=mask_load - ) + tl.store(QK_ptr + offs_m[:, None] * stride_qk_m + offs_n[None, :] * stride_qk_n, qk_ub, mask=mask_load) -def launch_cal_atten_mask( - qk_tensor, - indices_tensor, - sparse_block_size=64, - block_sbs=128 -): +def launch_cal_atten_mask(qk_tensor, indices_tensor, sparse_block_size=64, block_sbs=128): """ qk_tensor: (SEQ_LEN, SEQ_LEN) indices_tensor: (K,) / (BATCH, K, ...) @@ -154,17 +141,13 @@ def torch_cal_atten_mask( return qk_out -@pytest.mark.parametrize('param_list', - [ - ['float32', 1024, 128, 64] - ] - ) +@pytest.mark.parametrize('param_list', [['float32', 1024, 128, 64]]) def test_divsiop_select_analysis1(param_list): dtype, SEQ_LEN, BLOCK_SBS, SPARSE_BLOCK = param_list qk = torch.zeros((SEQ_LEN, SEQ_LEN), dtype=eval('torch.' 
+ dtype), device='npu') K_SIZE = 20 - indices = torch.full((K_SIZE,), -1, dtype=torch.int32, device='npu') + indices = torch.full((K_SIZE, ), -1, dtype=torch.int32, device='npu') indices[10] = 20 qk_ref = torch_cal_atten_mask(qk.clone(), indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) qk_cal = launch_cal_atten_mask(qk, indices, sparse_block_size=SPARSE_BLOCK, block_sbs=BLOCK_SBS) - test_common.validate_cmp(dtype, qk_cal, qk_ref) \ No newline at end of file + test_common.validate_cmp(dtype, qk_cal, qk_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_signbit.py b/third_party/ascend/unittest/pytest_ut/test_signbit.py index b576e98a9e..8350094340 100644 --- a/third_party/ascend/unittest/pytest_ut/test_signbit.py +++ b/third_party/ascend/unittest/pytest_ut/test_signbit.py @@ -66,12 +66,10 @@ def test_all_blocks_parallel(param_list, monkeypatch): monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") -@pytest.mark.parametrize('param_list', - [ - ['float16', (2, 4096, 8), 2, 32768, 1024], - ['float32', (2, 4096, 8), 2, 32768, 1024], - ] - ) +@pytest.mark.parametrize('param_list', [ + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) def test_auto_blockify(param_list, monkeypatch): monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") dtype, shape, ncore, xblock, xblock_sub = param_list diff --git a/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py index 44f95114ab..d1dabaf669 100644 --- a/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py +++ b/third_party/ascend/unittest/pytest_ut/test_simplify_iterargs.py @@ -54,27 +54,25 @@ def triton_fn_broadcast(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.cons tl.store(out_ptr0 + (x0), tmp0, None) -@pytest.mark.parametrize('param_list', - [ - ['float32', (128, 128), 128, 128, 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (128, 128), 128, 128, 32], +]) 
def test_expanddims(param_list): dtype, shape, xblock, yblock, yblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_fn_expanddims[(1,)](x0, y_cal, xblock, yblock, yblock_sub) + triton_fn_expanddims[(1, )](x0, y_cal, xblock, yblock, yblock_sub) test_common.validate_cmp(dtype, y_cal, x0) -@pytest.mark.parametrize('param_list', - [ - ['float32', (128, 128), 128, 128, 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (128, 128), 128, 128, 32], +]) def test_broadcast(param_list): dtype, shape, xblock, yblock, yblock_sub = param_list x0 = test_common.generate_tensor(shape, dtype).npu() y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_fn_broadcast[(1,)](x0, y_cal, xblock, yblock, yblock_sub) + triton_fn_broadcast[(1, )](x0, y_cal, xblock, yblock, yblock_sub) test_common.validate_cmp(dtype, y_cal, x0) diff --git a/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py index cf902cc00a..99b347246d 100644 --- a/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py +++ b/third_party/ascend/unittest/pytest_ut/test_sink_broadcast.py @@ -59,10 +59,9 @@ def triton_sink_broadcast3(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, YBLOCK: tl.c tl.store(out_ptr0 + index, tmp0, index < XBLOCK * YBLOCK) -@pytest.mark.parametrize('param_list', - [ - ['float32', (32, 32), 32, 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (32, 32), 32, 32], +]) def test_sink_broadcast(param_list): dtype, shape, xblock, yblock = param_list x0 = test_common.generate_tensor(shape, dtype).npu() @@ -71,21 +70,20 @@ def test_sink_broadcast(param_list): y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() y_cal2 = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - triton_sink_broadcast1[(1,)](x0, y_cal, xblock, yblock) - triton_sink_broadcast2[(1,)](x0, y_cal2, xblock, yblock) + triton_sink_broadcast1[(1, )](x0, y_cal, xblock, yblock) + triton_sink_broadcast2[(1, )](x0, y_cal2, xblock, yblock) test_common.validate_cmp(dtype, y_cal, y_ref) test_common.validate_cmp(dtype, y_cal2, y_ref) -@pytest.mark.parametrize('param_list', - [ - ['float32', (32, 32), 32, 32], - ]) +@pytest.mark.parametrize('param_list', [ + ['float32', (32, 32), 32, 32], +]) def test_sink_broadcast3(param_list): dtype, shape, xblock, yblock = param_list x0 = test_common.generate_tensor(shape, dtype).npu() y_ref = x0.clone() y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - triton_sink_broadcast3[(1,)](x0, y_cal, xblock, yblock) + triton_sink_broadcast3[(1, )](x0, y_cal, xblock, yblock) test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/unittest/pytest_ut/test_use_analysis.py b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py index 7a9ce7ba4d..06ce0a2470 100644 --- a/third_party/ascend/unittest/pytest_ut/test_use_analysis.py +++ b/third_party/ascend/unittest/pytest_ut/test_use_analysis.py @@ -30,7 +30,8 @@ @triton.jit -def triton_reduce_deadcode(v_ptr, in_ptr0, in_ptr1, out_ptr0, VBLOCK: tl.constexpr, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr): +def triton_reduce_deadcode(v_ptr, in_ptr0, in_ptr1, out_ptr0, VBLOCK: tl.constexpr, XBLOCK: tl.constexpr, + YBLOCK: tl.constexpr): v_idx = tl.arange(0, VBLOCK) v = tl.load(v_ptr + v_idx) v_ret = tl.argmax(v, 0) @@ -65,9 +66,9 @@ def test_reduce_deadcode(): dtype = torch.float32 in0 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu') in1 = torch.randn((XBLOCK, YBLOCK), dtype=dtype, device='npu') - v = torch.randn((VBLOCK,), dtype=dtype, device='npu') + v = torch.randn((VBLOCK, ), dtype=dtype, device='npu') out = torch.zeros((XBLOCK, YBLOCK), dtype=dtype, device='npu') - triton_reduce_deadcode[(1,)](v, in0, in1, out, VBLOCK=VBLOCK, 
XBLOCK=XBLOCK, YBLOCK=YBLOCK) + triton_reduce_deadcode[(1, )](v, in0, in1, out, VBLOCK=VBLOCK, XBLOCK=XBLOCK, YBLOCK=YBLOCK) expected = torch_reduce_deadcode(in0, in1, v) test_common.validate_cmp(sigtype, out, expected)