Skip to content

[flang][OpenMP] Support for "atomic compare capture"#202315

Open
SunilKuravinakop wants to merge 3 commits into
llvm:mainfrom
SunilKuravinakop:compare_capture
Open

[flang][OpenMP] Support for "atomic compare capture"#202315
SunilKuravinakop wants to merge 3 commits into
llvm:mainfrom
SunilKuravinakop:compare_capture

Conversation

@SunilKuravinakop
Copy link
Copy Markdown
Contributor

Add support for the "!$omp atomic compare capture" construct. For example:

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also fixes #202311.

@llvmorg-github-actions
Copy link
Copy Markdown

llvmorg-github-actions Bot commented Jun 8, 2026

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: SunilKuravinakop

Changes

Add support for the "!$omp atomic compare capture" construct. For example:

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also fixes #202311.


Patch is 28.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/202315.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/Atomic.cpp (+51-39)
  • (modified) flang/test/Integration/OpenMP/atomic-compare.f90 (+32)
  • (modified) flang/test/Lower/OpenMP/atomic-compare.f90 (+56)
  • (modified) flang/test/Parser/OpenMP/atomic-unparse.f90 (+29)
  • (modified) mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td (+8-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+177-40)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+48)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+41)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 4ce0c8f878c48..c0f97aada637d 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
   int action1 = analysis.op1.what & analysis.Action;
   memOrder = makeValidForAction(memOrder, action0, action1, version);
 
+  // --- Shared capture scaffolding ---
+  mlir::Operation *captureOp = nullptr;
+  fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
+  fir::FirOpBuilder::InsertPoint atomicAt, postAt;
+
+  if (construct.IsCapture()) {
+    assert(action0 != analysis.None && action1 != analysis.None &&
+           "Expexcing two actions");
+    (void)action0;
+    (void)action1;
+    captureOp = mlir::omp::AtomicCaptureOp::create(
+        builder, loc, hint, makeMemOrderAttr(converter, memOrder));
+    // Set the non-atomic insertion point to before the atomic.capture.
+    preAt = getInsertionPointBefore(captureOp);
+
+    mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
+    builder.setInsertionPointToEnd(block);
+    // Set the atomic insertion point to before the terminator inside
+    // atomic.capture.
+    mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
+    atomicAt = getInsertionPointBefore(term);
+    postAt = getInsertionPointAfter(captureOp);
+    hint = nullptr;
+    memOrder = std::nullopt;
+  }
+
   if (auto *cond = get(analysis.cond)) {
     // atomic compare: if (x == e) x = d
     // e : expecteVal
     // d : desiredVal
 
-    // Check for compound clauses (fail, capture) that are not yet
+    // Restore insertion point so pre-processing code (e.g. computing
+    // expectedVal) is emitted before the capture op, not after the terminator.
+    builder.restoreInsertionPoint(preAt);
+
+    // Check for compound clause (fail) that is not yet
     // supported with atomic compare.
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
-          return clause.id == llvm::omp::Clause::OMPC_fail ||
-                 clause.id == llvm::omp::Clause::OMPC_capture;
+          return clause.id == llvm::omp::Clause::OMPC_fail;
         })) {
       TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
     }
@@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
       expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
     }
 
+    // If this is a compare+capture, generate the read op first.
+    if (construct.IsCapture()) {
+      assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
+             "Expected a read operation for compare capture");
+      mlir::Operation *readOp = genAtomicRead(
+          converter, semaCtx, loc, stmtCtx, atomAddr, atom,
+          *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
+      assert(readOp && "Should have created an atomic read operation");
+      builder.setInsertionPointAfter(readOp);
+    }
+
     mlir::UnitAttr weakAttr = nullptr;
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
           return clause.id == llvm::omp::Clause::OMPC_weak;
@@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
     // Generate omp.yield
     mlir::omp::YieldOp::create(builder, loc, newVal);
     builder.setInsertionPointAfter(atomicOp);
