diff --git a/backends/aoti/common_shims_slim.cpp b/backends/aoti/common_shims_slim.cpp index 15d5343ad31..b004c1d16a6 100644 --- a/backends/aoti/common_shims_slim.cpp +++ b/backends/aoti/common_shims_slim.cpp @@ -56,6 +56,48 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) { return Error::Ok; } +// ============================================================ +// Storage & Device Property Getters - Implementations +// ============================================================ + +AOTITorchError aoti_torch_get_storage_offset( + Tensor* tensor, + int64_t* ret_storage_offset) { + if (tensor == nullptr || ret_storage_offset == nullptr) { + return Error::InvalidArgument; + } + *ret_storage_offset = tensor->storage_offset(); + return Error::Ok; +} + +AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) { + if (tensor == nullptr || ret_size == nullptr) { + return Error::InvalidArgument; + } + *ret_size = static_cast(tensor->storage()->nbytes()); + return Error::Ok; +} + +AOTITorchError aoti_torch_get_device_type( + Tensor* tensor, + int32_t* ret_device_type) { + if (tensor == nullptr || ret_device_type == nullptr) { + return Error::InvalidArgument; + } + *ret_device_type = static_cast(tensor->device_type()); + return Error::Ok; +} + +AOTITorchError aoti_torch_get_device_index( + Tensor* tensor, + int32_t* ret_device_index) { + if (tensor == nullptr || ret_device_index == nullptr) { + return Error::InvalidArgument; + } + *ret_device_index = static_cast(tensor->device_index()); + return Error::Ok; +} + } // namespace aoti } // namespace backends } // namespace executorch diff --git a/backends/aoti/common_shims_slim.h b/backends/aoti/common_shims_slim.h index 4669c7d771c..26022a76f14 100644 --- a/backends/aoti/common_shims_slim.h +++ b/backends/aoti/common_shims_slim.h @@ -46,6 +46,22 @@ aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype); AOTI_SHIM_EXPORT AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); +// ============================================================ +// Storage & Device Property Getters - Declarations +// ============================================================ + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index); + } // namespace aoti } // namespace backends } // namespace executorch diff --git a/backends/aoti/tests/test_common_shims_slim.cpp b/backends/aoti/tests/test_common_shims_slim.cpp index 2e4bfa63286..728bcc6a34f 100644 --- a/backends/aoti/tests/test_common_shims_slim.cpp +++ b/backends/aoti/tests/test_common_shims_slim.cpp @@ -289,6 +289,93 @@ void runGetDimTest(slim_c10::DeviceType device_type) { } } +// ============================================================================ +// Storage & Device Property Tests +// ============================================================================ + +void runGetStorageOffsetTest(slim_c10::DeviceType device_type) { + std::vector sizes = {2, 3}; + std::vector strides = calculateContiguousStrides(sizes); + slim_c10::Device device(device_type, 0); + + Tensor* tensor = new Tensor(slim::empty_strided( + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + slim_c10::ScalarType::Float, + device)); + + int64_t ret_storage_offset = -1; + AOTITorchError error = + aoti_torch_get_storage_offset(tensor, &ret_storage_offset); + + EXPECT_EQ(error, Error::Ok); + // Default storage offset for newly created tensor is 0 + EXPECT_EQ(ret_storage_offset, 0); + + delete tensor; +} + +void runGetStorageSizeTest(slim_c10::DeviceType device_type) { + std::vector sizes = {2, 3}; + std::vector strides = calculateContiguousStrides(sizes); + slim_c10::Device device(device_type, 0); + + Tensor* tensor = new Tensor(slim::empty_strided( + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + slim_c10::ScalarType::Float, + device)); + + int64_t ret_size = -1; + AOTITorchError error = aoti_torch_get_storage_size(tensor, &ret_size); + + EXPECT_EQ(error, Error::Ok); + // 2 * 3 * sizeof(float) = 6 * 4 = 24 bytes + EXPECT_EQ(ret_size, 24); + + delete tensor; +} + +void runGetDeviceTypeTest(slim_c10::DeviceType device_type) { + std::vector sizes = {2, 3}; + std::vector strides = calculateContiguousStrides(sizes); + slim_c10::Device device(device_type, 0); + + Tensor* tensor = new Tensor(slim::empty_strided( + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + slim_c10::ScalarType::Float, + device)); + + int32_t ret_device_type = -1; + AOTITorchError error = aoti_torch_get_device_type(tensor, &ret_device_type); + + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(ret_device_type, static_cast(device_type)); + + delete tensor; +} + +void runGetDeviceIndexTest(slim_c10::DeviceType device_type) { + std::vector sizes = {2, 3}; + std::vector strides = calculateContiguousStrides(sizes); + slim_c10::Device device(device_type, 0); + + Tensor* tensor = new Tensor(slim::empty_strided( + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + slim_c10::ScalarType::Float, + device)); + + int32_t ret_device_index = -1; + AOTITorchError error = aoti_torch_get_device_index(tensor, &ret_device_index); + + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(ret_device_index, 0); + + delete tensor; +} + // ============================================================================ // CPU Tests // ============================================================================ @@ -313,6 +400,22 @@ TEST_F(CommonShimsSlimTest, GetDim_CPU) { runGetDimTest(slim_c10::DeviceType::CPU); } +TEST_F(CommonShimsSlimTest, GetStorageOffset_CPU) { + runGetStorageOffsetTest(slim_c10::DeviceType::CPU); +} + +TEST_F(CommonShimsSlimTest, GetStorageSize_CPU) { + runGetStorageSizeTest(slim_c10::DeviceType::CPU); +} + +TEST_F(CommonShimsSlimTest, GetDeviceType_CPU) { + runGetDeviceTypeTest(slim_c10::DeviceType::CPU); +} + +TEST_F(CommonShimsSlimTest, GetDeviceIndex_CPU) { + runGetDeviceIndexTest(slim_c10::DeviceType::CPU); +} + // ============================================================================ // CUDA Tests // ============================================================================ @@ -352,6 +455,34 @@ TEST_F(CommonShimsSlimTest, GetDim_CUDA) { } runGetDimTest(slim_c10::DeviceType::CUDA); } + +TEST_F(CommonShimsSlimTest, GetStorageOffset_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + runGetStorageOffsetTest(slim_c10::DeviceType::CUDA); +} + +TEST_F(CommonShimsSlimTest, GetStorageSize_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + runGetStorageSizeTest(slim_c10::DeviceType::CUDA); +} + +TEST_F(CommonShimsSlimTest, GetDeviceType_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + runGetDeviceTypeTest(slim_c10::DeviceType::CUDA); +} + +TEST_F(CommonShimsSlimTest, GetDeviceIndex_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + runGetDeviceIndexTest(slim_c10::DeviceType::CUDA); +} #endif // ============================================================================