diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 3da7f8d1c18e..cf52ec32ea1f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -29,12 +29,12 @@ #define TVM_IR_ATTRS_H_ #include +#include +#include #include #include #include #include -#include -#include #include #include diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 264198333e6f..65d9a6f387fe 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index faf2c18c1cac..7c4b8e7cb2bd 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -28,7 +28,9 @@ #include #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index c14549f41283..cfd859406df8 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -28,7 +28,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index becd19ed70be..543c895ce519 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include diff --git a/include/tvm/ir/serialization.h b/include/tvm/ir/serialization.h new file mode 100644 index 000000000000..59bdb87067f3 --- /dev/null +++ b/include/tvm/ir/serialization.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/serialization.h + * \brief Utility functions for serialization. + * + * This is a thin forwarding header to ffi/extra/serialization.h. + * Prefer using ffi::ToJSONGraph / ffi::FromJSONGraph directly. + */ +#ifndef TVM_IR_SERIALIZATION_H_ +#define TVM_IR_SERIALIZATION_H_ + +#include +#include +#include + +#include + +namespace tvm { + +/*! + * \brief Save the node as well as all the node it depends on as json. + * This can be used to serialize any TVM object. + * + * \return the string representation of the node. + */ +TVM_DLL std::string SaveJSON(ffi::Any node); + +/*! + * \brief Load tvm Node object from json and return a shared_ptr of Node. + * \param json_str The json string to load from. + * + * \return The shared_ptr of the Node. + */ +TVM_DLL ffi::Any LoadJSON(std::string json_str); + +} // namespace tvm +#endif // TVM_IR_SERIALIZATION_H_ diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 60a30ffe1709..19aba9461cca 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -23,9 +23,9 @@ #ifndef TVM_IR_SOURCE_MAP_H_ #define TVM_IR_SOURCE_MAP_H_ +#include #include #include -#include #include #include diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 5e38f3876937..902778c3db02 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -52,7 +52,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h deleted file mode 100644 index 734a28c13301..000000000000 --- a/include/tvm/node/node.h +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/node.h - * \brief Definitions and helper macros for IR/AST nodes. - * - * The node folder contains base utilities for IR/AST nodes, - * invariant of which specific language dialect. - * - * We implement AST/IR nodes as sub-classes of runtime::Object. - * The base class Node is just an alias of runtime::Object. - * - * Besides the runtime type checking provided by Object, - * node folder contains additional functionalities such as - * reflection and serialization, which are important features - * for building a compiler infra. - */ -#ifndef TVM_NODE_NODE_H_ -#define TVM_NODE_NODE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { - -using ffi::Any; -using ffi::AnyView; -using ffi::Object; -using ffi::ObjectPtr; -using ffi::ObjectPtrEqual; -using ffi::ObjectPtrHash; -using ffi::ObjectRef; -using ffi::PackedArgs; -using ffi::TypeIndex; - -} // namespace tvm -#endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h deleted file mode 100644 index d5716f96f6d5..000000000000 --- a/include/tvm/node/reflection.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/reflection.h - * \brief Reflection utilities for IR/AST nodes. - */ -#ifndef TVM_NODE_REFLECTION_H_ -#define TVM_NODE_REFLECTION_H_ - -#include -#include - -namespace tvm { - -/*! - * \brief Create an object from a type key and a map of fields. - * \param type_key The type key of the object. - * \param fields The fields of the object. - * \return The created object. - */ -TVM_DLL ffi::Any CreateObject(const ffi::String& type_key, - const ffi::Map& fields); - -} // namespace tvm -#endif // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h index 5a8e098cfd6e..892ff5fdf19c 100644 --- a/include/tvm/node/serialization.h +++ b/include/tvm/node/serialization.h @@ -18,34 +18,13 @@ */ /*! - * Utility functions for serialization. * \file tvm/node/serialization.h + * \brief Forwarding header. Use tvm/ir/serialization.h instead. */ #ifndef TVM_NODE_SERIALIZATION_H_ #define TVM_NODE_SERIALIZATION_H_ -#include -#include +// This header has moved to tvm/ir/serialization.h +#include -#include - -namespace tvm { -/*! - * \brief save the node as well as all the node it depends on as json. - * This can be used to serialize any TVM object - * - * \return the string representation of the node. - */ -TVM_DLL std::string SaveJSON(ffi::Any node); - -/*! - * \brief Internal implementation of LoadJSON - * Load tvm Node object from json and return a shared_ptr of Node. - * \param json_str The json string to load from. - * - * \return The shared_ptr of the Node. - */ -TVM_DLL ffi::Any LoadJSON(std::string json_str); - -} // namespace tvm #endif // TVM_NODE_SERIALIZATION_H_ diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 4f00e1770b41..cbf7652b8093 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -18,96 +18,12 @@ */ /*! * \file tvm/node/structural_equal.h - * \brief Structural equality comparison. + * \brief Forwarding header. Use tvm/ffi/extra/structural_equal.h instead. */ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ -#include -#include -#include -#include +// This header has moved to tvm/ffi/extra/structural_equal.h +#include -#include -#include - -namespace tvm { - -/*! - * \brief Equality definition of base value class. - */ -class BaseValueEqual { - public: - bool operator()(const double& lhs, const double& rhs) const { - if (std::isnan(lhs) && std::isnan(rhs)) { - // IEEE floats do not compare as equivalent to each other. - // However, for the purpose of comparing IR representation, two - // NaN values are equivalent. - return true; - } else if (std::isnan(lhs) || std::isnan(rhs)) { - return false; - } else if (lhs == rhs) { - return true; - } else { - // fuzzy float pt comparison - constexpr double atol = 1e-9; - double diff = lhs - rhs; - return diff > -atol && diff < atol; - } - } - - bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } - bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { - return lhs == rhs; - } - bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } - bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } - bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } - bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } - template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs) const { - return lhs == rhs; - } -}; - -/*! - * \brief Content-aware structural equality comparator for objects. - * - * The structural equality is recursively defined in the DAG of IR nodes via SEqual. - * There are two kinds of nodes: - * - * - Graph node: a graph node in lhs can only be mapped as equal to - * one and only one graph node in rhs. - * - Normal node: equality is recursively defined without the restriction - * of graph nodes. - * - * Vars(tir::Var, relax::Var) nodes are graph nodes. - * - * A var-type node(e.g. tir::Var) can be mapped as equal to another var - * with the same type if one of the following condition holds: - * - * - They appear in a same definition point(e.g. function argument). - * - They points to the same VarNode via the same_as relation. - * - They appear in a same usage point, and map_free_vars is set to be True. - */ -class StructuralEqual : public BaseValueEqual { - public: - // inheritate operator() - using BaseValueEqual::operator(); - /*! - * \brief Compare objects via strutural equal. - * \param lhs The left operand. - * \param rhs The right operand. - * \param map_free_params Whether or not to map free variables. - * \return The comparison result. - */ - TVM_DLL bool operator()(const ffi::Any& lhs, const ffi::Any& rhs, - const bool map_free_params = false) const; -}; - -} // namespace tvm #endif // TVM_NODE_STRUCTURAL_EQUAL_H_ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index ba7cbaf88aa6..8f90820b150d 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -17,113 +17,13 @@ * under the License. */ /*! - * \file tvm/node/structural_equal.h - * \brief Structural hash class. + * \file tvm/node/structural_hash.h + * \brief Forwarding header. Use tvm/ffi/extra/structural_hash.h instead. */ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ #define TVM_NODE_STRUCTURAL_HASH_H_ -#include -#include -#include +// This header has moved to tvm/ffi/extra/structural_hash.h +#include -#include -#include -#include -#include - -namespace tvm { - -/*! - * \brief Hash definition of base value classes. - */ -class BaseValueHash { - protected: - template - uint64_t Reinterpret(T value) const { - union Union { - T a; - U b; - } u; - static_assert(sizeof(Union) == sizeof(T), "sizeof(Union) != sizeof(T)"); - static_assert(sizeof(Union) == sizeof(U), "sizeof(Union) != sizeof(U)"); - u.b = 0; - u.a = value; - return u.b; - } - - public: - uint64_t operator()(const float& key) const { return Reinterpret(key); } - uint64_t operator()(const double& key) const { - if (std::isnan(key)) { - // The IEEE format defines more than one bit-pattern that - // represents NaN. For the purpose of comparing IR - // representations, all NaN values are considered equivalent. - return Reinterpret(std::numeric_limits::quiet_NaN()); - } else { - return Reinterpret(key); - } - } - uint64_t operator()(const int64_t& key) const { return Reinterpret(key); } - uint64_t operator()(const uint64_t& key) const { return key; } - uint64_t operator()(const int& key) const { return Reinterpret(key); } - uint64_t operator()(const bool& key) const { return key; } - uint64_t operator()(const runtime::DataType& key) const { - return Reinterpret(key); - } - template ::value>::type> - uint64_t operator()(const ENum& key) const { - return Reinterpret(static_cast(key)); - } - uint64_t operator()(const std::string& key) const { - return tvm::ffi::details::StableHashBytes(key.data(), key.length()); - } - uint64_t operator()(const ffi::Optional& key) const { - if (key.has_value()) { - return Reinterpret(*key); - } else { - return 0; - } - } - uint64_t operator()(const ffi::Optional& key) const { - if (key.has_value()) { - return Reinterpret(*key); - } else { - return 0; - } - } - /*! - * \brief Compute structural hash value for a POD value in Any. - * \param key The Any object. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t HashPODValueInAny(const ffi::Any& key) const { - return ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(key)->v_uint64; - } -}; - -/*! - * \brief Content-aware structural hashing. - * - * The structural hash value is recursively defined in the DAG of IRNodes. - * There are two kinds of nodes: - * - * - Normal node: the hash value is defined by its content and type only. - * - Graph node: each graph node will be assigned a unique index ordered by the - * first occurrence during the visit. The hash value of a graph node is - * combined from the hash values of its contents and the index. - */ -class StructuralHash : public BaseValueHash { - public: - // inherit operator() - using BaseValueHash::operator(); - /*! - * \brief Compute structural hashing value for an object. - * \param key The left operand. - * \return The hash value. - */ - TVM_DLL uint64_t operator()(const ffi::Any& key) const; -}; - -} // namespace tvm #endif // TVM_NODE_STRUCTURAL_HASH_H_ diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 6ea322938f06..6dc2022d5f19 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -251,7 +251,7 @@ using AxisShardingSpec = std::pair; class AxisShardingSpecEqual { public: bool operator()(const AxisShardingSpec& lhs, const AxisShardingSpec& rhs) const { - return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second; + return ffi::StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second; } }; @@ -259,7 +259,7 @@ class AxisShardingSpecHash { public: size_t operator()(const AxisShardingSpec& sharding_spec) const { size_t seed = 0; - seed ^= StructuralHash()(sharding_spec.first); + seed ^= ffi::StructuralHash()(sharding_spec.first); seed ^= std::hash()(sharding_spec.second) << 1; return seed; } diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 222dea3fb1f2..66bae5411a16 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -23,6 +23,8 @@ #ifndef TVM_RELAX_EXEC_BUILDER_H_ #define TVM_RELAX_EXEC_BUILDER_H_ +#include +#include #include #include #include @@ -178,7 +180,8 @@ class ExecBuilderNode : public Object { /*! \brief The mutable internal executable. */ ObjectPtr exec_; // mutable /*! \brief internal dedup map when creating index for a new constant */ - std::unordered_map const_dedup_map_; + std::unordered_map + const_dedup_map_; }; class ExecBuilder : public ObjectRef { diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 9b4fa913795f..f8cebafa551c 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -25,9 +25,9 @@ #include #include #include -#include #include #include +#include #include #include diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 12b97e20c21d..c51a6db5a2a0 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -22,10 +22,11 @@ #include #include #include -#include +#include #include #include #include +#include #include diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 80279a4862e0..e186e85b9d92 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -139,7 +139,10 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= } // namespace runtime +using tvm::ffi::Object; using tvm::ffi::ObjectPtr; +using tvm::ffi::ObjectPtrEqual; +using tvm::ffi::ObjectPtrHash; using tvm::ffi::ObjectRef; } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/s_tir/meta_schedule/arg_info.h b/include/tvm/s_tir/meta_schedule/arg_info.h index cf705508740f..ae2c3c9057df 100644 --- a/include/tvm/s_tir/meta_schedule/arg_info.h +++ b/include/tvm/s_tir/meta_schedule/arg_info.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 47ed628da0ad..e5679c6064ac 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 53efc9df7f2b..cff1242c8890 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index 9fe3d7e1ac65..761e2a995ff1 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -21,7 +21,6 @@ #include #include -#include #include #include diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 9ce980d268df..9d63ae08e3d0 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index cf8c72daf89a..bd8c37780c1c 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 68caa5ff4d97..2cc1782d9240 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -20,8 +20,8 @@ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ #include -#include #include +#include #include #include diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index b8de3fffba57..c6c828e2df5c 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -26,7 +26,6 @@ #include #include -#include #include #include diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 9a0bedd1cc4b..b71a4952b530 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 02ac88a9af6e..86f55a24106a 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 889d0eff8904..52512edda88f 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -365,7 +365,7 @@ class VirtualDeviceCache { private: /*! \brief Already constructed VirtualDevices. */ - std::unordered_set cache_; + std::unordered_set cache_; }; /*! brief The attribute key for the virtual device. This key will be promoted to first class on diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 529765469165..34c11bdd3ed8 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 97dfbb133026..31c6e3bc5ccd 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 6866431ee487..c4e716cbe166 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4fcb91403fe8..b41c92d66af2 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -25,6 +25,7 @@ #define TVM_TIR_STMT_H_ #include +#include #include #include diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 521b03a4728b..e83064b86489 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -25,7 +25,6 @@ #define TVM_TIR_VAR_H_ #include -#include #include #include diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py index d5810094b511..13557f9bd97a 100644 --- a/python/tvm/runtime/_tensor.py +++ b/python/tvm/runtime/_tensor.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-import, redefined-outer-name -# ruff: noqa: E722, F401, RUF005 +# ruff: noqa: F401, RUF005 """Runtime Tensor API""" import ctypes @@ -111,7 +111,7 @@ def copyfrom(self, source_array): if not isinstance(source_array, np.ndarray): try: source_array = np.array(source_array, dtype=self.dtype) - except: + except Exception: raise TypeError( f"array must be an array_like data, type {type(source_array)} is not supported" ) diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py index 2ec12b26c0dc..3a148361092e 100644 --- a/python/tvm/s_tir/dlight/analysis/common_analysis.py +++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E722 # pylint: disable=missing-function-docstring, missing-class-docstring # pylint: disable=unused-argument, unused-variable @@ -377,7 +376,7 @@ def get_max_shared_memory_per_block(target: Target) -> int: def get_root_block(sch: Schedule, func_name: str = "main") -> SBlockRV: try: block = sch.mod[func_name].body.block - except: + except Exception: raise ValueError( f"The function body is expected to be the root block, but got:\n" f"{sch.mod[func_name].body}" diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 7a87a1fbabf5..47b6156cfa06 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -133,10 +133,10 @@ class AndOfOrs { std::vector> chunks_; /*! \brief Mapping from internal Key to PrimExpr */ - std::unordered_map key_to_expr_; + std::unordered_map key_to_expr_; /*! \brief Mapping from PrimExpr to internal Key */ - std::unordered_map expr_to_key_; + std::unordered_map expr_to_key_; /*! \brief Cached key representing tir::Bool(true) */ Key key_true_; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 4d08c790724e..1779c42583a2 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -458,10 +458,12 @@ class IterMapRewriter : public ExprMutator { // usage of an input iterator. (e.g. (i-1) occurring in the // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be // left-padded by 31 for each occurrence.) - std::unordered_map padded_iter_map_; + std::unordered_map + padded_iter_map_; // Map from padded iter mark to it's origin mark - std::unordered_map padded_origin_map_; + std::unordered_map + padded_origin_map_; /* If update_iterator_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 3b8e96773ba5..6c932ea5221b 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -103,7 +103,7 @@ void AddInequality(std::vector* inequality_set, const PrimExpr& new_in Analyzer* analyzer) { if (analyzer->CanProve(new_ineq) || std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) { - return StructuralEqual()(e, new_ineq); + return ffi::StructuralEqual()(e, new_ineq); }) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added @@ -175,7 +175,7 @@ void MoveEquality(std::vector* upper_bounds, std::vector* lo // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), - [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); }); + [&](const PrimExpr& e) { return ffi::StructuralEqual()(e, *ub); }); if (lb != lower_bounds->end()) { equalities->push_back(*lb); lower_bounds->erase(lb); diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 3794ff150bb9..23aaf2140c33 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -139,7 +139,7 @@ class TransitiveComparisonAnalyzer::Impl { * \see ExprToKey * \see ExprToPreviousKey */ - std::unordered_map expr_to_key; + std::unordered_map expr_to_key; /*! \brief Internal representation of a comparison operator */ struct Comparison { diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index fa6fc378f24f..aeecd79750ff 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -23,6 +23,7 @@ #include "msc_base_printer.h" +#include #include #include "../utils.h" diff --git a/src/ir/module.cc b/src/ir/module.cc index 04e3026b0f11..935d9e0ccdb4 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,13 +21,13 @@ * \brief The global module in TVM. */ #include +#include #include #include #include #include #include #include -#include #include #include diff --git a/src/node/serialization.cc b/src/ir/serialization.cc similarity index 97% rename from src/node/serialization.cc rename to src/ir/serialization.cc index 2faf8d170bd8..4d9074e98cf3 100644 --- a/src/node/serialization.cc +++ b/src/ir/serialization.cc @@ -18,7 +18,7 @@ */ /*! - * \file node/serialization.cc + * \file src/ir/serialization.cc * \brief Utilities to serialize TVM AST/IR objects. */ #include diff --git a/src/node/structural_equal.cc b/src/ir/structural_equal.cc similarity index 90% rename from src/node/structural_equal.cc rename to src/ir/structural_equal.cc index e33d7c774687..1d7cbd23d0ca 100644 --- a/src/node/structural_equal.cc +++ b/src/ir/structural_equal.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/node/structural_equal.cc + * \file src/ir/structural_equal.cc */ #include #include @@ -25,8 +25,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -80,8 +80,4 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("node.GetFirstStructuralMismatch", ffi::StructuralEqual::GetFirstMismatch); } -bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs, - bool map_free_params) const { - return ffi::StructuralEqual::Equal(lhs, rhs, map_free_params); -} } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/ir/structural_hash.cc similarity index 96% rename from src/node/structural_hash.cc rename to src/ir/structural_hash.cc index f32f0756c04d..ad74742e5144 100644 --- a/src/node/structural_hash.cc +++ b/src/ir/structural_hash.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/node/structural_hash.cc + * \file src/ir/structural_hash.cc */ #include #include @@ -26,8 +26,6 @@ #include #include #include -#include -#include #include #include #include @@ -81,10 +79,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -uint64_t StructuralHash::operator()(const ffi::Any& object) const { - return ffi::StructuralHash::Hash(object, false); -} - struct RefToObjectPtr : public ObjectRef { static ObjectPtr Get(const ObjectRef& ref) { return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(ref); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 148918be8eee..9d2dec6b1a7c 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -21,12 +21,12 @@ * \file src/ir/transform.cc * \brief Infrastructure for transformation passes. */ +#include #include #include #include #include #include -#include #include #include @@ -312,10 +312,10 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node, const PassContext& pass_ctx) { - size_t before_pass_hash = tvm::StructuralHash()(mod); + size_t before_pass_hash = ffi::StructuralHash()(mod); IRModule copy_mod = mod; IRModule ret = node->operator()(mod, pass_ctx); - size_t after_pass_hash = tvm::StructuralHash()(copy_mod); + size_t after_pass_hash = ffi::StructuralHash()(copy_mod); if (before_pass_hash != after_pass_hash) { // The chance of getting a hash conflict between a module and the same module but mutated // must be very low. diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 2565a02b64a5..a5f4ffb84d7c 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -25,7 +25,6 @@ #include #include #include -#include namespace tvm { diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 09413ba007e3..1774ba4b4b03 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index cd951896d821..101e8e8b7410 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -485,7 +485,7 @@ class StructInfoBaseChecker // analyzer arith::Analyzer* analyzer_; // struct equal checker - StructuralEqual struct_equal_; + ffi::StructuralEqual struct_equal_; // customizable functions. /*! @@ -742,7 +742,7 @@ class StructInfoBasePreconditionCollector return Bool(false); } - StructuralEqual struct_equal; + ffi::StructuralEqual struct_equal; if (!struct_equal(lhs->device_mesh, rhs->device_mesh) || !struct_equal(lhs->placement, rhs->placement)) { return Bool(false); @@ -1154,7 +1154,7 @@ class StructInfoLCAFinder // analyzer arith::Analyzer* analyzer_; // struct equal checker - StructuralEqual struct_equal_; + ffi::StructuralEqual struct_equal_; // check arrays ffi::Optional> UnifyArray(const ffi::Array& lhs, @@ -1303,7 +1303,7 @@ class NonNegativeExpressionCollector : relax::StructInfoVisitor { } ffi::Array expressions_; - std::unordered_set dedup_lookup_; + std::unordered_set dedup_lookup_; }; ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo) { diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 861e57aeb7d5..396e5f9cbb1f 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -236,7 +236,7 @@ * */ -#include +#include #include #include #include diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index a7103cde9577..af34a3ac10ef 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -23,7 +23,7 @@ * store into global scope avoiding unnecessary device copy. */ -#include +#include #include #include #include diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 68c266c5dd11..4cca5d7b6c8b 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -69,7 +69,7 @@ struct MatchShapeTodoItem { /*! \brief Slot map used for shape lowering. */ using PrimExprSlotMap = - std::unordered_map; + std::unordered_map; // Collector to collect PrimExprSlotMap class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 7faa4697874d..88fdb14ffd0c 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -74,7 +74,7 @@ class RedistributeLegalizer : public ExprMutator { // and the device mesh must be 1d // todo: extend the ccl ops so that it can support 2d device mesh, and different sharding // dimension - TVM_FFI_ICHECK(StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh)); + TVM_FFI_ICHECK(ffi::StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh)); TVM_FFI_ICHECK(input_sinfo->device_mesh->shape.size() == 1); // only support "S[x]"-> "R" and "R" -> "S[x]" PlacementSpec input_spec = input_sinfo->placement->dim_specs[0]; diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 83300f80acb2..49fc366a5360 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -256,7 +256,8 @@ class DistIRSharder : public ExprMutator { Function func_; ffi::Array new_params_; - std::unordered_map tuple_getitem_remap_; + std::unordered_map + tuple_getitem_remap_; }; namespace transform { diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 703857da9141..1123b1db25b5 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -283,7 +283,7 @@ class ShardingConflictHandler : public ExprVisitor { } if (device_mesh.defined()) { - TVM_FFI_ICHECK(StructuralEqual()(device_mesh.value(), sharding_spec.first)) + TVM_FFI_ICHECK(ffi::StructuralEqual()(device_mesh.value(), sharding_spec.first)) << "Sharding conflict detected for tensor " << var->name_hint() << ": Device Mesh mismatch" << ". Conflict Handling logic will be added in the future."; @@ -561,7 +561,7 @@ class DistributedIRBuilder : public ExprMutator { if (const auto* inferred_dtensor_sinfo = new_call->struct_info_.as()) { Expr new_value = RemoveAnnotateSharding(new_call); - if (!StructuralEqual()( + if (!ffi::StructuralEqual()( DTensorStructInfo(inferred_dtensor_sinfo->tensor_sinfo, device_mesh, placements[0]), new_call->struct_info_)) { new_value = InsertRedistribute(new_value, device_mesh, placements[0]); @@ -577,7 +577,7 @@ class DistributedIRBuilder : public ExprMutator { Var new_var = builder_->Emit(new_call); var_remap_[binding->var->vid] = new_var; for (int i = 0; i < static_cast(inferred_tuple_sinfo->fields.size()); i++) { - if (!StructuralEqual()( + if (!ffi::StructuralEqual()( DTensorStructInfo( Downcast(inferred_tuple_sinfo->fields[i])->tensor_sinfo, device_mesh, placements[i]), @@ -607,7 +607,8 @@ class DistributedIRBuilder : public ExprMutator { } ffi::Map input_tensor_remap_; - std::unordered_map tuple_getitem_remap_; + std::unordered_map + tuple_getitem_remap_; AxisGroupGraph axis_group_graph_; }; namespace transform { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 057351e3d069..c2af644fba26 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -445,7 +445,7 @@ class BlockBuilderImpl : public BlockBuilderNode { */ std::unique_ptr< std::unordered_map, - StructuralHashIgnoreNDarray, StructuralEqual>> + StructuralHashIgnoreNDarray, ffi::StructuralEqual>> ctx_func_dedup_map_ = nullptr; /*! @@ -455,7 +455,7 @@ class BlockBuilderImpl : public BlockBuilderNode { if (ctx_func_dedup_map_ != nullptr) return; ctx_func_dedup_map_ = std::make_unique< std::unordered_map, - StructuralHashIgnoreNDarray, StructuralEqual>>(); + StructuralHashIgnoreNDarray, ffi::StructuralEqual>>(); for (const auto& kv : context_mod_->functions) { const GlobalVar gv = kv.first; const BaseFunc func = kv.second; diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 2f2d1dac9ae9..b13fb84105b7 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include #include diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 72f62041dbf0..a95b51745de0 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -22,9 +22,9 @@ * \brief A transform to match a Relax Expr and rewrite */ +#include #include #include -#include #include #include #include @@ -543,7 +543,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( for (size_t i = 1; i < indices.size(); i++) { for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) { if (auto it = merged_matches.find(pat); it != merged_matches.end()) { - if (!StructuralEqual()(expr, (*it).second)) { + if (!ffi::StructuralEqual()(expr, (*it).second)) { return std::nullopt; } } else { @@ -698,7 +698,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { auto sinfo_pattern = GetStructInfo(func_pattern); auto sinfo_replacement = GetStructInfo(func_replacement); - TVM_FFI_CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement), ValueError) + TVM_FFI_CHECK(ffi::StructuralEqual()(sinfo_pattern, sinfo_replacement), ValueError) << "The pattern and replacement must have the same signature, " << "but the pattern has struct info " << sinfo_pattern << ", while the replacement has struct info " << sinfo_replacement; @@ -832,7 +832,7 @@ class PatternMatchingMutator : public ExprMutator { Expr VisitExpr_(const SeqExprNode* seq) override { SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq)); - StructuralEqual struct_equal; + ffi::StructuralEqual struct_equal; while (auto opt = TryRewriteSeqExpr(prev)) { SeqExpr next = Downcast(builder_->Normalize(opt.value())); diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 2f7099937fec..3c0e57dc073d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -25,7 +25,7 @@ #include "dataflow_matcher.h" #include -#include +#include #include #include #include @@ -67,7 +67,7 @@ bool MatchAttrs(const Any& attrs, const ffi::Map& attribu auto attr_name = kv.first; auto attr_value = kv.second; if (dict_attrs->dict.count(attr_name)) { - if (!StructuralEqual()(attr_value, dict_attrs->dict[attr_name])) { + if (!ffi::StructuralEqual()(attr_value, dict_attrs->dict[attr_name])) { return false; } } else { @@ -89,7 +89,7 @@ bool MatchAttrs(const Any& attrs, const ffi::Map& attribu if (attributes.count(field_name)) { ffi::reflection::FieldGetter field_getter(field_info); ffi::Any field_value = field_getter(obj); - if (!StructuralEqual()(attributes[field_name], field_value)) { + if (!ffi::StructuralEqual()(attributes[field_name], field_value)) { success = false; return true; } @@ -194,7 +194,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons if (Op::HasAttrMap(attr_name)) { auto op_map = Op::GetAttrMap(attr_name); if (op_map.count(op)) { - matches &= StructuralEqual()(attr_value, op_map[op]); + matches &= ffi::StructuralEqual()(attr_value, op_map[op]); } else { matches = false; } @@ -208,7 +208,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons matches = true; for (auto kv : attributes) { if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { - matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + matches &= ffi::StructuralEqual()(kv.second, op->attrs->dict[kv.first]); } else { matches = false; break; @@ -332,7 +332,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { auto expr = UnwrapBindings(expr0, var2val_); - return StructuralEqual()(op->expr, expr); + return ffi::StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { @@ -570,7 +570,8 @@ bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr // no need to jump, as var.dtype == value.dtype auto expr_sinfo = expr.as()->struct_info_; if (const TensorStructInfoNode* tensor_sinfo = expr_sinfo.as()) { - return (StructuralEqual()(op->dtype, tensor_sinfo->dtype)) && VisitDFPattern(op->pattern, expr); + return (ffi::StructuralEqual()(op->dtype, tensor_sinfo->dtype)) && + VisitDFPattern(op->pattern, expr); } return false; } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 13ef41eede43..fdedf8091130 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -848,7 +848,7 @@ Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { if (var->struct_info_.defined()) { // use same-as as a quick path if (var->struct_info_.same_as(struct_info) || - StructuralEqual()(var->struct_info_, struct_info)) { + ffi::StructuralEqual()(var->struct_info_, struct_info)) { return var; } else { Var new_var = var.as() ? DataflowVar(var->vid, struct_info, var->span) diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc index d8a23da3825b..57c80abbf6d7 100644 --- a/src/relax/op/distributed/utils.cc +++ b/src/relax/op/distributed/utils.cc @@ -45,8 +45,8 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); for (int i = 1; i < static_cast(input_dtensor_sinfos.size()); i++) { - TVM_FFI_ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, - input_dtensor_sinfos[i]->device_mesh)); + TVM_FFI_ICHECK(ffi::StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, + input_dtensor_sinfos[i]->device_mesh)); } distributed::DeviceMesh device_mesh = input_dtensor_sinfos[0]->device_mesh; Var output_var("output", orig_output_sinfo); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index e5f3d19e8dd9..fc6ec6b8aa04 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -163,7 +163,7 @@ ffi::Optional> CheckConcatOutputShape( // For the specified axis, we compute the sum of shape value over each tensor. // Special case, if all concatenated values have the same shape - StructuralEqual structural_equal; + ffi::StructuralEqual structural_equal; PrimExpr first_concat_dim = shape_values[0][axis]; bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { return structural_equal(a[axis], first_concat_dim); diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 614f20dba7f8..000a0d4b7d79 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -150,7 +150,7 @@ class AppendLossMutator : private ExprMutator { * sets up var_remap_ from loss parameter Vars to backbone returned Vars. */ void CheckAndRemapLossParams(const ffi::Array& loss_func_params) { - static StructuralEqual checker; + static ffi::StructuralEqual checker; TVM_FFI_ICHECK(static_cast(loss_func_params.size()) >= num_backbone_outputs_) << "The number of parameters of the loss function is " << loss_func_params.size() << ", which is less than the given num_backbone_outputs " << num_backbone_outputs_; diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index f066bd02daec..93db755059b9 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 05c86d92630d..98fd075f55d5 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -92,7 +92,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { // within each branch. auto new_sinfo = VisitExprDepStructInfoField(Downcast(op->struct_info_)); - StructuralEqual struct_equal; + ffi::StructuralEqual struct_equal; if (!struct_equal(new_sinfo, GetStructInfo(true_b))) { auto output_var = Var("then_branch_with_dyn", new_sinfo); @@ -351,7 +351,8 @@ class CanonicalizePlanner : public ExprVisitor { if (binding.as()) { return true; } else if (auto match_cast = binding.as()) { - return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value)); + return ffi::StructuralEqual()(GetStructInfo(binding->var), + GetStructInfo(match_cast->value)); } else { TVM_FFI_THROW(InternalError) << "Invalid binding type: " << binding->GetTypeKey(); } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index c71675bb26dd..888554df6751 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -22,7 +22,7 @@ */ #include -#include +#include #include #include #include diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 7e7f069cdd14..20e4ce4f59b7 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -24,6 +24,8 @@ * * Currently it removes common subexpressions within a Function. */ +#include +#include #include #include #include @@ -58,7 +60,7 @@ struct ReplacementKey { } friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { - tvm::StructuralEqual eq; + ffi::StructuralEqual eq; return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); } }; @@ -76,7 +78,7 @@ struct ReplacementKey { template <> struct std::hash { std::size_t operator()(const tvm::relax::ReplacementKey& key) const { - tvm::StructuralHash hasher; + tvm::ffi::StructuralHash hasher; return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } }; diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 3a289ebfff49..5194941c26c9 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -412,7 +412,8 @@ class ConstantFolder : public ExprMutator { } // cache for function build, via structural equality - std::unordered_map, StructuralHash, StructuralEqual> + std::unordered_map, ffi::StructuralHash, + ffi::StructuralEqual> func_build_cache_; }; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 4a36047906c2..3f739cd243e4 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -492,7 +492,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { // structurally equal to the `new_buf` passed auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) { if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) { - TVM_FFI_ICHECK(StructuralEqual()((*it).second, new_buf)) + TVM_FFI_ICHECK(ffi::StructuralEqual()((*it).second, new_buf)) << "Inconsistent buffers " << (*it).second << " and " << new_buf << " mapped to the same relax var: " << expr; } diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index bae9794ecc22..969319063c7b 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -52,7 +52,7 @@ class UnusedTrivialBindingRemover : public ExprMutator { } void VisitBinding_(const MatchCastNode* binding) override { if (binding->value.as() && - StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(binding->value))) { + ffi::StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(binding->value))) { has_trivial_binding.insert(binding->var.get()); } ExprVisitor::VisitBinding_(binding); diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 15ba2b82e8b4..0e9cc204ca06 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -525,7 +525,7 @@ class ParamRemapper : private ExprFunctor { int index_i = j + num_inputs_i; int index_0 = j + num_inputs_0; mapper.VisitExpr(functions[i]->params[index_i], functions[0]->params[index_0]); - StructuralEqual eq; + ffi::StructuralEqual eq; eq(functions[i]->params[index_i]->struct_info_, functions[0]->params[index_0]->struct_info_); } @@ -642,7 +642,7 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { // The mapping between the unified bindings and the original bindings in different functions. // The unified binding is the binding with all variables replaced by the unified variables as // defined in var_remap_. - std::unordered_map, StructuralHash, StructuralEqual> + std::unordered_map, ffi::StructuralHash, ffi::StructuralEqual> original_bindings_; }; // namespace diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 9de26d8b1a4e..192dc7acef8e 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index 10fc575e729d..8a38baedd733 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -21,7 +21,7 @@ * \brief Update PrimFunc buffers based on updated scope (or structure) info. */ -#include +#include #include #include #include diff --git a/src/s_tir/meta_schedule/module_equality.cc b/src/s_tir/meta_schedule/module_equality.cc index 6973ba809627..fff1a88c3386 100644 --- a/src/s_tir/meta_schedule/module_equality.cc +++ b/src/s_tir/meta_schedule/module_equality.cc @@ -21,8 +21,6 @@ #include #include #include -#include -#include #include #include @@ -33,8 +31,8 @@ namespace meta_schedule { class ModuleEqualityStructural : public ModuleEquality { public: - size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); } - bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } + size_t Hash(IRModule mod) const { return ffi::StructuralHash()(mod); } + bool Equal(IRModule lhs, IRModule rhs) const { return ffi::StructuralEqual()(lhs, rhs); } ffi::String GetName() const { return "structural"; } }; diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 6b2dd3c96f47..d5569d07ec06 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -21,8 +21,9 @@ #include #include -#include -#include +#include +#include +#include #include #include #include @@ -228,7 +229,7 @@ inline ffi::String SHash2Hex(const ObjectRef& obj) { std::ostringstream os; size_t hash_code = 0; if (obj.defined()) { - hash_code = StructuralHash()(obj); + hash_code = ffi::StructuralHash()(obj); } os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code; return os.str(); diff --git a/src/s_tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc index 4ceb444ecd4d..17f804514db4 100644 --- a/src/s_tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -744,7 +744,7 @@ class ReverseComputeInliner : public BaseInliner { if (const auto* if_ = producer_block->body.as()) { if (!if_->else_case.defined()) { PrimExpr if_predicate = analyzer_.Simplify(if_->condition); - if (!StructuralEqual()(predicate, if_predicate)) { + if (!ffi::StructuralEqual()(predicate, if_predicate)) { predicate = analyzer_.Simplify(predicate && if_->condition); producer_block.CopyOnWrite()->body = if_->then_case; } diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index f608a4b0a3ff..2d8629c06fac 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -18,7 +18,8 @@ */ #include -#include +#include +#include #include #include diff --git a/src/s_tir/schedule/utils.h b/src/s_tir/schedule/utils.h index 715e34b09f61..d8aebb2f6dd3 100644 --- a/src/s_tir/schedule/utils.h +++ b/src/s_tir/schedule/utils.h @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index 6e749dbe6416..1e1bb446e42e 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -21,6 +21,7 @@ * \file inject_software_pipeline.cc * \brief Transform annotated loops into pipelined one that parallelize producers and consumers */ +#include #include #include #include @@ -784,7 +785,7 @@ class PipelineRewriter : public StmtExprMutator { auto stage_id = commit_group_indices[i]; auto predicate = new_blocks[i].predicate; for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { - TVM_FFI_ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) + TVM_FFI_ICHECK(ffi::StructuralEqual()(predicate, new_blocks[i].predicate)) << "Predicates in the same stage are expected to be identical"; group_bodies.push_back(new_blocks[i].block->body); } diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc b/src/s_tir/transform/using_assume_to_reduce_branches.cc index 2c356c8f8efa..e506d1985431 100644 --- a/src/s_tir/transform/using_assume_to_reduce_branches.cc +++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc @@ -204,7 +204,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { PrimExpr current_predicate_and_context = CurrentScopePredicate(); PrimExpr buffer_predicate_and_context = buffer_assumption.buffer_context && buffer_assumption.buffer_predicate; - bool current_context_and_buffer_constraint_is_same = StructuralEqual()( + bool current_context_and_buffer_constraint_is_same = ffi::StructuralEqual::Equal( current_predicate_and_context, buffer_predicate_and_context, /*map_free_vars=*/true); if (current_context_and_buffer_constraint_is_same) { @@ -251,10 +251,11 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } auto n = this->CopyOnWrite(op); - if (StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) { + if (ffi::StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) { n->value = analyzer_->Simplify(else_clause); return Stmt(n); - } else if (StructuralEqual()(then_clause_in_else_context, else_clause_in_else_context)) { + } else if (ffi::StructuralEqual()(then_clause_in_else_context, + else_clause_in_else_context)) { n->value = analyzer_->Simplify(then_clause); return Stmt(n); } else { diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index 0c8cd3c12371..99d9618639f7 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -19,6 +19,7 @@ #include +#include #include #include "./utils.h" diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 7dddfaecbbe7..558abaef3350 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -125,7 +125,7 @@ inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessP inferred_sinfo = trivial_binding->struct_info_.as(); } - if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) { + if (inferred_sinfo && ffi::StructuralEqual()(inferred_sinfo, v->struct_info_)) { return std::nullopt; } } diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 2ea588c5eeae..c0dbd2e46c27 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -19,7 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_UTILS_H_ #define TVM_SCRIPT_PRINTER_UTILS_H_ -#include +#include #include #include diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 0f8806a117c6..d82c96f2636e 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include diff --git a/src/support/scalars.h b/src/support/scalars.h index fa5a3482f5f6..069ed62445e7 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -25,6 +25,7 @@ #ifndef TVM_SUPPORT_SCALARS_H_ #define TVM_SUPPORT_SCALARS_H_ +#include #include #include "tvm/ir/expr.h" diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 8dfdd977accb..517dbe07b54e 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 61abb610183a..5ee7feb11608 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -63,7 +63,7 @@ static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::Reduce "each reduction must be structurally identical, " "except for the ReduceNode::value_index. "; - StructuralEqual eq; + ffi::StructuralEqual eq; TVM_FFI_ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " << a->combiner << " does not match " diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 831abb929927..3d2536e423eb 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -250,7 +250,7 @@ ffi::Array GenerateOutputBuffers(const te::ComputeOp& compute_op, Create ffi::Array tensors; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - StructuralEqual eq; + ffi::StructuralEqual eq; return eq(a->combiner, b->combiner) && // eq(a->source, b->source) && // eq(a->axis, b->axis) && // diff --git a/src/tir/transform/common_subexpr_elim_tools.cc b/src/tir/transform/common_subexpr_elim_tools.cc index 1c52c6f97f5d..4aa4cbbe764d 100644 --- a/src/tir/transform/common_subexpr_elim_tools.cc +++ b/src/tir/transform/common_subexpr_elim_tools.cc @@ -797,7 +797,7 @@ std::vector> SyntacticToSemanticComputations( // normalized. This normalized table will keep the count for each set of equivalent terms // (i.e. each equivalence class), together with a term that did appear in this equivalence class // (in practice, the first term of the equivalence class that was encoutered). - support::OrderedMap, StructuralHash, ExprDeepEqual> + support::OrderedMap, ffi::StructuralHash, ExprDeepEqual> norm_table; // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for diff --git a/src/tir/transform/common_subexpr_elim_tools.h b/src/tir/transform/common_subexpr_elim_tools.h index b9c056dcf230..cd548ec0ed0a 100644 --- a/src/tir/transform/common_subexpr_elim_tools.h +++ b/src/tir/transform/common_subexpr_elim_tools.h @@ -26,6 +26,7 @@ #ifndef TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_TOOLS_H_ #define TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_TOOLS_H_ +#include #include #include // For the ExprDeepEqual analysis #include @@ -46,13 +47,12 @@ namespace tir { /*! * \brief A computation table is a hashtable which associates to each expression being computed a number (which is the number of time that it is computed) - It is important to note that the hash used is a StructuralHash (and not an ObjectPtrHash) - as we need to hash similarly deeply equal terms. - The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does - not do variables remapping), so it is compatible with StructuralHash (intended to be used - with StructuralEqual). + It is important to note that the hash used is a ffi::StructuralHash (and not an + ObjectPtrHash) as we need to hash similarly deeply equal terms. The comparison used is + ExprDeepEqual, which is stricter than ffi::StructuralEqual (as it does not do variables remapping), + so it is compatible with ffi::StructuralHash (intended to be used with ffi::StructuralEqual). */ -using ComputationTable = support::OrderedMap; +using ComputationTable = support::OrderedMap; /*! * \brief A cache of computations is made of a pair of two hashtables, which respectively associate diff --git a/src/tir/transform/vectorize_loop.cc b/src/tir/transform/vectorize_loop.cc index 2e8f1811996a..719d27e7431c 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tir/transform/vectorize_loop.cc @@ -22,6 +22,7 @@ */ // Loop vectorizer as in Halide pipeline. #include +#include #include #include #include @@ -168,7 +169,7 @@ class TryPredicateBufferAccesses : public StmtExprMutator { Ramp ramp = Downcast(node->indices[0]); // The vectorized access pattern must match the base of the predicate - if (!tvm::StructuralEqual()(ramp->base, base_)) { + if (!ffi::StructuralEqual()(ramp->base, base_)) { return node; } diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 495739d76608..9f6108617696 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -55,7 +56,7 @@ TEST(Simplify, Mod) { } TEST(ConstantFold, Broadcast) { - tvm::StructuralEqual checker; + tvm::ffi::StructuralEqual checker; auto i32x4 = tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(32), 10), 4); auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4); auto i64x4_expected = tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(64), 10), 4); @@ -63,7 +64,7 @@ TEST(ConstantFold, Broadcast) { } TEST(ConstantFold, Ramp) { - tvm::StructuralEqual checker; + tvm::ffi::StructuralEqual checker; auto i32x4 = tvm::tir::Ramp(tvm::IntImm(tvm::DataType::Int(32), 10), tvm::IntImm(tvm::DataType::Int(32), 1), 4); auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4); diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 67c9fe99cf30..1d3aa62f6629 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -40,7 +40,7 @@ TEST(Expr, VarTypeAnnotation) { using namespace tvm::tir; Var x("x", DataType::Float(32)); Var y("y", PrimType(DataType::Float(32))); - StructuralEqual checker; + tvm::ffi::StructuralEqual checker; TVM_FFI_ICHECK(checker(x->dtype, y->dtype)); TVM_FFI_ICHECK(checker(x->type_annotation, y->type_annotation)); } diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index b1f7b80c996a..02b662875c62 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -215,10 +216,11 @@ TEST(NestedMsg, MapToNestedMsgBySInfo) { auto arr1 = arr[1].NestedArray(); EXPECT_TRUE(arr1[0].IsLeaf()); - EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0))); + EXPECT_TRUE( + tvm::ffi::StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0))); EXPECT_TRUE(arr[2].IsLeaf()); - EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2))); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2))); } TEST(NestedMsg, NestedMsgToExpr) { @@ -246,13 +248,13 @@ TEST(NestedMsg, NestedMsgToExpr) { }); Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})}); - EXPECT_TRUE(StructuralEqual()(expr, expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(expr, expected)); // test simplified relax::Var t("t", sf1); NestedMsg msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)}; auto expr1 = NestedMsgToExpr(msg1, [](ffi::Optional leaf) { return leaf.value(); }); - EXPECT_TRUE(StructuralEqual()(expr1, t)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(expr1, t)); } TEST(NestedMsg, CombineNestedMsg) { @@ -323,7 +325,7 @@ TEST(NestedMsg, TransformTupleLeaf) { Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})}); - EXPECT_TRUE(StructuralEqual()( + EXPECT_TRUE(tvm::ffi::StructuralEqual()( TransformTupleLeaf(expr, std::array({msg1, msg2}), ftransleaf), expected)); EXPECT_TRUE( diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc index 60b643396ca7..8e2000852dbe 100644 --- a/tests/cpp/target/virtual_device_test.cc +++ b/tests/cpp/target/virtual_device_test.cc @@ -32,7 +32,7 @@ TEST(VirtualDevice, Join_Defined) { ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected)); } { Target target_a = Target("cuda"); @@ -41,7 +41,7 @@ TEST(VirtualDevice, Join_Defined) { ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected)); } { Target target_a = Target("cuda"); @@ -50,7 +50,7 @@ TEST(VirtualDevice, Join_Defined) { ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected)); } { Target target_a = Target("cuda"); @@ -59,7 +59,7 @@ TEST(VirtualDevice, Join_Defined) { ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = rhs; - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected)); } } @@ -96,7 +96,7 @@ TEST(VirtualDevice, Default) { VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "local"); VirtualDevice actual = VirtualDevice::Default(lhs, rhs); VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual, expected)); + EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual, expected)); } TEST(VirtualDevice, Constructor_Invalid) { diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index 39ae8d5a30c7..2d157584904a 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E501, E722, F401, F841, RUF005 +# ruff: noqa: E501, F401, F841, RUF005 """Test for FlashInfer GroupedGemm TVM integration""" @@ -58,7 +58,7 @@ def has_cutlass(): handle = pynvml.nvmlDeviceGetHandleByIndex(0) major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) return major >= 9 # SM90+ - except: + except Exception: return False diff --git a/tests/scripts/release/make_notes.py b/tests/scripts/release/make_notes.py index 82def8adc23c..82e5a4372b0a 100644 --- a/tests/scripts/release/make_notes.py +++ b/tests/scripts/release/make_notes.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E722 import argparse import csv @@ -218,7 +217,7 @@ def pr_title(number, heading): try: title = pr_dict[int(number)]["title"] title = strip_header(title, heading) - except: + except Exception: sprint("The out.pkl file is not match with csv file.") exit(1) return title