Skip to content

Commit cd1ac59

Browse files
committed
[Fix] Reject invalid universal copy operands before lowering
`fly.copy_atom_call` / `fly.copy_atom_call_ssa` with `!fly.universal_copy<...>` used to accept non-contiguous memrefs and lower them as if the atom operated on a contiguous slice. Fix this by verifying universal-copy operands before lowering:

* the memref layout must coalesce to a single static leaf;
* the coalesced bit count must match the copy atom's bit width;
* the contiguous bit count must not be smaller than the copy granularity.

This turns strided or otherwise incompatible memrefs into a clear verification error instead of silently generating code that deviates from the original atom semantics.
1 parent 56dd78b commit cd1ac59

4 files changed

Lines changed: 85 additions & 64 deletions

File tree

include/flydsl/Dialect/Fly/IR/FlyOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,15 @@ def Fly_AtomSetValueOp : Fly_Op<"atom.set_value", [Pure, DeclareOpInterfaceMetho
358358
}
359359

360360
def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> {
  // Verifier lives in FlyOps.cpp: universal-copy atoms reject operands whose
  // layouts do not coalesce to a contiguous slice of the atom's bit width.
  let hasVerifier = 1;
  let arguments = (ins Fly_CopyAtom:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, Optional<Fly_MemRef>:$pred);
}
363364
def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> {
  // No custom verifier here (unlike the copy-atom calls): d/a/b/c layout
  // compatibility is not checked at this level.
  let arguments = (ins Fly_MmaAtom:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, Fly_MemRef:$b, Fly_MemRef:$c);
}
366367

367368
def Fly_CopyAtomCallSSA : Fly_Op<"copy_atom_call_ssa", [AttrSizedOperandSegments]> {
369+
let hasVerifier = 1;
368370
let arguments = (ins Fly_CopyAtom:$copyAtom, AnyType:$src,
369371
Optional<AnyType>:$dst, Optional<AnyType>:$pred);
370372
let results = (outs Variadic<AnyType>:$results);

lib/Dialect/Fly/IR/FlyOps.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,53 @@ Type applyOffsetOnTensorLike(LayoutBuilder<LayoutAttr> &builder, Type tensorLike
124124
llvm_unreachable("Unsupported tensor like type");
125125
}
126126

127+
/// Coalesces the layout of \p memRefTy and, when the result is a single leaf
/// with static shape and stride, returns (element count, stride).
///
/// Fails when the memref's layout is not a LayoutAttr, when coalescing does
/// not reduce it to one leaf, or when the leaf's shape/stride is dynamic.
FailureOr<std::pair<int64_t, int64_t>> getCoalescedLeafCountAndStride(fly::MemRefType memRefTy) {
  auto layoutAttr = dyn_cast<LayoutAttr>(memRefTy.getLayout());
  if (!layoutAttr)
    return failure();
  LayoutBuilder<LayoutAttr> builder(memRefTy.getContext());
  auto coalesced = layoutCoalesce(builder, layoutAttr);
  if (!coalesced.isLeaf())
    return failure();
  auto shape = coalesced.getShape().getLeafAsInt();
  auto stride = coalesced.getStride().getLeafAsInt();
  if (!shape.isStatic() || !stride.isStatic())
    return failure();
  // Construct the pair explicitly instead of std::make_pair<int64_t, int64_t>:
  // explicit template arguments on make_pair force rvalue-reference parameters
  // and defeat deduction, which is a well-known pitfall.
  return std::pair<int64_t, int64_t>(shape.getValue(), stride.getValue());
}
141+
142+
/// Verifies one memref operand (named \p operandName for diagnostics) of a
/// copy-atom call against \p copyAtomTy.
///
/// Non-universal copy ops are accepted unchanged. For a universal copy the
/// operand layout must coalesce to a single static leaf, the leaf's total bit
/// count must equal the atom's bit width, and the contiguous run must be at
/// least the copy granularity.
LogicalResult verifyUniversalCopyOperand(Operation *op, StringRef operandName, CopyAtomType copyAtomTy,
                                         fly::MemRefType memRefTy) {
  auto universalCopy = dyn_cast<CopyOpUniversalCopyType>(copyAtomTy.getCopyOp());
  if (!universalCopy)
    return success(); // Only universal copies carry these layout constraints.

  auto countAndStride = getCoalescedLeafCountAndStride(memRefTy);
  if (failed(countAndStride)) {
    return op->emitOpError() << operandName
                             << " memref layout must coalesce to a single static leaf for "
                             << copyAtomTy;
  }

  // Guard before getIntOrFloatBitWidth(), which asserts on other type kinds;
  // a verifier should emit a diagnostic rather than crash. (Assumes element
  // types are normally int/float -- TODO confirm against Fly_MemRef's
  // constraint.)
  Type elemTy = memRefTy.getElemTy();
  if (!elemTy.isIntOrFloat()) {
    return op->emitOpError() << operandName
                             << " memref element type must be an integer or float type for "
                             << copyAtomTy;
  }

  auto [count, stride] = *countAndStride;
  int64_t elemBits = elemTy.getIntOrFloatBitWidth();
  int64_t copyBits = universalCopy.getBitSize();
  int64_t totalBits = count * elemBits;
  if (totalBits != copyBits) {
    return op->emitOpError() << operandName << " memref covers " << totalBits
                             << " bits after coalescing, but " << copyAtomTy << " expects "
                             << copyBits << " bits";
  }

  // With a unit stride (or a single element) the whole leaf is contiguous;
  // otherwise only one element at a time is contiguous in memory.
  int64_t contiguousBits = (count <= 1 || stride == 1) ? totalBits : elemBits;
  if (contiguousBits < copyBits) {
    return op->emitOpError() << operandName << " memref contiguous bit count " << contiguousBits
                             << " is smaller than copy granularity " << copyBits;
  }

  return success();
}
173+
127174
} // namespace
128175

