Skip to content
61 changes: 61 additions & 0 deletions backends/aoti/common_shims_slim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/common_shims_slim.h>

namespace executorch {
namespace backends {
namespace aoti {

// ============================================================
// Basic Property Getters - Implementations
// ============================================================

AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
if (tensor == nullptr || ret_data_ptr == nullptr) {
return Error::InvalidArgument;
}
*ret_data_ptr = tensor->data_ptr();
return Error::Ok;
}

AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
if (tensor == nullptr || ret_sizes == nullptr) {
return Error::InvalidArgument;
}
*ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
return Error::Ok;
}

AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
if (tensor == nullptr || ret_strides == nullptr) {
return Error::InvalidArgument;
}
*ret_strides = const_cast<int64_t*>(tensor->strides().data());
return Error::Ok;
}

AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
if (tensor == nullptr || ret_dtype == nullptr) {
return Error::InvalidArgument;
}
*ret_dtype = static_cast<int32_t>(tensor->dtype());
return Error::Ok;
}

AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
if (tensor == nullptr || ret_dim == nullptr) {
return Error::InvalidArgument;
}
*ret_dim = static_cast<int64_t>(tensor->dim());
return Error::Ok;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
51 changes: 51 additions & 0 deletions backends/aoti/common_shims_slim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/slim/core/SlimTensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace aoti {

// Common using declarations for ExecuTorch types
using executorch::runtime::Error;

// Tensor type definition using SlimTensor
using Tensor = executorch::backends::aoti::slim::SlimTensor;

// Common AOTI type aliases
using AOTIRuntimeError = Error;
using AOTITorchError = Error;

// ============================================================
// Basic Property Getters - Declarations
// ============================================================

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);

} // namespace aoti
} // namespace backends
} // namespace executorch
18 changes: 18 additions & 0 deletions backends/aoti/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,21 @@ def define_common_targets():
":delegate_handle",
],
)

# SlimTensor-based common shims library
# Uses SlimTensor for all tensor operations
runtime.cxx_library(
name = "common_shims_slim",
srcs = [
"common_shims_slim.cpp",
],
headers = [
"common_shims_slim.h",
"export.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/backends/aoti/slim/core:slimtensor",
],
)
25 changes: 25 additions & 0 deletions backends/aoti/tests/TARGETS
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")

oncall("executorch")

Expand All @@ -20,3 +21,27 @@ cpp_unittest(
"//executorch/extension/tensor:tensor",
],
)

cpp_unittest(
name = "test_common_shims_slim",
srcs = [
"test_common_shims_slim.cpp",
],
deps = [
"//executorch/backends/aoti:common_shims_slim",
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
preprocessor_flags = [
"-DCUDA_AVAILABLE=1",
],
keep_gpu_sections = True,
remote_execution = re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
)
Loading
Loading