2 changes: 2 additions & 0 deletions include/flydsl/Dialect/Fly/IR/FlyOps.td
@@ -358,13 +358,15 @@ def Fly_AtomSetValueOp : Fly_Op<"atom.set_value", [Pure, DeclareOpInterfaceMetho
}

def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> {
let hasVerifier = 1;
let arguments = (ins Fly_CopyAtom:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, Optional<Fly_MemRef>:$pred);
}
def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> {
let arguments = (ins Fly_MmaAtom:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, Fly_MemRef:$b, Fly_MemRef:$c);
}

def Fly_CopyAtomCallSSA : Fly_Op<"copy_atom_call_ssa", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
let arguments = (ins Fly_CopyAtom:$copyAtom, AnyType:$src,
Optional<AnyType>:$dst, Optional<AnyType>:$pred);
let results = (outs Variadic<AnyType>:$results);
81 changes: 81 additions & 0 deletions lib/Dialect/Fly/IR/FlyOps.cpp
@@ -124,6 +124,53 @@ Type applyOffsetOnTensorLike(LayoutBuilder<LayoutAttr> &builder, Type tensorLike
llvm_unreachable("Unsupported tensor like type");
}

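// Coalesce the memref layout and, when it flattens to a single static leaf,
// return its (element count, stride) pair. For example, the layout 4:1 used in
// the tests flattens to count=4, stride=1, while a multi-mode layout such as
// (4,2):(1,8) does not coalesce to one leaf and yields failure().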
FailureOr<std::pair<int64_t, int64_t>> getCoalescedLeafCountAndStride(fly::MemRefType memRefTy) {
auto layoutAttr = dyn_cast<LayoutAttr>(memRefTy.getLayout());
if (!layoutAttr)
return failure();
LayoutBuilder<LayoutAttr> builder(memRefTy.getContext());
auto coalesced = layoutCoalesce(builder, layoutAttr);
if (!coalesced.isLeaf())
return failure();
auto shape = coalesced.getShape().getLeafAsInt();
auto stride = coalesced.getStride().getLeafAsInt();
if (!shape.isStatic() || !stride.isStatic())
return failure();
return std::pair<int64_t, int64_t>(shape.getValue(), stride.getValue());
}

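// For universal_copy atoms, require that the operand's coalesced layout covers
// exactly the copy bit width and that those bits are contiguous in memory.
// Operands of non-universal copy ops are accepted unchanged.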
LogicalResult verifyUniversalCopyOperand(Operation *op, StringRef operandName, CopyAtomType copyAtomTy,
fly::MemRefType memRefTy) {
auto universalCopy = dyn_cast<CopyOpUniversalCopyType>(copyAtomTy.getCopyOp());
if (!universalCopy)
return success();

auto countAndStride = getCoalescedLeafCountAndStride(memRefTy);
if (failed(countAndStride)) {
return op->emitOpError() << operandName
<< " memref layout must coalesce to a single static leaf for "
<< copyAtomTy;
}

auto [count, stride] = *countAndStride;
int64_t elemBits = memRefTy.getElemTy().getIntOrFloatBitWidth();
int64_t copyBits = universalCopy.getBitSize();
int64_t totalBits = count * elemBits;
if (totalBits != copyBits) {
return op->emitOpError() << operandName << " memref covers " << totalBits
<< " bits after coalescing, but " << copyAtomTy << " expects "
<< copyBits << " bits";
}

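// A stride-1 (or single-element) leaf is fully contiguous; otherwise only one
// element is. E.g. an f16 memref with layout 4:8 covers 4 * 16 = 64 bits in
// total but only 16 contiguous bits, which universal_copy<64> cannot service.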
int64_t contiguousBits = (count <= 1 || stride == 1) ? totalBits : elemBits;
if (contiguousBits < copyBits) {
return op->emitOpError() << operandName << " memref contiguous bit count " << contiguousBits
<< " is smaller than copy granularity " << copyBits;
}

return success();
}

} // namespace

#define FLY_INFER_RETURN_TYPES(OP) \
@@ -133,6 +180,40 @@ Type applyOffsetOnTensorLike(LayoutBuilder<LayoutAttr> &builder, Type tensorLike
mlir::OpaqueProperties properties, mlir::RegionRange regions, \
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes)

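// copy_atom_call: both operands are memrefs, so check element-type agreement
// and run the universal_copy checks on src and dst.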
LogicalResult CopyAtomCall::verify() {
auto copyAtomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
if (!copyAtomTy)
return emitOpError("copyAtom is not CopyAtomType");

auto srcTy = cast<fly::MemRefType>(getSrc().getType());
auto dstTy = cast<fly::MemRefType>(getDst().getType());
if (srcTy.getElemTy() != dstTy.getElemTy())
return emitOpError("src/dst element types mismatch");

if (failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy)))
return failure();
if (failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy)))
return failure();
return success();
}

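// copy_atom_call_ssa: src/dst may already be SSA values (e.g. vectors) rather
// than memrefs, so each check only applies to operands that are still memrefs.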
LogicalResult CopyAtomCallSSA::verify() {
auto copyAtomTy = dyn_cast<CopyAtomType>(getCopyAtom().getType());
if (!copyAtomTy)
return emitOpError("copyAtom is not CopyAtomType");

auto srcTy = dyn_cast<fly::MemRefType>(getSrc().getType());
auto dstTy = getDst() ? dyn_cast<fly::MemRefType>(getDst().getType()) : fly::MemRefType();
if (srcTy && dstTy && srcTy.getElemTy() != dstTy.getElemTy())
return emitOpError("src/dst element types mismatch");

if (srcTy && failed(verifyUniversalCopyOperand(getOperation(), "src", copyAtomTy, srcTy)))
return failure();
if (dstTy && failed(verifyUniversalCopyOperand(getOperation(), "dst", copyAtomTy, dstTy)))
return failure();
return success();
}

//===----------------------------------------------------------------------===//
// Constructors
//===----------------------------------------------------------------------===//
25 changes: 0 additions & 25 deletions tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir
@@ -120,31 +120,6 @@ gpu.module @convert_atom_call_to_ssa_form {
gpu.return
}

// Test 3b: copy_atom_call with register dst, non-coalescable layout should NOT be promoted
// (4,2):(1,8) cannot coalesce to rank=1 stride=1
// CHECK-LABEL: gpu.func @copy_dst_register_non_coalescable
// CHECK: fly.copy_atom_call(
// CHECK-NOT: fly.copy_atom_call_ssa
gpu.func @copy_dst_register_non_coalescable(%src: !fly.ptr<f16, global>) kernel {
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
%vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>

%src_view = fly.make_view(%src, %vec4) : (!fly.ptr<f16, global>, !fly.layout<4:1>) -> !fly.memref<f16, global, 4:1>

%nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)>
%nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)>
%nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)>

%copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom<!fly.universal_copy<64>, 16>

%reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr<f16, register>
%reg_view = fly.make_view(%reg_ptr, %nc_layout) : (!fly.ptr<f16, register>, !fly.layout<(4,2):(1,8)>) -> !fly.memref<f16, register, (4,2):(1,8)>

fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom<!fly.universal_copy<64>, 16>, !fly.memref<f16, global, 4:1>, !fly.memref<f16, register, (4,2):(1,8)>) -> ()
gpu.return
}

// Test 4: mma_atom_call with register d (rank=1, stride=1) should be promoted
// a, b, c are also register eligible, so they get pre-loaded as vectors
// CHECK-LABEL: gpu.func @mma_d_register
@@ -0,0 +1,35 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 FlyDSL Project Contributors
// RUN: not %fly-opt %s --fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl 2>&1 | FileCheck %s

gpu.module @bug_strided_universal_copy {

// CHECK: error: 'fly.copy_atom_call' op src memref contiguous bit count 16 is smaller than copy granularity 64
gpu.func @load_strided_global_into_register(%src: !fly.ptr<f16, global>) kernel {
%shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4>
%stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1>
%stride8 = fly.make_int_tuple() : () -> !fly.int_tuple<8>

%src_layout = fly.make_layout(%shape4, %stride8)
: (!fly.int_tuple<4>, !fly.int_tuple<8>) -> !fly.layout<4:8>
%reg_layout = fly.make_layout(%shape4, %stride1)
: (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1>

%src_view = fly.make_view(%src, %src_layout)
: (!fly.ptr<f16, global>, !fly.layout<4:8>) -> !fly.memref<f16, global, 4:8>

%copy = fly.make_copy_atom {valBits = 16 : i32}
: !fly.copy_atom<!fly.universal_copy<64>, 16>

%reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}}
: () -> !fly.ptr<f16, register>
%reg_view = fly.make_view(%reg_ptr, %reg_layout)
: (!fly.ptr<f16, register>, !fly.layout<4:1>) -> !fly.memref<f16, register, 4:1>

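// The 4:8 layout strides by 8 elements, so only one f16 (16 bits) is contiguous;
// universal_copy<64> needs 64 contiguous bits, so the verifier rejects this call.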
fly.copy_atom_call(%copy, %src_view, %reg_view)
: (!fly.copy_atom<!fly.universal_copy<64>, 16>,
!fly.memref<f16, global, 4:8>,
!fly.memref<f16, register, 4:1>) -> ()
gpu.return
}
}