129176
#define FLY_INFER_RETURN_TYPES(OP) \
@@ -133,6 +180,40 @@ Type applyOffsetOnTensorLike(LayoutBuilder<LayoutAttr> &builder, Type tensorLike
133180
mlir::OpaqueProperties properties, mlir::RegionRange regions, \
134181
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes)
135182

183+
/// Verifies a copy_atom_call: src/dst element types must agree, and both
/// operands must satisfy the universal-copy layout constraints (which are a
/// no-op for non-universal copy atoms).
LogicalResult CopyAtomCall::verify() {
  auto atomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
  if (!atomTy)
    return emitOpError("copyAtom is not CopyAtomType");

  auto srcMemRefTy = cast<fly::MemRefType>(getSrc().getType());
  auto dstMemRefTy = cast<fly::MemRefType>(getDst().getType());
  if (srcMemRefTy.getElemTy() != dstMemRefTy.getElemTy())
    return emitOpError("src/dst element types mismatch");

  // Check src first, then dst, so diagnostics match operand order.
  const std::pair<StringRef, fly::MemRefType> operands[] = {
      {"src", srcMemRefTy}, {"dst", dstMemRefTy}};
  for (const auto &[name, memRefTy] : operands)
    if (failed(verifyUniversalCopyOperand(getOperation(), name, atomTy, memRefTy)))
      return failure();
  return success();
}
199+
200+
/// Verifies a copy_atom_call_ssa. Operands in SSA form may be non-memref
/// values, so only memref-typed operands are checked: element-type agreement
/// when both src and dst are memrefs, and the universal-copy layout
/// constraints on each memref operand individually.
LogicalResult CopyAtomCallSSA::verify() {
  auto atomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
  if (!atomTy)
    return emitOpError("copyAtom is not CopyAtomType");

  auto srcMemRefTy = dyn_cast<fly::MemRefType>(getSrc().getType());
  fly::MemRefType dstMemRefTy;
  if (auto dst = getDst())
    dstMemRefTy = dyn_cast<fly::MemRefType>(dst.getType());

  if (srcMemRefTy && dstMemRefTy && srcMemRefTy.getElemTy() != dstMemRefTy.getElemTy())
    return emitOpError("src/dst element types mismatch");

  if (srcMemRefTy && failed(verifyUniversalCopyOperand(getOperation(), "src", atomTy, srcMemRefTy)))
    return failure();
  if (dstMemRefTy && failed(verifyUniversalCopyOperand(getOperation(), "dst", atomTy, dstMemRefTy)))
    return failure();
  return success();
}
216+
136217
//===----------------------------------------------------------------------===//
137218
// Constructors
138219
//===----------------------------------------------------------------------===//

tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -120,31 +120,6 @@ gpu.module @convert_atom_call_to_ssa_form {
120120
gpu.return
121121
}
122122

