-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathconftest.py
More file actions
366 lines (301 loc) · 14 KB
/
conftest.py
File metadata and controls
366 lines (301 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Root conftest — CLI options, markers, ST platform filtering, runtime isolation, and ST fixtures.
Runtime isolation: CANN's AICPU framework caches the user .so per device context.
Switching runtimes on the same device within one process causes hangs. When multiple
runtimes are collected and --runtime is not specified, pytest_runtestloop spawns a
subprocess per runtime so each gets a clean CANN context. See docs/testing.md.
"""
from __future__ import annotations
import os
import signal
import subprocess
import sys
import pytest
# Process exit status reported when the session watchdog armed by
# --pto-session-timeout fires. 124 is the status GNU `timeout` uses on expiry,
# which lets CI / shell wrappers tell a timed-out run apart from an ordinary
# test failure.
TIMEOUT_EXIT_CODE = 124
def _parse_device_range(s: str) -> list[int]:
"""Parse '4-7' -> [4,5,6,7] or '0' -> [0]."""
if "-" in s:
start, end = s.split("-", 1)
return list(range(int(start), int(end) + 1))
return [int(s)]
class DevicePool:
"""Device allocator for pytest fixtures.
Manages a fixed set of device IDs. Tests allocate IDs before use
and release them after. Works identically for sim and onboard.
"""
def __init__(self, device_ids: list[int]):
self._available = list(device_ids)
def allocate(self, n: int = 1) -> list[int]:
if n > len(self._available):
return []
allocated = self._available[:n]
self._available = self._available[n:]
return allocated
def release(self, ids: list[int]) -> None:
self._available.extend(ids)
_device_pool: DevicePool | None = None
def pytest_addoption(parser):
    """Register every command-line option this test suite understands.

    The options fall into three groups:

    * target selection: ``--platform``, ``--device``, ``--case``, ``--manual``,
      ``--runtime``;
    * run behaviour: ``--rounds`` and the boolean switches ``--skip-golden``,
      ``--enable-profiling``, ``--dump-tensor`` and ``--build``;
    * environment setup: pto-isa pinning (``--pto-isa-commit``,
      ``--clone-protocol``) and the session watchdog
      (``--pto-session-timeout``).
    """
    add = parser.addoption

    # --- target selection -------------------------------------------------
    add("--platform", action="store", default=None, help="Target platform (e.g., a2a3sim, a2a3)")
    add("--device", action="store", default="0", help="Device ID or range (e.g., 0, 4-7)")
    add(
        "--case",
        action="append",
        default=None,
        help="Case selector; repeatable. Forms: 'Foo' (any class), 'ClassA::Foo', 'ClassA::' (whole class).",
    )
    add(
        "--manual",
        action="store",
        choices=["exclude", "include", "only"],
        default="exclude",
        help="Manual case handling: exclude (default), include, only",
    )
    add("--runtime", action="store", default=None, help="Only run tests for this runtime")

    # --- run behaviour ----------------------------------------------------
    add("--rounds", type=int, default=1, help="Run each case N times (default: 1)")
    boolean_switches = (
        ("--skip-golden", "Skip golden comparison (benchmark mode)"),
        ("--enable-profiling", "Enable profiling (first round only)"),
        ("--dump-tensor", "Dump per-task tensor I/O at runtime"),
        ("--build", "Compile runtime from source"),
    )
    for flag, description in boolean_switches:
        add(flag, action="store_true", default=False, help=description)

    # --- environment setup ------------------------------------------------
    add(
        "--pto-isa-commit",
        action="store",
        default=None,
        help="Pin pto-isa clone to this commit before running tests",
    )
    add(
        "--clone-protocol",
        action="store",
        default="ssh",
        choices=["ssh", "https"],
        help="Protocol for cloning pto-isa when --pto-isa-commit is set",
    )
    # Not to be confused with pytest-timeout's per-test --timeout (pulled in by
    # `.[test]` on the a2a3 hardware runner): this limit covers the whole session.
    add(
        "--pto-session-timeout",
        action="store",
        type=int,
        default=0,
        help=f"Abort whole pytest session after N seconds (0 = disabled; exit code {TIMEOUT_EXIT_CODE} on timeout)",
    )
def _install_session_timeout(timeout_s: int) -> None:
def _handler(signum, frame):
print(
f"\n{'=' * 40}\n"
f"[pytest] TIMEOUT: session exceeded {timeout_s}s "
f"({timeout_s // 60}min) limit, aborting\n"
f"{'=' * 40}",
flush=True,
)
os._exit(TIMEOUT_EXIT_CODE)
# signal.alarm / SIGALRM are Unix-only; skip silently on platforms without
# them so --pto-session-timeout is a no-op rather than a crash (e.g. Windows).
if hasattr(signal, "alarm") and hasattr(signal, "SIGALRM"):
signal.signal(signal.SIGALRM, _handler)
signal.alarm(timeout_s)
def pytest_configure(config):
    """Register the suite's custom markers and apply CLI-driven global setup.

    In order, this:

    * registers the ``platforms``, ``requires_hardware``, ``device_count`` and
      ``runtime`` markers;
    * exports ``--log-level`` (when given) as ``PTO_LOG_LEVEL``;
    * when ``--pto-isa-commit`` is given, calls
      ``simpler_setup.pto_isa.ensure_pto_isa_root`` for that commit and exports
      the returned path as ``PTO_ISA_ROOT`` if one is returned;
    * arms the session watchdog when ``--pto-session-timeout`` is positive.
    """
    marker_descriptions = (
        "platforms(list): supported platforms for standalone ST functions",
        "requires_hardware: test needs Ascend toolchain and real device",
        "device_count(n): number of NPU devices needed",
        "runtime(name): runtime this standalone test targets; used by runtime-isolation subprocess "
        "filtering so non-@scene_test tests only run under their matching runtime",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)

    requested_log_level = config.getoption("--log-level", default=None)
    if requested_log_level:
        os.environ["PTO_LOG_LEVEL"] = requested_log_level

    pinned_commit = config.getoption("--pto-isa-commit")
    if pinned_commit:
        from simpler_setup.pto_isa import ensure_pto_isa_root  # noqa: PLC0415

        isa_root = ensure_pto_isa_root(
            verbose=True,
            commit=pinned_commit,
            clone_protocol=config.getoption("--clone-protocol"),
        )
        if isa_root:
            os.environ["PTO_ISA_ROOT"] = isa_root

    # `or 0` keeps a None/0 value from ever arming the watchdog.
    session_timeout = config.getoption("--pto-session-timeout")
    if (session_timeout or 0) > 0:
        _install_session_timeout(session_timeout)
def _scene_skip_reason(cls, platform, runtime_filter):
    """Return why a SceneTestCase item (class with a ``CASES`` list) must be skipped, or None."""
    if not platform:
        return "--platform required"
    if not any(platform in c.get("platforms", []) for c in cls.CASES):
        return f"No cases for {platform}"
    if runtime_filter and getattr(cls, "_st_runtime", None) != runtime_filter:
        return f"Runtime {getattr(cls, '_st_runtime', '?')} != {runtime_filter}"
    return None


def _marker_skip_reason(item, platform, runtime_filter):
    """Return why a standalone item must be skipped per its markers, or None."""
    platforms_marker = item.get_closest_marker("platforms")
    if platforms_marker is not None:
        if not platform:
            return "--platform required"
        args = platforms_marker.args
        # Documented form is platforms([...]). Bare strings (platforms("a2a3") or
        # platforms("a", "b")) are treated as platform names: a plain `in` on a
        # string would be a substring test ("a2a3" in "a2a3sim"). A marker with
        # no args supports nothing rather than raising IndexError.
        supported = args[0] if args and not isinstance(args[0], str) else args
        if platform not in supported:
            return f"Not supported on {platform}"

    # Runtime-isolation filter for non-@scene_test tests: when the item declares
    # `@pytest.mark.runtime("X")` and a --runtime filter is active, skip on a
    # mismatch so the test only runs under its own runtime's subprocess
    # (e.g. keeps test_explicit_fatal_reports from running under every runtime).
    runtime_marker = item.get_closest_marker("runtime")
    if runtime_marker and runtime_marker.args and runtime_filter and runtime_marker.args[0] != runtime_filter:
        return f"Runtime {runtime_marker.args[0]} != {runtime_filter}"
    return None


def pytest_collection_modifyitems(session, config, items):
    """Skip ST tests based on --platform and --runtime filters, and order L3 before L2.

    SceneTestCase items (a test class with a ``CASES`` list) are filtered on the
    platforms listed in their cases and on the class's ``_st_runtime``. All
    other items are filtered on their ``platforms`` and ``runtime`` markers.
    """
    platform = config.getoption("--platform")
    runtime_filter = config.getoption("--runtime")

    # Sort: L3 tests first (they fork child processes that inherit main process CANN state,
    # so they must run before L2 tests pollute the CANN context).
    def sort_key(item):
        cls = getattr(item, "cls", None)
        level = getattr(cls, "_st_level", 0) if cls else 0
        return (0 if level >= 3 else 1, item.nodeid)

    items.sort(key=sort_key)

    for item in items:
        cls = getattr(item, "cls", None)
        if cls and hasattr(cls, "CASES") and isinstance(cls.CASES, list):
            reason = _scene_skip_reason(cls, platform, runtime_filter)
        else:
            reason = _marker_skip_reason(item, platform, runtime_filter)
        if reason is not None:
            item.add_marker(pytest.mark.skip(reason=reason))
# ---------------------------------------------------------------------------
# Runtime isolation: spawn subprocess per runtime
# ---------------------------------------------------------------------------
def _collect_st_runtimes(items):
"""Return sorted list of unique runtimes from collected SceneTestCase items."""
runtimes = set()
for item in items:
cls = getattr(item, "cls", None)
rt = getattr(cls, "_st_runtime", None) if cls else None
if rt:
runtimes.add(rt)
return sorted(runtimes)
def pytest_runtestloop(session):
    """Override test execution to isolate runtimes in subprocesses.

    If --runtime is specified (or at most one runtime is collected), run
    normally. Otherwise, spawn one ``python -m pytest <original args> --runtime
    <rt>`` subprocess per runtime and aggregate the results.

    Returns:
        None to let pytest's default loop run (or handle --collect-only /
        collection errors). True when the per-runtime subprocesses have done
        the work; this suppresses the default loop.
    """
    config = session.config
    if config.getoption("--runtime"):
        return  # single runtime — let pytest run normally

    # Let pytest's default loop handle these two cases itself: it stops after
    # collection for --collect-only, and raises Interrupted on collection errors.
    # Fanning out would only re-collect N times and print misleading
    # PASSED/FAILED banners.
    if getattr(config.option, "collectonly", False):
        return
    if session.testsfailed and not getattr(config.option, "continue_on_collection_errors", False):
        return

    runtimes = _collect_st_runtimes(session.items)
    if len(runtimes) <= 1:
        return  # zero or one runtime — no isolation needed

    # Re-invoke pytest with the same args + --runtime <rt> for each runtime.
    base_args = [sys.executable, "-m", "pytest", *(str(arg) for arg in config.invocation_params.args)]

    failed = False
    for rt in runtimes:
        header = f" Runtime: {rt}"
        print(f"\n{'=' * 60}\n{header}\n{'=' * 60}\n", flush=True)
        result = subprocess.run([*base_args, "--runtime", rt], check=False, cwd=config.invocation_params.dir)
        if result.returncode == TIMEOUT_EXIT_CODE:
            # Propagate the child's watchdog timeout as our own exit status.
            print(f"\n*** Runtime {rt}: TIMED OUT ***\n", flush=True)
            os._exit(TIMEOUT_EXIT_CODE)
        if result.returncode != 0:
            failed = True
            print(f"\n*** Runtime {rt}: FAILED ***\n", flush=True)
        else:
            print(f"\n--- Runtime {rt}: PASSED ---\n", flush=True)

    # No tests ran in this process; these counters drive pytest's exit status.
    if failed:
        session.testsfailed = 1
    else:
        session.testscollected = len(session.items)
        session.testsfailed = 0
    return True  # returning True prevents default runtestloop
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def device_pool(request):
    """Return the session-wide DevicePool built from the ``--device`` option.

    The pool is stored in the module-level ``_device_pool``. It is created on
    the first request and the same instance is returned afterwards.
    """
    global _device_pool  # noqa: PLW0603
    if _device_pool is not None:
        return _device_pool
    device_spec = request.config.getoption("--device")
    _device_pool = DevicePool(_parse_device_range(device_spec))
    return _device_pool
@pytest.fixture(scope="session")
def st_platform(request):
    """Return the ``--platform`` value; skip the requesting test when it is unset."""
    selected_platform = request.config.getoption("--platform")
    if selected_platform:
        return selected_platform
    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip("--platform required for ST tests")
@pytest.fixture()
def st_worker(request, st_platform, device_pool):
    """Per-test Worker with devices allocated from the session pool.

    Reads ``_st_level``, ``_st_runtime`` and ``CASES`` (plus ``CALLABLE`` for
    level 3) from the test class to decide how many devices and sub-workers to
    allocate:

    * level 2: one device, ``Worker(level=2, device_id=...)``.
    * level 3: the largest ``config["device_count"]`` (default 1) and
      ``config["num_sub_workers"]`` (default 0) over all ``CASES``. Every
      ``CALLABLE["callables"]`` entry that has a ``"callable"`` key is
      registered before ``init()``; the returned IDs are stored on
      ``cls._st_sub_ids``, keyed by entry name.

    The allocated device IDs are always returned to the pool, even when the
    Worker import, construction, registration, ``init()`` or ``close()`` raises.
    ``close()`` is only called on a worker whose ``init()`` succeeded.
    """
    cls = request.node.cls
    if cls is None or not hasattr(cls, "_st_level"):
        pytest.skip("st_worker requires SceneTestCase")

    level = cls._st_level
    runtime = cls._st_runtime
    build = request.config.getoption("--build", default=False)

    if level == 2:
        ids = device_pool.allocate(1)
        if not ids:
            pytest.fail(f"no devices available in --device pool (requested 1, pool has {len(device_pool._available)})")
        worker_kwargs = {"level": 2, "device_id": ids[0]}
    elif level == 3:
        max_devices = max((c.get("config", {}).get("device_count", 1) for c in cls.CASES), default=1)
        max_subs = max((c.get("config", {}).get("num_sub_workers", 0) for c in cls.CASES), default=0)
        ids = device_pool.allocate(max_devices)
        if not ids:
            pytest.fail(
                f"need {max_devices} devices but --device pool has {len(device_pool._available)}; widen --device range"
            )
        worker_kwargs = {"level": 3, "device_ids": ids, "num_sub_workers": max_subs}
    else:
        # Previously the generator yielded nothing here, which surfaces as an
        # opaque "fixture did not yield a value" error.
        pytest.fail(f"st_worker supports _st_level 2 or 3, got {level!r}")

    try:
        from simpler.worker import Worker  # noqa: PLC0415

        w = Worker(**worker_kwargs, platform=st_platform, runtime=runtime, build=build)
        w._st_device_id = ids[0]  # expose primary device to test_run for profiling snapshots
        if level == 3:
            # Register SubCallable entries from cls.CALLABLE (must happen before init()).
            sub_ids = {}
            for entry in cls.CALLABLE.get("callables", []):
                if "callable" in entry:
                    sub_ids[entry["name"]] = w.register(entry["callable"])
            cls._st_sub_ids = sub_ids
        w.init()
        try:
            yield w
        finally:
            w.close()
    finally:
        device_pool.release(ids)
@pytest.fixture()
def st_device_ids(request, device_pool):
    """Allocate device IDs for one test and return them to the pool afterwards.

    Request more than one device with ``@pytest.mark.device_count(n)`` or
    ``@pytest.mark.device_count(n=...)``. Without the marker (or with no
    argument) a single device is allocated.
    """
    marker = request.node.get_closest_marker("device_count")
    if marker is None:
        n = 1
    elif marker.args:
        n = marker.args[0]
    else:
        # Keyword form device_count(n=2); previously this raised IndexError.
        n = marker.kwargs.get("n", 1)

    if n < 1:
        pytest.fail(f"device_count must be >= 1, got {n}")
    ids = device_pool.allocate(n)
    if not ids:
        pytest.fail(f"need {n} devices but --device pool has {len(device_pool._available)}; widen --device range")
    try:
        yield ids
    finally:
        device_pool.release(ids)