From 56dd78b2fc3a5dacede5ee16d7232e023b2a0cfb Mon Sep 17 00:00:00 2001 From: "kefan.cao" Date: Thu, 23 Apr 2026 11:26:59 +0800 Subject: [PATCH 1/2] [Test] Add reproducer for strided universal_copy in convert-fly-to-rocdl Freezes current (buggy) lowering: a non-unit-stride !fly.memref on one side of fly.copy_atom_call is lowered to a single contiguous llvm.load / llvm.store against the memory-side pointer, silently ignoring the stride. Next commit will fix emitAtomCallSSA and update these CHECKs. --- ...t_fly_to_rocdl_universal_copy_strided.mlir | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir diff --git a/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir new file mode 100644 index 000000000..5523c53b4 --- /dev/null +++ b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl | FileCheck %s + +gpu.module @bug_strided_universal_copy { + +// CHECK-LABEL: gpu.func @load_strided_global_into_register( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> +// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5> +// CHECK: %[[V:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<1> -> vector<4xf16> +// CHECK-NEXT: llvm.store %[[V]], %[[REG]] : vector<4xf16>, !llvm.ptr<5> + gpu.func @load_strided_global_into_register(%src: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8> + + %src_layout = fly.make_layout(%shape4, %stride8) + : (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8> + %reg_layout = fly.make_layout(%shape4, %stride1) + : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_view = fly.make_view(%src, %src_layout) + : (!fly.ptr, !fly.layout<4:8>) -> !fly.memref + + %copy = fly.make_copy_atom {valBits = 16 : i32} + : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} + : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %reg_layout) + : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy, %src_view, %reg_view) + : (!fly.copy_atom, 16>, + !fly.memref, + !fly.memref) -> () + gpu.return + } + +// CHECK-LABEL: gpu.func @store_register_into_strided_global( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> +// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5> +// CHECK: %[[V:.*]] = llvm.load %[[REG]] : !llvm.ptr<5> -> vector<4xf16> +// CHECK-NEXT: llvm.store %[[V]], %[[ARG0]] : vector<4xf16>, !llvm.ptr<1> + gpu.func @store_register_into_strided_global(%dst: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8> + + %dst_layout = fly.make_layout(%shape4, %stride8) + : (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8> + %reg_layout = fly.make_layout(%shape4, %stride1) + : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %dst_view = fly.make_view(%dst, %dst_layout) + : (!fly.ptr, !fly.layout<4:8>) -> !fly.memref + + %copy = fly.make_copy_atom {valBits = 16 : i32} + : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} + : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %reg_layout) + : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy, %reg_view, %dst_view) + : (!fly.copy_atom, 16>, + !fly.memref, + !fly.memref) -> () + gpu.return + } +} From cd1ac5950419875b83b9fcbefce8388d8a0b0c0b Mon Sep 17 00:00:00 2001 From: "kefan.cao" Date: Fri, 24 Apr 2026 09:56:11 +0800 Subject: [PATCH 2/2] [Fix] Reject invalid universal copy operands before lowering `fly.copy_atom_call` / `fly.copy_atom_call_ssa` with `!fly.universal_copy<...>` used to accept non-contiguous memrefs and lower them as if the atom operated on a contiguous slice. Fix this by verifying universal copy operands before lowering: * the memref layout must coalesce to a single static leaf * the coalesced bit count must match the copy atom bit width * the contiguous bit count must not be smaller than the copy granularity This turns strided/otherwise incompatible memrefs into a clear verification error instead of silently generating code that deviates from the original atom semantics. --- include/flydsl/Dialect/Fly/IR/FlyOps.td | 2 + lib/Dialect/Fly/IR/FlyOps.cpp | 81 +++++++++++++++++++ .../convert-atom-call-to-ssa-form.mlir | 25 ------ ...t_fly_to_rocdl_universal_copy_strided.mlir | 41 +--------- 4 files changed, 85 insertions(+), 64 deletions(-) diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index f0adac843..92e015ef0 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -358,6 +358,7 @@ def Fly_AtomSetValueOp : Fly_Op<"atom.set_value", [Pure, DeclareOpInterfaceMetho } def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> { + let hasVerifier = 1; let arguments = (ins Fly_CopyAtom:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, Optional:$pred); } def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> { @@ -365,6 +366,7 @@ def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> { } def Fly_CopyAtomCallSSA : Fly_Op<"copy_atom_call_ssa", [AttrSizedOperandSegments]> { + let hasVerifier = 1; let arguments = (ins Fly_CopyAtom:$copyAtom, AnyType:$src, Optional:$dst, Optional:$pred); let results = (outs Variadic:$results); diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index a278fea36..a1c083220 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -124,6 +124,53 @@ Type applyOffsetOnTensorLike(LayoutBuilder &builder, Type tensorLike llvm_unreachable("Unsupported tensor like type"); } +FailureOr> getCoalescedLeafCountAndStride(fly::MemRefType memRefTy) { + auto layoutAttr = dyn_cast(memRefTy.getLayout()); + if (!layoutAttr) + return failure(); + LayoutBuilder builder(memRefTy.getContext()); + auto coalesced = layoutCoalesce(builder, layoutAttr); + if (!coalesced.isLeaf()) + return failure(); + auto shape = coalesced.getShape().getLeafAsInt(); + auto stride = coalesced.getStride().getLeafAsInt(); + if (!shape.isStatic() || !stride.isStatic()) + return failure(); + return std::make_pair(shape.getValue(), stride.getValue()); +} + +LogicalResult verifyUniversalCopyOperand(Operation *op, StringRef operandName, CopyAtomType copyAtomTy, + fly::MemRefType memRefTy) { + auto universalCopy = dyn_cast(copyAtomTy.getCopyOp()); + if (!universalCopy) + return success(); + + auto countAndStride = getCoalescedLeafCountAndStride(memRefTy); + if (failed(countAndStride)) { + return op->emitOpError() << operandName + << " memref layout must coalesce to a single static leaf for " + << copyAtomTy; + } + + auto [count, stride] = *countAndStride; + int64_t elemBits = memRefTy.getElemTy().getIntOrFloatBitWidth(); + int64_t copyBits = universalCopy.getBitSize(); + int64_t totalBits = count * elemBits; + if (totalBits != copyBits) { + return op->emitOpError() << operandName << " memref covers " << totalBits + << " bits after coalescing, but " << copyAtomTy << " expects " + << copyBits << " bits"; + } + + int64_t contiguousBits = (count <= 1 || stride == 1) ? totalBits : elemBits; + if (contiguousBits < copyBits) { + return op->emitOpError() << operandName << " memref contiguous bit count " << contiguousBits + << " is smaller than copy granularity " << copyBits; + } + + return success(); +} + } // namespace #define FLY_INFER_RETURN_TYPES(OP) \ @@ -133,6 +180,40 @@ Type applyOffsetOnTensorLike(LayoutBuilder &builder, Type tensorLike mlir::OpaqueProperties properties, mlir::RegionRange regions, \ llvm::SmallVectorImpl &inferredReturnTypes) +LogicalResult CopyAtomCall::verify() { + auto copyAtomTy = dyn_cast(getCopyAtom().getType()); + if (!copyAtomTy) + return emitOpError("copyAtom is not CopyAtomType"); + + auto srcTy = cast(getSrc().getType()); + auto dstTy = cast(getDst().getType()); + if (srcTy.getElemTy() != dstTy.getElemTy()) + return emitOpError("src/dst element types mismatch"); + + if (failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy))) + return failure(); + if (failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy))) + return failure(); + return success(); +} + +LogicalResult CopyAtomCallSSA::verify() { + auto copyAtomTy = dyn_cast(getCopyAtom().getType()); + if (!copyAtomTy) + return emitOpError("copyAtom is not CopyAtomType"); + + auto srcTy = dyn_cast(getSrc().getType()); + auto dstTy = getDst() ? dyn_cast(getDst().getType()) : fly::MemRefType(); + if (srcTy && dstTy && srcTy.getElemTy() != dstTy.getElemTy()) + return emitOpError("src/dst element types mismatch"); + + if (srcTy && failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy))) + return failure(); + if (dstTy && failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy))) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // Constructors //===----------------------------------------------------------------------===// diff --git a/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir index 7db68b081..662e00a5f 100644 --- a/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir +++ b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir @@ -120,31 +120,6 @@ gpu.module @convert_atom_call_to_ssa_form { gpu.return } - // Test 3b: copy_atom_call with register dst, non-coalescable layout should NOT be promoted - // (4,2):(1,8) cannot coalesce to rank=1 stride=1 - // CHECK-LABEL: gpu.func @copy_dst_register_non_coalescable - // CHECK: fly.copy_atom_call( - // CHECK-NOT: fly.copy_atom_call_ssa - gpu.func @copy_dst_register_non_coalescable(%src: !fly.ptr) kernel { - %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> - %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> - %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> - - %src_view = fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref - - %nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)> - %nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)> - %nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)> - - %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> - - %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr - %reg_view = fly.make_view(%reg_ptr, %nc_layout) : (!fly.ptr, !fly.layout<(4,2):(1,8)>) -> !fly.memref - - fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () - gpu.return - } - // Test 4: mma_atom_call with register d (rank=1, stride=1) should be promoted // a, b, c are also register eligible, so they get pre-loaded as vectors // CHECK-LABEL: gpu.func @mma_d_register diff --git a/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir index 5523c53b4..aae144a48 100644 --- a/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir +++ b/tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir @@ -1,14 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors -// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl | FileCheck %s +// RUN: not %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl 2>&1 | FileCheck %s gpu.module @bug_strided_universal_copy { -// CHECK-LABEL: gpu.func @load_strided_global_into_register( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> -// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5> -// CHECK: %[[V:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<1> -> vector<4xf16> -// CHECK-NEXT: llvm.store %[[V]], %[[REG]] : vector<4xf16>, !llvm.ptr<5> +// CHECK: error: 'fly.copy_atom_call' op src memref contiguous bit count 16 is smaller than copy granularity 64 gpu.func @load_strided_global_into_register(%src: !fly.ptr) kernel { %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> @@ -36,37 +32,4 @@ gpu.module @bug_strided_universal_copy { !fly.memref) -> () gpu.return } - -// CHECK-LABEL: gpu.func @store_register_into_strided_global( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1> -// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5> -// CHECK: %[[V:.*]] = llvm.load %[[REG]] : !llvm.ptr<5> -> vector<4xf16> -// CHECK-NEXT: llvm.store %[[V]], %[[ARG0]] : vector<4xf16>, !llvm.ptr<1> - gpu.func @store_register_into_strided_global(%dst: !fly.ptr) kernel { - %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> - %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> - %stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8> - - %dst_layout = fly.make_layout(%shape4, %stride8) - : (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8> - %reg_layout = fly.make_layout(%shape4, %stride1) - : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> - - %dst_view = fly.make_view(%dst, %dst_layout) - : (!fly.ptr, !fly.layout<4:8>) -> !fly.memref - - %copy = fly.make_copy_atom {valBits = 16 : i32} - : !fly.copy_atom, 16> - - %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} - : () -> !fly.ptr - %reg_view = fly.make_view(%reg_ptr, %reg_layout) - : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref - - fly.copy_atom_call(%copy, %reg_view, %dst_view) - : (!fly.copy_atom, 16>, - !fly.memref, - !fly.memref) -> () - gpu.return - } }