diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index e19bb6229..8c519f985 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -7712,6 +7712,293 @@ pto.trap --- +### 4.21 Communication Operations + +This section documents PTO communication primitives. PTOAS currently exposes: + +- Synchronous point-to-point ops: `pto.tput`, `pto.tget` +- Synchronous signal ops: `pto.tnotify`, `pto.twait`, `pto.ttest` +- Synchronous collective ops: `pto.tbroadcast`, `pto.comm_tgather`, `pto.comm_tscatter`, `pto.treduce` +- Asynchronous communication/session ops: `pto.build_async_session`, `pto.tput_async`, `pto.tget_async`, `pto.wait_async_event`, `pto.test_async_event` + +##### `pto.build_async_session` - Create Async DMA Session + +**Summary:** Creates an async DMA session handle used by `pto.tput_async` and `pto.tget_async`. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `scratch` | `pto.tile_buf` / local memref | Local scratch/staging buffer used by the async runtime | +| `workspace` | `!pto.ptr<...>` / GM memref | Global workspace backing the async session | +| `sync_id` | optional `i32` attr | Session synchronization ID | +| `block_bytes` | optional `i64` attr | Communication block size in bytes | +| `comm_block_offset` | optional `i64` attr | Per-block GM offset in bytes | +| `queue_num` | optional `i32` attr | Queue count hint | +| `channel_group_idx` | optional `i64` attr | Communication channel-group selector | + +**Results:** `!pto.async_session` + +**Constraints & Verification:** + +- `scratch` must be tile-like local storage. +- `workspace` must be a GM pointer/memref. +- Optional attrs are forwarded as session configuration and must use the declared integer types. + +**Basic Example:** + +```mlir +%session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf, !pto.ptr) {sync_id = 0 : i32} -> !pto.async_session +``` + +--- + +##### `pto.tput_async` - Asynchronous Remote Write + +**Summary:** Starts an asynchronous remote write from local GM to remote GM and returns an async event handle. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer | +| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer | +| `session` | `!pto.async_session` | Async DMA session | + +**Results:** `!pto.async_event` + +**Constraints & Verification:** + +- `dst` / `src` must be GM-shaped values with identical element type and static shape. +- Current lowering only supports flat contiguous logical-1D transfers for async GM operands. +- `session` must come from `pto.build_async_session`. + +**Basic Example:** + +```mlir +%event = pto.tput_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event +``` + +--- + +##### `pto.tget_async` - Asynchronous Remote Read + +**Summary:** Starts an asynchronous remote read from remote GM to local GM and returns an async event handle. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer | +| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer | +| `session` | `!pto.async_session` | Async DMA session | + +**Results:** `!pto.async_event` + +**Constraints & Verification:** + +- Same operand constraints as `pto.tput_async`. +- `session` must be compatible with the transfer workspace and staging configuration. + +**Basic Example:** + +```mlir +%event = pto.tget_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event +``` + +--- + +##### `pto.wait_async_event` / `pto.test_async_event` - Async Event Completion + +**Summary:** Consume an async event produced by `pto.tput_async` / `pto.tget_async`. + +**Arguments:** + +| Op | Operands | Result | Description | +|----|----------|--------|-------------| +| `pto.wait_async_event` | `event`, `session` | `i1` | Blocking wait for completion | +| `pto.test_async_event` | `event`, `session` | `i1` | Non-blocking completion test | + +**Constraints & Verification:** + +- `event` must have type `!pto.async_event`. +- `session` must have type `!pto.async_session`. +- The event/session pair is expected to come from the same async communication flow. + +**Basic Example:** + +```mlir +%done0 = pto.wait_async_event(%event0, %session : !pto.async_event, !pto.async_session) -> i1 +%done1 = pto.test_async_event(%event1, %session : !pto.async_event, !pto.async_session) -> i1 +``` + +--- + +##### `pto.tput` - Synchronous Remote Write + +**Summary:** Lowers to `pto::comm::TPUT(...)` and copies data from local GM to remote GM through a VEC staging tile. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer | +| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer | +| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile | +| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer | +| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` | + +**Constraints & Verification:** + +- `dst` / `src` must be GM-shaped values with positive static shapes. +- `dst` and `src` must have the same element type and static shape. +- `ping` / `pong` must be local VEC tile-like values whose element type matches `src`. + +**Basic Example:** + +```mlir +pto.tput %dst, %src, %ping {atomicType = #pto.atomic_type} : + !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf + +pto.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type} : + !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf +``` + +--- + +##### `pto.tget` - Synchronous Remote Read + +**Summary:** Lowers to `pto::comm::TGET(...)` and copies data from remote GM to local GM through a VEC staging tile. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer | +| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer | +| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile | +| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer | + +**Constraints & Verification:** + +- Same GM/global-like and staging constraints as `pto.tput`. +- `dst` and `src` must have the same element type and static shape. + +**Basic Example:** + +```mlir +pto.tget %dst, %src, %ping : + !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf +``` + +--- + +##### `pto.tnotify` / `pto.twait` / `pto.ttest` - Communication Signal Ops + +**Summary:** Lower to `pto::comm::TNOTIFY/TWAIT/TTEST` for GM `i32` signal buffers. + +**Arguments:** + +| Op | Operands | Attributes | Result | +|----|----------|------------|--------| +| `pto.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op` | none | +| `pto.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp` | none | +| `pto.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp` | `i1` | + +**Constraints & Verification:** + +- `signal` must be a GM-shaped value with element type `i32`. +- `value` / `cmpValue` must be signless integer scalars. + +**Basic Example:** + +```mlir +pto.tnotify %sig, %v {notifyOp = #pto.notify_op} : !pto.partition_tensor_view<1xi32>, i32 +pto.twait %sig, %v {cmp = #pto.wait_cmp} : !pto.partition_tensor_view<1xi32>, i32 +%ok = pto.ttest %sig, %v {cmp = #pto.wait_cmp} : !pto.partition_tensor_view<1xi32>, i32 -> i1 +``` + +--- + +##### `pto.tbroadcast` - Collective Broadcast + +**Summary:** Lowers to `pto::comm::TBROADCAST(...)`. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `src` | GM-shaped value | Root source buffer | +| `ping` / `pong` | local VEC tile-like values | Staging tiles | +| `group` | variadic GM-shaped values | Parallel group members | +| `root` | `i32` attr | Root rank index inside `group` | + +**Constraints & Verification:** + +- `group` must be non-empty and all members must have identical types. +- `src` must have the same type as each `group` member. +- `root` must be in range `[0, group.size)`. + +**Basic Example:** + +```mlir +pto.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array} : + !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32> +``` + +--- + +##### `pto.comm_tgather` - Collective Gather + +**Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`. + +**Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root` + +**Constraints & Verification:** + +- `group` must be non-empty and all members must have identical types. +- `dst` element type must match the group element type. +- `ping` / `pong` must be local VEC tile-like values with matching element type. + +--- + +##### `pto.comm_tscatter` - Collective Scatter + +**Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`. + +**Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root` + +**Constraints & Verification:** + +- `group` must be non-empty and all members must have identical types. +- `src` element type must match the group element type. +- `ping` / `pong` must be local VEC tile-like values with matching element type. + +--- + +##### `pto.treduce` - Collective Reduce + +**Summary:** Lowers to `pto::comm::TREDUCE(...)`. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `dst` | GM-shaped value | Root destination buffer | +| `acc` | local VEC tile-like value | Accumulation tile | +| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles | +| `group` | variadic GM-shaped values | Parallel group members | +| `reduceOp` | `#pto.reduce_op` | Reduction mode | +| `root` | `i32` attr | Root rank index inside `group` | + +**Constraints & Verification:** + +- `group` must be non-empty and all members must have identical types. +- `dst` element type must match the group element type. +- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`. + +--- + ## 5. Operation Summary Table | Category | Count | Pipeline | diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 1a975fea1..b78a10bfe 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -394,6 +394,41 @@ def PTO_AtomicTypeAttr : EnumAttr, + I32EnumAttrCase<"Set", 1, "set"> + ]>; + +def PTO_NotifyOpAttr : EnumAttr { + let summary = "communication notify operation attribute"; +} + +def PTO_WaitCmpEnum : PTO_I32Enum< + "WaitCmp", "PTO communication wait/test compare", [ + I32EnumAttrCase<"EQ", 0, "eq">, + I32EnumAttrCase<"NE", 1, "ne">, + I32EnumAttrCase<"GT", 2, "gt">, + I32EnumAttrCase<"GE", 3, "ge">, + I32EnumAttrCase<"LT", 4, "lt">, + I32EnumAttrCase<"LE", 5, "le"> + ]>; + +def PTO_WaitCmpAttr : EnumAttr { + let summary = "communication wait/test comparison attribute"; +} + +def PTO_ReduceOpEnum : PTO_I32Enum< + "ReduceOp", "PTO communication reduce operation", [ + I32EnumAttrCase<"Sum", 0, "sum">, + I32EnumAttrCase<"Max", 1, "max">, + I32EnumAttrCase<"Min", 2, "min"> + ]>; + +def PTO_ReduceOpAttr : EnumAttr { + let summary = "communication reduce operation attribute"; +} + def PTO_ReluPreModeEnum : PTO_I32Enum< "ReluPreMode", "PTO TSTORE relu pre mode", [ I32EnumAttrCase<"NoRelu", 0, "no_relu">, diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 193fe44c5..293a042b3 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1655,6 +1655,140 @@ def TestAsyncEventOp : PTO_Op<"test_async_event", [ }]; } +def TPutOp : PTO_Op<"tput", [ + DeclareOpInterfaceMethods +]> { + let summary = "Synchronous remote write from local GM to remote GM"; + let arguments = (ins + PTODpsType:$dst, + PTODpsType:$src, + PTODpsType:$ping, + Optional:$pong, + DefaultValuedAttr:$atomicType + ); + let results = (outs); + let hasVerifier = 1; +} + +def TGetOp : PTO_Op<"tget", [ + DeclareOpInterfaceMethods +]> { + let summary = "Synchronous remote read from remote GM to local GM"; + let arguments = (ins + PTODpsType:$dst, + PTODpsType:$src, + PTODpsType:$ping, + Optional:$pong + ); + let results = (outs); + let hasVerifier = 1; +} + +def TNotifyOp : PTO_Op<"tnotify", [ + DeclareOpInterfaceMethods +]> { + let summary = "Send a signal notification to remote GM"; + let arguments = (ins + PTODpsType:$signal, + AnySignlessInteger:$value, + PTO_NotifyOpAttr:$notifyOp + ); + let results = (outs); + let hasVerifier = 1; +} + +def TWaitOp : PTO_Op<"twait", [ + DeclareOpInterfaceMethods +]> { + let summary = "Block until signal(s) satisfy a comparison"; + let arguments = (ins + PTODpsType:$signal, + AnySignlessInteger:$cmpValue, + PTO_WaitCmpAttr:$cmp + ); + let results = (outs); + let hasVerifier = 1; +} + +def TTestOp : PTO_Op<"ttest", [ + DeclareOpInterfaceMethods +]> { + let summary = "Non-blocking signal comparison test"; + let arguments = (ins + PTODpsType:$signal, + AnySignlessInteger:$cmpValue, + PTO_WaitCmpAttr:$cmp + ); + let results = (outs I1:$result); + let hasVerifier = 1; +} + +def TBroadcastOp : PTO_Op<"tbroadcast", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods +]> { + let summary = "Broadcast local GM data to all group members"; + let arguments = (ins + PTODpsType:$src, + PTODpsType:$ping, + Optional:$pong, + Variadic:$group, + I32Attr:$root + ); + let results = (outs); + let hasVerifier = 1; +} + +def CommTGatherOp : PTO_Op<"comm_tgather", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods +]> { + let summary = "Gather remote GM data from a communication group"; + let arguments = (ins + PTODpsType:$dst, + PTODpsType:$ping, + Optional:$pong, + Variadic:$group, + I32Attr:$root + ); + let results = (outs); + let hasVerifier = 1; +} + +def CommTScatterOp : PTO_Op<"comm_tscatter", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods +]> { + let summary = "Scatter local GM data to a communication group"; + let arguments = (ins + PTODpsType:$src, + PTODpsType:$ping, + Optional:$pong, + Variadic:$group, + I32Attr:$root + ); + let results = (outs); + let hasVerifier = 1; +} + +def TReduceOp : PTO_Op<"treduce", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods +]> { + let summary = "Reduce remote GM data from a communication group"; + let arguments = (ins + PTODpsType:$dst, + PTODpsType:$acc, + PTODpsType:$recvPing, + Optional:$recvPong, + Variadic:$group, + PTO_ReduceOpAttr:$reduceOp, + I32Attr:$root + ); + let results = (outs); + let hasVerifier = 1; +} + def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [ DeclareOpInterfaceMethods ]> { diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index 787749157..c820c2684 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -98,6 +98,18 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOAccToVecModeAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAReluPreModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTOReluPreModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED int32_t mlirPTOReluPreModeAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAAtomicTypeAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOAtomicTypeAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTOAtomicTypeAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsANotifyOpAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTONotifyOpAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTONotifyOpAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAWaitCmpAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOWaitCmpAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTOWaitCmpAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAReduceOpAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOReduceOpAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTOReduceOpAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr); diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index c8dff3109..c2a03b2b3 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -153,6 +153,31 @@ PYBIND11_MODULE(_pto, m) { .value("NormalRelu", mlir::pto::ReluPreMode::NormalRelu) .export_values(); + py::enum_(m, "AtomicType") + .value("AtomicNone", mlir::pto::AtomicType::AtomicNone) + .value("AtomicAdd", mlir::pto::AtomicType::AtomicAdd) + .export_values(); + + py::enum_(m, "NotifyOp") + .value("AtomicAdd", mlir::pto::NotifyOp::AtomicAdd) + .value("Set", mlir::pto::NotifyOp::Set) + .export_values(); + + py::enum_(m, "WaitCmp") + .value("EQ", mlir::pto::WaitCmp::EQ) + .value("NE", mlir::pto::WaitCmp::NE) + .value("GT", mlir::pto::WaitCmp::GT) + .value("GE", mlir::pto::WaitCmp::GE) + .value("LT", mlir::pto::WaitCmp::LT) + .value("LE", mlir::pto::WaitCmp::LE) + .export_values(); + + py::enum_(m, "ReduceOp") + .value("Sum", mlir::pto::ReduceOp::Sum) + .value("Max", mlir::pto::ReduceOp::Max) + .value("Min", mlir::pto::ReduceOp::Min) + .export_values(); + py::enum_(m, "SyncOpType") .value("TLOAD", mlir::pto::SyncOpType::TLOAD) .value("TSTORE_ACC", mlir::pto::SyncOpType::TSTORE_ACC) @@ -266,6 +291,58 @@ PYBIND11_MODULE(_pto, m) { return cls(a); }, py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + + mlir_attribute_subclass(m, "AtomicTypeAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsAAtomicTypeAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, mlir::pto::AtomicType value, MlirContext ctx) -> py::object { + MlirAttribute a = mlirPTOAtomicTypeAttrGet(ctx, static_cast(value)); + if (mlirAttributeIsNull(a)) return py::none(); + return cls(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + + mlir_attribute_subclass(m, "NotifyOpAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsANotifyOpAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, mlir::pto::NotifyOp value, MlirContext ctx) -> py::object { + MlirAttribute a = mlirPTONotifyOpAttrGet(ctx, static_cast(value)); + if (mlirAttributeIsNull(a)) return py::none(); + return cls(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + + mlir_attribute_subclass(m, "WaitCmpAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsAWaitCmpAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, mlir::pto::WaitCmp value, MlirContext ctx) -> py::object { + MlirAttribute a = mlirPTOWaitCmpAttrGet(ctx, static_cast(value)); + if (mlirAttributeIsNull(a)) return py::none(); + return cls(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + + mlir_attribute_subclass(m, "ReduceOpAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsAReduceOpAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, mlir::pto::ReduceOp value, MlirContext ctx) -> py::object { + MlirAttribute a = mlirPTOReduceOpAttrGet(ctx, static_cast(value)); + if (mlirAttributeIsNull(a)) return py::none(); + return cls(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); // [保留 HEAD]: AddressSpaceAttr 定义 mlir_attribute_subclass( m, "AddressSpaceAttr", diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index 162519fa6..4f1fe309a 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -585,6 +585,66 @@ int32_t mlirPTOReluPreModeAttrGetValue(MlirAttribute attr) { return static_cast(a.getValue()); } +bool mlirPTOAttrIsAAtomicTypeAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirPTOAtomicTypeAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::AtomicTypeAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTOAtomicTypeAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + +bool mlirPTOAttrIsANotifyOpAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirPTONotifyOpAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::NotifyOpAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTONotifyOpAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + +bool mlirPTOAttrIsAWaitCmpAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirPTOWaitCmpAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::WaitCmpAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTOWaitCmpAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + +bool mlirPTOAttrIsAReduceOpAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirPTOReduceOpAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::ReduceOpAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTOReduceOpAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout, diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..d0e93787f 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2051,6 +2051,92 @@ static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, return success(); } +static bool isCommGlobalLikeType(Type ty) { + if (auto memTy = dyn_cast(ty)) + return isGmAddressSpaceAttr(memTy.getMemorySpace()); + return isa(ty); +} + +static LogicalResult verifyCommGlobalLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isCommGlobalLikeType(ty)) + return op->emitOpError() << "expects " << name + << " to be a GM memref/tensor_view/partition_view"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommSignalLike(Operation *op, Value value, + StringRef name) { + if (failed(verifyCommGlobalLike(op, value, name))) + return failure(); + Type elemTy = getElemTy(value.getType()); + if (!elemTy || !elemTy.isSignlessInteger(32)) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + return success(); +} + +static LogicalResult verifyCommStagingTileLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a tile_buf or memref tile"; + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in vec address space"; + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommGlobalGroup(Operation *op, ValueRange group, + StringRef name) { + if (group.empty()) + return op->emitOpError() << "expects at least one " << name << " operand"; + Type groupTy = group.front().getType(); + for (auto it : llvm::enumerate(group)) { + if (failed(verifyCommGlobalLike(op, it.value(), + (name + "[" + Twine(it.index()) + "]").str()))) + return failure(); + if (it.value().getType() != groupTy) + return op->emitOpError() << "expects all " << name + << " operands to have identical types"; + } + return success(); +} + +static LogicalResult verifyCommPingPongSameType(Operation *op, Value ping, + Value pong, StringRef pingName, + StringRef pongName) { + if (!pong) + return success(); + if (failed(verifyCommStagingTileLike(op, ping, pingName)) || + failed(verifyCommStagingTileLike(op, pong, pongName))) + return failure(); + if (ping.getType() != pong.getType()) + return op->emitOpError() << "expects " << pingName << " and " << pongName + << " to have identical types"; + return success(); +} + static std::optional getStaticByteSize(Type ty) { SmallVector shape = getShapeVec(ty); if (shape.empty()) @@ -9934,6 +10020,150 @@ LogicalResult TGetAsyncOp::verify() { return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); } +LogicalResult TPutOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TGetOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TNotifyOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto valueTy = dyn_cast(getValue().getType()); + if (!valueTy || valueTy.getWidth() != 32) + return emitOpError("expects value to be i32"); + return success(); +} + +LogicalResult TWaitOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +LogicalResult TTestOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +LogicalResult TBroadcastOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getSrc().getType() != getGroup().front().getType()) + return emitOpError("expects src type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult CommTGatherOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects dst element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects staging tile element type to match dst"); + return success(); +} + +LogicalResult CommTScatterOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getSrc().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects src element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult TReduceOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getAcc(), "acc")) || + failed(verifyCommStagingTileLike(*this, getRecvPing(), "recv_ping")) || + failed(verifyCommPingPongSameType(*this, getRecvPing(), getRecvPong(), + "recv_ping", "recv_pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects dst element type to match group member type"); + if (getAcc().getType() != getRecvPing().getType()) + return emitOpError("expects acc and recv_ping to have identical types"); + if (getElemTy(getAcc().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects accumulator/receive tiles to match dst element type"); + return success(); +} + LogicalResult AicInitializePipeOp::verify() { return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); } @@ -10070,6 +10300,83 @@ void TGetAsyncOp::getEffects( addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); } +void TPutOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); +} + +void TGetOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); +} + +void TNotifyOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getValueMutable(), MemoryEffects::Read::get()); +} + +void TWaitOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); +} + +void TTestOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TBroadcastOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void CommTGatherOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); +} + +void CommTScatterOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void TReduceOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRecvPingMutable(), MemoryEffects::Read::get()); +} + void WaitAsyncEventOp::getEffects( SmallVectorImpl> &effects) { diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index b362b84b3..ce53dd7ff 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -354,7 +354,10 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { OpKillHandle(curOpInfo, live, op->getBlock()); } else if (isa(op)) { + pto::TPutAsyncOp, pto::TGetAsyncOp, pto::TPutOp, + pto::TGetOp, pto::TNotifyOp, pto::TWaitOp, pto::TTestOp, + pto::TBroadcastOp, pto::CommTGatherOp, + pto::CommTScatterOp, pto::TReduceOp>(op)) { UpdateOpGenInfo(curOpInfo, llvm::to_vector(op->getOperands())); OpKillHandle(curOpInfo, live, op->getBlock()); } else if (auto gpuLaunchOp = dyn_cast(op)) { diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f091e3a9b..561b55b06 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -5252,6 +5252,367 @@ struct PTOAsyncEventToEmitC : public OpConversionPattern { std::string callee; }; +static FailureOr buildCommGlobalTensorValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalValue, + Value emittedValue, Operation *anchor) { + Value value = peelUnrealized(emittedValue); + if (isEmitCGlobalTensorLikeType(value.getType())) + return value; + + auto memTy = dyn_cast(originalValue.getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); + if (!gt) + return failure(); + return gt; +} + +static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalValue, + Value emittedValue) { + Value value = peelUnrealized(emittedValue); + if (auto opaqueTy = dyn_cast(value.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return value; + } + return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); +} + +static std::string notifyOpTok(pto::NotifyOp op) { + switch (op) { + case pto::NotifyOp::AtomicAdd: + return "pto::comm::NotifyOp::AtomicAdd"; + case pto::NotifyOp::Set: + return "pto::comm::NotifyOp::Set"; + } + return "pto::comm::NotifyOp::Set"; +} + +static std::string waitCmpTok(pto::WaitCmp cmp) { + switch (cmp) { + case pto::WaitCmp::EQ: + return "pto::comm::WaitCmp::EQ"; + case pto::WaitCmp::NE: + return "pto::comm::WaitCmp::NE"; + case pto::WaitCmp::GT: + return "pto::comm::WaitCmp::GT"; + case pto::WaitCmp::GE: + return "pto::comm::WaitCmp::GE"; + case pto::WaitCmp::LT: + return "pto::comm::WaitCmp::LT"; + case pto::WaitCmp::LE: + return "pto::comm::WaitCmp::LE"; + } + return "pto::comm::WaitCmp::EQ"; +} + +static std::string reduceOpTok(pto::ReduceOp op) { + switch (op) { + case pto::ReduceOp::Sum: + return "pto::comm::ReduceOp::Sum"; + case pto::ReduceOp::Max: + return "pto::comm::ReduceOp::Max"; + case pto::ReduceOp::Min: + return "pto::comm::ReduceOp::Min"; + } + return "pto::comm::ReduceOp::Sum"; +} + +template +static FailureOr> buildCommGroupGlobalTensors( + ConversionPatternRewriter &rewriter, Location loc, OpTy op, + ValueRange originalGroup, ValueRange emittedGroup) { + SmallVector groupGTs; + groupGTs.reserve(originalGroup.size()); + for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { + FailureOr gt = + buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); + if (failed(gt)) + return failure(); + groupGTs.push_back(*gt); + } + return groupGTs; +} + +template +struct PTOCommCollectiveToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef apiName) + : OpConversionPattern(typeConverter, ctx), + apiName(apiName.str()) {} + + LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + SmallVector operands; + std::string lambda = "([&]("; + + auto appendParam = [&](StringRef name) { + if (lambda.back() != '(') + lambda += ", "; + lambda += "auto &"; + lambda += name.str(); + }; + + auto appendOperand = [&](Value value, StringRef name) { + appendParam(name); + operands.push_back(value); + }; + + auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { + if (!original) + return failure(); + return buildCommTileValue(rewriter, loc, original, emitted); + }; + + if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); + appendOperand(*srcGT, "__src"); + appendOperand(*pingTile, "__ping"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + appendOperand(*pongTile, "__pong"); + } + for (size_t i = 0; i < groupGTs->size(); ++i) + appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); + lambda += ") { "; + lambda += "using __GT = std::decay_t; __GT __group[] = {"; + for (size_t i = 0; i < groupGTs->size(); ++i) { + if (i) + lambda += ", "; + lambda += "__g" + std::to_string(i); + } + lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; + lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); + lambda += "); pto::comm::TBROADCAST(__pg, __src, __ping"; + if (op.getPong()) + lambda += ", __pong"; + lambda += "); })"; + } else if constexpr (std::is_same_v) { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); + appendOperand(*dstGT, "__dst"); + appendOperand(*pingTile, "__ping"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + appendOperand(*pongTile, "__pong"); + } + for (size_t i = 0; i < groupGTs->size(); ++i) + appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); + lambda += ") { using __GT = std::decay_t; __GT __group[] = {"; + for (size_t i = 0; i < groupGTs->size(); ++i) { + if (i) + lambda += ", "; + lambda += "__g" + std::to_string(i); + } + lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; + lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); + lambda += "); pto::comm::TGATHER(__pg, __dst, __ping"; + if (op.getPong()) + lambda += ", __pong"; + lambda += "); })"; + } else if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); + appendOperand(*srcGT, "__src"); + appendOperand(*pingTile, "__ping"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + appendOperand(*pongTile, "__pong"); + } + for (size_t i = 0; i < groupGTs->size(); ++i) + appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); + lambda += ") { using __GT = std::decay_t; __GT __group[] = {"; + for (size_t i = 0; i < groupGTs->size(); ++i) { + if (i) + lambda += ", "; + lambda += "__g" + std::to_string(i); + } + lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; + lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); + lambda += "); pto::comm::TSCATTER(__pg, __src, __ping"; + if (op.getPong()) + lambda += ", __pong"; + lambda += "); })"; + } else { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr accTile = + buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); + FailureOr recvPing = + buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); + appendOperand(*dstGT, "__dst"); + appendOperand(*accTile, "__acc"); + appendOperand(*recvPing, "__recv_ping"); + if (op.getRecvPong()) { + FailureOr recvPong = + buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); + if (failed(recvPong)) + return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); + appendOperand(*recvPong, "__recv_pong"); + } + for (size_t i = 0; i < groupGTs->size(); ++i) + appendOperand((*groupGTs)[i], ("__g" + Twine(i)).str()); + lambda += ") { using __GT = std::decay_t; __GT __group[] = {"; + for (size_t i = 0; i < groupGTs->size(); ++i) { + if (i) + lambda += ", "; + lambda += "__g" + std::to_string(i); + } + lambda += "}; auto __pg = pto::comm::ParallelGroup<__GT>::Create(__group, "; + lambda += std::to_string(groupGTs->size()) + ", " + std::to_string(op.getRoot()); + lambda += "); pto::comm::TREDUCE(__pg, __dst, __acc, __recv_ping"; + if (op.getRecvPong()) + lambda += ", __recv_pong"; + lambda += ", " + reduceOpTok(op.getReduceOp()) + "); })"; + } + + rewriter.create(loc, TypeRange{}, lambda, ArrayAttr{}, + ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string apiName; +}; + +template +struct PTOP2PCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); + if (failed(dstGT) || failed(srcGT) || failed(pingTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); + + SmallVector operands{*dstGT, *srcGT, *pingTile}; + std::string actualCallee = callee; + if constexpr (std::is_same_v) { + if (op.getAtomicType() == pto::AtomicType::AtomicAdd) + actualCallee = "pto::comm::TPUT"; + } + if (op.getPong()) { + FailureOr pongTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + operands.push_back(*pongTile); + } + + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string callee; +}; + +template +struct PTOSignalCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr signalGT = buildCommGlobalTensorValue( + rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); + if (failed(signalGT)) + return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); + + if constexpr (std::is_same_v) { + std::string actualCallee = + "([&](auto &__signal, auto __value){ pto::comm::TNOTIFY(__signal, __value, " + + notifyOpTok(op.getNotifyOp()) + "); })"; + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue())}; + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } else { + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue())}; + if constexpr (std::is_same_v) { + Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); + std::string actualCallee = + "([&](auto &__signal, auto __cmp){ return pto::comm::TTEST(__signal, __cmp, " + + waitCmpTok(op.getCmp()) + "); })"; + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, actualCallee, ArrayAttr{}, ArrayAttr{}, operands); + } else { + std::string actualCallee = + "([&](auto &__signal, auto __cmp){ pto::comm::TWAIT(__signal, __cmp, " + + waitCmpTok(op.getCmp()) + "); })"; + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } + } + return success(); + } + + std::string callee; +}; + struct PTODeclareTileMemRefToEmitC : public OpConversionPattern { using OpConversionPattern< @@ -9712,6 +10073,24 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add>( typeConverter, ctx, "pto::comm::TGET_ASYNC"); + patterns.add>(typeConverter, ctx, + "pto::comm::TPUT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TGET"); + patterns.add>(typeConverter, ctx, + "([&](auto &__signal, auto __value){ pto::comm::TNOTIFY(__signal, __value, "); + patterns.add>(typeConverter, ctx, + "([&](auto &__signal, auto __cmp){ pto::comm::TWAIT(__signal, __cmp, "); + patterns.add>(typeConverter, ctx, + "([&](auto &__signal, auto __cmp){ return pto::comm::TTEST(__signal, __cmp, "); + patterns.add>(typeConverter, ctx, + "TBROADCAST"); + patterns.add>(typeConverter, ctx, + "TGATHER"); + patterns.add>(typeConverter, ctx, + "TSCATTER"); + patterns.add>(typeConverter, ctx, + "TREDUCE"); patterns.add>( typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); patterns.add>( diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 06ee9843c..91fd99de7 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -52,6 +52,14 @@ def _load_local_pto_ext(): AccToVecModeAttr = _pto_mod.AccToVecModeAttr ReluPreMode = _pto_mod.ReluPreMode ReluPreModeAttr = _pto_mod.ReluPreModeAttr +AtomicType = _pto_mod.AtomicType +AtomicTypeAttr = _pto_mod.AtomicTypeAttr +NotifyOp = _pto_mod.NotifyOp +NotifyOpAttr = _pto_mod.NotifyOpAttr +WaitCmp = _pto_mod.WaitCmp +WaitCmpAttr = _pto_mod.WaitCmpAttr +ReduceOp = _pto_mod.ReduceOp +ReduceOpAttr = _pto_mod.ReduceOpAttr RoundMode = _pto_mod.RoundMode RoundModeAttr = _pto_mod.RoundModeAttr CmpMode = _pto_mod.CmpMode @@ -88,6 +96,10 @@ def _load_local_pto_ext(): "CompactMode", "CompactModeAttr", "AccToVecMode", "AccToVecModeAttr", "ReluPreMode", "ReluPreModeAttr", + "AtomicType", "AtomicTypeAttr", + "NotifyOp", "NotifyOpAttr", + "WaitCmp", "WaitCmpAttr", + "ReduceOp", "ReduceOpAttr", "RoundMode", "RoundModeAttr", "CmpMode", "CmpModeAttr", "PIPE", "PipeAttr", diff --git a/test/basic/comm_collective_emitc.pto b/test/basic/comm_collective_emitc.pto new file mode 100644 index 000000000..0a418dd5b --- /dev/null +++ b/test/basic/comm_collective_emitc.pto @@ -0,0 +1,39 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @comm_collective_basic(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr, %peer0_ptr: !pto.ptr, %peer1_ptr: !pto.ptr, %peer2_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %peer0_view = pto.make_tensor_view %peer0_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %peer1_view = pto.make_tensor_view %peer1_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %peer2_view = pto.make_tensor_view %peer2_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %dst = pto.partition_view %dst_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %src = pto.partition_view %src_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %peer0 = pto.partition_view %peer0_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %peer1 = pto.partition_view %peer1_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %peer2 = pto.partition_view %peer2_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %ping = pto.alloc_tile : !pto.tile_buf + %pong = pto.alloc_tile : !pto.tile_buf + %acc = pto.alloc_tile : !pto.tile_buf + "pto.tbroadcast"(%src, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.tbroadcast"(%src, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.comm_tgather"(%dst, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.comm_tgather"(%dst, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.comm_tscatter"(%src, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.comm_tscatter"(%src, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.treduce"(%dst, %acc, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, reduceOp = #pto, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + "pto.treduce"(%dst, %acc, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, reduceOp = #pto, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + return + } +} + +// A3: pto::comm::ParallelGroup +// A3: pto::comm::TBROADCAST( +// A3: pto::comm::TGATHER( +// A3: pto::comm::TSCATTER( +// A3: pto::comm::TREDUCE( +// A3: pto::comm::ReduceOp::Sum +// A3: pto::comm::ReduceOp::Max diff --git a/test/basic/comm_p2p_emitc.pto b/test/basic/comm_p2p_emitc.pto new file mode 100644 index 000000000..e445a89d4 --- /dev/null +++ b/test/basic/comm_p2p_emitc.pto @@ -0,0 +1,33 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @comm_p2p_basic(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr, %signal_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c7_i32 = arith.constant 7 : i32 + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32> + %signal_view = pto.make_tensor_view %signal_ptr, shape = [%c1], strides = [%c1] : !pto.tensor_view<1xi32> + %dst = pto.partition_view %dst_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %src = pto.partition_view %src_view, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32> -> !pto.partition_tensor_view<128xf32> + %signal = pto.partition_view %signal_view, offsets = [%c0], sizes = [%c1] : !pto.tensor_view<1xi32> -> !pto.partition_tensor_view<1xi32> + %ping = pto.alloc_tile : !pto.tile_buf + %pong = pto.alloc_tile : !pto.tile_buf + "pto.tput"(%dst, %src, %ping) <{atomicType = #pto}> : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) -> () + "pto.tput"(%dst, %src, %ping, %pong) <{atomicType = #pto}> : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) -> () + "pto.tget"(%dst, %src, %ping) : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) -> () + "pto.tget"(%dst, %src, %ping, %pong) : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) -> () + "pto.tnotify"(%signal, %c7_i32) <{notifyOp = #pto}> : (!pto.partition_tensor_view<1xi32>, i32) -> () + "pto.twait"(%signal, %c7_i32) <{cmp = #pto}> : (!pto.partition_tensor_view<1xi32>, i32) -> () + %tested = "pto.ttest"(%signal, %c7_i32) <{cmp = #pto}> : (!pto.partition_tensor_view<1xi32>, i32) -> i1 + return + } +} + +// A3: pto::comm::TPUT( +// A3: pto::comm::TPUT( +// A3: pto::comm::TGET( +// A3: pto::comm::TNOTIFY( +// A3: pto::comm::TWAIT( +// A3: pto::comm::TTEST( diff --git a/test/samples/CommSync/comm_collective.py b/test/samples/CommSync/comm_collective.py new file mode 100644 index 000000000..5fbb88325 --- /dev/null +++ b/test/samples/CommSync/comm_collective.py @@ -0,0 +1,95 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, F32Type, IndexType, InsertionPoint, Location, Module +from mlir.dialects import arith, func, pto + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + idx = IndexType.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + tv1_f32 = pto.TensorViewType.get([128], f32, ctx) + pv1_f32 = pto.PartitionTensorViewType.get([128], f32, ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx) + tb_f32 = pto.TileBufType.get([1, 128], f32, vec, [1, 128], cfg, ctx) + + fn_ty = func.FunctionType.get( + [ptr_f32, ptr_f32, ptr_f32, ptr_f32, ptr_f32], [] + ) + with InsertionPoint(module.body): + fn = func.FuncOp("comm_collective_kernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + dst_ptr, src_ptr, peer0_ptr, peer1_ptr, peer2_ptr = entry.arguments + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c128 = arith.ConstantOp(idx, 128).result + + def make_part(arg): + view = pto.MakeTensorViewOp(tv1_f32, arg, [c128], [c1]).result + return pto.PartitionViewOp( + pv1_f32, view, offsets=[c0], sizes=[c128] + ).result + + dst = make_part(dst_ptr) + src = make_part(src_ptr) + peer0 = make_part(peer0_ptr) + peer1 = make_part(peer1_ptr) + peer2 = make_part(peer2_ptr) + + ping = pto.AllocTileOp(tb_f32).result + pong = pto.AllocTileOp(tb_f32).result + acc = pto.AllocTileOp(tb_f32).result + + group = [peer0, peer1, peer2] + root = 1 + + pto.TBroadcastOp(src, ping, group, root) + pto.TBroadcastOp(src, ping, group, root, pong=pong) + pto.CommTGatherOp(dst, ping, group, root) + pto.CommTGatherOp(dst, ping, group, root, pong=pong) + pto.CommTScatterOp(src, ping, group, root) + pto.CommTScatterOp(src, ping, group, root, pong=pong) + pto.TReduceOp( + dst, + acc, + ping, + group, + pto.ReduceOpAttr.get(pto.ReduceOp.Sum, ctx), + root, + ) + pto.TReduceOp( + dst, + acc, + ping, + group, + pto.ReduceOpAttr.get(pto.ReduceOp.Max, ctx), + root, + recvPong=pong, + ) + + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/CommSync/comm_p2p.py b/test/samples/CommSync/comm_p2p.py new file mode 100644 index 000000000..ce92a47bf --- /dev/null +++ b/test/samples/CommSync/comm_p2p.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module +from mlir.dialects import arith, func, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + i1 = IntegerType.get_signless(1, ctx) + idx = IndexType.get(ctx) + + ptr_f32 = pto.PtrType.get(f32, ctx) + ptr_i32 = pto.PtrType.get(i32, ctx) + + tv1_f32 = pto.TensorViewType.get([128], f32, ctx) + pv1_f32 = pto.PartitionTensorViewType.get([128], f32, ctx) + tv1_i32 = pto.TensorViewType.get([1], i32, ctx) + pv1_i32 = pto.PartitionTensorViewType.get([1], i32, ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx) + tb_f32 = pto.TileBufType.get([1, 128], f32, vec, [1, 128], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_i32], []) + with InsertionPoint(module.body): + fn = func.FuncOp("comm_p2p_kernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + dst_ptr, src_ptr, signal_ptr = entry.arguments + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c128 = arith.ConstantOp(idx, 128).result + c7 = arith.ConstantOp(i32, 7).result + + dst_view = pto.MakeTensorViewOp(tv1_f32, dst_ptr, [c128], [c1]).result + src_view = pto.MakeTensorViewOp(tv1_f32, src_ptr, [c128], [c1]).result + signal_view = pto.MakeTensorViewOp(tv1_i32, signal_ptr, [c1], [c1]).result + + dst = pto.PartitionViewOp(pv1_f32, dst_view, offsets=[c0], sizes=[c128]).result + src = pto.PartitionViewOp(pv1_f32, src_view, offsets=[c0], sizes=[c128]).result + signal = pto.PartitionViewOp(pv1_i32, signal_view, offsets=[c0], sizes=[c1]).result + + ping = pto.AllocTileOp(tb_f32).result + pong = pto.AllocTileOp(tb_f32).result + + pto.TPutOp(dst, src, ping) + pto.TPutOp( + dst, + src, + ping, + pong=pong, + atomicType=pto.AtomicTypeAttr.get(pto.AtomicType.AtomicAdd, ctx), + ) + pto.TGetOp(dst, src, ping) + pto.TGetOp(dst, src, ping, pong=pong) + pto.TNotifyOp(signal, c7, pto.NotifyOpAttr.get(pto.NotifyOp.Set, ctx)) + pto.TWaitOp(signal, c7, pto.WaitCmpAttr.get(pto.WaitCmp.GE, ctx)) + pto.TTestOp(signal, c7, pto.WaitCmpAttr.get(pto.WaitCmp.EQ, ctx)) + + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index e0af12d8e..c7df6fc4e 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -862,6 +862,46 @@ PY fi fi + if [[ "$base" == "comm_p2p" ]]; then + for pat in \ + "pto::comm::TPUT(" \ + "pto::comm::TGET(" \ + "pto::comm::TNOTIFY(" \ + "pto::comm::TWAIT(" \ + "pto::comm::TTEST("; do + if ! grep -Fq "$pat" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering" + overall=1 + continue 2 + fi + done + if ! grep -Fq "pto::AtomicType::AtomicAdd" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing atomic-add TPUT lowering" + overall=1 + continue + fi + fi + + if [[ "$base" == "comm_collective" ]]; then + for pat in \ + "pto::comm::ParallelGroup" \ + "pto::comm::TBROADCAST(" \ + "pto::comm::TGATHER(" \ + "pto::comm::TSCATTER(" \ + "pto::comm::TREDUCE("; do + if ! grep -Fq "$pat" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering" + overall=1 + continue 2 + fi + done + if ! grep -Fq "pto::comm::ReduceOp::Sum" "$cpp" || ! grep -Fq "pto::comm::ReduceOp::Max" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing reduce-op enum lowering" + overall=1 + continue + fi + fi + # Regression guard for Issue #190: # Infer layout for a 2D column-vector view (16 x 1) should prefer DN. if [[ "$base" == "tensor_view_infer_layout_dn" ]]; then diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 49a57474d..9af1cc743 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -155,6 +155,15 @@ inline constexpr OpInfo kOpTable[] = { {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1075, "pto.tpack", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1076, "pto.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1077, "pto.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1078, "pto.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1079, "pto.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107A, "pto.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x107B, "pto.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107C, "pto.comm_tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107D, "pto.comm_tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107E, "pto.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, @@ -320,6 +329,15 @@ inline std::optional lookupOpcodeByName(llvm::StringRef name) { .Case("pto.subset", 0x1072) .Case("pto.trowexpanddiv", 0x1073) .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tput", 0x1076) + .Case("pto.tget", 0x1077) + .Case("pto.tnotify", 0x1078) + .Case("pto.twait", 0x1079) + .Case("pto.ttest", 0x107A) + .Case("pto.tbroadcast", 0x107B) + .Case("pto.comm_tgather", 0x107C) + .Case("pto.comm_tscatter", 0x107D) + .Case("pto.treduce", 0x107E) .Case("scf.for", 0x4000) .Case("scf.if", 0x4001) .Case("scf.yield", 0x4002) @@ -471,6 +489,15 @@ inline std::optional lookupOpcodeAndVariantByFullName(llvm::St .Case("pto.subset", OpcodeAndVariant{0x1072, 0, 0}) .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tput", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.tget", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tnotify", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.twait", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.ttest", OpcodeAndVariant{0x107A, 0, 0}) + .Case("pto.tbroadcast", OpcodeAndVariant{0x107B, 0, 0}) + .Case("pto.comm_tgather", OpcodeAndVariant{0x107C, 0, 0}) + .Case("pto.comm_tscatter", OpcodeAndVariant{0x107D, 0, 0}) + .Case("pto.treduce", OpcodeAndVariant{0x107E, 0, 0}) .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0})