-
     // END omp atomic compare
   } else {
-    mlir::Operation *captureOp = nullptr;
-    fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
-    fir::FirOpBuilder::InsertPoint atomicAt, postAt;
-
-    if (construct.IsCapture()) {
-      // Capturing operation.
-      assert(action0 != analysis.None && action1 != analysis.None &&
-             "Expexcing two actions");
-      (void)action0;
-      (void)action1;
-      captureOp = mlir::omp::AtomicCaptureOp::create(
-          builder, loc, hint, makeMemOrderAttr(converter, memOrder));
-      // Set the non-atomic insertion point to before the atomic.capture.
-      preAt = getInsertionPointBefore(captureOp);
-
-      mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
-      builder.setInsertionPointToEnd(block);
-      // Set the atomic insertion point to before the terminator inside
-      // atomic.capture.
-      mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
-      atomicAt = getInsertionPointBefore(term);
-      postAt = getInsertionPointAfter(captureOp);
-      hint = nullptr;
-      memOrder = std::nullopt;
-    } else {
+    if (!construct.IsCapture()) {
       // Non-capturing operation.
       assert(action0 != analysis.None && action1 == analysis.None &&
              "Expexcing single action");
@@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
           *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
     }
 
-    if (construct.IsCapture()) {
-      // If this is a capture operation, the first/second ops will be inside
-      // of it. Set the insertion point to past the capture op itself.
-      builder.restoreInsertionPoint(postAt);
-    } else {
-      if (secondOp) {
-        builder.setInsertionPointAfter(secondOp);
-      } else {
-        builder.setInsertionPointAfter(firstOp);
-      }
+    if (!construct.IsCapture()) {
+      builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
     }
   }
+
+  // Shared capture cleanup.
+  if (construct.IsCapture()) {
+    builder.restoreInsertionPoint(postAt);
+  }
 }
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index 249fb0dd8fa64..650f64b80af12 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
   if (x == e) x = d
 end 
 
+! Integer equality compare+capture: cmpxchg + store old value
+!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
+! Compare+capture with clause order reversed: capture compare
+!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_capture_compare_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic capture compare
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index ac70edbed4e60..752a221aa538d 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
   !$omp atomic compare weak
   if (x .eq. e) x = d
 end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           }
+! CHECK:         }
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x .eq. e) x = d
+  !$omp end atomic
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           } {{.*}}weak{{.*}}
+! CHECK:         }
+subroutine atomic_compare_capture_int_gt(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture weak
+    v = x
+    if (x > e) x = d
+  !$omp end atomic
+end
diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90
index 4f3cf0eac0338..dc0cc1a62f6c2 100644
--- a/flang/test/Parser/OpenMP/atomic-unparse.f90
+++ b/flang/test/Parser/OpenMP/atomic-unparse.f90
@@ -192,6 +192,26 @@ program main
       i = j
    end if
 
+!COMPARE CAPTURE
+!$omp atomic compare capture
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare weak
+   k = i
+   if (i < j) then
+      i = k
+   end if
+!$omp end atomic
+
 !ATOMIC
 !$omp atomic
    i = j
@@ -296,6 +316,15 @@ end program main
 !CHECK: !$OMP ATOMIC WEAK COMPARE
 !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK
 
+!COMPARE CAPTURE
+
+!CHECK: !$OMP ATOMIC COMPARE CAPTURE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
+!CHECK: !$OMP END ATOMIC
+
 !ATOMIC
 !CHECK: !$OMP ATOMIC
 !CHECK: !$OMP ATOMIC SEQ_CST
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index abb21705b3c1c..8c9015f05bb72 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
         auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
         auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
         auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
+        auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);
 
         if (!((firstUpdateStmt && secondReadStmt) ||
               (firstReadStmt && secondUpdateStmt) ||
-              (firstReadStmt && secondWriteStmt)))
+              (firstReadStmt && secondWriteStmt) ||
+              (firstReadStmt && secondCompareStmt)))
           return ops.front().emitError()
                 << "invalid sequence of operations in the capture region";
         if (firstUpdateStmt && secondReadStmt &&
@@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
           return firstReadStmt.emitError()
                 << "captured variable in atomic.read must be updated in "
                     "second operation";
+        if (firstReadStmt && secondCompareStmt &&
+            firstReadStmt.getX() != secondCompareStmt.getX())
+          return firstReadStmt.emitError()
+                << "captured variable in atomic.read must be updated in "
+                    "second operation";
 
         return mlir::success();
       }]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0962b330e2f23..1241abc10298f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1928,6 +1928,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
         omp.atomic.write ...
         omp.terminator
       }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.compare ...
