Skip to content

Commit b4130b0

Browse files
committed
[Fix] Honor stride in universal copy atom lowering
CopyOpUniversalCopyType::emitAtomCallSSA / emitAtomCall used to emit a single contiguous llvm.load / llvm.store / llvm.memcpy against the memory-side pointer of fly.copy_atom_call, ignoring the memref's layout stride. When the memref's coalesced leaf has stride != 1, this silently produced wrong IR (adjacent lanes were read/written instead of stride-spaced elements).

Fix by consulting the coalesced leaf of the memref layout:

* count <= 1 or stride == 1: keep the existing fast path (single vector load/store or memcpy).
* otherwise: emit element-wise gather/scatter — per-element GEP by `i * stride` (in elements), swizzle the resulting pointer, then load/insertelement (load side) or extractelement/store (store side).
1 parent bc574f4 commit b4130b0

2 files changed

Lines changed: 188 additions & 18 deletions

File tree

lib/Dialect/Fly/IR/FlyUniversalOps.cpp

Lines changed: 162 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,87 @@
99
#include "mlir/IR/DialectImplementation.h"
1010

1111
#include "flydsl/Dialect/Fly/IR/FlyDialect.h"
12+
#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h"
1213
#include "flydsl/Dialect/Fly/Utils/PointerUtils.h"
1314
#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc"
1415

