Skip to content

Commit 8a95a98

Browse files
ChaoWao and claude committed
Refactor: L3 SceneTestCase — HostCallable as Python DAG + ChipCallable + SubCallable
Extend SceneTestCase to support L3 distributed tests (multi-chip + SubWorker). L3 CALLABLE mirrors L2 structurally:

- orchestration: Python DAG function (vs C++ binary for L2)
- callables: list of ChipCallable entries (compiled) + SubCallable entries (registered), distinguished by field presence

Key additions:

- CallableNamespace: dot-access container for orch functions to access compiled ChipCallables and registered SubCallable IDs by name, with keep() for lifetime management of transient objects past drain()
- _compile_chip_callable_from_spec: extracted from compile_chip_callable, reused by both L2 and L3 compilation paths
- _run_and_validate_l3: builds CallableNamespace, wraps orch in Task, compares all tensors against golden
- conftest.py L3 branch: registers SubCallable entries from CALLABLE, reads device_count/num_sub_workers from case config dict
- Rewrites test_l3_dependency.py and test_l3_group.py as SceneTestCase subclasses with module-level orch + sub functions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8d5f25b commit 8a95a98

7 files changed

Lines changed: 467 additions & 316 deletions

File tree

conftest.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,24 @@ def st_worker(request, st_platform, device_pool):
137137
device_pool.release(ids)
138138

139139
elif level == 3:
140-
max_devices = max((c.get("device_count", 1) for c in cls.CASES), default=1)
141-
max_subs = max((c.get("num_sub_workers", 0) for c in cls.CASES), default=0)
140+
max_devices = max((c.get("config", {}).get("device_count", 1) for c in cls.CASES), default=1)
141+
max_subs = max((c.get("config", {}).get("num_sub_workers", 0) for c in cls.CASES), default=0)
142142
ids = device_pool.allocate(max_devices)
143143
if not ids:
144144
pytest.fail(f"need {max_devices} devices")
145145

146146
from simpler.worker import Worker # noqa: PLC0415
147147

148148
w = Worker(level=3, device_ids=ids, num_sub_workers=max_subs, platform=st_platform, runtime=runtime)
149+
150+
# Register SubCallable entries from cls.CALLABLE
151+
sub_ids = {}
152+
for entry in cls.CALLABLE.get("callables", []):
153+
if "callable" in entry:
154+
cid = w.register(entry["callable"])
155+
sub_ids[entry["name"]] = cid
156+
cls._st_sub_ids = sub_ids
157+
149158
w.init()
150159
yield w
151160
w.close()

python/bindings/task_interface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ NB_MODULE(_task_interface, m) {
222222
return reinterpret_cast<uint64_t>(&self);
223223
},
224224
"Return the memory address of the underlying C++ object."
225+
)
226+
227+
.def_static(
228+
"sizeof",
229+
[]() -> size_t {
230+
return sizeof(ChipStorageTaskArgs);
231+
},
232+
"Return sizeof(ChipStorageTaskArgs) in bytes."
225233
);
226234

227235
// --- TensorArgType enum ---

python/simpler/worker.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -272,26 +272,14 @@ def _init_level3(self) -> None:
272272
device_ids = self._config.get("device_ids", [])
273273
n_sub = self._config.get("num_sub_workers", 0)
274274

275-
# 1. Allocate mailboxes
275+
# 1. Allocate sub-worker mailboxes
276276
for _ in range(n_sub):
277277
shm = SharedMemory(create=True, size=DIST_SUB_MAILBOX_SIZE)
278278
assert shm.buf is not None
279279
struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE)
280280
self._shms.append(shm)
281281

282-
# 2. Fork SubWorker processes (MUST be before any C++ threads)
283-
registry = self._callable_registry
284-
for i in range(n_sub):
285-
pid = os.fork()
286-
if pid == 0:
287-
buf = self._shms[i].buf
288-
assert buf is not None
289-
_sub_worker_loop(buf, registry)
290-
os._exit(0)
291-
else:
292-
self._pids.append(pid)
293-
294-
# 3. Fork ChipWorker processes (only if device_ids provided)
282+
# 2. Prepare chip-worker config (but do NOT fork yet — deferred to _start_level3)
295283
if device_ids:
296284
from runtime_builder import RuntimeBuilder # noqa: PLC0415
297285

@@ -302,39 +290,71 @@ def _init_level3(self) -> None:
302290
builder = RuntimeBuilder(platform)
303291
binaries = builder.get_binaries(runtime, build=False)
304292

