From 0a08a287c46e259676260d2518afb554a84aea89 Mon Sep 17 00:00:00 2001 From: Ashwin Hari Date: Mon, 22 Apr 2024 08:47:52 -0700 Subject: [PATCH 1/5] rename ort dispatch key to maia --- aten/src/ATen/Context.h | 10 ++--- aten/src/ATen/TensorIterator.cpp | 4 +- aten/src/ATen/Version.cpp | 4 +- aten/src/ATen/core/TensorBase.h | 6 +-- aten/src/ATen/core/dispatch/OperatorEntry.cpp | 2 +- aten/src/ATen/core/op_registration/README.md | 4 +- aten/src/ATen/detail/MAIAHooksInterface.cpp | 29 ++++++++++++++ aten/src/ATen/detail/MAIAHooksInterface.h | 31 ++++++++++++++ aten/src/ATen/detail/ORTHooksInterface.cpp | 29 -------------- aten/src/ATen/detail/ORTHooksInterface.h | 36 ----------------- aten/src/ATen/test/extension_backend_test.cpp | 22 +++++----- build_variables.bzl | 4 +- c10/core/Backend.h | 18 ++++----- c10/core/Device.cpp | 2 +- c10/core/Device.h | 6 +-- c10/core/DeviceType.cpp | 6 +-- c10/core/DeviceType.h | 4 +- c10/core/DispatchKey.cpp | 6 +-- c10/core/DispatchKey.h | 8 ++-- c10/core/DispatchKeySet.h | 2 +- c10/core/TensorImpl.h | 6 +-- c10/core/TensorOptions.h | 8 ++-- caffe2/proto/caffe2.proto | 2 +- caffe2/proto/caffe2_pb2.pyi | 4 +- .../{ort_extension.cpp => maia_extension.cpp} | 38 +++++++++--------- test/cpp_extensions/setup.py | 4 +- test/test_cpp_extensions_aot.py | 40 +++++++++---------- third_party/ideep | 2 +- third_party/kineto | 2 +- third_party/onnx | 2 +- third_party/pybind11 | 2 +- third_party/sleef | 2 +- tools/pyi/gen_pyi.py | 2 +- torch/_C/_autograd.pyi | 2 +- torch/_tensor.py | 6 +-- torch/csrc/autograd/init.cpp | 2 +- torch/csrc/autograd/python_variable.cpp | 8 ++-- torch/csrc/jit/frontend/sugared_value.cpp | 2 +- torch/csrc/jit/runtime/register_prim_ops.cpp | 4 +- torch/library.h | 4 +- torch/overrides.py | 2 +- torchgen/model.py | 2 +- 42 files changed, 186 insertions(+), 193 deletions(-) create mode 100644 aten/src/ATen/detail/MAIAHooksInterface.cpp create mode 100644 aten/src/ATen/detail/MAIAHooksInterface.h delete mode 100644 
aten/src/ATen/detail/ORTHooksInterface.cpp delete mode 100644 aten/src/ATen/detail/ORTHooksInterface.h rename test/cpp_extensions/{ort_extension.cpp => maia_extension.cpp} (78%) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 5c5036caa90fb..e85af59b0dbe8 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -14,7 +14,7 @@ #include #include #include -#include +#include #include #include #include @@ -142,8 +142,8 @@ class TORCH_API Context { static bool hasLazy() { return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy); } - static bool hasORT() { - return c10::impl::hasDeviceGuardImpl(c10::DeviceType::ORT); + static bool hasMAIA() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); } // defined in header so that getNonVariableType has ability to inline // call_once check. getNonVariableType is called fairly frequently @@ -455,8 +455,8 @@ static inline bool hasMPS() { return globalContext().hasMPS(); } -static inline bool hasORT() { - return globalContext().hasORT(); +static inline bool hasMAIA() { + return globalContext().hasMAIA(); } static inline bool hasXPU() { diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 868fdd83cc7c8..0afac10d44fbf 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -1530,13 +1530,13 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) { // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer. // Nothing beyond this point is important for meta functions, so it's fine to exit early here. - // Extend the condition to ORT tesnors as ORT tensors also don't have storage. + // Extend the condition to MAIA tensors as MAIA tensors also don't have storage. 
if (privateuse1_without_storage || common_device_.type() == DeviceType::MTIA || common_device_.type() == DeviceType::XLA || common_device_.type() == DeviceType::IPU || common_device_.type() == DeviceType::Lazy || - common_device_.type() == DeviceType::ORT || + common_device_.type() == DeviceType::MAIA || common_device_.type() == DeviceType::HPU) return; for (auto& op : operands_) { diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index eb71fe315d430..cf33d89e0814e 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -190,8 +190,8 @@ std::string show_config() { ss << detail::getCUDAHooks().showConfig(); } - if (hasORT()) { - ss << detail::getORTHooks().showConfig(); + if (hasMAIA()) { + ss << detail::getMAIAHooks().showConfig(); } if (hasXPU()) { diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index a94b28b86f5a1..e03c6bdf2bd10 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -507,10 +507,10 @@ class TORCH_API TensorBase { return impl_->is_mps(); } - /// Returns if a `Tensor` is ort tensor. - bool is_ort() const { + /// Returns if a `Tensor` is maia tensor. + bool is_maia() const { // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_ort(); + return impl_->is_maia(); } /// Returns if a `Tensor` is vulkan tensor. diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 627109c516daf..5f4538f2c9790 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -421,7 +421,7 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp // In theory, we should only have to check if the given runtime key has "dense" functionality, // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit). 
// However, there are some backends that should be included in this set that don't have the dense key set. - // E.g. DispatchKey::Meta, DispatchKey::ORT. + // E.g. DispatchKey::Meta, DispatchKey::MAIA. if (c10::isBackendDispatchKey(dispatch_key)) { DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key)); updateDispatchTableEntry_(dispatcher, autograd_key); diff --git a/aten/src/ATen/core/op_registration/README.md b/aten/src/ATen/core/op_registration/README.md index 5605e962a6e5e..61b41b48c4a67 100644 --- a/aten/src/ATen/core/op_registration/README.md +++ b/aten/src/ATen/core/op_registration/README.md @@ -13,13 +13,13 @@ There’s four main use cases * You’re writing a new operator that isn’t supposed to be part of the public PyTorch API. * You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators. * You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files. -* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`. +* You’re writing a backend library like XLA or MAIA that adds new kernels to all operators defined in `native_functions.yaml`. For these use cases, the custom operator API is the better solution. ### What is the price for using the custom operator API instead of `native_functions.yaml`? -If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats. +If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MAIA example above), then you’re fine and don’t pay any price. 
If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats. * It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator. * The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`. diff --git a/aten/src/ATen/detail/MAIAHooksInterface.cpp b/aten/src/ATen/detail/MAIAHooksInterface.cpp new file mode 100644 index 0000000000000..e82ad8f677018 --- /dev/null +++ b/aten/src/ATen/detail/MAIAHooksInterface.cpp @@ -0,0 +1,29 @@ +#include + +#include +#include + +#include +#include + +namespace at { +namespace detail { + +// See getCUDAHooks for some more commentary +const MAIAHooksInterface& getMAIAHooks() { + static std::unique_ptr maia_hooks; + static c10::once_flag once; + c10::call_once(once, [] { + maia_hooks = MAIAHooksRegistry()->Create("MAIAHooks", {}); + if (!maia_hooks) { + maia_hooks = std::make_unique(); + } + }); + return *maia_hooks; +} +} // namespace detail + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DEFINE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs) + +} // namespace at diff --git a/aten/src/ATen/detail/MAIAHooksInterface.h b/aten/src/ATen/detail/MAIAHooksInterface.h new file mode 100644 index 0000000000000..ad4ef146eccd9 --- /dev/null +++ b/aten/src/ATen/detail/MAIAHooksInterface.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +// NB: Class must live in `at` due to limitations of Registry.h. 
+namespace at { + +struct TORCH_API MAIAHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~MAIAHooksInterface() = default; + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed MAIA version information."); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API MAIAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs); +#define REGISTER_MAIA_HOOKS(clsname) \ + C10_REGISTER_CLASS(MAIAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MAIAHooksInterface& getMAIAHooks(); +} // namespace detail + +} // namespace at diff --git a/aten/src/ATen/detail/ORTHooksInterface.cpp b/aten/src/ATen/detail/ORTHooksInterface.cpp deleted file mode 100644 index bbb69809e8770..0000000000000 --- a/aten/src/ATen/detail/ORTHooksInterface.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -#include -#include - -#include -#include - -namespace at { -namespace detail { - -// See getCUDAHooks for some more commentary -const ORTHooksInterface& getORTHooks() { - static std::unique_ptr ort_hooks; - static c10::once_flag once; - c10::call_once(once, [] { - ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {}); - if (!ort_hooks) { - ort_hooks = std::make_unique(); - } - }); - return *ort_hooks; -} -} // namespace detail - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs) - -} // namespace at diff --git a/aten/src/ATen/detail/ORTHooksInterface.h b/aten/src/ATen/detail/ORTHooksInterface.h deleted file mode 100644 index f49969ec66a5b..0000000000000 --- a/aten/src/ATen/detail/ORTHooksInterface.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -constexpr const char* ORT_HELP = - " You need to 'import torch_ort' to use the 'ort' device 
in PyTorch. " - "The 'torch_ort' module is provided by the ONNX Runtime itself " - "(https://onnxruntime.ai)."; - -// NB: Class must live in `at` due to limitations of Registry.h. -namespace at { - -struct TORCH_API ORTHooksInterface { - // This should never actually be implemented, but it is used to - // squelch -Werror=non-virtual-dtor - virtual ~ORTHooksInterface() = default; - - virtual std::string showConfig() const { - TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP); - } -}; - -// NB: dummy argument to suppress "ISO C++11 requires at least one argument -// for the "..." in a variadic macro" -struct TORCH_API ORTHooksArgs {}; - -TORCH_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs); -#define REGISTER_ORT_HOOKS(clsname) \ - C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname) - -namespace detail { -TORCH_API const ORTHooksInterface& getORTHooks(); -} // namespace detail - -} // namespace at diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index f2ce15e99ecda..4be68b1d0a710 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -6,8 +6,8 @@ #include -// NB. These tests use the ORT dispatch key to test backend dispatching -// machinery, but these tests are not specific to ORT at all. The ORT +// NB. These tests use the MAIA dispatch key to test backend dispatching +// machinery, but these tests are not specific to MAIA at all. The MAIA // backend is fully out-of-tree, so it's safe to use this key for // in-tree tests. 
@@ -22,16 +22,16 @@ Tensor empty_override(SymIntArrayRef size, c10::optional dtype, c10: Storage( Storage::use_byte_size_t(), 0, - at::DataPtr(nullptr, Device(DeviceType::ORT, 1)), + at::DataPtr(nullptr, Device(DeviceType::MAIA, 1)), nullptr, false), - DispatchKey::ORT, + DispatchKey::MAIA, caffe2::TypeMeta::Make()); return Tensor(std::move(tensor_impl)); } Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) { - auto out = empty({5, 5}, at::kORT); // Don't return self as-is + auto out = empty({5, 5}, at::kMAIA); // Don't return self as-is test_int = 2; return out; } @@ -47,28 +47,28 @@ Tensor empty_strided_override( return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, c10::nullopt); } -TORCH_LIBRARY_IMPL(aten, ORT, m) { +TORCH_LIBRARY_IMPL(aten, MAIA, m) { m.impl("aten::empty.memory_format", empty_override); m.impl("aten::empty_strided", empty_strided_override); m.impl("aten::add.Tensor", add_override); } TEST(BackendExtensionTest, TestRegisterOp) { - Tensor a = empty({5, 5}, at::kORT); - ASSERT_EQ(a.device().type(), at::kORT); + Tensor a = empty({5, 5}, at::kMAIA); + ASSERT_EQ(a.device().type(), at::kMAIA); ASSERT_EQ(a.device().index(), 1); ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make()); ASSERT_EQ(test_int, 1); - Tensor b = empty_like(a, at::kORT); - ASSERT_EQ(b.device().type(), at::kORT); + Tensor b = empty_like(a, at::kMAIA); + ASSERT_EQ(b.device().type(), at::kMAIA); ASSERT_EQ(b.device().index(), 1); ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make()); add(a, b); ASSERT_EQ(test_int, 2); - // Ensure that non-ORT operator still works + // Ensure that non-MAIA operator still works Tensor d = empty({5, 5}, at::kCPU); ASSERT_EQ(d.device().type(), at::kCPU); } diff --git a/build_variables.bzl b/build_variables.bzl index ec0c31369ee53..a8b173ac3fce7 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -947,7 +947,7 @@ aten_cpu_non_globed_sources = [ "aten/src/ATen/detail/CUDAHooksInterface.cpp", 
"aten/src/ATen/detail/HIPHooksInterface.cpp", "aten/src/ATen/detail/MPSHooksInterface.cpp", - "aten/src/ATen/detail/ORTHooksInterface.cpp", + "aten/src/ATen/detail/MAIAHooksInterface.cpp", "aten/src/ATen/detail/PrivateUse1HooksInterface.cpp", "aten/src/ATen/detail/XPUHooksInterface.cpp", "aten/src/ATen/detail/MTIAHooksInterface.cpp", @@ -964,7 +964,7 @@ aten_cpu_non_globed_headers = [ "aten/src/ATen/detail/CUDAHooksInterface.h", "aten/src/ATen/detail/MPSHooksInterface.h", "aten/src/ATen/detail/HIPHooksInterface.h", - "aten/src/ATen/detail/ORTHooksInterface.h", + "aten/src/ATen/detail/MAIAHooksInterface.h", "aten/src/ATen/detail/PrivateUse1HooksInterface.h", "aten/src/ATen/detail/XPUHooksInterface.h", "aten/src/ATen/detail/MTIAHooksInterface.h", diff --git a/c10/core/Backend.h b/c10/core/Backend.h index d298f0d697b2c..1cf1782fa5707 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -46,7 +46,7 @@ enum class Backend { SparseCsrVE, SparseCsrXPU, SparseCsrPrivateUse1, - ORT, + MAIA, XLA, Vulkan, Metal, @@ -76,8 +76,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::VE; } else if (t == DispatchKey::FPGA) { return Backend::FPGA; - } else if (t == DispatchKey::ORT) { - return Backend::ORT; + } else if (t == DispatchKey::MAIA) { + return Backend::MAIA; } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) { return Backend::XLA; } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) { @@ -154,8 +154,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::VE; case Backend::FPGA: return DispatchKey::FPGA; - case Backend::ORT: - return DispatchKey::ORT; + case Backend::MAIA: + return DispatchKey::MAIA; case Backend::XLA: return DispatchKey::XLA; case Backend::Lazy: @@ -236,8 +236,8 @@ static inline DeviceType backendToDeviceType(Backend b) { return DeviceType::VE; case Backend::FPGA: return DeviceType::FPGA; - case Backend::ORT: - return DeviceType::ORT; + case Backend::MAIA: + 
return DeviceType::MAIA; case Backend::XLA: return DeviceType::XLA; case Backend::Lazy: @@ -298,8 +298,8 @@ static inline const char* toString(Backend b) { return "XPU"; case Backend::IPU: return "IPU"; - case Backend::ORT: - return "ORT"; + case Backend::MAIA: + return "MAIA"; case Backend::XLA: return "XLA"; case Backend::Lazy: diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 7cc97d1a33aca..1b19114663c1f 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -26,7 +26,7 @@ DeviceType parse_type(const std::string& device_string) { {"hip", DeviceType::HIP}, {"ve", DeviceType::VE}, {"fpga", DeviceType::FPGA}, - {"ort", DeviceType::ORT}, + {"maia", DeviceType::MAIA}, {"xla", DeviceType::XLA}, {"lazy", DeviceType::Lazy}, {"vulkan", DeviceType::Vulkan}, diff --git a/c10/core/Device.h b/c10/core/Device.h index c58c03c9b9adf..cbe9129852ade 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -142,9 +142,9 @@ struct C10_API Device final { return type_ == DeviceType::Metal; } - /// Return true if the device is of ORT type. - bool is_ort() const noexcept { - return type_ == DeviceType::ORT; + /// Return true if the device is of MAIA type. + bool is_maia() const noexcept { + return type_ == DeviceType::MAIA; } /// Return true if the device is of META type. diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 0b44e1d862e13..3cd70f42e2746 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -27,8 +27,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { return lower_case ? "ve" : "VE"; case DeviceType::FPGA: return lower_case ? "fpga" : "FPGA"; - case DeviceType::ORT: - return lower_case ? "ort" : "ORT"; + case DeviceType::MAIA: + return lower_case ? "maia" : "MAIA"; case DeviceType::XLA: return lower_case ? 
"xla" : "XLA"; case DeviceType::Lazy: @@ -83,7 +83,7 @@ bool isValidDeviceType(DeviceType d) { case DeviceType::HIP: case DeviceType::VE: case DeviceType::FPGA: - case DeviceType::ORT: + case DeviceType::MAIA: case DeviceType::XLA: case DeviceType::Lazy: case DeviceType::MPS: diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 701ea3f3bd211..911c863363f96 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -42,7 +42,7 @@ enum class DeviceType : int8_t { IDEEP = 5, // IDEEP. HIP = 6, // AMD HIP FPGA = 7, // FPGA - ORT = 8, // ONNX Runtime / Microsoft + MAIA = 8, // ONNX Runtime / Microsoft XLA = 9, // XLA / TPU Vulkan = 10, // Vulkan Metal = 11, // Metal @@ -66,7 +66,7 @@ constexpr DeviceType kCPU = DeviceType::CPU; constexpr DeviceType kCUDA = DeviceType::CUDA; constexpr DeviceType kHIP = DeviceType::HIP; constexpr DeviceType kFPGA = DeviceType::FPGA; -constexpr DeviceType kORT = DeviceType::ORT; +constexpr DeviceType kMAIA = DeviceType::MAIA; constexpr DeviceType kXLA = DeviceType::XLA; constexpr DeviceType kMPS = DeviceType::MPS; constexpr DeviceType kMeta = DeviceType::Meta; diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 62f1ac03e5ba4..0388234efd5b3 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -66,8 +66,8 @@ const char* toString(DispatchKey t) { return "Dense"; case DispatchKey::FPGA: return "FPGA"; - case DispatchKey::ORT: - return "ORT"; + case DispatchKey::MAIA: + return "MAIA"; case DispatchKey::Vulkan: return "Vulkan"; case DispatchKey::Metal: @@ -263,7 +263,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"Undefined", c10::DispatchKey::Undefined}, {"Dense", c10::DispatchKey::Dense}, {"FPGA", c10::DispatchKey::FPGA}, - {"ORT", c10::DispatchKey::ORT}, + {"MAIA", c10::DispatchKey::MAIA}, {"MPS", c10::DispatchKey::MPS}, {"Vulkan", c10::DispatchKey::Vulkan}, {"Metal", c10::DispatchKey::Metal}, diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 
0219db40edccb..71277ebfd891e 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -181,13 +181,11 @@ enum class DispatchKey : uint16_t { // https://gitlab.com/pytorch-complex/vitis_kernels // TODO: put this in BackendComponents - // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and - // https://github.com/microsoft/onnxruntime, and is also used to test general - // backend/extension machinery in the core. cf: - // - test/cpp_extensions/ort_extension.cpp + // MAIA backend lives out of tree + // - test/cpp_extensions/maia_extension.cpp // - test/test_torch.py // - aten/src/ATen/test/extension_backend_test.cpp - ORT, + MAIA, Vulkan, // TODO: put this in BackendComponents Metal, // TODO: put this in BackendComponents diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index db2e94fd8cdc0..f7461ea73a6dd 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -702,7 +702,7 @@ constexpr DispatchKeySet autogradother_backends = // Technically, HIP will now redispatch to its own custom AutogradHIP // slot in the runtime table. 
{DispatchKey::FPGA, - DispatchKey::ORT, + DispatchKey::MAIA, DispatchKey::Vulkan, DispatchKey::Metal, DispatchKey::CustomRNGKeyId, diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 40a65cb10788d..95e7a0e3b6117 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1204,11 +1204,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return device_opt_.has_value() && device_opt_->type() == kMPS; } - bool is_ort() const { + bool is_maia() const { if (C10_UNLIKELY(device_policy_)) { - return device_custom().is_ort(); + return device_custom().is_maia(); } - return device_opt_.has_value() && device_opt_->type() == kORT; + return device_opt_.has_value() && device_opt_->type() == kMAIA; } bool is_nested() const { diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 2d9e4a24331ef..765f474702ef7 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -653,8 +653,8 @@ inline DispatchKey computeDispatchKey( #undef DO_CASE case c10::DeviceType::FPGA: return DispatchKey::FPGA; - case c10::DeviceType::ORT: - return DispatchKey::ORT; + case c10::DeviceType::MAIA: + return DispatchKey::MAIA; case c10::DeviceType::Vulkan: return DispatchKey::Vulkan; case c10::DeviceType::Metal: @@ -757,8 +757,8 @@ inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { case DispatchKey::Vulkan: return c10::DeviceType::Vulkan; - case DispatchKey::ORT: - return c10::DeviceType::ORT; + case DispatchKey::MAIA: + return c10::DeviceType::MAIA; default: TORCH_CHECK( false, diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 861a6c5d43740..077e7b0ed5446 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -218,7 +218,7 @@ enum DeviceTypeProto { PROTO_IDEEP = 5; // IDEEP. 
PROTO_HIP = 6; // AMD HIP PROTO_FPGA = 7; // FPGA - PROTO_ORT = 8; // ONNX Runtime + PROTO_MAIA = 8; // MAIA PROTO_XLA = 9; // XLA / TPU PROTO_MPS = 10; // MPS // Change the following number if you add more devices in the code. diff --git a/caffe2/proto/caffe2_pb2.pyi b/caffe2/proto/caffe2_pb2.pyi index ed1f4249a43ee..43249ebf75dbd 100644 --- a/caffe2/proto/caffe2_pb2.pyi +++ b/caffe2/proto/caffe2_pb2.pyi @@ -23,7 +23,7 @@ class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapp PROTO_IDEEP = DeviceTypeProto.V(5) PROTO_HIP = DeviceTypeProto.V(6) PROTO_FPGA = DeviceTypeProto.V(7) - PROTO_ORT = DeviceTypeProto.V(8) + PROTO_MAIA = DeviceTypeProto.V(8) PROTO_XLA = DeviceTypeProto.V(9) PROTO_MPS = DeviceTypeProto.V(10) PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) @@ -37,7 +37,7 @@ PROTO_OPENCL = DeviceTypeProto.V(4) PROTO_IDEEP = DeviceTypeProto.V(5) PROTO_HIP = DeviceTypeProto.V(6) PROTO_FPGA = DeviceTypeProto.V(7) -PROTO_ORT = DeviceTypeProto.V(8) +PROTO_MAIA = DeviceTypeProto.V(8) PROTO_XLA = DeviceTypeProto.V(9) PROTO_MPS = DeviceTypeProto.V(10) PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) diff --git a/test/cpp_extensions/ort_extension.cpp b/test/cpp_extensions/maia_extension.cpp similarity index 78% rename from test/cpp_extensions/ort_extension.cpp rename to test/cpp_extensions/maia_extension.cpp index b646f3b14939d..13315810f54c4 100644 --- a/test/cpp_extensions/ort_extension.cpp +++ b/test/cpp_extensions/maia_extension.cpp @@ -10,10 +10,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { Storage( Storage::use_byte_size_t(), 0, - at::DataPtr(nullptr, Device(DeviceType::ORT, 0)), + at::DataPtr(nullptr, Device(DeviceType::MAIA, 0)), nullptr, false), - DispatchKey::ORT, + DispatchKey::MAIA, dtype); // This is a hack to workaround the shape checks in _convolution. 
tensor_impl->set_sizes_contiguous(size); @@ -52,7 +52,7 @@ std::tuple fake_convolution_backward( get_tensor(input.dtype(), {})); } -TORCH_LIBRARY_IMPL(aten, ORT, m) { +TORCH_LIBRARY_IMPL(aten, MAIA, m) { m.impl("empty.memory_format", empty_override); m.impl("add.out", add_out_override); m.impl("convolution_overrideable", fake_convolution); @@ -61,34 +61,34 @@ TORCH_LIBRARY_IMPL(aten, ORT, m) { // TODO: Extend this to exercise multi-device setting. In that case, // we need to add a thread local variable to track the current device. -struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface { - static constexpr DeviceType static_type = DeviceType::ORT; - ORTGuardImpl() {} - ORTGuardImpl(DeviceType t) { - AT_ASSERT(t == DeviceType::ORT); +struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::MAIA; + MAIAGuardImpl() {} + MAIAGuardImpl(DeviceType t) { + AT_ASSERT(t == DeviceType::MAIA); } DeviceType type() const override { - return DeviceType::ORT; + return DeviceType::MAIA; } Device exchangeDevice(Device d) const override { - AT_ASSERT(d.type() == DeviceType::ORT); + AT_ASSERT(d.type() == DeviceType::MAIA); AT_ASSERT(d.index() == 0); return d; } Device getDevice() const override { - return Device(DeviceType::ORT, 0); + return Device(DeviceType::MAIA, 0); } void setDevice(Device d) const override { - AT_ASSERT(d.type() == DeviceType::ORT); + AT_ASSERT(d.type() == DeviceType::MAIA); AT_ASSERT(d.index() == 0); } void uncheckedSetDevice(Device d) const noexcept override { } Stream getStream(Device d) const noexcept override { - return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0)); + return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0)); } Stream exchangeStream(Stream s) const noexcept override { - return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0)); + return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0)); } DeviceIndex deviceCount() const noexcept override { 
return 1; @@ -99,23 +99,23 @@ struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface { const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { - TORCH_CHECK(false, "ORT backend doesn't support events."); + TORCH_CHECK(false, "MAIA backend doesn't support events."); } void block( void* event, const Stream& stream) const override { - TORCH_CHECK(false, "ORT backend doesn't support events."); + TORCH_CHECK(false, "MAIA backend doesn't support events."); } bool queryEvent(void* event) const override { - TORCH_CHECK(false, "ORT backend doesn't support events."); + TORCH_CHECK(false, "MAIA backend doesn't support events."); } void destroyEvent( void* event, const DeviceIndex device_index) const noexcept override { } }; -constexpr DeviceType ORTGuardImpl::static_type; -C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl); +constexpr DeviceType MAIAGuardImpl::static_type; +C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl); int get_test_int() { return test_int; diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index 3731dc8c91d0a..4d4288a3076fc 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -28,8 +28,8 @@ "torch_test_cpp_extension.cpp", ["extension.cpp"], extra_compile_args=CXX_FLAGS ), CppExtension( - "torch_test_cpp_extension.ort", - ["ort_extension.cpp"], + "torch_test_cpp_extension.maia", + ["maia_extension.cpp"], extra_compile_args=CXX_FLAGS, ), CppExtension( diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 1d5df82a1259e..3e5ce5cfcef45 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -26,11 +26,11 @@ try: if HAS_PYTEST: cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp") - ort_extension = pytest.importorskip("torch_test_cpp_extension.ort") + maia_extension = pytest.importorskip("torch_test_cpp_extension.maia") rng_extension = pytest.importorskip("torch_test_cpp_extension.rng") else: import 
torch_test_cpp_extension.cpp as cpp_extension - import torch_test_cpp_extension.ort as ort_extension + import torch_test_cpp_extension.maia as maia_extension import torch_test_cpp_extension.rng as rng_extension except ImportError as e: raise RuntimeError( @@ -255,46 +255,46 @@ def test_pybind_return_types(self): @torch.testing._internal.common_utils.markDynamoStrictTest -class TestORTTensor(common.TestCase): +class TestMAIATensor(common.TestCase): def test_unregistered(self): a = torch.arange(0, 10, device="cpu") with self.assertRaisesRegex(RuntimeError, "Could not run"): - b = torch.arange(0, 10, device="ort") + b = torch.arange(0, 10, device="maia") - @skipIfTorchDynamo("dynamo cannot model ort device") + @skipIfTorchDynamo("dynamo cannot model maia device") def test_zeros(self): a = torch.empty(5, 5, device="cpu") self.assertEqual(a.device, torch.device("cpu")) - b = torch.empty(5, 5, device="ort") - self.assertEqual(b.device, torch.device("ort", 0)) - self.assertEqual(ort_extension.get_test_int(), 0) + b = torch.empty(5, 5, device="maia") + self.assertEqual(b.device, torch.device("maia", 0)) + self.assertEqual(maia_extension.get_test_int(), 0) self.assertEqual(torch.get_default_dtype(), b.dtype) - c = torch.empty((5, 5), dtype=torch.int64, device="ort") - self.assertEqual(ort_extension.get_test_int(), 0) + c = torch.empty((5, 5), dtype=torch.int64, device="maia") + self.assertEqual(maia_extension.get_test_int(), 0) self.assertEqual(torch.int64, c.dtype) def test_add(self): - a = torch.empty(5, 5, device="ort", requires_grad=True) - self.assertEqual(ort_extension.get_test_int(), 0) + a = torch.empty(5, 5, device="maia", requires_grad=True) + self.assertEqual(maia_extension.get_test_int(), 0) - b = torch.empty(5, 5, device="ort") - self.assertEqual(ort_extension.get_test_int(), 0) + b = torch.empty(5, 5, device="maia") + self.assertEqual(maia_extension.get_test_int(), 0) c = a + b - self.assertEqual(ort_extension.get_test_int(), 1) + 
self.assertEqual(maia_extension.get_test_int(), 1) def test_conv_backend_override(self): # To simplify tests, we use 4d input here to avoid doing view4d( which # needs more overrides) in _convolution. - input = torch.empty(2, 4, 10, 2, device="ort", requires_grad=True) - weight = torch.empty(6, 4, 2, 2, device="ort", requires_grad=True) - bias = torch.empty(6, device="ort") + input = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True) + weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True) + bias = torch.empty(6, device="maia") # Make sure forward is overriden out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1) - self.assertEqual(ort_extension.get_test_int(), 2) + self.assertEqual(maia_extension.get_test_int(), 2) self.assertEqual(out.shape[0], input.shape[0]) self.assertEqual(out.shape[1], weight.shape[0]) @@ -302,7 +302,7 @@ def test_conv_backend_override(self): # Double backward is dispatched to _convolution_double_backward. # It is not tested here as it involves more computation/overrides. 
grad = torch.autograd.grad(out, input, out, create_graph=True) - self.assertEqual(ort_extension.get_test_int(), 3) + self.assertEqual(maia_extension.get_test_int(), 3) self.assertEqual(grad[0].shape, input.shape) diff --git a/third_party/ideep b/third_party/ideep index 8a6cc4e09dc50..6c581ef0fdd48 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8a6cc4e09dc509f04f83c085e38786b1fb44e14d +Subproject commit 6c581ef0fdd487e51bb1518f9473dfcc6ec15415 diff --git a/third_party/kineto b/third_party/kineto index 47911e2326097..8466a8b111b36 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 47911e232609720f551b7a412076899d8a16a744 +Subproject commit 8466a8b111b36dc725e6855d52a0b133d925a8e0 diff --git a/third_party/onnx b/third_party/onnx index 990217f043af7..ccde5da81388f 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 990217f043af7222348ca8f0301e17fa7b841781 +Subproject commit ccde5da81388ffa770ca98b64e07f803ad089414 diff --git a/third_party/pybind11 b/third_party/pybind11 index 3e9dfa2866941..8a099e44b3d5f 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 3e9dfa2866941655c56877882565e7577de6fc7b +Subproject commit 8a099e44b3d5f85b20f05828d919d2332a8de841 diff --git a/third_party/sleef b/third_party/sleef index 60e76d2bce17d..e0a003ee838b7 160000 --- a/third_party/sleef +++ b/third_party/sleef @@ -1 +1 @@ -Subproject commit 60e76d2bce17d278b439d9da17177c8f957a9e9b +Subproject commit e0a003ee838b75d11763aa9c3ef17bf71a725bff diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 369f1504bf48f..f0b9044c6fe9c 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1160,7 +1160,7 @@ def replace_special_case(hint: str) -> str: "is_meta": ["is_meta: _bool"], "is_mps": ["is_mps: _bool"], "is_mtia": ["is_mtia: _bool"], - "is_ort": ["is_ort: _bool"], + "is_maia": ["is_maia: _bool"], "is_mkldnn": ["is_mkldnn: _bool"], 
"is_vulkan": ["is_vulkan: _bool"], "is_ipu": ["is_ipu: _bool"], diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 2c50a28bfbf6f..34eb451be08c0 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -22,7 +22,7 @@ class DeviceType(Enum): IDEEP = ... HIP = ... FPGA = ... - ORT = ... + MAIA = ... XLA = ... MPS = ... HPU = ... diff --git a/torch/_tensor.py b/torch/_tensor.py index 0ce59ca924bd5..4ae1ff943c885 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -101,7 +101,7 @@ def __deepcopy__(self, memo): if ( self.is_sparse or self.device.type - in ["lazy", "xla", "mtia", "mps", "ort", "meta", "ipu"] + in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"] or ( not torch._C._has_storage(self) and self.device.type == torch._C._get_privateuse1_backend_name() @@ -249,7 +249,7 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] torch.utils.hooks.warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() - # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, ORT Tensors. + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. # Otherwise in torch.load CPU storage is reconstructed with randomly @@ -259,7 +259,7 @@ def _reduce_ex_internal(self, proto): # 2. Python list is not a good fit due to performance reason. # `tolist()` converts every single element in the tensor into python objects # and serialize them one by one. 
- if self.device.type in ["xla", "mtia", "ort"] or ( + if self.device.type in ["xla", "mtia", "maia"] or ( not torch._C._has_storage(self) and self.device.type == torch._C._get_privateuse1_backend_name() ): diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 2bea7c4cda5c0..8edf23cd2ec0b 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -162,7 +162,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .value("IDEEP", c10::DeviceType::IDEEP) .value("HIP", c10::DeviceType::HIP) .value("FPGA", c10::DeviceType::FPGA) - .value("ORT", c10::DeviceType::ORT) + .value("MAIA", c10::DeviceType::MAIA) .value("XLA", c10::DeviceType::XLA) .value("Vulkan", c10::DeviceType::Vulkan) .value("Metal", c10::DeviceType::Metal) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index ea55bb55dd243..3705ac5e423e0 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1451,13 +1451,13 @@ PyObject* THPVariable_is_mps(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } -PyObject* THPVariable_is_ort(THPVariable* self, void* unused) { +PyObject* THPVariable_is_maia(THPVariable* self, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "is_ort"); + return handle_torch_function_getter(self, "is_maia"); } auto& self_ = THPVariable_Unpack(self); - return torch::autograd::utils::wrap(self_.is_ort()); + return torch::autograd::utils::wrap(self_.is_maia()); END_HANDLE_TH_ERRORS } @@ -1674,7 +1674,7 @@ static struct PyGetSetDef THPVariable_properties[] = { nullptr}, {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr}, {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr}, - {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr}, + {"is_maia", (getter)THPVariable_is_maia, nullptr, nullptr, nullptr}, 
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr}, {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr}, {"is_quantized", diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index e9f090cfbbe09..80b5d27fba079 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -145,7 +145,7 @@ std::shared_ptr SimpleValue::attr( {"H", "prim"}, {"mT", "aten"}, {"mH", "aten"}, - {"is_ort", "prim"}, + {"is_maia", "prim"}, {"itemsize", "prim"}, {"nbytes", "prim"}, {"ndim", "prim"}, diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 4d8a0cd89d8ff..cec9c70bc7b67 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -2431,11 +2431,11 @@ static const std::vector opGenArgs1{ }, aliasAnalysisFromSchema()), OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_maia(Tensor a) -> bool"), [](Stack& stack) { at::Tensor a; pop(stack, a); - push(stack, a.is_ort()); + push(stack, a.is_maia()); }, aliasAnalysisFromSchema()), OperatorGeneratorArgs( diff --git a/torch/library.h b/torch/library.h index fcac0e80942da..c38179a6eea1d 100644 --- a/torch/library.h +++ b/torch/library.h @@ -370,8 +370,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { return c10::DispatchKey::Meta; case c10::DeviceType::HIP: return c10::DispatchKey::HIP; - case c10::DeviceType::ORT: - return c10::DispatchKey::ORT; + case c10::DeviceType::MAIA: + return c10::DispatchKey::MAIA; case c10::DeviceType::HPU: return c10::DispatchKey::HPU; case c10::DeviceType::MTIA: diff --git a/torch/overrides.py b/torch/overrides.py index 6a5d3e891dc8a..9f99ee0c54dde 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1283,7 +1283,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: 
Tensor.is_mps.__get__: lambda self: -1, Tensor.is_mtia.__get__: lambda self: -1, Tensor.is_nested.__get__: lambda self: -1, - Tensor.is_ort.__get__: lambda self: -1, + Tensor.is_maia.__get__: lambda self: -1, Tensor.is_mkldnn.__get__: lambda self: -1, Tensor.is_quantized.__get__: lambda self: -1, Tensor.is_sparse.__get__: lambda self: -1, diff --git a/torchgen/model.py b/torchgen/model.py index 7b0dd8cc1feee..2706f234c56b0 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -79,7 +79,7 @@ class DispatchKey(Enum): CatchAll = Undefined FPGA = auto() - ORT = auto() + MAIA = auto() Vulkan = auto() Metal = auto() MKLDNN = auto() From 97daf6efc5e55b063175d4474fc0b0b535817035 Mon Sep 17 00:00:00 2001 From: Ashwin Hari Date: Thu, 18 Apr 2024 10:20:30 -0700 Subject: [PATCH 2/5] revert changes to thirdparty --- third_party/ideep | 2 +- third_party/kineto | 2 +- third_party/onnx | 2 +- third_party/pybind11 | 2 +- third_party/sleef | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/ideep b/third_party/ideep index 6c581ef0fdd48..8a6cc4e09dc50 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 6c581ef0fdd487e51bb1518f9473dfcc6ec15415 +Subproject commit 8a6cc4e09dc509f04f83c085e38786b1fb44e14d diff --git a/third_party/kineto b/third_party/kineto index 8466a8b111b36..47911e2326097 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 8466a8b111b36dc725e6855d52a0b133d925a8e0 +Subproject commit 47911e232609720f551b7a412076899d8a16a744 diff --git a/third_party/onnx b/third_party/onnx index ccde5da81388f..990217f043af7 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit ccde5da81388ffa770ca98b64e07f803ad089414 +Subproject commit 990217f043af7222348ca8f0301e17fa7b841781 diff --git a/third_party/pybind11 b/third_party/pybind11 index 8a099e44b3d5f..3e9dfa2866941 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 
8a099e44b3d5f85b20f05828d919d2332a8de841 +Subproject commit 3e9dfa2866941655c56877882565e7577de6fc7b diff --git a/third_party/sleef b/third_party/sleef index e0a003ee838b7..60e76d2bce17d 160000 --- a/third_party/sleef +++ b/third_party/sleef @@ -1 +1 @@ -Subproject commit e0a003ee838b75d11763aa9c3ef17bf71a725bff +Subproject commit 60e76d2bce17d278b439d9da17177c8f957a9e9b From 6a8e19a612ec126f1ac6c471d4e09dc33dbed678 Mon Sep 17 00:00:00 2001 From: Ashwin Hari Date: Thu, 18 Apr 2024 11:14:21 -0700 Subject: [PATCH 3/5] fix case --- torch/csrc/Storage.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index c22e6f5d1b95d..a3f8263303782 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -355,7 +355,7 @@ static PyObject* THPStorage_pynew( } else if (device.type() == at::DeviceType::PrivateUse1) { at::globalContext().lazyInitPrivateUse1(); allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::ORT) { + } else if (device.type() == at::DeviceType::MAIA) { allocator = c10::GetAllocator(device.type()); } else { // NOLINTEND(bugprone-branch-clone) From eb43a6a1f9505ff878250a2876be6ce5dd88eea2 Mon Sep 17 00:00:00 2001 From: Ashwin Hari Date: Thu, 18 Apr 2024 11:40:22 -0700 Subject: [PATCH 4/5] applied lint warning --- aten/src/ATen/Context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index e85af59b0dbe8..32b22855f939b 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -12,9 +12,9 @@ #include #include #include +#include #include #include -#include #include #include #include From cbc83328e06eecd6e16875b714853226a1ebe958 Mon Sep 17 00:00:00 2001 From: Ashwin Hari Date: Mon, 22 Apr 2024 12:04:08 -0700 Subject: [PATCH 5/5] add to allow list --- .../check_forward_backward_compatibility.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 5a4aac572c17c..093e27154ea31 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -46,6 +46,7 @@ ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)), ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)), ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)), + ("prim::is_ort", datetime.date(9999, 1, 1)), ("prim::Concat", datetime.date(9999, 1, 1)), ("aten::_NestedTensor_GeneralizedBMM", datetime.date(9999, 1, 1)), # Internal, profiler-specific ops