Skip to content
20 changes: 20 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,26 @@ AOTITorchError aoti_torch__reinterpret_tensor(
return Error::Ok;
}

AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
(void)non_blocking; // SlimTensor::copy_() is always synchronous for now

ET_CHECK_OR_RETURN_ERROR(
self != nullptr, InvalidArgument, "aoti_torch_copy_: self is null");

ET_CHECK_OR_RETURN_ERROR(
src != nullptr, InvalidArgument, "aoti_torch_copy_: src is null");

// SlimTensor::copy_() handles:
// - Same numel validation
// - Same dtype validation
// - CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA copies
// - Contiguous fast path and non-contiguous element-wise copy
self->copy_(*src);

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
15 changes: 15 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
int64_t storage_offset,
Tensor** ret_new_tensor);

/**
* Copies data from source tensor to destination tensor.
*
* Handles all device combinations (CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA)
* and supports tensors with different strides. The destination tensor must
* already be allocated with sufficient storage.
*
* @param self Destination tensor (must not be null)
* @param src Source tensor to copy from (must not be null)
* @param non_blocking If true, the copy may be asynchronous (currently ignored)
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);

} // extern "C"

} // namespace executorch::backends::cuda
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ def define_common_targets():
cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_slim_cpp_unittest("aoti_torch_copy_")
Loading
Loading