|
9 | 9 | #include "mlir/IR/DialectImplementation.h" |
10 | 10 |
|
11 | 11 | #include "flydsl/Dialect/Fly/IR/FlyDialect.h" |
| 12 | +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" |
12 | 13 | #include "flydsl/Dialect/Fly/Utils/PointerUtils.h" |
13 | 14 | #include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" |
14 | 15 |
|
15 | 16 | namespace mlir::fly { |
16 | 17 |
|
| 18 | +namespace { |
| 19 | + |
| 20 | +/// Returns the element-count and element-stride of a memref, assuming its |
| 21 | +/// layout coalesces into a single static leaf. This represents how a single |
| 22 | +/// atom-call moves `count` elements starting from the base pointer with the |
| 23 | +/// given `elementStride` (in elements). |
| 24 | +FailureOr<std::pair<int64_t, int64_t>> |
| 25 | +getCoalescedLeafCountAndStride(fly::MemRefType memTy) { |
| 26 | + auto layoutAttr = dyn_cast<LayoutAttr>(memTy.getLayout()); |
| 27 | + if (!layoutAttr) |
| 28 | + return failure(); |
| 29 | + LayoutBuilder<LayoutAttr> lb(memTy.getContext()); |
| 30 | + auto coal = layoutCoalesce(lb, layoutAttr); |
| 31 | + if (!coal.isLeaf()) |
| 32 | + return failure(); |
| 33 | + auto shapeInt = coal.getShape().getLeafAsInt(); |
| 34 | + auto strideInt = coal.getStride().getLeafAsInt(); |
| 35 | + if (!shapeInt.isStatic() || !strideInt.isStatic()) |
| 36 | + return failure(); |
| 37 | + return std::make_pair<int64_t, int64_t>(shapeInt.getValue(), strideInt.getValue()); |
| 38 | +} |
| 39 | + |
| 40 | +/// Emits a sequence of element-wise loads from a strided memory pointer and |
| 41 | +/// packs the values into a `vector<count x elemTy>`. The i-th element is read |
| 42 | +/// from `base + i * elementStride` (in elements). |
| 43 | +Value emitStridedLoadAsVector(OpBuilder &b, Location loc, VectorType vecTy, |
| 44 | + TypedValue<LLVM::LLVMPointerType> base, |
| 45 | + int64_t count, int64_t elementStride, |
| 46 | + fly::MemRefType memTy) { |
| 47 | + Type llvmElemTy = projectToLLVMCompatibleElemTy(memTy.getElemTy()); |
| 48 | + Type vecElemTy = vecTy.getElementType(); |
| 49 | + auto ptrTy = base.getType(); |
| 50 | + Value vec = LLVM::UndefOp::create(b, loc, vecTy); |
| 51 | + for (int64_t i = 0; i < count; ++i) { |
| 52 | + Value gep = LLVM::GEPOp::create(b, loc, ptrTy, llvmElemTy, base, |
| 53 | + ArrayRef<LLVM::GEPArg>{int32_t(i * elementStride)}); |
| 54 | + Value swz = applySwizzleOnPtr(b, loc, |
| 55 | + cast<TypedValue<LLVM::LLVMPointerType>>(gep), |
| 56 | + memTy.getSwizzle()); |
| 57 | + Value elem = LLVM::LoadOp::create(b, loc, llvmElemTy, swz); |
| 58 | + if (llvmElemTy != vecElemTy) |
| 59 | + elem = LLVM::BitcastOp::create(b, loc, vecElemTy, elem); |
| 60 | + Value idx = arith::ConstantIntOp::create(b, loc, i, /*width=*/32); |
| 61 | + vec = LLVM::InsertElementOp::create(b, loc, vec, elem, idx); |
| 62 | + } |
| 63 | + return vec; |
| 64 | +} |
| 65 | + |
| 66 | +/// Emits a sequence of element-wise stores that scatter a `vector<count x E>` |
| 67 | +/// into a strided memory pointer. The i-th element is written to |
| 68 | +/// `base + i * elementStride` (in elements). |
| 69 | +void emitStridedStoreFromVector(OpBuilder &b, Location loc, Value vec, |
| 70 | + TypedValue<LLVM::LLVMPointerType> base, |
| 71 | + int64_t count, int64_t elementStride, |
| 72 | + fly::MemRefType memTy) { |
| 73 | + auto vecTy = cast<VectorType>(vec.getType()); |
| 74 | + Type llvmElemTy = projectToLLVMCompatibleElemTy(memTy.getElemTy()); |
| 75 | + Type vecElemTy = vecTy.getElementType(); |
| 76 | + auto ptrTy = base.getType(); |
| 77 | + for (int64_t i = 0; i < count; ++i) { |
| 78 | + Value idx = arith::ConstantIntOp::create(b, loc, i, /*width=*/32); |
| 79 | + Value elem = LLVM::ExtractElementOp::create(b, loc, vec, idx); |
| 80 | + if (llvmElemTy != vecElemTy) |
| 81 | + elem = LLVM::BitcastOp::create(b, loc, llvmElemTy, elem); |
| 82 | + Value gep = LLVM::GEPOp::create(b, loc, ptrTy, llvmElemTy, base, |
| 83 | + ArrayRef<LLVM::GEPArg>{int32_t(i * elementStride)}); |
| 84 | + Value swz = applySwizzleOnPtr(b, loc, |
| 85 | + cast<TypedValue<LLVM::LLVMPointerType>>(gep), |
| 86 | + memTy.getSwizzle()); |
| 87 | + LLVM::StoreOp::create(b, loc, elem, swz); |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +} // namespace |
| 92 | + |
17 | 93 | bool CopyOpUniversalCopyType::isStatic() const { return true; } |
18 | 94 |
|
19 | 95 | Value CopyOpUniversalCopyType::rebuildStaticValue(OpBuilder &builder, Location loc, |
@@ -129,23 +205,55 @@ FailureOr<Value> CopyOpUniversalCopyType::emitAtomCallSSA(OpBuilder &builder, Lo |
129 | 205 | Value dst) const { |
130 | 206 | Value result; |
131 | 207 | if (isa<fly::MemRefType>(srcTyArg)) { |
132 | | - // src is memory |
| 208 | + // src is memory: honor the memref's coalesced stride so a non-unit stride |
| 209 | + // layout lowers to a per-element scatter/gather rather than a single |
| 210 | + // contiguous vector load that silently ignores the stride. |
133 | 211 | auto srcMemTy = cast<fly::MemRefType>(srcTyArg); |
| 212 | + auto srcBase = cast<TypedValue<LLVM::LLVMPointerType>>(src); |
| 213 | + auto countAndStride = getCoalescedLeafCountAndStride(srcMemTy); |
| 214 | + if (failed(countAndStride)) |
| 215 | + return failure(); |
| 216 | + auto [count, stride] = *countAndStride; |
| 217 | + |
134 | 218 | Type loadTy = resultTy ? resultTy : builder.getIntegerType(getBitSize()); |
135 | | - Value srcPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(src), |
136 | | - srcMemTy.getSwizzle()); |
137 | | - result = LLVM::LoadOp::create(builder, loc, loadTy, srcPtr); |
| 219 | + |
| 220 | + if (count <= 1 || stride == 1) { |
| 221 | + Value srcPtr = applySwizzleOnPtr(builder, loc, srcBase, srcMemTy.getSwizzle()); |
| 222 | + result = LLVM::LoadOp::create(builder, loc, loadTy, srcPtr); |
| 223 | + } else { |
| 224 | + Type llvmElemTy = projectToLLVMCompatibleElemTy(srcMemTy.getElemTy()); |
| 225 | + auto vecTy = VectorType::get({count}, llvmElemTy); |
| 226 | + Value vec = |
| 227 | + emitStridedLoadAsVector(builder, loc, vecTy, srcBase, count, stride, srcMemTy); |
| 228 | + if (vec.getType() != loadTy) |
| 229 | + vec = LLVM::BitcastOp::create(builder, loc, loadTy, vec); |
| 230 | + result = vec; |
| 231 | + } |
138 | 232 | } else { |
139 | 233 | // src is register |
140 | 234 | result = src; |
141 | 235 | } |
142 | 236 |
|
143 | 237 | if (!resultTy) { |
144 | | - // dst is memory |
| 238 | + // dst is memory: symmetric treatment for strided stores. |
145 | 239 | auto dstMemTy = cast<fly::MemRefType>(dstTyArg); |
146 | | - Value dstPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(dst), |
147 | | - dstMemTy.getSwizzle()); |
148 | | - LLVM::StoreOp::create(builder, loc, result, dstPtr); |
| 240 | + auto dstBase = cast<TypedValue<LLVM::LLVMPointerType>>(dst); |
| 241 | + auto countAndStride = getCoalescedLeafCountAndStride(dstMemTy); |
| 242 | + if (failed(countAndStride)) |
| 243 | + return failure(); |
| 244 | + auto [count, stride] = *countAndStride; |
| 245 | + |
| 246 | + if (count <= 1 || stride == 1) { |
| 247 | + Value dstPtr = applySwizzleOnPtr(builder, loc, dstBase, dstMemTy.getSwizzle()); |
| 248 | + LLVM::StoreOp::create(builder, loc, result, dstPtr); |
| 249 | + } else { |
| 250 | + Type llvmElemTy = projectToLLVMCompatibleElemTy(dstMemTy.getElemTy()); |
| 251 | + auto vecTy = VectorType::get({count}, llvmElemTy); |
| 252 | + Value vec = result; |
| 253 | + if (vec.getType() != vecTy) |
| 254 | + vec = LLVM::BitcastOp::create(builder, loc, vecTy, vec); |
| 255 | + emitStridedStoreFromVector(builder, loc, vec, dstBase, count, stride, dstMemTy); |
| 256 | + } |
149 | 257 | } |
150 | 258 | return result; |
151 | 259 | } |
@@ -188,14 +296,53 @@ LogicalResult CopyOpUniversalCopyType::emitAtomCall(OpBuilder &builder, Location |
188 | 296 | if (!isa<LLVM::LLVMPointerType>(src.getType()) || !isa<LLVM::LLVMPointerType>(dst.getType())) |
189 | 297 | return failure(); |
190 | 298 |
|
191 | | - int32_t copyBytes = getBitSize() / 8; |
192 | | - Value srcPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(src), |
193 | | - srcMemTy.getSwizzle()); |
194 | | - Value dstPtr = applySwizzleOnPtr(builder, loc, cast<TypedValue<LLVM::LLVMPointerType>>(dst), |
195 | | - dstMemTy.getSwizzle()); |
196 | | - Value len = arith::ConstantIntOp::create(builder, loc, copyBytes, /*width=*/32); |
197 | | - LLVM::MemcpyOp::create(builder, loc, dstPtr, srcPtr, len, /*isVolatile=*/false); |
| 299 | + auto srcBase = cast<TypedValue<LLVM::LLVMPointerType>>(src); |
| 300 | + auto dstBase = cast<TypedValue<LLVM::LLVMPointerType>>(dst); |
| 301 | + |
| 302 | + auto srcCs = getCoalescedLeafCountAndStride(srcMemTy); |
| 303 | + auto dstCs = getCoalescedLeafCountAndStride(dstMemTy); |
| 304 | + if (failed(srcCs) || failed(dstCs)) |
| 305 | + return failure(); |
| 306 | + auto [srcCount, srcStride] = *srcCs; |
| 307 | + auto [dstCount, dstStride] = *dstCs; |
| 308 | + if (srcCount != dstCount) |
| 309 | + return failure(); |
| 310 | + |
| 311 | + bool srcContig = srcCount <= 1 || srcStride == 1; |
| 312 | + bool dstContig = dstCount <= 1 || dstStride == 1; |
| 313 | + |
| 314 | + if (srcContig && dstContig) { |
| 315 | + // Fast path: both sides are contiguous, lower to a single memcpy. |
| 316 | + int32_t copyBytes = getBitSize() / 8; |
| 317 | + Value srcPtr = applySwizzleOnPtr(builder, loc, srcBase, srcMemTy.getSwizzle()); |
| 318 | + Value dstPtr = applySwizzleOnPtr(builder, loc, dstBase, dstMemTy.getSwizzle()); |
| 319 | + Value len = arith::ConstantIntOp::create(builder, loc, copyBytes, /*width=*/32); |
| 320 | + LLVM::MemcpyOp::create(builder, loc, dstPtr, srcPtr, len, /*isVolatile=*/false); |
| 321 | + return success(); |
| 322 | + } |
198 | 323 |
|
| 324 | + // At least one side is strided: emit element-wise gather/scatter so each |
| 325 | + // side honors its own stride. |
| 326 | + Type llvmSrcElemTy = projectToLLVMCompatibleElemTy(srcMemTy.getElemTy()); |
| 327 | + Type llvmDstElemTy = projectToLLVMCompatibleElemTy(dstMemTy.getElemTy()); |
| 328 | + auto srcPtrTy = srcBase.getType(); |
| 329 | + auto dstPtrTy = dstBase.getType(); |
| 330 | + for (int64_t i = 0; i < srcCount; ++i) { |
| 331 | + Value srcGep = LLVM::GEPOp::create(builder, loc, srcPtrTy, llvmSrcElemTy, srcBase, |
| 332 | + ArrayRef<LLVM::GEPArg>{int32_t(i * srcStride)}); |
| 333 | + Value srcSwz = applySwizzleOnPtr(builder, loc, |
| 334 | + cast<TypedValue<LLVM::LLVMPointerType>>(srcGep), |
| 335 | + srcMemTy.getSwizzle()); |
| 336 | + Value v = LLVM::LoadOp::create(builder, loc, llvmSrcElemTy, srcSwz); |
| 337 | + if (llvmSrcElemTy != llvmDstElemTy) |
| 338 | + v = LLVM::BitcastOp::create(builder, loc, llvmDstElemTy, v); |
| 339 | + Value dstGep = LLVM::GEPOp::create(builder, loc, dstPtrTy, llvmDstElemTy, dstBase, |
| 340 | + ArrayRef<LLVM::GEPArg>{int32_t(i * dstStride)}); |
| 341 | + Value dstSwz = applySwizzleOnPtr(builder, loc, |
| 342 | + cast<TypedValue<LLVM::LLVMPointerType>>(dstGep), |
| 343 | + dstMemTy.getSwizzle()); |
| 344 | + LLVM::StoreOp::create(builder, loc, v, dstSwz); |
| 345 | + } |
199 | 346 | return success(); |
200 | 347 | } |
201 | 348 |
|
|
0 commit comments