diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index fe9f9c59e6ede..4674aa351315f 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -454,10 +454,6 @@ LogicalResult AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (isa(atomicOp.getType())) - return rewriter.notifyMatchFailure(atomicOp, - "unimplemented floating-point case"); - auto memrefType = cast(atomicOp.getMemref().getType()); std::optional scope = getAtomicOpScope(memrefType); if (!scope) @@ -488,13 +484,13 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); - auto dstType = dyn_cast( - getElementTypeForStoragePointer(pointeeType, typeConverter)); - if (!dstType) + Type storageElemType = + getElementTypeForStoragePointer(pointeeType, typeConverter); + if (!storageElemType || !storageElemType.isIntOrFloat()) return rewriter.notifyMatchFailure( atomicOp, "failed to determine destination element type"); - int dstBits = static_cast(dstType.getWidth()); + int dstBits = static_cast(storageElemType.getIntOrFloatBitWidth()); assert(dstBits % srcBits == 0); spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType); @@ -509,6 +505,7 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, break switch (atomicOp.getKind()) { + ATOMIC_CASE(addf, EXTAtomicFAddOp); ATOMIC_CASE(addi, AtomicIAddOp); ATOMIC_CASE(maxs, AtomicSMaxOp); ATOMIC_CASE(maxu, AtomicUMaxOp); @@ -546,6 +543,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, atomicOp, "sub-element-width atomic ops unsupported with Kernel capability"); + auto dstType = cast(storageElemType); + auto accessChainOp = ptr.getDefiningOp(); if (!accessChainOp) return failure(); diff --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir index b5815a73ee8b2..fa416512aa144 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir @@ -146,3 +146,21 @@ func.func @atomic_andi_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spi } } + +// ----- + +// Floating-point atomic add requires the shader_atomic_float_add extension. + +module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + +// CHECK: func.func @atomic_addf_storage_buffer +// CHECK-SAME: (%[[VAL:.+]]: f32, +func.func @atomic_addf_storage_buffer(%value: f32, %memref: memref<2x3x4xf32, #spirv.storage_class>, %i0: index, %i1: index, %i2: index) -> f32 { + // CHECK: %[[AC:.+]] = spirv.AccessChain + // CHECK: %[[ATOMIC:.+]] = spirv.EXT.AtomicFAdd %[[AC]], %[[VAL]] : !spirv.ptr + // CHECK: return %[[ATOMIC]] + %0 = memref.atomic_rmw "addf" %value, %memref[%i0, %i1, %i2] : (f32, memref<2x3x4xf32, #spirv.storage_class>) -> f32 + return %0: f32 +} + +}