4 changes: 2 additions & 2 deletions docs/distributed_level_runtime.md
@@ -228,7 +228,7 @@ def my_orch(w, args):
payload.callable = chip_callable.buffer_ptr()
payload.args = task_args.__ptr__()
payload.block_dim = 24
r = w.submit(WorkerType.CHIP, payload, outputs=[64])
r = w.submit(WorkerType.NEXT_LEVEL, payload, outputs=[64])

# SubWorker task: runs Python callable, depends on chip output
sub_p = WorkerPayload()
@@ -254,7 +254,7 @@ def my_orch(w, args):
args_list.append(a.__ptr__())

# 1 DAG node, 4 chips execute in parallel
w.submit(WorkerType.CHIP, payload, args_list=args_list, outputs=[out_size])
w.submit(WorkerType.NEXT_LEVEL, payload, args_list=args_list, outputs=[out_size])
```

### Why It's Uniform
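For readers skimming the diff, a minimal sketch of what the rename buys at the orchestration layer: the same orch body targets `WorkerType.NEXT_LEVEL` whether the next level is a ChipWorker (L3) or a child DistWorker (L4+). This is illustrative only — `kernel`, `task_args`, and `sub_payload` are placeholder names, not part of the PR.

```python
# Illustrative sketch (not from the PR): identical orch code at L3 and L4+.
# The scheduler decides what "next level" resolves to; the orch never names it.
def my_orch(w, args):
    payload = WorkerPayload()
    payload.callable = args.kernel.buffer_ptr()   # placeholder attribute names
    payload.args = args.task_args.__ptr__()
    payload.block_dim = 24

    r = w.submit(WorkerType.NEXT_LEVEL, payload, outputs=[64])
    w.submit(WorkerType.SUB, args.sub_payload, inputs=[r.outputs[0].ptr])

# At L3 the NEXT_LEVEL task lands on a ChipWorker; at L4+ it lands on a
# lower-level DistWorker. The submit call is unchanged either way.
```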
23 changes: 10 additions & 13 deletions python/bindings/dist_worker_bind.h
@@ -36,10 +36,7 @@ namespace nb = nanobind;

inline void bind_dist_worker(nb::module_ &m) {
// --- WorkerType ---
nb::enum_<WorkerType>(m, "WorkerType")
.value("CHIP", WorkerType::CHIP)
.value("SUB", WorkerType::SUB)
.value("DIST", WorkerType::DIST);
nb::enum_<WorkerType>(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB);

// --- TaskState ---
nb::enum_<TaskState>(m, "TaskState")
@@ -167,27 +164,27 @@ inline void bind_dist_worker(nb::module_ &m) {
)

.def(
"add_chip_worker",
"add_next_level_worker",
[](DistWorker &self, DistWorker &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a lower-level DistWorker as a CHIP sub-worker (for L4+)."
nb::arg("worker"), "Add a lower-level DistWorker as a NEXT_LEVEL sub-worker."
)

.def(
"add_chip_worker_native",
"add_next_level_worker",
[](DistWorker &self, ChipWorker &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a ChipWorker (_ChipWorker) as a CHIP sub-worker (for L3)."
nb::arg("worker"), "Add a ChipWorker as a NEXT_LEVEL sub-worker."
)

.def(
"add_chip_process",
"add_next_level_worker",
[](DistWorker &self, DistChipProcess &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a forked ChipProcess as a CHIP sub-worker (process-isolated)."
nb::arg("worker"), "Add a forked process as a NEXT_LEVEL sub-worker."
)

.def(
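The three `.def` entries above register one overloaded Python method, `add_next_level_worker`, dispatching on the argument type. A hedged wiring sketch follows — the constructor arguments are assumptions taken from `worker.py::_start_level3()` and the usage comment in `dist_worker.h`, not signatures confirmed by the bindings shown here.

```python
# Assumed constructors; only the overloaded method name comes from this diff.
parent = DistWorker(3)                       # assumed: level as the sole ctor arg

parent.add_next_level_worker(chip_worker)    # ChipWorker overload (in-process L3)
parent.add_next_level_worker(chip_process)   # DistChipProcess overload (forked, process-isolated)
parent.add_next_level_worker(child_worker)   # DistWorker overload (lower-level node, L4+)

parent.init()                                # wiring must finish before init()
```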
34 changes: 20 additions & 14 deletions python/simpler/worker.py
@@ -23,7 +23,7 @@
w.init()

def my_orch(w, args):
r = w.submit(WorkerType.CHIP, chip_payload, inputs=[...], outputs=[64])
r = w.submit(WorkerType.NEXT_LEVEL, chip_payload, inputs=[...], outputs=[64])
w.submit(WorkerType.SUB, sub_payload(cid), inputs=[r.outputs[0].ptr])

w.run(Task(orch=my_orch, args=my_args))
@@ -235,6 +235,8 @@ def __init__(self, level: int, **config) -> None:

def register(self, fn: Callable) -> int:
"""Register a callable for SubWorker use. Must be called before init()."""
if self.level < 3:
raise RuntimeError("Worker.register() is only available at level 3+")
if self._initialized:
raise RuntimeError("Worker.register() must be called before init()")
cid = len(self._callable_registry)
Expand Down Expand Up @@ -365,7 +367,7 @@ def _start_level3(self) -> None:
for shm in self._chip_shms:
cp = DistChipProcess(_mailbox_addr(shm), self._l3_args_size)
self._dist_chip_procs.append(cp)
dw.add_chip_process(cp)
dw.add_next_level_worker(cp)

for shm in self._shms:
sw = DistSubWorker(_mailbox_addr(shm))
@@ -391,19 +393,8 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
if self.level == 2:
assert self._chip_worker is not None
if isinstance(task_or_payload, WorkerPayload):
from .task_interface import ChipCallConfig # noqa: PLC0415

config = ChipCallConfig()
config.block_dim = task_or_payload.block_dim
config.aicpu_thread_num = task_or_payload.aicpu_thread_num
config.enable_profiling = task_or_payload.enable_profiling
self._chip_worker.run(
task_or_payload.callable, # type: ignore[arg-type]
task_or_payload.args,
config,
)
self._run_l2_from_payload(task_or_payload)
else:
# run(callable, args, **kwargs)
self._chip_worker.run(task_or_payload, args, **kwargs)
else:
self._start_level3()
@@ -412,6 +403,21 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
task.orch(self, task.args)
self._dist_worker.drain()

def _run_l2_from_payload(self, payload: WorkerPayload) -> None:
"""Unpack a WorkerPayload and forward to ChipWorker (L2 only)."""
from .task_interface import ChipCallConfig # noqa: PLC0415

assert self._chip_worker is not None
config = ChipCallConfig()
config.block_dim = payload.block_dim
config.aicpu_thread_num = payload.aicpu_thread_num
config.enable_profiling = payload.enable_profiling
self._chip_worker.run(
payload.callable, # type: ignore[arg-type]
payload.args,
config,
)

# ------------------------------------------------------------------
# Orchestration API (called from inside orch functions at L3+)
# ------------------------------------------------------------------
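Two behavioural changes in this file are easy to miss: `register()` now refuses to run below level 3, and the L2 payload path is factored into `_run_l2_from_payload()`. A hypothetical sketch of both; `some_fn`, `my_kernel`, `my_args`, and the `Worker` config kwargs are placeholders.

```python
# Hypothetical usage; placeholder objects, not APIs confirmed by this PR.
w2 = Worker(level=2)
try:
    w2.register(some_fn)          # new guard: register() is level-3+ only
except RuntimeError:
    pass

w2.init()
w2.run(my_kernel, my_args)        # direct L2 path -> ChipWorker.run(callable, args)

p = WorkerPayload()
p.callable = my_kernel.buffer_ptr()
p.args = my_args.__ptr__()
p.block_dim = 24
w2.run(p)                         # payload L2 path -> _run_l2_from_payload(p)
```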
12 changes: 6 additions & 6 deletions src/common/distributed/dist_scheduler.cpp
@@ -80,7 +80,7 @@ void DistScheduler::start(const Config &cfg) {
threads.push_back(std::move(wt));
}
};
make_threads(cfg_.chip_workers, chip_threads_);
make_threads(cfg_.next_level_workers, next_level_threads_);
make_threads(cfg_.sub_workers, sub_threads_);

stop_requested_.store(false, std::memory_order_relaxed);
@@ -95,11 +95,11 @@ void DistScheduler::stop() {

if (sched_thread_.joinable()) sched_thread_.join();

for (auto &wt : chip_threads_)
for (auto &wt : next_level_threads_)
wt->stop();
for (auto &wt : sub_threads_)
wt->stop();
chip_threads_.clear();
next_level_threads_.clear();
sub_threads_.clear();

running_.store(false, std::memory_order_release);
@@ -157,7 +157,7 @@ void DistScheduler::run() {
// Exit when stop requested and all workers idle
if (stop_requested_.load(std::memory_order_acquire)) {
bool any_busy = false;
for (auto &wt : chip_threads_)
for (auto &wt : next_level_threads_)
if (!wt->idle()) {
any_busy = true;
break;
@@ -268,15 +268,15 @@ void DistScheduler::dispatch_ready() {
}

WorkerThread *DistScheduler::pick_idle(WorkerType type) {
auto &threads = (type == WorkerType::CHIP) ? chip_threads_ : sub_threads_;
auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_;
for (auto &wt : threads) {
if (wt->idle()) return wt.get();
}
return nullptr;
}

std::vector<WorkerThread *> DistScheduler::pick_n_idle(WorkerType type, int n) {
auto &threads = (type == WorkerType::CHIP) ? chip_threads_ : sub_threads_;
auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_;
std::vector<WorkerThread *> result;
result.reserve(n);
for (auto &wt : threads) {
6 changes: 3 additions & 3 deletions src/common/distributed/dist_scheduler.h
@@ -89,8 +89,8 @@ class DistScheduler {
DistTaskSlotState *slots;
int32_t num_slots;
DistReadyQueue *ready_queue;
std::vector<IWorker *> chip_workers; // WorkerType::CHIP
std::vector<IWorker *> sub_workers; // WorkerType::SUB
std::vector<IWorker *> next_level_workers; // WorkerType::NEXT_LEVEL
std::vector<IWorker *> sub_workers; // WorkerType::SUB
// Called when a task reaches CONSUMED (TensorMap cleanup + ring release).
std::function<void(DistTaskSlot)> on_consumed_cb;
};
@@ -104,7 +104,7 @@ class DistScheduler {
Config cfg_;

// Per-worker threads
std::vector<std::unique_ptr<WorkerThread>> chip_threads_;
std::vector<std::unique_ptr<WorkerThread>> next_level_threads_;
std::vector<std::unique_ptr<WorkerThread>> sub_threads_;

// Shared completion queue (WorkerThread → Scheduler)
7 changes: 3 additions & 4 deletions src/common/distributed/dist_types.h
@@ -51,9 +51,8 @@ using DistTaskSlot = int32_t;
// =============================================================================

enum class WorkerType : int32_t {
CHIP = 0, // ChipWorker: L2 hardware device
SUB = 1, // SubWorker: fork/shm Python function
DIST = 2, // DistWorker: lower-level node (L4+)
NEXT_LEVEL = 0, // Next-level Worker (L3→ChipWorker, L4→DistWorker(L3), …)
SUB = 1, // SubWorker: fork/shm Python function
};

// =============================================================================
@@ -75,7 +74,7 @@ enum class TaskState : int32_t {

struct WorkerPayload {
DistTaskSlot task_slot = DIST_INVALID_SLOT;
WorkerType worker_type = WorkerType::CHIP;
WorkerType worker_type = WorkerType::NEXT_LEVEL;

// --- ChipWorker fields (set in PR 2-2) ---
const void *callable = nullptr; // ChipCallable buffer ptr
4 changes: 2 additions & 2 deletions src/common/distributed/dist_worker.cpp
@@ -24,7 +24,7 @@ DistWorker::~DistWorker() {

void DistWorker::add_worker(WorkerType type, IWorker *worker) {
if (initialized_) throw std::runtime_error("DistWorker: add_worker after init");
if (type == WorkerType::CHIP || type == WorkerType::DIST) chip_workers_.push_back(worker);
if (type == WorkerType::NEXT_LEVEL) next_level_workers_.push_back(worker);
else sub_workers_.push_back(worker);
}

@@ -38,7 +38,7 @@ void DistWorker::init() {
cfg.slots = slots_.get();
cfg.num_slots = DIST_TASK_WINDOW_SIZE;
cfg.ready_queue = &ready_queue_;
cfg.chip_workers = chip_workers_;
cfg.next_level_workers = next_level_workers_;
cfg.sub_workers = sub_workers_;
cfg.on_consumed_cb = [this](DistTaskSlot slot) {
on_consumed(slot);
6 changes: 3 additions & 3 deletions src/common/distributed/dist_worker.h
@@ -20,7 +20,7 @@
* Usage (L3 host worker, instantiated from Python via nanobind):
*
* DistWorker dw(level=3);
* dw.add_worker(WorkerType::CHIP, chip_worker_ptr);
* dw.add_worker(WorkerType::NEXT_LEVEL, chip_worker_ptr);
* dw.add_worker(WorkerType::SUB, sub_worker_ptr);
* dw.init();
*
@@ -32,7 +32,7 @@
* dw.execute(); // blocks until all submitted tasks complete
*
* // When used as an IWorker by a higher-level DistWorker (L4+):
* parent.add_worker(WorkerType::DIST, &dw);
* parent.add_worker(WorkerType::NEXT_LEVEL, &dw);
* // parent scheduler calls dw.dispatch() / dw.poll()
*/

@@ -107,7 +107,7 @@ class DistWorker : public IWorker {
DistOrchestrator orchestrator_;
DistScheduler scheduler_;

std::vector<IWorker *> chip_workers_;
std::vector<IWorker *> next_level_workers_;
std::vector<IWorker *> sub_workers_;

// --- Drain support ---
4 changes: 2 additions & 2 deletions tests/st/a2a3/tensormap_and_ringbuffer/test_l3_dependency.py
@@ -34,12 +34,12 @@ def run_dag(w, callables, task_args, config):
callables.keep(chip_args) # prevent GC before drain

chip_p = WorkerPayload()
chip_p.worker_type = WorkerType.CHIP
chip_p.worker_type = WorkerType.NEXT_LEVEL
chip_p.callable = callables.vector_kernel.buffer_ptr()
chip_p.args = chip_args.__ptr__()
chip_p.block_dim = config.block_dim
chip_p.aicpu_thread_num = config.aicpu_thread_num
chip_result = w.submit(WorkerType.CHIP, chip_p, inputs=[], outputs=[task_args.f.numel() * 4])
chip_result = w.submit(WorkerType.NEXT_LEVEL, chip_p, inputs=[], outputs=[task_args.f.numel() * 4])

sub_p = WorkerPayload()
sub_p.worker_type = WorkerType.SUB
4 changes: 2 additions & 2 deletions tests/st/a2a3/tensormap_and_ringbuffer/test_l3_group.py
@@ -42,11 +42,11 @@ def run_dag(w, callables, task_args, config):
callables.keep(args0, args1) # prevent GC before drain

chip_p = WorkerPayload()
chip_p.worker_type = WorkerType.CHIP
chip_p.worker_type = WorkerType.NEXT_LEVEL
chip_p.callable = callables.vector_kernel.buffer_ptr()
chip_p.block_dim = config.block_dim
chip_p.aicpu_thread_num = config.aicpu_thread_num
group_result = w.submit(WorkerType.CHIP, chip_p, args_list=[args0.__ptr__(), args1.__ptr__()], outputs=[4])
group_result = w.submit(WorkerType.NEXT_LEVEL, chip_p, args_list=[args0.__ptr__(), args1.__ptr__()], outputs=[4])

sub_p = WorkerPayload()
sub_p.worker_type = WorkerType.SUB
23 changes: 12 additions & 11 deletions tests/ut/cpp/test_dist_orchestrator.cpp
@@ -39,11 +39,12 @@ struct OrchestratorFixture : public ::testing::Test {

void TearDown() override { ring.shutdown(); }

// Submit a CHIP task with the given input/output specs.
DistSubmitResult submit_chip(const std::vector<DistInputSpec> &inputs, const std::vector<DistOutputSpec> &outputs) {
// Submit a NEXT_LEVEL task with the given input/output specs.
DistSubmitResult
submit_next_level(const std::vector<DistInputSpec> &inputs, const std::vector<DistOutputSpec> &outputs) {
WorkerPayload p;
p.worker_type = WorkerType::CHIP;
return orch.submit(WorkerType::CHIP, p, inputs, outputs);
p.worker_type = WorkerType::NEXT_LEVEL;
return orch.submit(WorkerType::NEXT_LEVEL, p, inputs, outputs);
}
};

@@ -52,7 +53,7 @@ struct OrchestratorFixture : public ::testing::Test {
// ---------------------------------------------------------------------------

TEST_F(OrchestratorFixture, IndependentTaskIsImmediatelyReady) {
auto res = submit_chip({}, {{64}});
auto res = submit_next_level({}, {{64}});
EXPECT_NE(res.task_slot, DIST_INVALID_SLOT);
ASSERT_EQ(res.outputs.size(), 1u);
EXPECT_NE(res.outputs[0].ptr, nullptr);
@@ -65,14 +66,14 @@

TEST_F(OrchestratorFixture, DependentTaskIsPending) {
// Task A produces a buffer
auto a = submit_chip({}, {{128}});
auto a = submit_next_level({}, {{128}});
DistTaskSlot a_slot;
rq.try_pop(a_slot); // drain ready queue

uint64_t a_out = reinterpret_cast<uint64_t>(a.outputs[0].ptr);

// Task B depends on A's output
auto b = submit_chip({{a_out}}, {{64}});
auto b = submit_next_level({{a_out}}, {{64}});
EXPECT_EQ(slots[b.task_slot].state.load(), TaskState::PENDING);
EXPECT_EQ(slots[b.task_slot].fanin_count, 1);

@@ -81,7 +82,7 @@
}

TEST_F(OrchestratorFixture, TensorMapTracksProducer) {
auto a = submit_chip({}, {{256}});
auto a = submit_next_level({}, {{256}});
DistTaskSlot drain_slot;
rq.try_pop(drain_slot);

@@ -90,7 +91,7 @@
}

TEST_F(OrchestratorFixture, OnConsumedCleansUpTensorMap) {
auto a = submit_chip({}, {{64}});
auto a = submit_next_level({}, {{64}});
DistTaskSlot slot;
rq.try_pop(slot);

@@ -107,7 +108,7 @@

TEST_F(OrchestratorFixture, ScopeRegistersAndReleasesRef) {
orch.scope_begin();
auto a = submit_chip({}, {{64}});
auto a = submit_next_level({}, {{64}});
DistTaskSlot slot;
rq.try_pop(slot);

@@ -126,7 +127,7 @@
}

TEST_F(OrchestratorFixture, MultipleOutputsAllocated) {
auto res = submit_chip({}, {{32}, {64}, {128}});
auto res = submit_next_level({}, {{32}, {64}, {128}});
ASSERT_EQ(res.outputs.size(), 3u);
EXPECT_EQ(res.outputs[0].size, 32u);
EXPECT_EQ(res.outputs[1].size, 64u);