+        omp.terminator
+      }
     ```
   }] # clausesDescription;
 
@@ -1946,6 +1952,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
     /// Returns the `atomic.update` operation inside the region, if any.
     /// Otherwise, it returns nullptr.
     AtomicUpdateOp getAtomicUpdateOp();
+
+    /// Returns the `atomic.compare` operation inside the region, if any.
+    /// Otherwise, it returns nullptr.
+    AtomicCompareOp getAtomicCompareOp();
   }] # clausesExtraClassDeclaration;
 
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db5fd8f2e3230..0eafd0a267b97 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4615,6 +4615,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
   return dyn_cast<AtomicUpdateOp>(getSecondOp());
 }
 
+AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
+  if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp()))
+    return op;
+  return dyn_cast<AtomicCompareOp>(getSecondOp());
+}
+
 LogicalResult AtomicCaptureOp::verify() {
   return verifySynchronizationHint(*this, getHint());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6f93ad231cfac..c5a07a7dc6cb2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4795,6 +4795,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   return success();
 }
 
+/// Helper to extract the OMPAtomicCompareOp from an integer comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::ICmpPredicate::slt:
+  case LLVM::ICmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::ICmpPredicate::sgt:
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::FCmpPredicate::oeq:
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::FCmpPredicate::olt:
+  case LLVM::FCmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::FCmpPredicate::ogt:
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                         llvm::IRBuilderBase &builder,
@@ -4803,13 +4840,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   if (failed(checkImplementationStatus(*atomicCaptureOp)))
     return failure();
 
+  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
+  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
+  omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp();
+
+  // If the capture contains an atomic.compare, delegate to
+  // createAtomicCompare with the capture variable (V) set.
+  if (atomicCompareOp) {
+    omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp();
+    assert(atomicReadOp && "expected atomic.read in capture+compare");
+
+    Region &region = atomicCompareOp.getRegion();
+    Block &block = region.front();
+
+    llvm::Type *llvmXElementType =
+        moduleTranslation.convertType(block.getArgument(0).getType());
+    llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+    llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
+
+    bool isSigned = false;
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {
+        llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {
+        llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false,
+                                                        false};
+
+    llvm::AtomicOrdering atomicOrdering =
+        convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
+
+    // Pre-translate non-pattern operations inside the compare region.
+    auto isAtomicComparePatternOp = [](Operation &op) {
+      return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
+                       LLVM::OrOp>(op);
+    };
+    for (Operation &op : block.without_terminator()) {
+      if (isAtomicComparePatternOp(op))
+        continue;
+      bool allOperandsMapped =
+          llvm::all_of(op.getOperands(), [&](mlir::Value v) {
+            return moduleTranslation.lookupValue(v) != nullptr;
+          });
+      if (!allOperandsMapped)
+        continue;
+      if (failed(moduleTranslation.convertOperation(op, builder)))
+        return atomicCompareOp.emitError(
+            "failed to translate operation inside atomic compare region");
+    }
+
+    auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+      if (llvm::Value *existing = moduleTranslation.lookupValue(val))
+        return existing;
+      if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+        if (loadOp->getParentRegion() == &region) {
+          llvm::Value *loadAddr =
+              moduleTranslation.lookupValue(loadOp.getAddr());
+          if (!loadAddr)
+            return nullptr;
+          llvm::Type *loadType =
+              moduleTranslation.convertType(loadOp.getResult().getType());
+          return builder.CreateLoad(loadType, loadAddr);
+        }
+      }
+      return nullptr;
+    };
+
+    // Extract comparison predicate, eVal, and dVal from the region.
+    llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    llvm::Value *eVal = nullptr;
+    llvm::Value *dVal = nullptr;
+    bool isXBinopExpr = false;
+
+    for (Operation &op : block.getOperations()) {
+      if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+        auto maybeOp =
+            convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+        if (!maybeOp)
+     ...
[truncated]

@llvmorg-github-actions
Copy link
Copy Markdown

@llvm/pr-subscribers-mlir-openacc

Author: SunilKuravinakop

Changes

Add support for the "!$omp atomic compare capture" construct. For example:

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also fixes #202311.


Patch is 28.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/202315.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/Atomic.cpp (+51-39)
  • (modified) flang/test/Integration/OpenMP/atomic-compare.f90 (+32)
  • (modified) flang/test/Lower/OpenMP/atomic-compare.f90 (+56)
  • (modified) flang/test/Parser/OpenMP/atomic-unparse.f90 (+29)
  • (modified) mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td (+8-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+177-40)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+48)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+41)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 4ce0c8f878c48..c0f97aada637d 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
   int action1 = analysis.op1.what & analysis.Action;
   memOrder = makeValidForAction(memOrder, action0, action1, version);
 
+  // --- Shared capture scaffolding ---
+  mlir::Operation *captureOp = nullptr;
+  fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
+  fir::FirOpBuilder::InsertPoint atomicAt, postAt;
+
+  if (construct.IsCapture()) {
+    assert(action0 != analysis.None && action1 != analysis.None &&
+           "Expexcing two actions");
+    (void)action0;
+    (void)action1;
+    captureOp = mlir::omp::AtomicCaptureOp::create(
+        builder, loc, hint, makeMemOrderAttr(converter, memOrder));
+    // Set the non-atomic insertion point to before the atomic.capture.
+    preAt = getInsertionPointBefore(captureOp);
+
+    mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
+    builder.setInsertionPointToEnd(block);
+    // Set the atomic insertion point to before the terminator inside
+    // atomic.capture.
+    mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
+    atomicAt = getInsertionPointBefore(term);
+    postAt = getInsertionPointAfter(captureOp);
+    hint = nullptr;
+    memOrder = std::nullopt;
+  }
+
   if (auto *cond = get(analysis.cond)) {
     // atomic compare: if (x == e) x = d
     // e : expecteVal
     // d : desiredVal
 
-    // Check for compound clauses (fail, capture) that are not yet
+    // Restore insertion point so pre-processing code (e.g. computing
+    // expectedVal) is emitted before the capture op, not after the terminator.
+    builder.restoreInsertionPoint(preAt);
+
+    // Check for compound clause (fail) that is not yet
     // supported with atomic compare.
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
-          return clause.id == llvm::omp::Clause::OMPC_fail ||
-                 clause.id == llvm::omp::Clause::OMPC_capture;
+          return clause.id == llvm::omp::Clause::OMPC_fail;
         })) {
       TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
     }
@@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
       expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
     }
 
+    // If this is a compare+capture, generate the read op first.
+    if (construct.IsCapture()) {
+      assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
+             "Expected a read operation for compare capture");
+      mlir::Operation *readOp = genAtomicRead(
+          converter, semaCtx, loc, stmtCtx, atomAddr, atom,
+          *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
+      assert(readOp && "Should have created an atomic read operation");
+      builder.setInsertionPointAfter(readOp);
+    }
+
     mlir::UnitAttr weakAttr = nullptr;
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
           return clause.id == llvm::omp::Clause::OMPC_weak;
@@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
     // Generate omp.yield
     mlir::omp::YieldOp::create(builder, loc, newVal);
     builder.setInsertionPointAfter(atomicOp);
-
     // END omp atomic compare
   } else {
-    mlir::Operation *captureOp = nullptr;
-    fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
-    fir::FirOpBuilder::InsertPoint atomicAt, postAt;
-
-    if (construct.IsCapture()) {
-      // Capturing operation.
-      assert(action0 != analysis.None && action1 != analysis.None &&
-             "Expexcing two actions");
-      (void)action0;
-      (void)action1;
-      captureOp = mlir::omp::AtomicCaptureOp::create(
-          builder, loc, hint, makeMemOrderAttr(converter, memOrder));
-      // Set the non-atomic insertion point to before the atomic.capture.
-      preAt = getInsertionPointBefore(captureOp);
-
-      mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
-      builder.setInsertionPointToEnd(block);
-      // Set the atomic insertion point to before the terminator inside
-      // atomic.capture.
-      mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
-      atomicAt = getInsertionPointBefore(term);
-      postAt = getInsertionPointAfter(captureOp);
-      hint = nullptr;
-      memOrder = std::nullopt;
-    } else {
+    if (!construct.IsCapture()) {
       // Non-capturing operation.
       assert(action0 != analysis.None && action1 == analysis.None &&
              "Expexcing single action");
@@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
           *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
     }
 
-    if (construct.IsCapture()) {
-      // If this is a capture operation, the first/second ops will be inside
-      // of it. Set the insertion point to past the capture op itself.
-      builder.restoreInsertionPoint(postAt);
-    } else {
-      if (secondOp) {
-        builder.setInsertionPointAfter(secondOp);
-      } else {
-        builder.setInsertionPointAfter(firstOp);
-      }
+    if (!construct.IsCapture()) {
+      builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
     }
   }
+
+  // Shared capture cleanup.
+  if (construct.IsCapture()) {
+    builder.restoreInsertionPoint(postAt);
+  }
 }
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index 249fb0dd8fa64..650f64b80af12 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
   if (x == e) x = d
 end 
 
+! Integer equality compare+capture: cmpxchg + store old value
+!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
+! Compare+capture with clause order reversed: capture compare
+!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_capture_compare_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic capture compare
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index ac70edbed4e60..752a221aa538d 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
   !$omp atomic compare weak
   if (x .eq. e) x = d
 end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           }
+! CHECK:         }
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x .eq. e) x = d
+  !$omp end atomic
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           } {{.*}}weak{{.*}}
+! CHECK:         }
+subroutine atomic_compare_capture_int_gt(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture weak
+    v = x
+    if (x > e) x = d
+  !$omp end atomic
+end
diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90
index 4f3cf0eac0338..dc0cc1a62f6c2 100644
--- a/flang/test/Parser/OpenMP/atomic-unparse.f90
+++ b/flang/test/Parser/OpenMP/atomic-unparse.f90
@@ -192,6 +192,26 @@ program main
       i = j
    end if
 
+!COMPARE CAPTURE
+!$omp atomic compare capture
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare weak
+   k = i
+   if (i < j) then
+      i = k
+   end if
+!$omp end atomic
+
 !ATOMIC
 !$omp atomic
    i = j
@@ -296,6 +316,15 @@ end program main
 !CHECK: !$OMP ATOMIC WEAK COMPARE
 !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK
 
+!COMPARE CAPTURE
+
+!CHECK: !$OMP ATOMIC COMPARE CAPTURE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
+!CHECK: !$OMP END ATOMIC
+
 !ATOMIC
 !CHECK: !$OMP ATOMIC
 !CHECK: !$OMP ATOMIC SEQ_CST
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index abb21705b3c1c..8c9015f05bb72 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
         auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
         auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
         auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
+        auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);
 
         if (!((firstUpdateStmt && secondReadStmt) ||
               (firstReadStmt && secondUpdateStmt) ||
-              (firstReadStmt && secondWriteStmt)))
+              (firstReadStmt && secondWriteStmt) ||
+              (firstReadStmt && secondCompareStmt)))
           return ops.front().emitError()
                 << "invalid sequence of operations in the capture region";
         if (firstUpdateStmt && secondReadStmt &&
@@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
           return firstReadStmt.emitError()
                 << "captured variable in atomic.read must be updated in "
                     "second operation";
+        if (firstReadStmt && secondCompareStmt &&
+            firstReadStmt.getX() != secondCompareStmt.getX())
+          return firstReadStmt.emitError()
+                << "captured variable in atomic.read must be updated in "
+                    "second operation";
 
         return mlir::success();
       }]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0962b330e2f23..1241abc10298f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1928,6 +1928,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
         omp.atomic.write ...
         omp.terminator
       }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.compare ...
+        omp.terminator
+      }
     ```
   }] # clausesDescription;
 
