Skip to content
41 changes: 41 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,47 @@ AOTITorchError aoti_torch_new_tensor_handle(
return Error::Ok;
}

AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
Tensor** ret_new_tensor) {
ET_CHECK_OR_RETURN_ERROR(
self != nullptr,
InvalidArgument,
"aoti_torch__reinterpret_tensor: self is null");

ET_CHECK_OR_RETURN_ERROR(
ret_new_tensor != nullptr,
InvalidArgument,
"aoti_torch__reinterpret_tensor: ret_new_tensor is null");

ET_CHECK_OR_RETURN_ERROR(
ndim >= 0,
InvalidArgument,
"aoti_torch__reinterpret_tensor: ndim must be non-negative, got %lld",
static_cast<long long>(ndim));

ET_CHECK_OR_RETURN_ERROR(
!(sizes_ptr == nullptr && ndim > 0),
InvalidArgument,
"aoti_torch__reinterpret_tensor: sizes_ptr is null but ndim > 0");

IntArrayRef sizes(sizes_ptr, static_cast<size_t>(ndim));
IntArrayRef strides(strides_ptr, static_cast<size_t>(ndim));

// Create a new tensor view using as_strided. This creates a tensor that
// shares the same underlying storage but with different sizes, strides,
// and storage offset. SlimTensor::as_strided() handles this via copy
// constructor which shares the SharedPtr<Storage>.
*ret_new_tensor =
new Tensor(self->as_strided(sizes, strides, storage_offset));

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
27 changes: 24 additions & 3 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,30 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
* @param new_handle Output parameter for the new tensor handle
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

/**
* Creates a reinterpreted view of a tensor with new sizes, strides, and offset.
*
* This is equivalent to torch.as_strided() - it creates a new tensor that
* shares the same underlying storage but with different view parameters.
*
* @param self Original tensor to reinterpret (must not be null)
* @param ndim Number of dimensions for the new view
* @param sizes_ptr Pointer to array of dimension sizes
* @param strides_ptr Pointer to array of strides for each dimension
* @param storage_offset Storage offset in number of elements
* @param ret_new_tensor Output parameter for the reinterpreted tensor view
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
Tensor** ret_new_tensor);

} // extern "C"

Expand Down
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ def define_common_targets():
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
Loading
Loading