diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 835923828..45b5ee2f0 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -1,4 +1,6 @@ /************************************************************************* +* This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -13,6 +15,9 @@ #include "../common.h" #include "../utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_device_utils.cuh" // for rocm_upper_bound() +#endif namespace transformer_engine { @@ -65,15 +70,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else int tensor_id = 0; while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int padded_num_rows = args.padded_num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ + const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -83,6 +95,35 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ + // Process subtiles with vectorized loads/stores +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) 
{ + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + const int remaining = row_length - col; + if (row < num_rows) { + // Valid data row: skip copy when in-place + if (!inplace) { + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); + } + } else if (row < padded_num_rows) { + // Padding row: fill with zeros + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.clear(); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); + } + } + } +#else // !__HIP_PLATFORM_AMD__ // Load input and store to registers // Note: Each thread loads n_iterations subtiles, casts to output // type, and transposes in registers. @@ -125,6 +166,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP } } } +#endif // __HIP_PLATFORM_AMD__ } template @@ -150,14 +192,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else int tensor_id = 0; while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ + const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -167,6 +216,26 @@ __global__ 
void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ + // Process subtiles with vectorized loads/stores +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + if (row < num_rows && !inplace) { + const int remaining = row_length - col; + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0); + v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0); + } + } + } +#else // !__HIP_PLATFORM_AMD__ // Load input and store to registers // Note: Each thread loads n_iterations subtiles, casts to output // type, and transposes in registers. @@ -202,6 +271,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult } } } +#endif // __HIP_PLATFORM_AMD__ } } // namespace diff --git a/transformer_engine/common/util/rocm_device_utils.cuh b/transformer_engine/common/util/rocm_device_utils.cuh index 0d2b4c658..89c49b533 100644 --- a/transformer_engine/common/util/rocm_device_utils.cuh +++ b/transformer_engine/common/util/rocm_device_utils.cuh @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) { atomicMax(reinterpret_cast(addr), __float_as_int(val)); } +// Binary search on a sorted array. +// Returns the largest index i in [0, n) such that arr[i] <= val. +// Precondition: arr is sorted in non-decreasing order and arr[0] <= val. 
+// Note: despite the name, the result is one less than the index that
+// std::upper_bound would yield over arr[0..n). When n <= 1 the loop does not
+// run and 0 is returned without reading arr.
+template <typename T>
+__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) {
+  // Invariant: the answer always lies in the closed index range [lo, hi].
+  int lo = 0;
+  int hi = n - 1;
+  while (lo < hi) {
+    // Round the midpoint up so that `lo = mid` always makes progress, and
+    // derive it from the range width so the sum cannot exceed the int range.
+    const int mid = lo + (hi - lo + 1) / 2;
+    if (arr[mid] <= val) {
+      lo = mid;  // arr[mid] qualifies; the answer is mid or to its right
+    } else {
+      hi = mid - 1;  // arr[mid] is too large; the answer is strictly left
+    }
+  }
+  return lo;
+}
+
 template
 __device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) {
   __shared__ float staging[WARPS];