Skip to content

Commit 4b28dcd

Browse files
Refactor: L3 SceneTestCase — HostCallable as Python DAG + ChipCallable + SubCallable (#514)
Extend SceneTestCase to support L3 distributed tests (multi-chip + SubWorker). L3 CALLABLE mirrors L2 structurally: - orchestration: Python DAG function (vs C++ binary for L2) - callables: list of ChipCallable entries (compiled) + SubCallable entries (registered), distinguished by field presence Key additions: - CallableNamespace: dot-access container for orch functions to access compiled ChipCallables and registered SubCallable IDs by name, with keep() for lifetime management of transient objects past drain() - _compile_chip_callable_from_spec: extracted from compile_chip_callable, reused by both L2 and L3 compilation paths - _run_and_validate_l3: builds CallableNamespace, wraps orch in Task, compares all tensors against golden - conftest.py L3 branch: registers SubCallable entries from CALLABLE, reads device_count/num_sub_workers from case config dict - Rewrites test_l3_dependency.py and test_l3_group.py as SceneTestCase subclasses with module-level orch + sub functions Co-authored-by: wcwxy <26245345+ChaoWao@users.noreply.github.com>
1 parent a90b0a2 commit 4b28dcd

7 files changed

Lines changed: 476 additions & 316 deletions

File tree

conftest.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,24 @@ def st_worker(request, st_platform, device_pool):
137137
device_pool.release(ids)
138138

139139
elif level == 3:
140-
max_devices = max((c.get("device_count", 1) for c in cls.CASES), default=1)
141-
max_subs = max((c.get("num_sub_workers", 0) for c in cls.CASES), default=0)
140+
max_devices = max((c.get("config", {}).get("device_count", 1) for c in cls.CASES), default=1)
141+
max_subs = max((c.get("config", {}).get("num_sub_workers", 0) for c in cls.CASES), default=0)
142142
ids = device_pool.allocate(max_devices)
143143
if not ids:
144144
pytest.fail(f"need {max_devices} devices")
145145

146146
from simpler.worker import Worker # noqa: PLC0415
147147

148148
w = Worker(level=3, device_ids=ids, num_sub_workers=max_subs, platform=st_platform, runtime=runtime)
149+
150+
# Register SubCallable entries from cls.CALLABLE
151+
sub_ids = {}
152+
for entry in cls.CALLABLE.get("callables", []):
153+
if "callable" in entry:
154+
cid = w.register(entry["callable"])
155+
sub_ids[entry["name"]] = cid
156+
cls._st_sub_ids = sub_ids
157+
149158
w.init()
150159
yield w
151160
w.close()

python/bindings/task_interface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ NB_MODULE(_task_interface, m) {
222222
return reinterpret_cast<uint64_t>(&self);
223223
},
224224
"Return the memory address of the underlying C++ object."
225+
)
226+
227+
.def_static(
228+
"sizeof",
229+
[]() -> size_t {
230+
return sizeof(ChipStorageTaskArgs);
231+
},
232+
"Return sizeof(ChipStorageTaskArgs) in bytes."
225233
);
226234

227235
// --- TensorArgType enum ---

python/simpler/worker.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ def _mailbox_addr(shm: SharedMemory) -> int:
9393
return ctypes.addressof(ctypes.c_char.from_buffer(buf))
9494

9595

96+
def _args_size(csa_cls) -> int:
97+
"""Return sizeof(ChipStorageTaskArgs). Uses C++ binding if available, else heap probe."""
98+
if hasattr(csa_cls, "sizeof"):
99+
return csa_cls.sizeof()
100+
objs = [csa_cls() for _ in range(5)]
101+
ptrs = [o.__ptr__() for o in objs]
102+
return min(abs(ptrs[i + 1] - ptrs[i]) for i in range(len(ptrs) - 1))
103+
104+
96105
def _sub_worker_loop(buf, registry: dict) -> None:
97106
"""Runs in forked child process."""
98107
while True:
@@ -272,26 +281,14 @@ def _init_level3(self) -> None:
272281
device_ids = self._config.get("device_ids", [])
273282
n_sub = self._config.get("num_sub_workers", 0)
274283

275-
# 1. Allocate mailboxes
284+
# 1. Allocate sub-worker mailboxes
276285
for _ in range(n_sub):
277286
shm = SharedMemory(create=True, size=DIST_SUB_MAILBOX_SIZE)
278287
assert shm.buf is not None
279288
struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE)
280289
self._shms.append(shm)
281290

282-
# 2. Fork SubWorker processes (MUST be before any C++ threads)
283-
registry = self._callable_registry
284-
for i in range(n_sub):
285-
pid = os.fork()
286-
if pid == 0:
287-
buf = self._shms[i].buf
288-
assert buf is not None
289-
_sub_worker_loop(buf, registry)
290-
os._exit(0)
291-
else:
292-
self._pids.append(pid)
293-
294-
# 3. Fork ChipWorker processes (only if device_ids provided)
291+
# 2. Prepare chip-worker config (but do NOT fork yet — deferred to _start_level3)
295292
if device_ids:
296293
from runtime_builder import RuntimeBuilder # noqa: PLC0415
297294

