Skip to content
Open
46 changes: 46 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using c10::DeviceIndex;
using c10::DeviceType;
using c10::ScalarType;
using executorch::backends::aoti::slim::empty_strided;
using executorch::backends::aoti::slim::makeArrayRef;
using executorch::backends::aoti::slim::from_blob;
using executorch::backends::aoti::slim::IntArrayRef;

Expand Down Expand Up @@ -76,6 +77,51 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
return Error::Ok;
}

AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor) {
ET_CHECK_OR_RETURN_ERROR(
ret_new_tensor != nullptr,
InvalidArgument,
"aoti_torch_empty_strided: ret_new_tensor is null");

ET_CHECK_OR_RETURN_ERROR(
!(sizes_ptr == nullptr && ndim > 0),
InvalidArgument,
"aoti_torch_empty_strided: sizes_ptr is null but ndim > 0");

IntArrayRef sizes(sizes_ptr, static_cast<size_t>(ndim));

// Handle nullptr strides by computing contiguous strides
if (strides_ptr == nullptr) {
std::vector<int64_t> contig_strides =
executorch::backends::aoti::slim::compute_contiguous_strides(sizes);
*ret_new_tensor = new Tensor(empty_strided(
sizes,
makeArrayRef(contig_strides),
static_cast<ScalarType>(dtype),
Device(
static_cast<DeviceType>(device_type),
static_cast<DeviceIndex>(device_index))));
} else {
IntArrayRef strides(strides_ptr, static_cast<size_t>(ndim));
*ret_new_tensor = new Tensor(empty_strided(
sizes,
strides,
static_cast<ScalarType>(dtype),
Device(
static_cast<DeviceType>(device_type),
static_cast<DeviceIndex>(device_index))));
}

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
22 changes: 22 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);

/**
* Creates an uninitialized tensor with specified dimensions, strides, and
* dtype on either CPU or CUDA device.
*
* @param ndim Number of dimensions in the tensor
* @param sizes_ptr Pointer to array of dimension sizes
* @param strides_ptr Pointer to array of strides for each dimension
* @param dtype Data type identifier (matches PyTorch scalar types)
* @param device_type Device type (0=CPU, 1=CUDA)
* @param device_index Device index
* @param ret_new_tensor Output parameter for the created tensor
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor);

} // extern "C"

} // namespace executorch::backends::cuda
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

# SlimTensor-based shim tests
cuda_shim_slim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
Loading
Loading