From ec79f86fa3810ed89ca1f6bc62e4cc45524cd1b1 Mon Sep 17 00:00:00 2001 From: liyuzhuo Date: Sun, 15 Feb 2026 00:46:35 +0800 Subject: [PATCH] [fix] fix import error cased by _load_cuda_libs when running pytest --- transformer_engine/plugin/core/backends/vendor/cuda/cuda.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py b/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py index 8be7dd5052..bb3f6daa1d 100644 --- a/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py +++ b/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py @@ -65,7 +65,11 @@ def try_load_lib(name, search_patterns): try_load_lib("nvrtc", [f"libnvrtc{ext}*"]) try_load_lib("curand", [f"libcurand{ext}*"]) - te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent + te_path_override = os.environ.get("TE_LIB_PATH") + if te_path_override: + te_path = Path(te_path_override) + else: + te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent for search_dir in [te_path, te_path / "transformer_engine"]: if search_dir.exists(): matches = list(search_dir.glob(f"libtransformer_engine{ext}*"))