@@ -1946,6 +1952,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
     /// Returns the `atomic.update` operation inside the region, if any.
     /// Otherwise, it returns nullptr.
     AtomicUpdateOp getAtomicUpdateOp();
+
+    /// Returns the `atomic.compare` operation inside the region, if any.
+    /// Otherwise, it returns nullptr.
+    AtomicCompareOp getAtomicCompareOp();
   }] # clausesExtraClassDeclaration;
 
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db5fd8f2e3230..0eafd0a267b97 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4615,6 +4615,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
   return dyn_cast<AtomicUpdateOp>(getSecondOp());
 }
 
+AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
+  if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp()))
+    return op;
+  return dyn_cast<AtomicCompareOp>(getSecondOp());
+}
+
 LogicalResult AtomicCaptureOp::verify() {
   return verifySynchronizationHint(*this, getHint());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6f93ad231cfac..c5a07a7dc6cb2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4795,6 +4795,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   return success();
 }
 
+/// Maps an integer comparison predicate to the matching OMPAtomicCompareOp
+/// (eq -> EQ, slt/ult -> MIN, sgt/ugt -> MAX); std::nullopt otherwise.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::ICmpPredicate::slt:
+  case LLVM::ICmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::ICmpPredicate::sgt:
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Maps a floating-point comparison predicate to an OMPAtomicCompareOp
+/// (oeq/ueq -> EQ, olt/ult -> MIN, ogt/ugt -> MAX); std::nullopt otherwise.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::FCmpPredicate::oeq:
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::FCmpPredicate::olt:
+  case LLVM::FCmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::FCmpPredicate::ogt:
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                         llvm::IRBuilderBase &builder,
@@ -4803,13 +4840,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   if (failed(checkImplementationStatus(*atomicCaptureOp)))
     return failure();
 
