287 changes: 287 additions & 0 deletions docs/PTO_IR_manual.md
@@ -7712,6 +7712,293 @@ pto.trap

---

### 4.21 Communication Operations

This section documents PTO communication primitives. PTOAS currently exposes:

- Synchronous point-to-point ops: `pto.tput`, `pto.tget`
- Synchronous signal ops: `pto.tnotify`, `pto.twait`, `pto.ttest`
- Synchronous collective ops: `pto.tbroadcast`, `pto.comm_tgather`, `pto.comm_tscatter`, `pto.treduce`
- Asynchronous communication/session ops: `pto.build_async_session`, `pto.tput_async`, `pto.tget_async`, `pto.wait_async_event`, `pto.test_async_event`

##### `pto.build_async_session` - Create Async DMA Session

**Summary:** Creates an async DMA session handle used by `pto.tput_async` and `pto.tget_async`.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `scratch` | `pto.tile_buf` / local memref | Local scratch/staging buffer used by the async runtime |
| `workspace` | `!pto.ptr<...>` / GM memref | Global workspace backing the async session |
| `sync_id` | optional `i32` attr | Session synchronization ID |
| `block_bytes` | optional `i64` attr | Communication block size in bytes |
| `comm_block_offset` | optional `i64` attr | Per-block GM offset in bytes |
| `queue_num` | optional `i32` attr | Queue count hint |
| `channel_group_idx` | optional `i64` attr | Communication channel-group selector |

**Results:** `!pto.async_session`

**Constraints & Verification:**

- `scratch` must be tile-like local storage.
- `workspace` must be a GM pointer/memref.
- Optional attrs are forwarded as session configuration and must use the declared integer types.

**Basic Example:**

```mlir
%session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.ptr<i8>) {sync_id = 0 : i32} -> !pto.async_session
```

---

##### `pto.tput_async` - Asynchronous Remote Write

**Summary:** Starts an asynchronous remote write from local GM to remote GM and returns an async event handle.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
| `session` | `!pto.async_session` | Async DMA session |

**Results:** `!pto.async_event`

**Constraints & Verification:**

- `dst` / `src` must be GM-shaped values with identical element type and static shape.
- Current lowering only supports flat contiguous logical-1D transfers for async GM operands.
- `session` must come from `pto.build_async_session`.

**Basic Example:**

```mlir
%event = pto.tput_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event
```

---

##### `pto.tget_async` - Asynchronous Remote Read

**Summary:** Starts an asynchronous remote read from remote GM to local GM and returns an async event handle.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
| `session` | `!pto.async_session` | Async DMA session |

**Results:** `!pto.async_event`

**Constraints & Verification:**

- Same operand constraints as `pto.tput_async`.
- `session` must be compatible with the transfer workspace and staging configuration.

**Basic Example:**

```mlir
%event = pto.tget_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event
```

---

##### `pto.wait_async_event` / `pto.test_async_event` - Async Event Completion

**Summary:** Consume an async event produced by `pto.tput_async` / `pto.tget_async`.

**Arguments:**

| Op | Operands | Result | Description |
|----|----------|--------|-------------|
| `pto.wait_async_event` | `event`, `session` | `i1` | Blocking wait for completion |
| `pto.test_async_event` | `event`, `session` | `i1` | Non-blocking completion test |

**Constraints & Verification:**

- `event` must have type `!pto.async_event`.
- `session` must have type `!pto.async_session`.
- The event/session pair is expected to come from the same async communication flow.

**Basic Example:**

```mlir
%done0 = pto.wait_async_event(%event0, %session : !pto.async_event, !pto.async_session) -> i1
%done1 = pto.test_async_event(%event1, %session : !pto.async_event, !pto.async_session) -> i1
```
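Putting the async ops together, a typical flow builds a session once, launches the transfer, overlaps independent work, and then blocks (or polls) on the returned event. The sketch below is illustrative only; the value names and `sync_id` are assumptions, and the types mirror the per-op examples above:

```mlir
// Build the session once per communication flow.
%session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.ptr<i8>) {sync_id = 0 : i32} -> !pto.async_session

// Launch the remote write; the event handle tracks completion.
%event = pto.tput_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event

// ... independent compute can overlap with the transfer here ...

// Block until the transfer completes (use pto.test_async_event to poll instead).
%done = pto.wait_async_event(%event, %session : !pto.async_event, !pto.async_session) -> i1
```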

---

##### `pto.tput` - Synchronous Remote Write

**Summary:** Lowers to `pto::comm::TPUT(...)` and copies data from local GM to remote GM through a VEC staging tile.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` |

**Constraints & Verification:**

- `dst` / `src` must be GM-shaped values with positive static shapes.
- `dst` and `src` must have the same element type and static shape.
- `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.

**Basic Example:**

```mlir
pto.tput %dst, %src, %ping {atomicType = #pto.atomic_type<atomic_none>} :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_add>} :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
```

---

##### `pto.tget` - Synchronous Remote Read

**Summary:** Lowers to `pto::comm::TGET(...)` and copies data from remote GM to local GM through a VEC staging tile.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |

**Constraints & Verification:**

- Same GM/global-like and staging constraints as `pto.tput`.
- `dst` and `src` must have the same element type and static shape.

**Basic Example:**

```mlir
pto.tget %dst, %src, %ping :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
```

---

##### `pto.tnotify` / `pto.twait` / `pto.ttest` - Communication Signal Ops

**Summary:** Lower to `pto::comm::TNOTIFY/TWAIT/TTEST` for GM `i32` signal buffers.

**Arguments:**

| Op | Operands | Attributes | Result |
|----|----------|------------|--------|
| `pto.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op<atomic_add/set>` | none |
| `pto.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | none |
| `pto.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | `i1` |

**Constraints & Verification:**

- `signal` must be a GM-shaped value with element type `i32`.
- `value` / `cmpValue` must be signless integer scalars.

**Basic Example:**

```mlir
pto.tnotify %sig, %v {notifyOp = #pto.notify_op<set>} : !pto.partition_tensor_view<1xi32>, i32
pto.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1xi32>, i32
%ok = pto.ttest %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32 -> i1
```

---

##### `pto.tbroadcast` - Collective Broadcast

**Summary:** Lowers to `pto::comm::TBROADCAST(...)`.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | GM-shaped value | Root source buffer |
| `ping` / `pong` | local VEC tile-like values | Staging tiles |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `src` must have the same type as each `group` member.
- `root` must be in range `[0, group.size)`.

**Basic Example:**

```mlir
pto.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
!pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
```

---

##### `pto.comm_tgather` - Collective Gather

**Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM-shaped value | Root destination buffer |
| `ping` / `pong` | local VEC tile-like values | Staging tiles (`pong` optional) |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `dst` element type must match the group element type.
- `ping` / `pong` must be local VEC tile-like values with matching element type.
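**Basic Example** (illustrative; mirrors the `pto.tbroadcast` syntax above and assumes the same operand-segment encoding of `dst`, `ping`, optional `pong`, then `group` — verify against the op definition):

```mlir
pto.comm_tgather %dst, %ping, %g0, %g1, %g2 {root = 0, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
  !pto.partition_tensor_view<384xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
```

Here `%dst` on the root is sized to hold one `group`-member chunk per rank (3 x 128 elements in this sketch).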

---

##### `pto.comm_tscatter` - Collective Scatter

**Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | GM-shaped value | Root source buffer |
| `ping` / `pong` | local VEC tile-like values | Staging tiles (`pong` optional) |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `src` element type must match the group element type.
- `ping` / `pong` must be local VEC tile-like values with matching element type.
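**Basic Example** (illustrative; mirrors the `pto.tbroadcast` syntax above and assumes the same operand-segment encoding of `src`, `ping`, optional `pong`, then `group` — verify against the op definition):

```mlir
pto.comm_tscatter %src, %ping, %g0, %g1, %g2 {root = 0, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
  !pto.partition_tensor_view<384xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
```

Here `%src` on the root holds one chunk per rank, and each `group` member receives its 128-element slice.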

---

##### `pto.treduce` - Collective Reduce

**Summary:** Lowers to `pto::comm::TREDUCE(...)`.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM-shaped value | Root destination buffer |
| `acc` | local VEC tile-like value | Accumulation tile |
| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles |
| `group` | variadic GM-shaped values | Parallel group members |
| `reduceOp` | `#pto.reduce_op<sum/max/min>` | Reduction mode |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `dst` element type must match the group element type.
- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`.
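**Basic Example** (illustrative; follows the pattern of the other collectives and assumes an operand-segment encoding of `dst`, `acc`, `recvPing`, optional `recvPong`, then `group` — verify against the op definition):

```mlir
pto.treduce %dst, %acc, %recvPing, %g0, %g1, %g2 {reduceOp = #pto.reduce_op<sum>, root = 0, operandSegmentSizes = array<i32: 1, 1, 1, 0, 3>} :
  !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
```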

---

## 5. Operation Summary Table

| Category | Count | Pipeline |
35 changes: 35 additions & 0 deletions include/PTO/IR/PTOAttrs.td
@@ -394,6 +394,41 @@ def PTO_AtomicTypeAttr : EnumAttr<PTO_Dialect, PTO_AtomicTypeEnum, "atomic_type"
let summary = "TSTORE atomic type attribute";
}

def PTO_NotifyOpEnum : PTO_I32Enum<
"NotifyOp", "PTO communication notify op", [
I32EnumAttrCase<"AtomicAdd", 0, "atomic_add">,
I32EnumAttrCase<"Set", 1, "set">
]>;

def PTO_NotifyOpAttr : EnumAttr<PTO_Dialect, PTO_NotifyOpEnum, "notify_op"> {
let summary = "communication notify operation attribute";
}

def PTO_WaitCmpEnum : PTO_I32Enum<
"WaitCmp", "PTO communication wait/test compare", [
I32EnumAttrCase<"EQ", 0, "eq">,
I32EnumAttrCase<"NE", 1, "ne">,
I32EnumAttrCase<"GT", 2, "gt">,
I32EnumAttrCase<"GE", 3, "ge">,
I32EnumAttrCase<"LT", 4, "lt">,
I32EnumAttrCase<"LE", 5, "le">
]>;

def PTO_WaitCmpAttr : EnumAttr<PTO_Dialect, PTO_WaitCmpEnum, "wait_cmp"> {
let summary = "communication wait/test comparison attribute";
}

def PTO_ReduceOpEnum : PTO_I32Enum<
"ReduceOp", "PTO communication reduce operation", [
I32EnumAttrCase<"Sum", 0, "sum">,
I32EnumAttrCase<"Max", 1, "max">,
I32EnumAttrCase<"Min", 2, "min">
]>;

def PTO_ReduceOpAttr : EnumAttr<PTO_Dialect, PTO_ReduceOpEnum, "reduce_op"> {
let summary = "communication reduce operation attribute";
}

def PTO_ReluPreModeEnum : PTO_I32Enum<
"ReluPreMode", "PTO TSTORE relu pre mode", [
I32EnumAttrCase<"NoRelu", 0, "no_relu">,