@@ -302,39 +299,71 @@ def _init_level3(self) -> None:
302299
builder = RuntimeBuilder(platform)
303300
binaries = builder.get_binaries(runtime, build=False)
304301

305-
# Determine args_size (sizeof ChipStorageTaskArgs)
306-
_objs = [_CSA() for _ in range(5)]
307-
_ptrs = [o.__ptr__() for o in _objs]
308-
args_size = min(abs(_ptrs[i + 1] - _ptrs[i]) for i in range(len(_ptrs) - 1))
309-
del _objs, _ptrs
310-
311-
host_lib_path = str(binaries.host_path)
312-
aicpu_path = str(binaries.aicpu_path)
313-
aicore_path = str(binaries.aicore_path)
314-
sim_ctx_path = str(binaries.sim_context_path) if getattr(binaries, "sim_context_path", None) else ""
302+
self._l3_args_size = _args_size(_CSA)
303+
self._l3_host_lib_path = str(binaries.host_path)
304+
self._l3_aicpu_path = str(binaries.aicpu_path)
305+
self._l3_aicore_path = str(binaries.aicore_path)
306+
self._l3_sim_ctx_path = (
307+
str(binaries.sim_context_path) if getattr(binaries, "sim_context_path", None) else ""
308+
)
315309

316-
for dev_id in device_ids:
310+
# Allocate chip mailboxes (shared memory, no fork yet)
311+
for _ in device_ids:
317312
shm = SharedMemory(create=True, size=DIST_CHIP_MAILBOX_SIZE)
318313
assert shm.buf is not None
319314
struct.pack_into("i", shm.buf, _CHIP_OFF_STATE, _IDLE)
320315
self._chip_shms.append(shm)
321316

317+
self._l3_started = False
318+
319+
def _start_level3(self) -> None:
320+
"""Fork child processes and start C++ scheduler. Called on first run()."""
321+
if self._l3_started:
322+
return
323+
self._l3_started = True
324+
325+
device_ids = self._config.get("device_ids", [])
326+
n_sub = self._config.get("num_sub_workers", 0)
327+
328+
# Fork SubWorker processes (MUST be before any C++ threads)
329+
registry = self._callable_registry
330+
for i in range(n_sub):
331+
pid = os.fork()
332+
if pid == 0:
333+
buf = self._shms[i].buf
334+
assert buf is not None
335+
_sub_worker_loop(buf, registry)
336+
os._exit(0)
337+
else:
338+
self._pids.append(pid)
339+
340+
# Fork ChipWorker processes
341+
if device_ids:
342+
for idx, dev_id in enumerate(device_ids):
322343
pid = os.fork()
323344
if pid == 0:
324-
buf = shm.buf
345+
buf = self._chip_shms[idx].buf
325346
assert buf is not None
326-
_chip_process_loop(buf, host_lib_path, dev_id, aicpu_path, aicore_path, sim_ctx_path, args_size)
347+
_chip_process_loop(
348+
buf,
349+
self._l3_host_lib_path,
350+
dev_id,
351+
self._l3_aicpu_path,
352+
self._l3_aicore_path,
353+
self._l3_sim_ctx_path,
354+
self._l3_args_size,
355+
)
327356
os._exit(0)
328357
else:
329358
self._chip_pids.append(pid)
330359

331-
# 4. Create DistWorker and wire chip processes + sub workers
360+
# Create DistWorker and wire chip processes + sub workers
332361
dw = DistWorker(3)
333362
self._dist_worker = dw
334363

335364
if device_ids:
336365
for shm in self._chip_shms:
337-
cp = DistChipProcess(_mailbox_addr(shm), args_size)
366+
cp = DistChipProcess(_mailbox_addr(shm), self._l3_args_size)
338367
self._dist_chip_procs.append(cp)
339368
dw.add_chip_process(cp)
340369

@@ -343,7 +372,7 @@ def _init_level3(self) -> None:
343372
self._dist_sub_workers.append(sw)
344373
dw.add_sub_worker(sw)
345374

346-
# 6. Start Scheduler + WorkerThreads (C++ threads start here, after fork)
375+
# Start Scheduler + WorkerThreads (C++ threads start here, after fork)
347376
dw.init()
348377

349378
# ------------------------------------------------------------------
@@ -377,6 +406,7 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
377406
# run(callable, args, **kwargs)
378407
self._chip_worker.run(task_or_payload, args, **kwargs)
379408
else:
409+
self._start_level3()
380410
assert self._dist_worker is not None
381411
task = task_or_payload
382412
task.orch(self, task.args)

simpler_setup/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from .platform_info import parse_platform
1414
from .pto_isa import ensure_pto_isa_root
1515
from .runtime_builder import RuntimeBuilder
16-
from .scene_test import Scalar, SceneTestCase, TaskArgsBuilder, Tensor, scene_test
16+
from .scene_test import CallableNamespace, Scalar, SceneTestCase, TaskArgsBuilder, Tensor, scene_test
1717

1818
__all__ = [
19+
"CallableNamespace",
1920
"KernelCompiler",
2021
"RuntimeBuilder",
2122
"Scalar",

0 commit comments

Comments (0)