From bf47ea85a53ca03a70acc715e497948ead326acd Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Feb 2026 16:53:21 +0000 Subject: [PATCH 1/2] [Relax][Analysis] Validate global_symbol on non-Relax functions Move CheckGlobalVarAndGsymbolConsistency call outside the Relax-function-only branch so it applies to all BaseFunc types (including TIR PrimFunc). Change the method parameter type from Function to BaseFunc and use structured bindings for the iteration loop. Fix tests that had inconsistent global_symbol attributes vs GlobalVar names: - test_tir_host_func.py: remove mismatched "global_symbol": "test" for main - test_tir_intrin.py: remove mismatched "global_symbol": "test_fma" for test_tir_fma - test_tir_transform_device_kernel_launch.py: fix kernel_by_another_name mismatch - test_s_tir_transform_manifest_shared_memory_local_stage.py: remove mismatched "default_function" for main --- src/relax/analysis/well_formed.cc | 12 ++++++------ ...r_transform_manifest_shared_memory_local_stage.py | 4 ++-- tests/python/tir-base/test_tir_host_func.py | 4 +--- tests/python/tir-base/test_tir_intrin.py | 2 +- .../test_tir_transform_device_kernel_launch.py | 6 +++--- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 2c1a42fbe843..5bc2df0d0e3c 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -93,12 +93,12 @@ class WellFormedChecker : public relax::ExprVisitor, WellFormedChecker(obj.as(), check_struct_info); if (const auto* mod = obj.as()) { - for (const auto& it : mod->functions) { + for (const auto& [gvar, base_func] : mod->functions) { + well_formed_checker.CheckGlobalVarAndGsymbolConsistency(gvar, base_func); // visit relax.Function - if (auto* n = it.second.as()) { - Function func = ffi::GetRef(n); - well_formed_checker.func_name_map_[n] = it.first->name_hint; - well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); + if (auto opt = base_func.as()) { + Function func = opt.value(); + well_formed_checker.func_name_map_[func.get()] = gvar->name_hint; well_formed_checker.VisitExpr(func); } } @@ -146,7 +146,7 @@ class WellFormedChecker : public relax::ExprVisitor, return "(anonymous function)"; } - void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { + void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, BaseFunc func) { // the uniqueness of all global vars are ensured by IRModule->global_var_map_, so do not need // to check again diff --git a/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py index 7a2f0ea136f5..d73e6e68c0f2 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py @@ -29,7 +29,7 @@ class MatmulBefore: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # body # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): @@ -70,7 +70,7 @@ class MatmulAfter: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # body # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): diff --git a/tests/python/tir-base/test_tir_host_func.py b/tests/python/tir-base/test_tir_host_func.py index 8f9a51942d0b..88578f2c3c15 100644 --- a/tests/python/tir-base/test_tir_host_func.py +++ b/tests/python/tir-base/test_tir_host_func.py @@ -33,7 +33,6 @@ def main( ): T.func_attr( { - "global_symbol": "test", "target": tvm.target.Target("llvm", host="llvm"), "tir.noalias": True, } @@ -59,12 +58,11 @@ def test_host_func(): func = tvm.te.create_prim_func( te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32") ) - mod = tvm.ir.IRModule({"main": func}) + mod = tvm.ir.IRModule({"main": func.with_attr("global_symbol", "main")}) target = tvm.target.Target("cuda") mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { - "global_symbol": "test", "tir.is_host_func": True, } ) diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 24b0bc90cc2a..e50219d2aea9 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -307,7 +307,7 @@ class Module: @T.prim_func def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) n = T.int32() stride = T.int32() stride_1 = T.int32() diff --git a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py index e9c5d9776409..cb3b5dea2ff8 100644 --- a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py +++ b/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py @@ -92,7 +92,7 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def kernel(A_data: T.handle("float32")): - T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"}) + T.func_attr({"target": T.target("cuda")}) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 @@ -101,7 +101,7 @@ class Expected: @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("llvm")}) - T.call_packed("kernel_by_another_name", A.data) + T.call_packed("kernel", A.data) @T.prim_func def kernel(A_data: T.handle("float32")): @@ -110,7 +110,7 @@ def kernel(A_data: T.handle("float32")): "target": T.target("cuda"), "calling_conv": 2, "tir.kernel_launch_params": [], - "global_symbol": "kernel_by_another_name", + "global_symbol": "kernel", "tir.is_global_func": True, } ) From ab6d39a411fadb13e28d1cd44946b3aa249ac41c Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Feb 2026 21:24:56 +0000 Subject: [PATCH 2/2] remove unused noqa E722 directives --- python/tvm/runtime/_tensor.py | 2 +- python/tvm/s_tir/dlight/analysis/common_analysis.py | 1 - tests/python/relax/test_group_gemm_flashinfer.py | 2 +- tests/scripts/release/make_notes.py | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py index 9eea8a36d2e9..13557f9bd97a 100644 --- a/python/tvm/runtime/_tensor.py +++ b/python/tvm/runtime/_tensor.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-import, redefined-outer-name -# ruff: noqa: E722, F401, RUF005 +# ruff: noqa: F401, RUF005 """Runtime Tensor API""" import ctypes diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py index fb44d3fceae8..3a148361092e 100644 --- a/python/tvm/s_tir/dlight/analysis/common_analysis.py +++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E722 # pylint: disable=missing-function-docstring, missing-class-docstring # pylint: disable=unused-argument, unused-variable diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index c128fdb35bb5..2d157584904a 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E501, E722, F401, F841, RUF005 +# ruff: noqa: E501, F401, F841, RUF005 """Test for FlashInfer GroupedGemm TVM integration""" diff --git a/tests/scripts/release/make_notes.py b/tests/scripts/release/make_notes.py index b599f4b4f305..82e5a4372b0a 100644 --- a/tests/scripts/release/make_notes.py +++ b/tests/scripts/release/make_notes.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E722 import argparse import csv