Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -6086,7 +6086,7 @@ pto.tscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)

##### `pto.mgather` - Gather-Load from Global Memory

**Summary:** Loads elements from global memory into a tile using per-element indices.
**Summary:** Loads elements from a global table into a VEC tile using per-element indices. Supports an optional A5-only out-of-bounds mode that lowers to the corresponding `MGATHER<...>` template overload.

**Semantics:**

Expand All @@ -6096,18 +6096,34 @@ dst[i, j] = mem[idx[i, j]]

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `mem` | `AnyMemRef/pto.tile_buf` | Source memory |
| `idx` | `pto.tile_buf` | Index tile |
| `dst` | `pto.tile_buf` | Destination tile |
| Name | Type | Default | Description |
|------|------|---------|-------------|
| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global source table |
| `idx` | `pto.tile_buf` | `NA` | Index tile |
| `dst` | `pto.tile_buf` | `NA` | Destination VEC tile |
| `gatherOob` | `#pto<gather_oob ...>` | `undefined` | A5-only out-of-bounds mode (`undefined/clamp/wrap/zero`) |

**Results:** None. Writes into `dst` via DPS pattern.

**Constraints & Verification:**

- Index interpretation is target-defined. The CPU simulator treats indices as linear element indices into `src.data()`.
- No bounds checks are enforced on `indexes` by the CPU simulator.
- **Types (data and indices)**
- `mem` and `dst` must have the **same element type**. Supported element types: `i8`/`i16`/`i32`/`f16`/`bf16`/`f32`. On **A5** targets, `float8_e4m3` / `float8_e5m2` family element types are also supported.
- `idx` element type must be signless `i32`.

- **Tile / memory roles**
- `dst` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `mem` must denote a GlobalTensor in GM memory.
- `mem` must use `ND` layout when layout can be inferred.

- **Shape**
- `dst row == idx row`.
- `idx column == 1` or `idx column == dst column`.
- If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`.

- **Out-of-bounds mode**
- Default `gatherOob = undefined` lowers to the default `MGATHER(dst, mem, idx)` overload.
- Non-default `gatherOob` values are only supported on **A5** and lower to `MGATHER<GatherOOB::...>(dst, mem, idx)`.

**Hardware Mapping:**

Expand All @@ -6118,13 +6134,17 @@ dst[i, j] = mem[idx[i, j]]
```mlir
pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
outs(%dst : !pto.tile_buf<...>)

pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
outs(%dst : !pto.tile_buf<...>)
{gatherOob = #pto<gather_oob zero>}
```

---

##### `pto.mscatter` - Scatter-Store to Global Memory

**Summary:** Stores elements from a tile into global memory using per-element indices.
**Summary:** Stores elements from a VEC tile into a global table using per-element indices. Supports optional A5-only atomic and out-of-bounds modes that lower to the corresponding `MSCATTER<...>` template overload family.

**Semantics:**

Expand All @@ -6134,18 +6154,41 @@ mem[idx[i, j]] = src[i, j]

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile |
| `idx` | `pto.tile_buf` | Index tile |
| `mem` | `AnyMemRef/pto.tile_buf` | Destination memory |
| Name | Type | Default | Description |
|------|------|---------|-------------|
| `src` | `pto.tile_buf` | `NA` | Source VEC tile |
| `idx` | `pto.tile_buf` | `NA` | Index tile |
| `mem` | `!pto.partition_tensor_view<...>` / GM memref | `NA` | Global destination table |
| `scatterAtomicOp` | `#pto<scatter_atomic_op ...>` | `none` | A5-only atomic mode (`none/add/max/min`) |
| `scatterOob` | `#pto<scatter_oob ...>` | `undefined` | A5-only out-of-bounds mode (`undefined/skip/clamp/wrap`) |

**Results:** None. Writes into `mem` via DPS pattern.

**Constraints & Verification:**

- Index interpretation is target-defined. The CPU simulator treats indices as linear element indices into `dst.data()`.
- No bounds checks are enforced on `indexes` by the CPU simulator.
- **Types (data and indices)**
- `src` and `mem` must have the **same element type**. Supported element types: `i8`/`i16`/`i32`/`f16`/`bf16`/`f32`. On **A5** targets, `float8_e4m3` / `float8_e5m2` family element types are also supported.
- `idx` element type must be signless `i32`.

- **Tile / memory roles**
- `src` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `mem` must denote a GlobalTensor in GM memory.
- `mem` must use `ND` layout when layout can be inferred.

- **Shape**
- `src row == idx row`.
- `idx column == 1` or `idx column == src column`.
- If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`.

- **Atomic modes**
- Default `scatterAtomicOp = none` lowers to the default `MSCATTER(mem, src, idx)` overload.
- Non-default `scatterAtomicOp` values are only supported on **A5**.
- `add` requires `i32`/`f16`/`f32`.
- `max`/`min` require signless `i32` or `f32`.

- **Out-of-bounds modes**
- Default `scatterOob = undefined` lowers to the 1-template-parameter `MSCATTER<Atomic>(mem, src, idx)` form when only atomic is specified, or to the default overload when both attrs are default.
- Non-default `scatterOob` values are only supported on **A5** and lower to `MSCATTER<ScatterAtomicOp::..., ScatterOOB::...>(mem, src, idx)`.

**Hardware Mapping:**

Expand All @@ -6156,6 +6199,15 @@ mem[idx[i, j]] = src[i, j]
```mlir
pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%mem : memref<...>)

pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%mem : memref<...>)
{scatterAtomicOp = #pto<scatter_atomic_op add>}

pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%mem : memref<...>)
{scatterAtomicOp = #pto<scatter_atomic_op add>,
scatterOob = #pto<scatter_oob skip>}
```

---
Expand Down
36 changes: 36 additions & 0 deletions include/PTO/IR/PTOAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,42 @@ def PTO_ReluPreModeAttr : EnumAttr<PTO_Dialect, PTO_ReluPreModeEnum, "relu_pre_m
let summary = "TSTORE relu pre mode attribute";
}

def PTO_GatherOOBEnum : PTO_I32Enum<
"GatherOOB", "PTO MGATHER out-of-bounds mode", [
I32EnumAttrCase<"Undefined", 0, "undefined">,
I32EnumAttrCase<"Clamp", 1, "clamp">,
I32EnumAttrCase<"Wrap", 2, "wrap">,
I32EnumAttrCase<"Zero", 3, "zero">
]>;

def PTO_GatherOOBAttr : EnumAttr<PTO_Dialect, PTO_GatherOOBEnum, "gather_oob"> {
let summary = "MGATHER out-of-bounds handling mode";
}

def PTO_ScatterAtomicOpEnum : PTO_I32Enum<
"ScatterAtomicOp", "PTO MSCATTER atomic mode", [
I32EnumAttrCase<"None", 0, "none">,
I32EnumAttrCase<"Add", 1, "add">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">
]>;

def PTO_ScatterAtomicOpAttr : EnumAttr<PTO_Dialect, PTO_ScatterAtomicOpEnum, "scatter_atomic_op"> {
let summary = "MSCATTER atomic mode";
}

def PTO_ScatterOOBEnum : PTO_I32Enum<
"ScatterOOB", "PTO MSCATTER out-of-bounds mode", [
I32EnumAttrCase<"Undefined", 0, "undefined">,
I32EnumAttrCase<"Skip", 1, "skip">,
I32EnumAttrCase<"Clamp", 2, "clamp">,
I32EnumAttrCase<"Wrap", 3, "wrap">
]>;

def PTO_ScatterOOBAttr : EnumAttr<PTO_Dialect, PTO_ScatterOOBEnum, "scatter_oob"> {
let summary = "MSCATTER out-of-bounds handling mode";
}

def PTO_AccToVecMode_Enum : PTO_I32Enum<"AccToVecMode", "TMOV acc-to-vec mode", [
I32EnumAttrCase<"SingleModeVec0", 0, "single_mode_vec0">,
I32EnumAttrCase<"SingleModeVec1", 1, "single_mode_vec1">,
Expand Down
7 changes: 5 additions & 2 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,8 @@ def MGatherOp : PTO_TOp<"mgather", [
let arguments = (ins
PTODpsType:$mem,
PTODpsType:$idx,
PTODpsType:$dst);
PTODpsType:$dst,
DefaultValuedAttr<PTO_GatherOOBAttr, "::mlir::pto::GatherOOB::Undefined">:$gatherOob);

let results = (outs);

Expand Down Expand Up @@ -2261,7 +2262,9 @@ def MScatterOp : PTO_TOp<"mscatter", [
let arguments = (ins
PTODpsType:$src,
PTODpsType:$idx,
PTODpsType:$mem // outs target
PTODpsType:$mem, // outs target
DefaultValuedAttr<PTO_ScatterAtomicOpAttr, "::mlir::pto::ScatterAtomicOp::None">:$scatterAtomicOp,
DefaultValuedAttr<PTO_ScatterOOBAttr, "::mlir::pto::ScatterOOB::Undefined">:$scatterOob
);

let results = (outs);
Expand Down
Loading
Loading