diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td
index f0adac843..92e015ef0 100644
--- a/include/flydsl/Dialect/Fly/IR/FlyOps.td
+++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td
@@ -358,6 +358,7 @@ def Fly_AtomSetValueOp : Fly_Op<"atom.set_value", [Pure, DeclareOpInterfaceMetho
 }
 def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> {
+  let hasVerifier = 1;
   let arguments = (ins Fly_CopyAtom:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, Optional:$pred);
 }
 def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> {
@@ -365,6 +366,7 @@ def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> {
 }
 def Fly_CopyAtomCallSSA : Fly_Op<"copy_atom_call_ssa", [AttrSizedOperandSegments]> {
+  let hasVerifier = 1;
   let arguments = (ins Fly_CopyAtom:$copyAtom, AnyType:$src, Optional:$dst, Optional:$pred);
   let results = (outs Variadic:$results);
diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp
index a278fea36..a1c083220 100644
--- a/lib/Dialect/Fly/IR/FlyOps.cpp
+++ b/lib/Dialect/Fly/IR/FlyOps.cpp
@@ -124,6 +124,53 @@ Type applyOffsetOnTensorLike(LayoutBuilder &builder, Type tensorLike
   llvm_unreachable("Unsupported tensor like type");
 }
 
+FailureOr<std::pair<int64_t, int64_t>> getCoalescedLeafCountAndStride(fly::MemRefType memRefTy) {
+  auto layoutAttr = dyn_cast<LayoutAttr>(memRefTy.getLayout());
+  if (!layoutAttr)
+    return failure();
+  LayoutBuilder builder(memRefTy.getContext());
+  auto coalesced = layoutCoalesce(builder, layoutAttr);
+  if (!coalesced.isLeaf())
+    return failure();
+  auto shape = coalesced.getShape().getLeafAsInt();
+  auto stride = coalesced.getStride().getLeafAsInt();
+  if (!shape.isStatic() || !stride.isStatic())
+    return failure();
+  return std::make_pair(shape.getValue(), stride.getValue());
+}
+
+LogicalResult verifyUniversalCopyOperand(Operation *op, StringRef operandName, CopyAtomType copyAtomTy,
+                                         fly::MemRefType memRefTy) {
+  auto universalCopy = dyn_cast<UniversalCopyAttr>(copyAtomTy.getCopyOp());
+  if (!universalCopy)
+    return success();
+
+  auto countAndStride =
+      getCoalescedLeafCountAndStride(memRefTy);
+  if (failed(countAndStride)) {
+    return op->emitOpError() << operandName
+           << " memref layout must coalesce to a single static leaf for "
+           << copyAtomTy;
+  }
+
+  auto [count, stride] = *countAndStride;
+  int64_t elemBits = memRefTy.getElemTy().getIntOrFloatBitWidth();
+  int64_t copyBits = universalCopy.getBitSize();
+  int64_t totalBits = count * elemBits;
+  if (totalBits != copyBits) {
+    return op->emitOpError() << operandName << " memref covers " << totalBits
+           << " bits after coalescing, but " << copyAtomTy << " expects "
+           << copyBits << " bits";
+  }
+
+  int64_t contiguousBits = (count <= 1 || stride == 1) ? totalBits : elemBits;
+  if (contiguousBits < copyBits) {
+    return op->emitOpError() << operandName << " memref contiguous bit count " << contiguousBits
+           << " is smaller than copy granularity " << copyBits;
+  }
+
+  return success();
+}
+
 } // namespace
 
 #define FLY_INFER_RETURN_TYPES(OP) \
@@ -133,6 +180,40 @@ Type applyOffsetOnTensorLike(LayoutBuilder &builder, Type tensorLike
       mlir::OpaqueProperties properties, mlir::RegionRange regions, \
       llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes)
 
+LogicalResult CopyAtomCall::verify() {
+  auto copyAtomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
+  if (!copyAtomTy)
+    return emitOpError("copyAtom is not CopyAtomType");
+
+  auto srcTy = cast<fly::MemRefType>(getSrc().getType());
+  auto dstTy = cast<fly::MemRefType>(getDst().getType());
+  if (srcTy.getElemTy() != dstTy.getElemTy())
+    return emitOpError("src/dst element types mismatch");
+
+  if (failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy)))
+    return failure();
+  if (failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy)))
+    return failure();
+  return success();
+}
+
+LogicalResult CopyAtomCallSSA::verify() {
+  auto copyAtomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
+  if (!copyAtomTy)
+    return emitOpError("copyAtom is not CopyAtomType");
+
+  auto srcTy = dyn_cast<fly::MemRefType>(getSrc().getType());
+  auto dstTy = getDst() ?
+      dyn_cast<fly::MemRefType>(getDst().getType()) : fly::MemRefType();
+  if (srcTy && dstTy && srcTy.getElemTy() != dstTy.getElemTy())
+    return emitOpError("src/dst element types mismatch");
+
+  if (srcTy && failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy)))
+    return failure();
+  if (dstTy && failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy)))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constructors
 //===----------------------------------------------------------------------===//
