Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backends/aoti/common_shims_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,48 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
return Error::Ok;
}

// ============================================================
// Storage & Device Property Getters - Implementations
// ============================================================

AOTITorchError aoti_torch_get_storage_offset(
Tensor* tensor,
int64_t* ret_storage_offset) {
if (tensor == nullptr || ret_storage_offset == nullptr) {
return Error::InvalidArgument;
}
*ret_storage_offset = tensor->storage_offset();
return Error::Ok;
}

AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
if (tensor == nullptr || ret_size == nullptr) {
return Error::InvalidArgument;
}
*ret_size = static_cast<int64_t>(tensor->storage()->nbytes());
return Error::Ok;
}

AOTITorchError aoti_torch_get_device_type(
Tensor* tensor,
int32_t* ret_device_type) {
if (tensor == nullptr || ret_device_type == nullptr) {
return Error::InvalidArgument;
}
*ret_device_type = static_cast<int32_t>(tensor->device_type());
return Error::Ok;
}

AOTITorchError aoti_torch_get_device_index(
Tensor* tensor,
int32_t* ret_device_index) {
if (tensor == nullptr || ret_device_index == nullptr) {
return Error::InvalidArgument;
}
*ret_device_index = static_cast<int32_t>(tensor->device_index());
return Error::Ok;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
16 changes: 16 additions & 0 deletions backends/aoti/common_shims_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);

// ============================================================
// Storage & Device Property Getters - Declarations
// ============================================================

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);

} // namespace aoti
} // namespace backends
} // namespace executorch
131 changes: 131 additions & 0 deletions backends/aoti/tests/test_common_shims_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,93 @@ void runGetDimTest(slim_c10::DeviceType device_type) {
}
}

// ============================================================================
// Storage & Device Property Tests
// ============================================================================

void runGetStorageOffsetTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int64_t ret_storage_offset = -1;
AOTITorchError error =
aoti_torch_get_storage_offset(tensor, &ret_storage_offset);

EXPECT_EQ(error, Error::Ok);
// Default storage offset for newly created tensor is 0
EXPECT_EQ(ret_storage_offset, 0);

delete tensor;
}

void runGetStorageSizeTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int64_t ret_size = -1;
AOTITorchError error = aoti_torch_get_storage_size(tensor, &ret_size);

EXPECT_EQ(error, Error::Ok);
// 2 * 3 * sizeof(float) = 6 * 4 = 24 bytes
EXPECT_EQ(ret_size, 24);

delete tensor;
}

void runGetDeviceTypeTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int32_t ret_device_type = -1;
AOTITorchError error = aoti_torch_get_device_type(tensor, &ret_device_type);

EXPECT_EQ(error, Error::Ok);
EXPECT_EQ(ret_device_type, static_cast<int32_t>(device_type));

delete tensor;
}

void runGetDeviceIndexTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int32_t ret_device_index = -1;
AOTITorchError error = aoti_torch_get_device_index(tensor, &ret_device_index);

EXPECT_EQ(error, Error::Ok);
EXPECT_EQ(ret_device_index, 0);

delete tensor;
}

// ============================================================================
// CPU Tests
// ============================================================================
Expand All @@ -313,6 +400,22 @@ TEST_F(CommonShimsSlimTest, GetDim_CPU) {
runGetDimTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetStorageOffset_CPU) {
runGetStorageOffsetTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetStorageSize_CPU) {
runGetStorageSizeTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetDeviceType_CPU) {
runGetDeviceTypeTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetDeviceIndex_CPU) {
runGetDeviceIndexTest(slim_c10::DeviceType::CPU);
}

// ============================================================================
// CUDA Tests
// ============================================================================
Expand Down Expand Up @@ -352,6 +455,34 @@ TEST_F(CommonShimsSlimTest, GetDim_CUDA) {
}
runGetDimTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetStorageOffset_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetStorageOffsetTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetStorageSize_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetStorageSizeTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetDeviceType_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetDeviceTypeTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetDeviceIndex_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetDeviceIndexTest(slim_c10::DeviceType::CUDA);
}
#endif

// ============================================================================
Expand Down
Loading