+  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
+  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
+  omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp();
+
+  // If the capture contains an atomic.compare, delegate to
+  // createAtomicCompare with the capture variable (V) set.
+  if (atomicCompareOp) {
+    omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp();
+    assert(atomicReadOp && "expected atomic.read in capture+compare");
+
+    Region &region = atomicCompareOp.getRegion();
+    Block &block = region.front();
+
+    llvm::Type *llvmXElementType =
+        moduleTranslation.convertType(block.getArgument(0).getType());
+    llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+    llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
+
+    bool isSigned = false;
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {
+        llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {
+        llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false,
+                                                        false};
+
+    llvm::AtomicOrdering atomicOrdering =
+        convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
+
+    // Pre-translate non-pattern operations inside the compare region.
+    auto isAtomicComparePatternOp = [](Operation &op) {
+      return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
+                       LLVM::OrOp>(op);
+    };
+    for (Operation &op : block.without_terminator()) {
+      if (isAtomicComparePatternOp(op))
+        continue;
+      bool allOperandsMapped =
+          llvm::all_of(op.getOperands(), [&](mlir::Value v) {
+            return moduleTranslation.lookupValue(v) != nullptr;
+          });
+      if (!allOperandsMapped)
+        continue;
+      if (failed(moduleTranslation.convertOperation(op, builder)))
+        return atomicCompareOp.emitError(
+            "failed to translate operation inside atomic compare region");
+    }
+
+    auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+      if (llvm::Value *existing = moduleTranslation.lookupValue(val))
+        return existing;
+      if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+        if (loadOp->getParentRegion() == &region) {
+          llvm::Value *loadAddr =
+              moduleTranslation.lookupValue(loadOp.getAddr());
+          if (!loadAddr)
+            return nullptr;
+          llvm::Type *loadType =
+              moduleTranslation.convertType(loadOp.getResult().getType());
+          return builder.CreateLoad(loadType, loadAddr);
+        }
+      }
+      return nullptr;
+    };
+
+    // Extract comparison predicate, eVal, and dVal from the region.
+    llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    llvm::Value *eVal = nullptr;
+    llvm::Value *dVal = nullptr;
+    bool isXBinopExpr = false;
+
+    for (Operation &op : block.getOperations()) {
+      if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+        auto maybeOp =
+            convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+        if (!maybeOp)
+     ...
[truncated]

@SunilKuravinakop
Copy link
Copy Markdown
Contributor Author

I am getting the following error for the Windows build:

++ sccache --zero-stats
sccache: error: Timed out waiting for server startup. Maybe the remote service is unreachable?
Run with SCCACHE_LOG=debug SCCACHE_NO_DAEMON=1 to get more information
Error: Process completed with exit code 2.

Can somebody help me understand it? If this is a server error, is there anything that I need to do?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[flang][OpenMP] Support for "atomic compare capture"

1 participant