Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,6 @@ LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (isa<FloatType>(atomicOp.getType()))
return rewriter.notifyMatchFailure(atomicOp,
"unimplemented floating-point case");

auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
Expand Down Expand Up @@ -488,13 +484,13 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
"failed to convert memref type");

Type pointeeType = pointerType.getPointeeType();
auto dstType = dyn_cast<IntegerType>(
getElementTypeForStoragePointer(pointeeType, typeConverter));
if (!dstType)
Type storageElemType =
getElementTypeForStoragePointer(pointeeType, typeConverter);
if (!storageElemType || !storageElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(
atomicOp, "failed to determine destination element type");

int dstBits = static_cast<int>(dstType.getWidth());
int dstBits = static_cast<int>(storageElemType.getIntOrFloatBitWidth());
assert(dstBits % srcBits == 0);

spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType);
Expand All @@ -509,6 +505,7 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
break

switch (atomicOp.getKind()) {
ATOMIC_CASE(addf, EXTAtomicFAddOp);
ATOMIC_CASE(addi, AtomicIAddOp);
ATOMIC_CASE(maxs, AtomicSMaxOp);
ATOMIC_CASE(maxu, AtomicUMaxOp);
Expand Down Expand Up @@ -546,6 +543,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
atomicOp,
"sub-element-width atomic ops unsupported with Kernel capability");

auto dstType = cast<IntegerType>(storageElemType);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it safe to cast float to int as the dstType?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Integers are handled above


auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
if (!accessChainOp)
return failure();
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,21 @@ func.func @atomic_andi_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spi
}

}

// -----

// Floating-point atomic add requires the shader_atomic_float_add extension.

module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, AtomicFloat32AddEXT], [SPV_EXT_shader_atomic_float_add]>, #spirv.resource_limits<>>} {

// CHECK: func.func @atomic_addf_storage_buffer
// CHECK-SAME: (%[[VAL:.+]]: f32,
func.func @atomic_addf_storage_buffer(%value: f32, %memref: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> f32 {
// CHECK: %[[AC:.+]] = spirv.AccessChain
// CHECK: %[[ATOMIC:.+]] = spirv.EXT.AtomicFAdd <Device> <AcquireRelease|UniformMemory> %[[AC]], %[[VAL]] : !spirv.ptr<f32, StorageBuffer>
// CHECK: return %[[ATOMIC]]
%0 = memref.atomic_rmw "addf" %value, %memref[%i0, %i1, %i2] : (f32, memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>) -> f32
return %0: f32
}

}
Loading