Skip to content
22 changes: 22 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
return Error::Ok;
}

AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle) {
ET_CHECK_OR_RETURN_ERROR(
orig_handle != nullptr,
InvalidArgument,
"aoti_torch_new_tensor_handle: orig_handle is null");

ET_CHECK_OR_RETURN_ERROR(
new_handle != nullptr,
InvalidArgument,
"aoti_torch_new_tensor_handle: new_handle is null");

// Create a new SlimTensor that shares the same underlying storage.
// SlimTensor's copy constructor shares the SharedPtr<Storage>, so both
// tensors will reference the same memory. When the last tensor is deleted,
// the storage will be freed.
*new_handle = new Tensor(*orig_handle);

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
16 changes: 16 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided(
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);

/**
* Creates a new tensor handle that shares storage with the original tensor.
*
* The new handle is a copy of the original tensor's metadata (sizes, strides,
* dtype, device) and shares the same underlying storage via SharedPtr.
* Both tensors will reference the same memory, and the memory will only be
* freed when all references are deleted.
*
* @param orig_handle Pointer to the original tensor (must not be null)
* @param new_handle Output parameter for the new tensor handle
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);

} // extern "C"

} // namespace executorch::backends::cuda
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,4 @@ def define_common_targets():
cuda_shim_slim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
Loading
Loading