123-
// Test 3b: copy_atom_call with register dst, non-coalescable layout should NOT be promoted
124-
// (4,2):(1,8) cannot coalesce to rank=1 stride=1
125-
// CHECK-LABEL: gpu.func @copy_dst_register_non_coalescable
126-
// CHECK: fly.copy_atom_call(
127-
// CHECK-NOT: fly.copy_atom_call_ssa
128-
gpu.func @copy_dst_register_non_coalescable(%src: !fly.ptr<f16, global>) kernel {
129-
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
130-
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
131-
%vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>
132-
133-
%src_view = fly.make_view(%src, %vec4) : (!fly.ptr<f16, global>, !fly.layout<4:1>) -> !fly.memref<f16, global, 4:1>
134-
135-
%nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)>
136-
%nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)>
137-
%nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)>
138-
139-
%copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom<!fly.universal_copy<64>, 16>
140-
141-
%reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr<f16, register>
142-
%reg_view = fly.make_view(%reg_ptr, %nc_layout) : (!fly.ptr<f16, register>, !fly.layout<(4,2):(1,8)>) -> !fly.memref<f16, register, (4,2):(1,8)>
143-
144-
fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom<!fly.universal_copy<64>, 16>, !fly.memref<f16, global, 4:1>, !fly.memref<f16, register, (4,2):(1,8)>) -> ()
145-
gpu.return
146-
}
147-
148123
// Test 4: mma_atom_call with register d (rank=1, stride=1) should be promoted
149124
// a, b, c are also register eligible, so they get pre-loaded as vectors
150125
// CHECK-LABEL: gpu.func @mma_d_register
Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// Copyright (c) 2025 FlyDSL Project Contributors
3-
// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl | FileCheck %s
3+
// RUN: not %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl 2>&1 | FileCheck %s
44

55
gpu.module @bug_strided_universal_copy {
66

7-
// CHECK-LABEL: gpu.func @load_strided_global_into_register(
8-
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>
9-
// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5>
10-
// CHECK: %[[V:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<1> -> vector<4xf16>
11-
// CHECK-NEXT: llvm.store %[[V]], %[[REG]] : vector<4xf16>, !llvm.ptr<5>
7+
// CHECK: error: 'fly.copy_atom_call' op src memref contiguous bit count 16 is smaller than copy granularity 64
128
gpu.func @load_strided_global_into_register(%src: !fly.ptr<f16, global>) kernel {
139
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
1410
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
@@ -36,37 +32,4 @@ gpu.module @bug_strided_universal_copy {
3632
!fly.memref<f16, register, 4:1>) -> ()
3733
gpu.return
3834
}
39-
40-
// CHECK-LABEL: gpu.func @store_register_into_strided_global(
41-
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>
42-
// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5>
43-
// CHECK: %[[V:.*]] = llvm.load %[[REG]] : !llvm.ptr<5> -> vector<4xf16>
44-
// CHECK-NEXT: llvm.store %[[V]], %[[ARG0]] : vector<4xf16>, !llvm.ptr<1>
45-
gpu.func @store_register_into_strided_global(%dst: !fly.ptr<f16, global>) kernel {
46-
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
47-
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
48-
%stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8>
49-
50-
%dst_layout = fly.make_layout(%shape4, %stride8)
51-
: (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8>
52-
%reg_layout = fly.make_layout(%shape4, %stride1)
53-
: (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>
54-
55-
%dst_view = fly.make_view(%dst, %dst_layout)
56-
: (!fly.ptr<f16, global>, !fly.layout<4:8>) -> !fly.memref<f16, global, 4:8>
57-
58-
%copy = fly.make_copy_atom {valBits = 16 : i32}
59-
: !fly.copy_atom<!fly.universal_copy<64>, 16>
60-
61-
%reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}}
62-
: () -> !fly.ptr<f16, register>
63-
%reg_view = fly.make_view(%reg_ptr, %reg_layout)
64-
: (!fly.ptr<f16, register>, !fly.layout<4:1>) -> !fly.memref<f16, register, 4:1>
65-
66-
fly.copy_atom_call(%copy, %reg_view, %dst_view)
67-
: (!fly.copy_atom<!fly.universal_copy<64>, 16>,
68-
!fly.memref<f16, register, 4:1>,
69-
!fly.memref<f16, global, 4:8>) -> ()
70-
gpu.return
71-
}
7235
}

0 commit comments

Comments
 (0)