diff --git a/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir
index 7db68b081..662e00a5f 100644
--- a/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir
+++ b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir
@@ -120,31 +120,6 @@ gpu.module @convert_atom_call_to_ssa_form {
     gpu.return
   }
 
-  // Test 3b: copy_atom_call with register dst, non-coalescable layout should NOT be promoted
-  // (4,2):(1,8) cannot coalesce to rank=1 stride=1
-  // CHECK-LABEL: gpu.func @copy_dst_register_non_coalescable
-  // CHECK: fly.copy_atom_call(
-  // CHECK-NOT: fly.copy_atom_call_ssa
-  gpu.func @copy_dst_register_non_coalescable(%src: !fly.ptr) kernel {
-    %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
-    %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
-    %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>
-
-    %src_view = fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref
-
-    %nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)>
-    %nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)>
-    %nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)>
-
-    %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16>
-
-    %reg_ptr = fly.make_ptr()
-        {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr
-    %reg_view = fly.make_view(%reg_ptr, %nc_layout) : (!fly.ptr, !fly.layout<(4,2):(1,8)>) -> !fly.memref
-
-    fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> ()
-    gpu.return
-  }
-
   // Test 4: mma_atom_call with register d (rank=1, stride=1) should be promoted
   // a, b, c are also register eligible, so they get pre-loaded as vectors
   // CHECK-LABEL: gpu.func @mma_d_register
diff --git a/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir
new file mode 100644
index 000000000..aae144a48
--- /dev/null
+++ b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: Apache-2.0
+// Copyright (c) 2025 FlyDSL Project Contributors
+// RUN: not %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl 2>&1 | FileCheck %s
+
+gpu.module @bug_strided_universal_copy {
+
+// CHECK: error: 'fly.copy_atom_call' op src memref contiguous bit count 16 is smaller than copy granularity 64
+  gpu.func @load_strided_global_into_register(%src: !fly.ptr) kernel {
+    %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
+    %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
+    %stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8>
+
+    %src_layout = fly.make_layout(%shape4, %stride8)
+        : (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8>
+    %reg_layout = fly.make_layout(%shape4, %stride1)
+        : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>
+
+    %src_view = fly.make_view(%src, %src_layout)
+        : (!fly.ptr, !fly.layout<4:8>) -> !fly.memref
+
+    %copy = fly.make_copy_atom {valBits = 16 : i32}
+        : !fly.copy_atom, 16>
+
+    %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}}
+        : () -> !fly.ptr
+    %reg_view = fly.make_view(%reg_ptr, %reg_layout)
+        : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref
+
+    fly.copy_atom_call(%copy, %src_view, %reg_view)
+        : (!fly.copy_atom, 16>,
+           !fly.memref,
+           !fly.memref) -> ()
+    gpu.return
+  }
+}