1516
namespace mlir::fly {
1617

18+
namespace {
19+
20+
/// Returns the element-count and element-stride of a memref, assuming its
21+
/// layout coalesces into a single static leaf. This represents how a single
22+
/// atom-call moves `count` elements starting from the base pointer with the
23+
/// given `elementStride` (in elements).
24+
FailureOr<std::pair<int64_t, int64_t>>
25+
getCoalescedLeafCountAndStride(fly::MemRefType memTy) {
26+
auto layoutAttr = dyn_cast<LayoutAttr>(memTy.getLayout());
27+
if (!layoutAttr)
28+
return failure();
29+
LayoutBuilder<LayoutAttr> lb(memTy.getContext());
30+
auto coal = layoutCoalesce(lb, layoutAttr);
31+
if (!coal.isLeaf())
32+
return failure();
33+
auto shapeInt = coal.getShape().getLeafAsInt();
34+
auto strideInt = coal.getStride().getLeafAsInt();
35+
if (!shapeInt.isStatic() || !strideInt.isStatic())
36+
return failure();
37+
return std::make_pair<int64_t, int64_t>(shapeInt.getValue(), strideInt.getValue());
38+
}
39+
40+
/// Emits a sequence of element-wise loads from a strided memory pointer and
41+
/// packs the values into a `vector<count x elemTy>`. The i-th element is read
42+
/// from `base + i * elementStride` (in elements).
43+
Value emitStridedLoadAsVector(OpBuilder &b, Location loc, VectorType vecTy,
44+
TypedValue<LLVM::LLVMPointerType> base,
45+
int64_t count, int64_t elementStride,
46+
fly::MemRefType memTy) {
47+
Type llvmElemTy = projectToLLVMCompatibleElemTy(memTy.getElemTy());
48+
Type vecElemTy = vecTy.getElementType();
49+
auto ptrTy = base.getType();
50+
Value vec = LLVM::UndefOp::create(b, loc, vecTy);
51+
for (int64_t i = 0; i < count; ++i) {
52+
Value gep = LLVM::GEPOp::create(b, loc, ptrTy, llvmElemTy, base,
53+
ArrayRef<LLVM::GEPArg>{int32_t(i * elementStride)});
54+
Value swz = applySwizzleOnPtr(b, loc,
55+
cast<TypedValue<LLVM::LLVMPointerType>>(gep),
56+
memTy.getSwizzle());
57+
Value elem = LLVM::LoadOp::create(b, loc, llvmElemTy, swz);
58+
if (llvmElemTy != vecElemTy)
59+
elem = LLVM::BitcastOp::create(b, loc, vecElemTy, elem);
60+
Value idx = arith::ConstantIntOp::create(b, loc, i, /*width=*/32);
61+
vec = LLVM::InsertElementOp::create(b, loc, vec, elem, idx);
62+
}
63+
return vec;
64+
}
65+
66+
/// Emits a sequence of element-wise stores that scatter a `vector<count x E>`
67+
/// into a strided memory pointer. The i-th element is written to
68+
/// `base + i * elementStride` (in elements).
69+
void emitStridedStoreFromVector(OpBuilder &b, Location loc, Value vec,
70+
TypedValue<LLVM::LLVMPointerType> base,
71+
int64_t count, int64_t elementStride,
72+
fly::MemRefType memTy) {
73+
auto vecTy = cast<VectorType>(vec.getType());
74+
Type llvmElemTy = projectToLLVMCompatibleElemTy(memTy.getElemTy());
75+
Type vecElemTy = vecTy.getElementType();
76+
auto ptrTy = base.getType();
77+
for (int64_t i = 0; i < count; ++i) {
78+
Value idx = arith::ConstantIntOp::create(b, loc, i, /*width=*/32);
79+
Value elem = LLVM::ExtractElementOp::create(b, loc, vec, idx);
80+
if (llvmElemTy != vecElemTy)
81+
elem = LLVM::BitcastOp::create(b, loc, llvmElemTy, elem);
82+
Value gep = LLVM::GEPOp::create(b, loc, ptrTy, llvmElemTy, base,
83+
ArrayRef<LLVM::GEPArg>{int32_t(i * elementStride)});
84+
Value swz = applySwizzleOnPtr(b, loc,
85+
cast<TypedValue<LLVM::LLVMPointerType>>(gep),
86+
memTy.getSwizzle());
87+
LLVM::StoreOp::create(b, loc, elem, swz);
88+
}
89+
}
90+
91+
} // namespace
92+
1793
/// Universal copies are unconditionally reported as static (constant true).
bool CopyOpUniversalCopyType::isStatic() const { return true; }
1894

1995
Value CopyOpUniversalCopyType::rebuildStaticValue(OpBuilder &builder, Location loc,
@@ -129,23 +205,55 @@ FailureOr<Value> CopyOpUniversalCopyType::emitAtomCallSSA(OpBuilder &builder, Lo
129205
Value dst) const {
130206
Value result;
131207
if (isa<fly::MemRefType>(srcTyArg)) {
132-
// src is memory
208+
// src is memory: honor the memref's coalesced stride so a non-unit stride
209+
// layout lowers to a per-element scatter/gather rather than a single
210+
// contiguous vector load that silently ignores the stride.
133211
auto srcMemTy = cast<fly::MemRefType>(srcTyArg);
212+
auto srcBase = cast<TypedValue<LLVM::LLVMPointerType>>(src);
213+
auto countAndStride = getCoalescedLeafCountAndStride(srcMemTy);
214+
if (failed(countAndStride))
215+
return failure();
216+
auto [count, stride] = *countAndStride;
217+
134218
Type loadTy = resultTy ? resultTy : builder.getIntegerType(getBitSize());
135-
Value srcPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(src),
136-
srcMemTy.getSwizzle());
137-
result = LLVM::LoadOp::create(builder, loc, loadTy, srcPtr);
219+
220+
if (count <= 1 || stride == 1) {
221+
Value srcPtr = applySwizzleOnPtr(builder, loc, srcBase, srcMemTy.getSwizzle());
222+
result = LLVM::LoadOp::create(builder, loc, loadTy, srcPtr);
223+
} else {
224+
Type llvmElemTy = projectToLLVMCompatibleElemTy(srcMemTy.getElemTy());
225+
auto vecTy = VectorType::get({count}, llvmElemTy);
226+
Value vec =
227+
emitStridedLoadAsVector(builder, loc, vecTy, srcBase, count, stride, srcMemTy);
228+
if (vec.getType() != loadTy)
229+
vec = LLVM::BitcastOp::create(builder, loc, loadTy, vec);
230+
result = vec;
231+
}
138232
} else {
139233
// src is register
140234
result = src;
141235
}
142236

143237
if (!resultTy) {
144-
// dst is memory
238+
// dst is memory: symmetric treatment for strided stores.
145239
auto dstMemTy = cast<fly::MemRefType>(dstTyArg);
146-
Value dstPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(dst),
147-
dstMemTy.getSwizzle());
148-
LLVM::StoreOp::create(builder, loc, result, dstPtr);
240+
auto dstBase = cast<TypedValue<LLVM::LLVMPointerType>>(dst);
241+
auto countAndStride = getCoalescedLeafCountAndStride(dstMemTy);
242+
if (failed(countAndStride))
243+
return failure();
244+
auto [count, stride] = *countAndStride;
245+
246+
if (count <= 1 || stride == 1) {
247+
Value dstPtr = applySwizzleOnPtr(builder, loc, dstBase, dstMemTy.getSwizzle());
248+
LLVM::StoreOp::create(builder, loc, result, dstPtr);
249+
} else {
250+
Type llvmElemTy = projectToLLVMCompatibleElemTy(dstMemTy.getElemTy());
251+
auto vecTy = VectorType::get({count}, llvmElemTy);
252+
Value vec = result;
253+
if (vec.getType() != vecTy)
254+
vec = LLVM::BitcastOp::create(builder, loc, vecTy, vec);
255+
emitStridedStoreFromVector(builder, loc, vec, dstBase, count, stride, dstMemTy);
256+
}
149257
}
150258
return result;
151259
}
@@ -188,14 +296,53 @@ LogicalResult CopyOpUniversalCopyType::emitAtomCall(OpBuilder &builder, Location
188296
if (!isa<LLVM::LLVMPointerType>(src.getType()) || !isa<LLVM::LLVMPointerType>(dst.getType()))
189297
return failure();
190298

191-
int32_t copyBytes = getBitSize() / 8;
192-
Value srcPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(src),
193-
srcMemTy.getSwizzle());
194-
Value dstPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(dst),
195-
dstMemTy.getSwizzle());
196-
Value len = arith::ConstantIntOp::create(builder, loc, copyBytes, /*width=*/32);
197-
LLVM::MemcpyOp::create(builder, loc, dstPtr, srcPtr, len, /*isVolatile=*/false);
299+
auto srcBase = cast<TypedValue<LLVM::LLVMPointerType>>(src);
300+
auto dstBase = cast<TypedValue<LLVM::LLVMPointerType>>(dst);
301+
302+
auto srcCs = getCoalescedLeafCountAndStride(srcMemTy);
303+
auto dstCs = getCoalescedLeafCountAndStride(dstMemTy);
304+
if (failed(srcCs) || failed(dstCs))
305+
return failure();
306+
auto [srcCount, srcStride] = *srcCs;
307+
auto [dstCount, dstStride] = *dstCs;
308+
if (srcCount != dstCount)
309+
return failure();
310+
311+
bool srcContig = srcCount <= 1 || srcStride == 1;
312+
bool dstContig = dstCount <= 1 || dstStride == 1;
313+
314+
if (srcContig && dstContig) {
315+
// Fast path: both sides are contiguous, lower to a single memcpy.
316+
int32_t copyBytes = getBitSize() / 8;
317+
Value srcPtr = applySwizzleOnPtr(builder, loc, srcBase, srcMemTy.getSwizzle());
318+
Value dstPtr = applySwizzleOnPtr(builder, loc, dstBase, dstMemTy.getSwizzle());
319+
Value len = arith::ConstantIntOp::create(builder, loc, copyBytes, /*width=*/32);
320+
LLVM::MemcpyOp::create(builder, loc, dstPtr, srcPtr, len, /*isVolatile=*/false);
321+
return success();
322+
}
198323

