This is the standalone specification of the AMK megakernel intermediate representation, the "DWG-format" asset AMK owns. It defines the SM-level task-DAG, the counter-synchronization model and its two safety invariants, the buffer/page memory model, the `ScheduleConfig` search surface, the on-disk JSON format, and the mapping to the on-device ABI in `vm/abi.h`. The normative Python reference is `schedule/ir.py`; this document is written so an external team could build a compatible importer, validator, or VM. Where prose and code disagree, `schedule/ir.py` wins.

Versions: `IR_VERSION = "0.2.0"` (on-disk format), `ABI_VERSION = "0.2"` (on-device ABI, mirrors `AMK_ABI_VERSION_{MAJOR,MINOR}` in `vm/abi.h`). Minor bumps are additive/compatible; `MegakernelProgram.from_dict` rejects a major mismatch. The format is dependency-free pure Python + stdlib: you can read, validate, diff, and version a schedule on a laptop with no GPU.
A program is a schedule.ir.MegakernelProgram: a complete, runnable-after-validation megakernel
schedule. It is the artifact compile.py emits, the flywheel stores, and the VM loads. Its parts:
| Field | Type | Role |
|---|---|---|
| `meta` | `dict` | `{model, gpu, regime, dtype, notes…}`, provenance for `results.tsv`. |
| `target` | `GpuTarget \| None` | the GPU described as data (the retargeting surface). |
| `buffers` | `list[Buffer]` | every named tensor (weights, activations, KV, IO, consts). |
| `counters` | `list[Counter]` | synchronization counters (the only cross-task signalling). |
| `tasks` | `list[Task]` | the DAG nodes; each == one ABI instruction invocation on one SM. |
| `pages` | `PageAllocation \| None` | physical scratch-slot binding for activations (Layer 2 output). |
| `config` | `ScheduleConfig \| None` | the Loop-2 search point that produced this lowering. |
| `ir_version`, `abi_version` | `str` | format/ABI versions. |
The task-DAG is the heart: nodes are Tasks, edges are producer→consumer relations induced
by counters (MegakernelProgram.dependency_edges()). A forward pass is a DAG; execution is a
topological walk of it with monotonic counters. There are no locks and no arbitrary
signalling, only counter increments and static-threshold waits.
- `DType` (`schedule.ir.DType`): `F32=0, F16=1, BF16=2, F8E4M3=3, F8E5M2=4, I32=5, I8=6, I4=7` (packed 4-bit, two per byte), `U8=8, BOOL=9`. `.bits` / `.nbytes(count)` give sizes (ceil for sub-byte packed types).
- `MemSpace` (`schedule.ir.MemSpace`): `HBM=0` (weights/KV/IO), `GLOBAL_SCRATCH=1` (large activations), `SMEM=2` (hot activations, the megakernel win), `REGISTER=3` (lowering hint).
- `BufferKind` (`schedule.ir.BufferKind`): `WEIGHT=0`, `ACTIVATION=1`, `KV_CACHE=2`, `IO_INPUT=3`, `IO_OUTPUT=4`, `CONST=5`. The read-only kinds (`WEIGHT`, `CONST`, `IO_INPUT`, i.e. `_READONLY_KINDS`) never require a producer edge to read.
- `InstructionKind` (the opcodes, §7): `NOP=0 … ATTENTION_COMBINE=18`. Extend by appending; never renumber.
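The size rule for packed sub-byte types is easy to get wrong, so here is a minimal sketch of it. The enum codes are the ones listed above; the free function `nbytes`, the `_BITS` table, and the 8-bit width chosen for `BOOL` are assumptions of this sketch, not the API of `schedule/ir.py`.

```python
import enum


class DType(enum.IntEnum):
    # Numeric codes as listed in this spec.
    F32 = 0
    F16 = 1
    BF16 = 2
    F8E4M3 = 3
    F8E5M2 = 4
    I32 = 5
    I8 = 6
    I4 = 7  # packed 4-bit, two per byte
    U8 = 8
    BOOL = 9


# Bit widths per element. BOOL = 8 is an assumption; the spec does not state it.
_BITS = {
    DType.F32: 32, DType.F16: 16, DType.BF16: 16,
    DType.F8E4M3: 8, DType.F8E5M2: 8,
    DType.I32: 32, DType.I8: 8, DType.I4: 4,
    DType.U8: 8, DType.BOOL: 8,
}


def nbytes(dtype: DType, count: int) -> int:
    """Bytes needed for `count` elements, rounding up for sub-byte types."""
    return (_BITS[dtype] * count + 7) // 8


print(nbytes(DType.I4, 3))    # three 4-bit values need two bytes
print(nbytes(DType.F16, 16))  # sixteen half-precision values
```

Serializing by name and keeping the numeric code stable are independent concerns: the name goes into JSON, the code goes into the ABI.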
Buffer(id, name, kind: BufferKind, dtype: DType, shape: tuple[int,...],
space: MemSpace = HBM, source: str | None = None)
source is the dotted state-dict key for WEIGHT/CONST buffers (else None). Derived:
numel, nbytes, contiguous_strides() (row-major element strides → amk_buffer_t.stride).
Rank (len(shape)) MUST be <= ABI_MAX_RANK (4).
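As an illustration of the derived stride field, the following sketch computes row-major element strides and enforces the rank cap. It is written as a free function for brevity; in the reference it is a method on `Buffer`, and the error message is hypothetical.

```python
ABI_MAX_RANK = 4  # cap stated in this spec


def contiguous_strides(shape: tuple) -> tuple:
    """Row-major strides in elements: the last dimension has stride 1."""
    if len(shape) > ABI_MAX_RANK:
        raise ValueError(f"rank {len(shape)} exceeds ABI_MAX_RANK={ABI_MAX_RANK}")
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


print(contiguous_strides((2, 3, 4)))
print(contiguous_strides((16, 16)))
```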
Task(id, op: InstructionKind, inputs: list[int], outputs: list[int], out_counter: int,
waits: list[Wait] = [], params: dict = {}, sm: int | None = None,
est_bytes=0, est_flops=0, label="")
- `inputs` / `outputs` are buffer ids (read / written). Caps: `<= ABI_MAX_INPUTS (8)` inputs, `<= ABI_MAX_OUTPUTS (4)` outputs.
- `out_counter`: the single counter this task increments by 1 on completion, after a release fence ordering all its output-buffer writes. Meaning: "all my outputs are written and visible." Exactly one increment per task.
- `waits`: preconditions (`<= ABI_MAX_WAITS (8)`).
- `params`: op-specific scalars (§7); keys/types validated against `PARAM_FIELDS`.
- `sm`: SM/worker assignment, a lowering OUTPUT (assigned from `ScheduleConfig.sm_assignment`; `None` = unassigned). Agents/search edit the config, never this field.
- `est_bytes` / `est_flops`: cost-model hints (weights dominate `est_bytes`); `label` is human text.
Counter(id, init: int = 0, note: str = "") # init is always 0 for a fresh forward pass
Wait(counter: int, threshold: int) # do not execute until counters[counter] >= threshold
A counter is a uint32 (see amk_counter_t), monotonically incremented by its producer tasks,
never decremented. threshold MUST be a statically-known positive int; for a counter with >1
producer it MUST equal the producer count (§3.2).
Each task, on completion, does exactly one out_counter += 1 (producers only increment)
after a release fence ordering ALL of its output-buffer writes before the increment. Each task,
before executing, waits on a set of (counter, threshold) pairs with statically-known
thresholds (consumers only wait, never signal).
Producer→consumer edges (MegakernelProgram.dependency_edges()): for every counter, each
producer task precedes every task that waits on it. The acyclicity of this graph guarantees a
topological order exists (MegakernelProgram.topological_order() via Kahn's algorithm; returns
None on a cycle). The reference VM (vm/reference_vm.py) and CUDA VM both execute by repeatedly
firing every task whose every wait is satisfied, incrementing counters, a counter-driven walk of
the DAG.
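To make the edge rule concrete, here is a minimal sketch of counter-induced edges followed by Kahn's algorithm. Plain dicts stand in for `Task` (keys `id`, `out_counter`, `waits` as `(counter, threshold)` pairs); these stand-ins and the free functions are assumptions of the sketch, not the code in `schedule/ir.py`.

```python
from collections import deque


def dependency_edges(tasks):
    """Every producer of a counter precedes every task that waits on it."""
    producers = {}
    for t in tasks:
        producers.setdefault(t["out_counter"], []).append(t["id"])
    edges = set()
    for t in tasks:
        for counter, _threshold in t["waits"]:
            for p in producers.get(counter, []):
                edges.add((p, t["id"]))
    return sorted(edges)


def topological_order(tasks):
    """Kahn's algorithm; returns None when the graph has a cycle."""
    indegree = {t["id"]: 0 for t in tasks}
    succs = {t["id"]: [] for t in tasks}
    for a, b in dependency_edges(tasks):
        succs[a].append(b)
        indegree[b] += 1
    ready = deque(tid for tid, d in indegree.items() if d == 0)
    order = []
    while ready:
        tid = ready.popleft()
        order.append(tid)
        for nxt in succs[tid]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
    return order if len(order) == len(tasks) else None


chain = [
    {"id": 0, "out_counter": 0, "waits": []},
    {"id": 1, "out_counter": 1, "waits": [(0, 1)]},
]
cycle = [
    {"id": 0, "out_counter": 0, "waits": [(1, 1)]},
    {"id": 1, "out_counter": 1, "waits": [(0, 1)]},
]
print(topological_order(chain), topological_order(cycle))
```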
Decode-loop model (frozen). One kernel launch == one forward pass == one decoded token.
Counters are host-memset to zero before each launch (Counter.init = 0); KV_CACHE persists in
HBM across launches; positions advance via params (pos, kv_start, kv_len) between
launches. The host drives the autoregressive loop. This keeps each launch under the Windows WDDM
~2s TDR watchdog (GpuTarget.wddm_tdr) while preserving the megakernel-within-a-step win.
schedule.ir.validate(prog) -> ValidationResult statically proves a program is safe to load. A
REJECTED result MUST prevent launch; the VM refuses anything validate() rejects.
`validate()` never raises on a malformed program; it always returns a result (the
clean-signal contract). `ValidationResult` carries `ok`, `errors`, `warnings`, `stats` and
`report()`.
- Referential integrity + opcode arity (`OP_REGISTRY`) + required params + ABI caps + param type/key sanity (`PARAM_FIELDS`: `'i'`→int32, `'f'`→real).
- Every `Wait.threshold` satisfies `1 <= threshold <= #producers(counter)`. A wait on a counter with no producer, or `threshold > #producers`, is unsatisfiable → REJECTED.
- The producer→consumer graph is acyclic (`topological_order() != None`); a cycle → REJECTED with a witness from `_describe_cycle` (iterative DFS, never `RecursionError`, even at 5000+ nodes).
- Per-SM queue ordering (once `sm` is assigned): each SM's serial queue (task-list order) must be a linear extension of the DAG: for every edge `a→b` with `sm[a]==sm[b]`, `a` must precede `b` in that SM's queue, else the SM blocks on a counter only its own later entry could signal. Assigned `sm` must be in `[0, target.num_sms)`.
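The threshold bound is the simplest of these checks to reproduce. The sketch below counts producers per counter and reports every wait outside `1 <= threshold <= #producers`. The dict-shaped tasks and the tuple-shaped findings are assumptions of the sketch; the reference reports through `ValidationResult`.

```python
def check_thresholds(tasks):
    """Return (task_id, counter, threshold, n_producers) for each unsatisfiable wait."""
    producer_count = {}
    for t in tasks:
        c = t["out_counter"]
        producer_count[c] = producer_count.get(c, 0) + 1
    findings = []
    for t in tasks:
        for counter, threshold in t["waits"]:
            n = producer_count.get(counter, 0)
            if not (1 <= threshold <= n):
                findings.append((t["id"], counter, threshold, n))
    return findings


tasks = [
    {"id": 0, "out_counter": 0, "waits": []},
    {"id": 1, "out_counter": 0, "waits": []},
    {"id": 2, "out_counter": 1, "waits": [(0, 2)]},  # ok: two producers
    {"id": 3, "out_counter": 2, "waits": [(0, 3)]},  # threshold above producer count
    {"id": 4, "out_counter": 3, "waits": [(9, 1)]},  # counter 9 has no producer
]
print(check_thresholds(tasks))
```

Like `validate()`, the sketch collects findings and returns them instead of raising.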
A counter carries a count, not which producer finished. Therefore:
- Shared-counter all-join rule. A counter with `>1` producer is a true join: every wait on it MUST use `threshold == #producers`. A partial wait (`1 < t < #producers`) is a "first-k-of-N" race (the wrong producers can satisfy it) → REJECTED.
- Transitive happens-before provenance. For every ACTIVATION / IO read, there must be a transitive predecessor (through dependency edges) that wrote that buffer. `validate()` walks the topo order maintaining, per task, the bitmask of buffers written by transitive predecessors (`avail[t] = ext_mask ∪ over preds (avail[p] ∪ out_mask[p])`); a read whose bit is unset is a data RACE → REJECTED. Read-only kinds (`_READONLY_KINDS`) are pre-set in `ext_mask`.
- KV_CACHE ordering. A `KV_CACHE` written this pass (`KV_APPEND`) may be read only by tasks ordered after the append. The writer reading its own cache (prior-step state) is fine; any other reader without a happens-before edge from the `KV_APPEND` is a RACE → REJECTED.
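The provenance walk can be sketched with Python integers as bitmasks. The function below assumes the caller already has a topological order and a map of direct predecessors; the name `find_unordered_reads`, the dict-shaped tasks, and the omission of the KV_CACHE self-read exemption are all simplifications of this sketch.

```python
def find_unordered_reads(tasks, order, preds, readonly_buffers):
    """Return (task_id, buffer_id) for reads with no transitive writer."""
    ext_mask = 0
    for b in readonly_buffers:  # read-only kinds are always available
        ext_mask |= 1 << b
    by_id = {t["id"]: t for t in tasks}
    out_mask = {}
    for t in tasks:
        m = 0
        for b in t["outputs"]:
            m |= 1 << b
        out_mask[t["id"]] = m
    avail, races = {}, []
    for tid in order:  # predecessors are visited before their successors
        mask = ext_mask
        for p in preds.get(tid, ()):
            mask |= avail[p] | out_mask[p]
        avail[tid] = mask
        for b in by_id[tid]["inputs"]:
            if not (mask >> b) & 1:
                races.append((tid, b))
    return races


tasks = [
    {"id": 0, "inputs": [0, 1], "outputs": [2]},
    {"id": 1, "inputs": [2, 1], "outputs": [3]},
]
ordered = find_unordered_reads(tasks, [0, 1], {1: [0]}, readonly_buffers=[0, 1])
unordered = find_unordered_reads(tasks, [0, 1], {}, readonly_buffers=[0, 1])
print(ordered, unordered)
```

Because `avail[p]` already folds in everything written before `p`, a single pass over direct predecessors yields the transitive set.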
- Page-aliasing (WAR/WAW). When `pages` is present (graphs `<= 4000` tasks), if a `Page` is reused by two activation buffers but a reader of the first and a writer of the second are unordered, a warning is emitted (possible clobber).
- Output reachability. Every `IO_OUTPUT` buffer must be produced by some task, else REJECTED.
- GPU labeling. `meta['gpu'] != target.name` is a warning (flywheel derives GPU from `target.name`).
- `MegakernelProgram.simulate_counters()`: static reachability cross-check; returns `(order, stuck)`; `stuck` non-empty iff a wait is permanently unsatisfiable.
- `MegakernelProgram.simulate_adversarial(seeds=16)`: fires ready tasks in varied worst-case interleavings and asserts every transient input was written by a prior-fired task; returns a list of race violations (the dynamic backstop to the static checks).
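A counter simulation of this kind can be sketched as repeated sweeps that fire every ready task until nothing changes. The dict-shaped tasks and the free function are assumptions of the sketch; only the `(order, stuck)` return shape comes from the description above.

```python
def simulate_counters(tasks):
    """Fire ready tasks until no progress; return (fired order, stuck task ids)."""
    counters = {}
    pending = list(tasks)
    order = []
    progressed = True
    while pending and progressed:
        progressed = False
        still_waiting = []
        for t in pending:
            if all(counters.get(c, 0) >= th for c, th in t["waits"]):
                # Completion: exactly one increment of the task's out_counter.
                counters[t["out_counter"]] = counters.get(t["out_counter"], 0) + 1
                order.append(t["id"])
                progressed = True
            else:
                still_waiting.append(t)
        pending = still_waiting
    return order, [t["id"] for t in pending]


ok_prog = [
    {"id": 0, "out_counter": 0, "waits": []},
    {"id": 1, "out_counter": 1, "waits": [(0, 1)]},
]
bad_prog = [
    {"id": 0, "out_counter": 0, "waits": []},
    {"id": 1, "out_counter": 1, "waits": [(0, 2)]},  # only one producer exists
]
print(simulate_counters(ok_prog), simulate_counters(bad_prog))
```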
- `WEIGHT` / `CONST` live in HBM, read-only, bound by `Buffer.source` into the model state-dict. The bandwidth bound is on these (`MegakernelProgram.total_weight_bytes()`).
- `KV_CACHE` is persistent read/write HBM state (written by `KV_APPEND`, read by attention), surviving across launches.
- `IO_INPUT` / `IO_OUTPUT` are the model's input (token ids / embeds) and output (logits / sampled token) in HBM.
- `ACTIVATION` is transient and logical: it is bound to a physical `Page` for its live range by the allocator; non-overlapping live ranges may share a page (graph-coloring reuse).
Page(id, space: MemSpace, nbytes, live_start=-1, live_end=-1) is a physical scratch slot.
PageAllocation(buffer_to_page: dict[int,int], pages: list[Page]) is the allocation output
(total_scratch_bytes sums the pages). Activations in SMEM/GLOBAL_SCRATCH do not
round-trip to HBM between ops; that is the megakernel win. On device, the host resolves every
buffer id to a fixed void* ptr (already offset for paged activations) before launch; the VM
never allocates HBM mid-flight.
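Page sharing between non-overlapping live ranges can be illustrated with a greedy interval pass. This is a simplified stand-in, not the allocator AMK ships: the function name, the inclusive `(start, end)` live ranges, and the first-fit policy are all assumptions of the sketch.

```python
def assign_pages(live_ranges, sizes):
    """live_ranges: {buffer_id: (start, end)} inclusive; sizes: {buffer_id: nbytes}."""
    pages = []  # each page: {"id", "nbytes", "live_end"}
    buffer_to_page = {}
    for buf in sorted(live_ranges, key=lambda b: live_ranges[b]):
        start, end = live_ranges[buf]
        for page in pages:
            if page["live_end"] < start:  # previous tenant is dead: reuse
                page["nbytes"] = max(page["nbytes"], sizes[buf])
                page["live_end"] = end
                buffer_to_page[buf] = page["id"]
                break
        else:
            pages.append({"id": len(pages), "nbytes": sizes[buf], "live_end": end})
            buffer_to_page[buf] = len(pages) - 1
    total_scratch_bytes = sum(p["nbytes"] for p in pages)
    return buffer_to_page, pages, total_scratch_bytes


mapping, pages, total = assign_pages(
    live_ranges={3: (0, 1), 5: (1, 2), 4: (2, 3)},
    sizes={3: 64, 5: 16, 4: 32},
)
print(mapping, total)
```

Buffers 3 and 4 never overlap, so they share one page sized for the larger of the two; buffer 5 overlaps both and gets its own.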
ScheduleConfig is the structured object a coding agent proposes; the frozen VM
deterministically lowers it into a runnable MegakernelProgram. The agent never writes kernel
code; it only chooses a point in this search space the VM knows how to realize safely. It is also
exactly what is logged to the flywheel as the "schedule" column.
| Field | Type | Default | Meaning |
|---|---|---|---|
| `tiling` | `dict[str, dict[str,int]]` | `{}` | per-op-archetype tile sizes, e.g. `{"gemv": {"N_tile": 256}, "attention": {"kv_block": 128}}`. |
| `fusion_grouping` | `list[list[str]]` | `[]` | adjacent op-name groups fused into one resident task group. |
| `sm_assignment` | `str \| dict[int,int]` | `"load_balance"` | `"round_robin"` \| `"load_balance"` \| explicit `{task_id: sm}`. INPUT only; lowerer resolves it into each `Task.sm`. |
| `pipelining_depth` | `int` | `2` | instructions-ahead to prefetch weights, the biggest megakernel win (hides the inter-op HBM bubble). `0` = no prefetch. |
| `page_allocation` | `str` | `"graph_color"` | `"linear"` \| `"graph_color"` \| `"none"`. |
| `threads_per_block` | `int` | `256` | block size of the persistent VM kernel (occupancy-proven by the loader). |
| `smem_bytes_per_block` | `int` | `0` | dynamic SMEM opt-in per block; MUST be `<= GpuTarget.smem_bytes_per_block_optin`. |
ScheduleConfig.to_dict() serializes int-keyed sm_assignment with string keys (re-parsed on
load). Search/agents tune these fields; the deterministic lowering + validate() guarantee the
result is safe regardless of the point chosen.
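JSON object keys are always strings, which is why an explicit `{task_id: sm}` map needs the key conversion described above. A minimal sketch of that round trip follows; the two helper names are hypothetical and the named strategies pass through unchanged.

```python
import json


def sm_assignment_to_jsonable(sm_assignment):
    """Explicit maps get string keys; strategy names are returned as-is."""
    if isinstance(sm_assignment, dict):
        return {str(task_id): sm for task_id, sm in sm_assignment.items()}
    return sm_assignment


def sm_assignment_from_jsonable(value):
    """Inverse of the above: re-parse string keys back to int task ids."""
    if isinstance(value, dict):
        return {int(task_id): sm for task_id, sm in value.items()}
    return value


explicit = {0: 3, 7: 1}
text = json.dumps({"sm_assignment": sm_assignment_to_jsonable(explicit)})
restored = sm_assignment_from_jsonable(json.loads(text)["sm_assignment"])
print(text, restored)
```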
GpuTarget(name, sm_arch, num_sms, smem_bytes_per_sm, smem_bytes_per_block_optin, regs_per_sm, max_threads_per_sm, max_regs_per_thread, l2_bytes, hbm_bytes, hbm_bandwidth_gbs, fp16_tflops, clock_ghz=0.0, supports_cooperative=True, wddm_tdr=False, note="").
bandwidth_bound_us(weight_bytes) = weight_bytes / (hbm_bandwidth_gbs*1e9) * 1e6, the honest
single-stream decode floor. Built-in registry schedule.ir.TARGETS: rtx5090 (sm_120, 82 SMs,
wddm_tdr=True, measured on this machine), b200 (sm_100, spec only), h100 (sm_90, spec
only). Add a chip by adding a record, never by editing the scheduler.
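The decode floor is a one-line unit conversion, shown here as a free function for illustration (in the reference it is a method on `GpuTarget`). The byte count and bandwidth in the usage line are hypothetical round numbers, not measurements of any listed target.

```python
def bandwidth_bound_us(weight_bytes: float, hbm_bandwidth_gbs: float) -> float:
    """Microseconds to stream `weight_bytes` once at `hbm_bandwidth_gbs` GB/s."""
    seconds = weight_bytes / (hbm_bandwidth_gbs * 1e9)
    return seconds * 1e6


# Hypothetical: 1 GB of weights at 1000 GB/s is about 1000 microseconds per token.
floor_us = bandwidth_bound_us(1e9, 1000.0)
print(round(floor_us, 3))
```

Any measured per-token latency below this floor would indicate a measurement or accounting error, since every weight byte has to cross the HBM interface once per decoded token.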
Serialization is explicit, human-readable, additive-compatible JSON (git-friendly).
MegakernelProgram.to_json(indent=2) / .save(path) write it;
MegakernelProgram.from_json(s) / .load(path) / .from_dict(d) read it. Enums serialize by
name (kind, dtype, space, op are strings like "WEIGHT", "GEMV_TILE"). Unknown
(newer) fields in target/config are dropped on load via _filter_known (forward
compatibility); a major ir_version mismatch raises. Round-trip is stable
(from_json(to_json(p)).to_json() == to_json(p)).
{
"ir_version": "0.2.0",
"abi_version": "0.2",
"meta": {"model": "toy", "gpu": "rtx5090"},
"target": {"name": "rtx5090", "sm_arch": 120, "num_sms": 82, "...": "..."},
"buffers": [
{"id": 0, "name": "x", "kind": "IO_INPUT", "dtype": "F16",
"shape": [1, 16], "space": "HBM", "source": null},
{"id": 1, "name": "proj.w", "kind": "WEIGHT", "dtype": "F16",
"shape": [16, 16], "space": "HBM", "source": "proj.weight"}
],
"counters": [{"id": 0, "init": 0, "note": "rmsnorm done"}],
"tasks": [
{"id": 0, "op": "RMSNORM", "inputs": [0, 2], "outputs": [3], "out_counter": 0,
"waits": [], "params": {"eps": 1e-06, "hidden": 16}, "sm": null,
"est_bytes": 0, "est_flops": 0, "label": "rmsnorm"},
{"id": 1, "op": "GEMV_TILE", "inputs": [3, 1], "outputs": [4], "out_counter": 1,
"waits": [{"counter": 0, "threshold": 1}],
"params": {"K": 16, "N_tile": 16, "n_off": 0}, "sm": null, "label": "gemv"}
],
"pages": {"buffer_to_page": {"3": 0, "4": 0}, "pages": [
{"id": 0, "space": "SMEM", "nbytes": 64, "live_start": 0, "live_end": 1}]},
"config": {"tiling": {}, "fusion_grouping": [], "sm_assignment": "load_balance",
"pipelining_depth": 2, "page_allocation": "graph_color",
"threads_per_block": 256, "smem_bytes_per_block": 0}
}

`pages` and `config` are null until those passes run. `buffer_to_page` keys are JSON strings
(int-parsed on load).
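An importer has three load-time duties that follow from the rules above: reject a major version mismatch, drop unknown config fields, and re-parse `buffer_to_page` keys. The sketch below works on plain dicts; `load_program_dict`, `filter_known`, and the hard-coded field set are assumptions of the sketch, not the reference `from_dict`.

```python
import json

IR_VERSION = "0.2.0"
KNOWN_CONFIG_FIELDS = {
    "tiling", "fusion_grouping", "sm_assignment", "pipelining_depth",
    "page_allocation", "threads_per_block", "smem_bytes_per_block",
}


def filter_known(d, known):
    """Forward compatibility: silently drop fields this version does not know."""
    return {k: v for k, v in d.items() if k in known}


def load_program_dict(text):
    d = json.loads(text)
    if d["ir_version"].split(".")[0] != IR_VERSION.split(".")[0]:
        raise ValueError(f"incompatible ir_version {d['ir_version']!r}")
    if d.get("config") is not None:
        d["config"] = filter_known(d["config"], KNOWN_CONFIG_FIELDS)
    if d.get("pages") is not None:
        raw = d["pages"]["buffer_to_page"]
        d["pages"]["buffer_to_page"] = {int(k): v for k, v in raw.items()}
    return d


doc = json.dumps({
    "ir_version": "0.3.1",  # newer minor: accepted
    "config": {"pipelining_depth": 2, "future_knob": True},
    "pages": {"buffer_to_page": {"3": 0, "4": 0}, "pages": []},
})
loaded = load_program_dict(doc)
print(loaded["config"], loaded["pages"]["buffer_to_page"])
```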
Opcodes are the Layer-1 ABI archetypes. Numeric codes are canonical (mirrored in vm/abi.h as
AMK_OP_*). OP_REGISTRY[kind] = OpSpec(kind, min_inputs, max_inputs, n_outputs, required_params, note) pins arity and the params that make each op mathematically correct;
validate() checks against it. -1 means variadic/variable (still capped by ABI limits).
| Opcode (code) | inputs | outputs | required params | semantics (`instructions/reference.py`) |
|---|---|---|---|---|
| `NOP` (0) | 0 | 0 | - | no-op |
| `COPY` (1) | 1 | 1 | - | page→page move |
| `EMBED` (2) | 2 | 1 | `hidden` | `[ids, table[V,H]]` → gathered rows |
| `RMSNORM` (3) | 2 | 1 | `eps`, `hidden` | `x*rsqrt(mean(x²)+eps)*w` |
| `LAYERNORM` (4) | 2–3 | 1 | `eps`, `hidden` | `LayerNorm(x, w[, b])` |
| `GEMV_TILE` (5) | 2–3 | 1 | `K`, `N_tile`, `n_off` | `out[..,n_off:n_off+N_tile] = x @ W[n_off:…].T` |
| `GEMM_TILE` (6) | 2–3 | 1 | `M_tile`, `K`, `N_tile`, `n_off` | tiled GEMM (prefill/batch) |
| `ATTENTION_TILE` (7) | 3–4 | 1 | `head_dim`, `kv_start`, `kv_len`, `scale`, `n_heads`, `n_kv_heads` | GQA attention over a KV window |
| `ROPE` (8) | 2 | 1 | `head_dim`, `theta` | Llama rotate-half rotary embedding |
| `SILU_MUL` (9) | 2 | 1 | - | SwiGLU `silu(gate)*up` |
| `GELU` (10) | 1 | 1 | - | GELU |
| `ADD` (11) | 2 | 1 | - | residual add |
| `MUL` (12) | 1–2 | 1 | - | elementwise / scale |
| `DEQUANT` (13) | 2–3 | 1 | `qdtype`, `group` | int4/int8 + scales → fp tile |
| `SOFTMAX` (14) | 1 | 1 | - | softmax over dim |
| `ALLREDUCE_SHARD` (15) | 1–8 | 1 | - | tensor-parallel shard reduce (multi-GPU) |
| `KV_APPEND` (16) | 2 | 1 | `pos` | append new k/v into the KV cache at `pos` |
| `SAMPLE_ARGMAX` (17) | 1 | 1 | - | greedy logits → next token |
| `ATTENTION_COMBINE` (18) | 2–8 | 1 | - | merge per-KV-block `(out,m,l)` flash partials |
Frozen numeric conventions (the backends MUST match instructions/reference.py): weight
layout [N_out, K_in], GEMV/GEMM compute x @ W.T, a tile writes out[..., n_off:n_off+N_tile];
reductions accumulate in fp32 then cast; RoPE rotate-half; GQA with repeat_interleave and scale
1/sqrt(head_dim). Known scalar param fields and marshalled types are in PARAM_FIELDS
(int32 'i' or float 'f'); unknown keys warn (they will not marshal to amk_params_t).
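The weight-layout and tile-offset conventions can be checked against a tiny pure-Python GEMV tile. This is a sketch for intuition only: it uses Python floats (double precision), so it does not model the fp32-accumulate-then-cast rule, and the function name and list-based tensors are assumptions.

```python
def gemv_tile(x, W, out, n_off, n_tile):
    """W is [N_out][K_in]; writes out[n_off:n_off+n_tile] = x @ W[n_off:n_off+n_tile].T."""
    for n in range(n_off, n_off + n_tile):
        acc = 0.0
        for k in range(len(x)):
            acc += x[k] * W[n][k]  # row n of W holds the weights of output n
        out[n] = acc
    return out


x = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]  # N_out=4, K_in=2
out = [0.0] * 4
gemv_tile(x, W, out, n_off=2, n_tile=2)  # only columns 2 and 3 are written
print(out)
```

Each tile touches a disjoint slice of the output, which is what lets several `GEMV_TILE` tasks write one buffer and signal a shared counter that consumers join on.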
Each Task maps 1:1 onto an amk_instruction_t (a fixed-size POD). The host loader resolves
the program into the on-device amk_program_t tables before launch. The IR enum values and
capacity/version constants are CANONICAL; vm/abi.h must match, and tests/test_abi_sync.py
parses both and fails the build on any drift.
| IR (`schedule/ir.py`) | ABI (`vm/abi.h`) |
|---|---|
| `Task.op` | `amk_instruction_t.op` (`amk_opcode_t`) |
| `Task.inputs` / `outputs` | `int32 inputs[AMK_MAX_INPUTS]` / `outputs[AMK_MAX_OUTPUTS]` + `n_inputs` / `n_outputs` |
| `Task.waits[i].counter` / `.threshold` | `wait_counter[AMK_MAX_WAITS]` / `wait_threshold[AMK_MAX_WAITS]` + `n_waits` |
| `Task.out_counter` | `amk_instruction_t.out_counter` |
| `Task.sm` | `amk_instruction_t.sm` (`>= 0`; loader rejects `< 0`) |
| `Task.params` | `amk_params_t` (keys/types per `PARAM_FIELDS`) |
| `Buffer` | `amk_buffer_t {ptr, numel, rank, dtype, space, shape[4], stride[4]}` (element strides) |
| `Counter` | `amk_counter_t` (uint32, host-memset 0 before each launch) |
| `MegakernelProgram` | `amk_program_t {buffers, counters, instructions, sm_queue[][], scratch, abort_flag}` |
| `ABI_MAX_INPUTS/OUTPUTS/WAITS/RANK` | `AMK_MAX_INPUTS/OUTPUTS/WAITS/RANK` (= 8/4/8/4) |
| `DType` / `MemSpace` / `InstructionKind` codes | `amk_dtype_t` / `amk_memspace_t` / `amk_opcode_t` |
| `ABI_VERSION = "0.2"` | `AMK_ABI_VERSION_MAJOR.MINOR` (= 0.2) |
- `signal(c)`: thread 0 issues a device-scope release fence (`__threadfence()`) ordering all output stores, then `atomicAdd(&prog.counters[c], 1u)`, then `__syncthreads()`. Cross-GPU counters (`ALLREDUCE_SHARD`) use `__threadfence_system()`.
- `wait(c,t)`: thread 0 spins on an acquire load the compiler may not hoist (`while (atomicAdd(&counters[c],0u) < t) { backoff(); if (*abort_flag) return; }`), then `__syncthreads()`. A plain non-volatile load is FORBIDDEN (it would hoist and spin forever). `backoff()` = exponential `__nanosleep`. The `abort_flag` poll is the watchdog escape.
- Launch contract: `cudaLaunchCooperativeKernel` (co-resident blocks make forward progress); `gridDim` capped by verified cooperative occupancy; dynamic SMEM opt-in `<= GpuTarget.smem_bytes_per_block_optin`; one launch per token to stay under the WDDM TDR. The host treats `cudaErrorLaunchTimeout` as a distinct TIMEOUT, not a clean REJECTED.
- Instruction contract (Layer 1): each micro-kernel is exactly `__device__ void amk_inst_<name>(const amk_program_t&, const amk_instruction_t&)`, pure compute, MUST NOT touch counters or any undeclared buffer, MUST NOT launch work.
A compatible implementation MUST:
- Represent the data model of §1 with the canonical enum codes (§1.1) and ABI caps (8/4/8 i/o/waits, rank 4).
- Implement `validate()` enforcing both invariants of §3 (deadlock-freedom and race-freedom, incl. the shared-counter all-join rule, transitive happens-before, and the KV_CACHE ordering rule) and refuse to load any rejected program.
- Execute by counter-driven scheduling (§2): fire a task only when every wait `counter >= threshold`; on completion increment `out_counter` by exactly 1 after a release fence.
- Read/write the JSON of §6 with name-encoded enums, additive forward-compatibility, and stable round-trip; reject a major `ir_version` mismatch.
- Honor the opcode semantics and frozen numeric conventions of §7 (verified against `instructions/reference.py`) and the ABI mapping of §8 (verified against `vm/abi.h`).
The canonical conformance oracle is vm/reference_vm.py (ReferenceVM) executing real numerics;
if a CUDA VM disagrees with it, the CUDA side is wrong by definition.