305-
# Determine args_size (sizeof ChipStorageTaskArgs)
306-
_objs = [_CSA() for _ in range(5)]
307-
_ptrs = [o.__ptr__() for o in _objs]
308-
args_size = min(abs(_ptrs[i + 1] - _ptrs[i]) for i in range(len(_ptrs) - 1))
309-
del _objs, _ptrs
310-
311-
host_lib_path = str(binaries.host_path)
312-
aicpu_path = str(binaries.aicpu_path)
313-
aicore_path = str(binaries.aicore_path)
314-
sim_ctx_path = str(binaries.sim_context_path) if getattr(binaries, "sim_context_path", None) else ""
293+
self._l3_args_size = _CSA.sizeof()
294+
self._l3_host_lib_path = str(binaries.host_path)
295+
self._l3_aicpu_path = str(binaries.aicpu_path)
296+
self._l3_aicore_path = str(binaries.aicore_path)
297+
self._l3_sim_ctx_path = (
298+
str(binaries.sim_context_path) if getattr(binaries, "sim_context_path", None) else ""
299+
)
315300

316-
for dev_id in device_ids:
301+
# Allocate chip mailboxes (shared memory, no fork yet)
302+
for _ in device_ids:
317303
shm = SharedMemory(create=True, size=DIST_CHIP_MAILBOX_SIZE)
318304
assert shm.buf is not None
319305
struct.pack_into("i", shm.buf, _CHIP_OFF_STATE, _IDLE)
320306
self._chip_shms.append(shm)
321307

308+
self._l3_started = False
309+
310+
def _start_level3(self) -> None:
311+
"""Fork child processes and start C++ scheduler. Called on first run()."""
312+
if self._l3_started:
313+
return
314+
self._l3_started = True
315+
316+
device_ids = self._config.get("device_ids", [])
317+
n_sub = self._config.get("num_sub_workers", 0)
318+
319+
# Fork SubWorker processes (MUST be before any C++ threads)
320+
registry = self._callable_registry
321+
for i in range(n_sub):
322+
pid = os.fork()
323+
if pid == 0:
324+
buf = self._shms[i].buf
325+
assert buf is not None
326+
_sub_worker_loop(buf, registry)
327+
os._exit(0)
328+
else:
329+
self._pids.append(pid)
330+
331+
# Fork ChipWorker processes
332+
if device_ids:
333+
for idx, dev_id in enumerate(device_ids):
322334
pid = os.fork()
323335
if pid == 0:
324-
buf = shm.buf
336+
buf = self._chip_shms[idx].buf
325337
assert buf is not None
326-
_chip_process_loop(buf, host_lib_path, dev_id, aicpu_path, aicore_path, sim_ctx_path, args_size)
338+
_chip_process_loop(
339+
buf,
340+
self._l3_host_lib_path,
341+
dev_id,
342+
self._l3_aicpu_path,
343+
self._l3_aicore_path,
344+
self._l3_sim_ctx_path,
345+
self._l3_args_size,
346+
)
327347
os._exit(0)
328348
else:
329349
self._chip_pids.append(pid)
330350

331-
# 4. Create DistWorker and wire chip processes + sub workers
351+
# Create DistWorker and wire chip processes + sub workers
332352
dw = DistWorker(3)
333353
self._dist_worker = dw
334354

335355
if device_ids:
336356
for shm in self._chip_shms:
337-
cp = DistChipProcess(_mailbox_addr(shm), args_size)
357+
cp = DistChipProcess(_mailbox_addr(shm), self._l3_args_size)
338358
self._dist_chip_procs.append(cp)
339359
dw.add_chip_process(cp)
340360

@@ -343,7 +363,7 @@ def _init_level3(self) -> None:
343363
self._dist_sub_workers.append(sw)
344364
dw.add_sub_worker(sw)
345365

346-
# 6. Start Scheduler + WorkerThreads (C++ threads start here, after fork)
366+
# Start Scheduler + WorkerThreads (C++ threads start here, after fork)
347367
dw.init()
348368

349369
# ------------------------------------------------------------------
@@ -377,6 +397,7 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
377397
# run(callable, args, **kwargs)
378398
self._chip_worker.run(task_or_payload, args, **kwargs)
379399
else:
400+
self._start_level3()
380401
assert self._dist_worker is not None
381402
task = task_or_payload
382403
task.orch(self, task.args)

simpler_setup/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from .platform_info import parse_platform
1414
from .pto_isa import ensure_pto_isa_root
1515
from .runtime_builder import RuntimeBuilder
16-
from .scene_test import Scalar, SceneTestCase, TaskArgsBuilder, Tensor, scene_test
16+
from .scene_test import CallableNamespace, Scalar, SceneTestCase, TaskArgsBuilder, Tensor, scene_test
1717

1818
__all__ = [
19+
"CallableNamespace",
1920
"KernelCompiler",
2021
"RuntimeBuilder",
2122
"Scalar",

0 commit comments

Comments (0)