324+
// At least one side is strided: emit element-wise gather/scatter so each
325+
// side honors its own stride.
326+
Type llvmSrcElemTy = projectToLLVMCompatibleElemTy(srcMemTy.getElemTy());
327+
Type llvmDstElemTy = projectToLLVMCompatibleElemTy(dstMemTy.getElemTy());
328+
auto srcPtrTy = srcBase.getType();
329+
auto dstPtrTy = dstBase.getType();
330+
for (int64_t i = 0; i < srcCount; ++i) {
331+
Value srcGep = LLVM::GEPOp::create(builder, loc, srcPtrTy, llvmSrcElemTy, srcBase,
332+
ArrayRef<LLVM::GEPArg>{int32_t(i * srcStride)});
333+
Value srcSwz = applySwizzleOnPtr(builder, loc,
334+
cast<TypedValue<LLVM::LLVMPointerType>>(srcGep),
335+
srcMemTy.getSwizzle());
336+
Value v = LLVM::LoadOp::create(builder, loc, llvmSrcElemTy, srcSwz);
337+
if (llvmSrcElemTy != llvmDstElemTy)
338+
v = LLVM::BitcastOp::create(builder, loc, llvmDstElemTy, v);
339+
Value dstGep = LLVM::GEPOp::create(builder, loc, dstPtrTy, llvmDstElemTy, dstBase,
340+
ArrayRef<LLVM::GEPArg>{int32_t(i * dstStride)});
341+
Value dstSwz = applySwizzleOnPtr(builder, loc,
342+
cast<TypedValue<LLVM::LLVMPointerType>>(dstGep),
343+
dstMemTy.getSwizzle());
344+
LLVM::StoreOp::create(builder, loc, v, dstSwz);
345+
}
199346
return success();
200347
}
201348

tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,20 @@ gpu.module @bug_strided_universal_copy {
77
// CHECK-LABEL: gpu.func @load_strided_global_into_register(
88
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>
99
// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5>
10-
// CHECK: %[[V:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<1> -> vector<4xf16>
11-
// CHECK-NEXT: llvm.store %[[V]], %[[REG]] : vector<4xf16>, !llvm.ptr<5>
10+
// CHECK: %[[U:.*]] = llvm.mlir.undef : vector<4xf16>
11+
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
12+
// CHECK: %[[E0:.*]] = llvm.load %[[P0]] : !llvm.ptr<1> -> f16
13+
// CHECK: %[[V0:.*]] = llvm.insertelement %[[E0]], %[[U]]{{.*}} : vector<4xf16>
14+
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[ARG0]][8] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
15+
// CHECK: %[[E1:.*]] = llvm.load %[[P1]] : !llvm.ptr<1> -> f16
16+
// CHECK: %[[V1:.*]] = llvm.insertelement %[[E1]], %[[V0]]{{.*}} : vector<4xf16>
17+
// CHECK: %[[P2:.*]] = llvm.getelementptr %[[ARG0]][16] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
18+
// CHECK: %[[E2:.*]] = llvm.load %[[P2]] : !llvm.ptr<1> -> f16
19+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[E2]], %[[V1]]{{.*}} : vector<4xf16>
20+
// CHECK: %[[P3:.*]] = llvm.getelementptr %[[ARG0]][24] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
21+
// CHECK: %[[E3:.*]] = llvm.load %[[P3]] : !llvm.ptr<1> -> f16
22+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[E3]], %[[V2]]{{.*}} : vector<4xf16>
23+
// CHECK: llvm.store %[[V3]], %[[REG]] : vector<4xf16>, !llvm.ptr<5>
1224
gpu.func @load_strided_global_into_register(%src: !fly.ptr<f16, global>) kernel {
1325
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
1426
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
@@ -41,7 +53,18 @@ gpu.module @bug_strided_universal_copy {
4153
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>
4254
// CHECK: %[[REG:.*]] = llvm.alloca %{{.*}} x f16 : (i64) -> !llvm.ptr<5>
4355
// CHECK: %[[V:.*]] = llvm.load %[[REG]] : !llvm.ptr<5> -> vector<4xf16>
44-
// CHECK-NEXT: llvm.store %[[V]], %[[ARG0]] : vector<4xf16>, !llvm.ptr<1>
56+
// CHECK: %[[E0:.*]] = llvm.extractelement %[[V]]{{.*}} : vector<4xf16>
57+
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
58+
// CHECK: llvm.store %[[E0]], %[[P0]] : f16, !llvm.ptr<1>
59+
// CHECK: %[[E1:.*]] = llvm.extractelement %[[V]]{{.*}} : vector<4xf16>
60+
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[ARG0]][8] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
61+
// CHECK: llvm.store %[[E1]], %[[P1]] : f16, !llvm.ptr<1>
62+
// CHECK: %[[E2:.*]] = llvm.extractelement %[[V]]{{.*}} : vector<4xf16>
63+
// CHECK: %[[P2:.*]] = llvm.getelementptr %[[ARG0]][16] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
64+
// CHECK: llvm.store %[[E2]], %[[P2]] : f16, !llvm.ptr<1>
65+
// CHECK: %[[E3:.*]] = llvm.extractelement %[[V]]{{.*}} : vector<4xf16>
66+
// CHECK: %[[P3:.*]] = llvm.getelementptr %[[ARG0]][24] : (!llvm.ptr<1>) -> !llvm.ptr<1>, f16
67+
// CHECK: llvm.store %[[E3]], %[[P3]] : f16, !llvm.ptr<1>
4568
gpu.func @store_register_into_strided_global(%dst: !fly.ptr<f16, global>) kernel {
4669
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
4770
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>

0 commit comments

Comments
 (0)