Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/IPUHooksInterface.h>
#include <ATen/detail/MAIAHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
Expand Down Expand Up @@ -142,8 +142,8 @@ class TORCH_API Context {
static bool hasLazy() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
}
static bool hasORT() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::ORT);
static bool hasMAIA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
Expand Down Expand Up @@ -455,8 +455,8 @@ static inline bool hasMPS() {
return globalContext().hasMPS();
}

static inline bool hasORT() {
return globalContext().hasORT();
static inline bool hasMAIA() {
return globalContext().hasMAIA();
}

static inline bool hasXPU() {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorIterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1530,13 +1530,13 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {

// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
// Extend the condition to ORT tesnors as ORT tensors also don't have storage.
// Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
if (privateuse1_without_storage ||
common_device_.type() == DeviceType::MTIA ||
common_device_.type() == DeviceType::XLA ||
common_device_.type() == DeviceType::IPU ||
common_device_.type() == DeviceType::Lazy ||
common_device_.type() == DeviceType::ORT ||
common_device_.type() == DeviceType::MAIA ||
common_device_.type() == DeviceType::HPU) return;

for (auto& op : operands_) {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/Version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ std::string show_config() {
ss << detail::getCUDAHooks().showConfig();
}

if (hasORT()) {
ss << detail::getORTHooks().showConfig();
if (hasMAIA()) {
ss << detail::getMAIAHooks().showConfig();
}

if (hasXPU()) {
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/core/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,10 @@ class TORCH_API TensorBase {
return impl_->is_mps();
}

/// Returns if a `Tensor` is ort tensor.
bool is_ort() const {
/// Returns if a `Tensor` is maia tensor.
bool is_maia() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_ort();
return impl_->is_maia();
}

/// Returns if a `Tensor` is vulkan tensor.
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
// In theory, we should only have to check if the given runtime key has "dense" functionality,
// e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
// However, there are some backends that should be included in this set that don't have the dense key set.
// E.g. DispatchKey::Meta, DispatchKey::ORT.
// E.g. DispatchKey::Meta, DispatchKey::MAIA.
if (c10::isBackendDispatchKey(dispatch_key)) {
DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
updateDispatchTableEntry_(dispatcher, autograd_key);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/op_registration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ There’s four main use cases
* You’re writing a new operator that isn’t supposed to be part of the public PyTorch API.
* You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators.
* You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files.
* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`.
* You’re writing a backend library like XLA or MAIA that adds new kernels to all operators defined in `native_functions.yaml`.

For these use cases, the custom operator API is the better solution.

### What is the price for using the custom operator API instead of `native_functions.yaml`?

If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MAIA example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.

* It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
* The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.
Expand Down
29 changes: 29 additions & 0 deletions aten/src/ATen/detail/MAIAHooksInterface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <ATen/detail/MAIAHooksInterface.h>

#include <c10/util/CallOnce.h>
#include <c10/util/Registry.h>

#include <cstddef>
#include <memory>

namespace at {
namespace detail {

// See getCUDAHooks for some more commentary
const MAIAHooksInterface& getMAIAHooks() {
static std::unique_ptr<MAIAHooksInterface> maia_hooks;
static c10::once_flag once;
c10::call_once(once, [] {
maia_hooks = MAIAHooksRegistry()->Create("MAIAHooks", {});
if (!maia_hooks) {
maia_hooks = std::make_unique<MAIAHooksInterface>();
}
});
return *maia_hooks;
}
} // namespace detail

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs)

} // namespace at
31 changes: 31 additions & 0 deletions aten/src/ATen/detail/MAIAHooksInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

struct TORCH_API MAIAHooksInterface {
// This should never actually be implemented, but it is used to
// squelch -Werror=non-virtual-dtor
virtual ~MAIAHooksInterface() = default;

virtual std::string showConfig() const {
TORCH_CHECK(false, "Cannot query detailed MAIA version information.");
}
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
struct TORCH_API MAIAHooksArgs {};

TORCH_DECLARE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs);
#define REGISTER_MAIA_HOOKS(clsname) \
C10_REGISTER_CLASS(MAIAHooksRegistry, clsname, clsname)

namespace detail {
TORCH_API const MAIAHooksInterface& getMAIAHooks();
} // namespace detail

} // namespace at
29 changes: 0 additions & 29 deletions aten/src/ATen/detail/ORTHooksInterface.cpp

This file was deleted.

36 changes: 0 additions & 36 deletions aten/src/ATen/detail/ORTHooksInterface.h

This file was deleted.

22 changes: 11 additions & 11 deletions aten/src/ATen/test/extension_backend_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

#include <torch/csrc/jit/runtime/operator.h>

// NB. These tests use the ORT dispatch key to test backend dispatching
// machinery, but these tests are not specific to ORT at all. The ORT
// NB. These tests use the MAIA dispatch key to test backend dispatching
// machinery, but these tests are not specific to MAIA at all. The MAIA
// backend is fully out-of-tree, so it's safe to use this key for
// in-tree tests.

Expand All @@ -22,16 +22,16 @@ Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10:
Storage(
Storage::use_byte_size_t(),
0,
at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
at::DataPtr(nullptr, Device(DeviceType::MAIA, 1)),
nullptr,
false),
DispatchKey::ORT,
DispatchKey::MAIA,
caffe2::TypeMeta::Make<float>());
return Tensor(std::move(tensor_impl));
}

Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
auto out = empty({5, 5}, at::kORT); // Don't return self as-is
auto out = empty({5, 5}, at::kMAIA); // Don't return self as-is
test_int = 2;
return out;
}
Expand All @@ -47,28 +47,28 @@ Tensor empty_strided_override(
return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, c10::nullopt);
}

TORCH_LIBRARY_IMPL(aten, ORT, m) {
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
m.impl("aten::empty.memory_format", empty_override);
m.impl("aten::empty_strided", empty_strided_override);
m.impl("aten::add.Tensor", add_override);
}

TEST(BackendExtensionTest, TestRegisterOp) {
Tensor a = empty({5, 5}, at::kORT);
ASSERT_EQ(a.device().type(), at::kORT);
Tensor a = empty({5, 5}, at::kMAIA);
ASSERT_EQ(a.device().type(), at::kMAIA);
ASSERT_EQ(a.device().index(), 1);
ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
ASSERT_EQ(test_int, 1);

Tensor b = empty_like(a, at::kORT);
ASSERT_EQ(b.device().type(), at::kORT);
Tensor b = empty_like(a, at::kMAIA);
ASSERT_EQ(b.device().type(), at::kMAIA);
ASSERT_EQ(b.device().index(), 1);
ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());

add(a, b);
ASSERT_EQ(test_int, 2);

// Ensure that non-ORT operator still works
// Ensure that non-MAIA operator still works
Tensor d = empty({5, 5}, at::kCPU);
ASSERT_EQ(d.device().type(), at::kCPU);
}
4 changes: 2 additions & 2 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ aten_cpu_non_globed_sources = [
"aten/src/ATen/detail/CUDAHooksInterface.cpp",
"aten/src/ATen/detail/HIPHooksInterface.cpp",
"aten/src/ATen/detail/MPSHooksInterface.cpp",
"aten/src/ATen/detail/ORTHooksInterface.cpp",
"aten/src/ATen/detail/MAIAHooksInterface.cpp",
"aten/src/ATen/detail/PrivateUse1HooksInterface.cpp",
"aten/src/ATen/detail/XPUHooksInterface.cpp",
"aten/src/ATen/detail/MTIAHooksInterface.cpp",
Expand All @@ -964,7 +964,7 @@ aten_cpu_non_globed_headers = [
"aten/src/ATen/detail/CUDAHooksInterface.h",
"aten/src/ATen/detail/MPSHooksInterface.h",
"aten/src/ATen/detail/HIPHooksInterface.h",
"aten/src/ATen/detail/ORTHooksInterface.h",
"aten/src/ATen/detail/MAIAHooksInterface.h",
"aten/src/ATen/detail/PrivateUse1HooksInterface.h",
"aten/src/ATen/detail/XPUHooksInterface.h",
"aten/src/ATen/detail/MTIAHooksInterface.h",
Expand Down
18 changes: 9 additions & 9 deletions c10/core/Backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum class Backend {
SparseCsrVE,
SparseCsrXPU,
SparseCsrPrivateUse1,
ORT,
MAIA,
XLA,
Vulkan,
Metal,
Expand Down Expand Up @@ -76,8 +76,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::VE;
} else if (t == DispatchKey::FPGA) {
return Backend::FPGA;
} else if (t == DispatchKey::ORT) {
return Backend::ORT;
} else if (t == DispatchKey::MAIA) {
return Backend::MAIA;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
} else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
Expand Down Expand Up @@ -154,8 +154,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::VE;
case Backend::FPGA:
return DispatchKey::FPGA;
case Backend::ORT:
return DispatchKey::ORT;
case Backend::MAIA:
return DispatchKey::MAIA;
case Backend::XLA:
return DispatchKey::XLA;
case Backend::Lazy:
Expand Down Expand Up @@ -236,8 +236,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::VE;
case Backend::FPGA:
return DeviceType::FPGA;
case Backend::ORT:
return DeviceType::ORT;
case Backend::MAIA:
return DeviceType::MAIA;
case Backend::XLA:
return DeviceType::XLA;
case Backend::Lazy:
Expand Down Expand Up @@ -298,8 +298,8 @@ static inline const char* toString(Backend b) {
return "XPU";
case Backend::IPU:
return "IPU";
case Backend::ORT:
return "ORT";
case Backend::MAIA:
return "MAIA";
case Backend::XLA:
return "XLA";
case Backend::Lazy:
Expand Down
2 changes: 1 addition & 1 deletion c10/core/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ DeviceType parse_type(const std::string& device_string) {
{"hip", DeviceType::HIP},
{"ve", DeviceType::VE},
{"fpga", DeviceType::FPGA},
{"ort", DeviceType::ORT},
{"maia", DeviceType::MAIA},
{"xla", DeviceType::XLA},
{"lazy", DeviceType::Lazy},
{"vulkan", DeviceType::Vulkan},
Expand Down
6 changes: 3 additions & 3 deletions c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ struct C10_API Device final {
return type_ == DeviceType::Metal;
}

/// Return true if the device is of ORT type.
bool is_ort() const noexcept {
return type_ == DeviceType::ORT;
/// Return true if the device is of MAIA type.
bool is_maia() const noexcept {
return type_ == DeviceType::MAIA;
}

/// Return true if the device is of META type.
Expand Down
Loading