diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index 4ce0c8f878c48..c0f97aada637d 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic( int action1 = analysis.op1.what & analysis.Action; memOrder = makeValidForAction(memOrder, action0, action1, version); + // --- Shared capture scaffolding --- + mlir::Operation *captureOp = nullptr; + fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint(); + fir::FirOpBuilder::InsertPoint atomicAt, postAt; + + if (construct.IsCapture()) { + assert(action0 != analysis.None && action1 != analysis.None && + "Expexcing two actions"); + (void)action0; + (void)action1; + captureOp = mlir::omp::AtomicCaptureOp::create( + builder, loc, hint, makeMemOrderAttr(converter, memOrder)); + // Set the non-atomic insertion point to before the atomic.capture. + preAt = getInsertionPointBefore(captureOp); + + mlir::Block *block = builder.createBlock(&captureOp->getRegion(0)); + builder.setInsertionPointToEnd(block); + // Set the atomic insertion point to before the terminator inside + // atomic.capture. + mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc); + atomicAt = getInsertionPointBefore(term); + postAt = getInsertionPointAfter(captureOp); + hint = nullptr; + memOrder = std::nullopt; + } + if (auto *cond = get(analysis.cond)) { // atomic compare: if (x == e) x = d // e : expecteVal // d : desiredVal - // Check for compound clauses (fail, capture) that are not yet + // Restore insertion point so pre-processing code (e.g. computing + // expectedVal) is emitted before the capture op, not after the terminator. + builder.restoreInsertionPoint(preAt); + + // Check for compound clause (fail) that is not yet // supported with atomic compare. if (llvm::any_of(clauses, [](const omp::Clause &clause) { - return clause.id == llvm::omp::Clause::OMPC_fail || - clause.id == llvm::omp::Clause::OMPC_capture; + return clause.id == llvm::omp::Clause::OMPC_fail; })) { TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE"); } @@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic( expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal); } + // If this is a compare+capture, generate the read op first. + if (construct.IsCapture()) { + assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) && + "Expected a read operation for compare capture"); + mlir::Operation *readOp = genAtomicRead( + converter, semaCtx, loc, stmtCtx, atomAddr, atom, + *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt); + assert(readOp && "Should have created an atomic read operation"); + builder.setInsertionPointAfter(readOp); + } + mlir::UnitAttr weakAttr = nullptr; if (llvm::any_of(clauses, [](const omp::Clause &clause) { return clause.id == llvm::omp::Clause::OMPC_weak; @@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic( // Generate omp.yield mlir::omp::YieldOp::create(builder, loc, newVal); builder.setInsertionPointAfter(atomicOp); - // END omp atomic compare } else { - mlir::Operation *captureOp = nullptr; - fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint(); - fir::FirOpBuilder::InsertPoint atomicAt, postAt; - - if (construct.IsCapture()) { - // Capturing operation. - assert(action0 != analysis.None && action1 != analysis.None && - "Expexcing two actions"); - (void)action0; - (void)action1; - captureOp = mlir::omp::AtomicCaptureOp::create( - builder, loc, hint, makeMemOrderAttr(converter, memOrder)); - // Set the non-atomic insertion point to before the atomic.capture. - preAt = getInsertionPointBefore(captureOp); - - mlir::Block *block = builder.createBlock(&captureOp->getRegion(0)); - builder.setInsertionPointToEnd(block); - // Set the atomic insertion point to before the terminator inside - // atomic.capture. - mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc); - atomicAt = getInsertionPointBefore(term); - postAt = getInsertionPointAfter(captureOp); - hint = nullptr; - memOrder = std::nullopt; - } else { + if (!construct.IsCapture()) { // Non-capturing operation. assert(action0 != analysis.None && action1 == analysis.None && "Expexcing single action"); @@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic( *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt); } - if (construct.IsCapture()) { - // If this is a capture operation, the first/second ops will be inside - // of it. Set the insertion point to past the capture op itself. - builder.restoreInsertionPoint(postAt); - } else { - if (secondOp) { - builder.setInsertionPointAfter(secondOp); - } else { - builder.setInsertionPointAfter(firstOp); - } + if (!construct.IsCapture()) { + builder.setInsertionPointAfter(secondOp ? secondOp : firstOp); } } + + // Shared capture cleanup. + if (construct.IsCapture()) { + builder.restoreInsertionPoint(postAt); + } } diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90 index 249fb0dd8fa64..650f64b80af12 100644 --- a/flang/test/Integration/OpenMP/atomic-compare.f90 +++ b/flang/test/Integration/OpenMP/atomic-compare.f90 @@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d) if (x == e) x = d end +! Integer equality compare+capture: cmpxchg + store old value +!CHECK-LABEL: define void @atomic_compare_capture_int_eq_( +!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]]) +!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]] +!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]] +!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic +!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0 +!CHECK: store i32 %[[OLD]], ptr %[[V]] +subroutine atomic_compare_capture_int_eq(x, e, d, v) + integer :: x, e, d, v + !$omp atomic compare capture + v = x + if (x == e) x = d + !$omp end atomic +end + +! Compare+capture with clause order reversed: capture compare +!CHECK-LABEL: define void @atomic_capture_compare_int_eq_( +!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]]) +!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]] +!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]] +!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic +!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0 +!CHECK: store i32 %[[OLD]], ptr %[[V]] +subroutine atomic_capture_compare_int_eq(x, e, d, v) + integer :: x, e, d, v + !$omp atomic capture compare + v = x + if (x == e) x = d + !$omp end atomic +end + diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90 index ac70edbed4e60..752a221aa538d 100644 --- a/flang/test/Lower/OpenMP/atomic-compare.f90 +++ b/flang/test/Lower/OpenMP/atomic-compare.f90 @@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d) !$omp atomic compare weak if (x .eq. e) x = d end + +! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq( +! CHECK-SAME: %[[X:.*]]: !fir.ref {fir.bindc_name = "x"}, +! CHECK-SAME: %[[E:.*]]: !fir.ref {fir.bindc_name = "e"}, +! CHECK-SAME: %[[D:.*]]: !fir.ref {fir.bindc_name = "d"}, +! CHECK-SAME: %[[V:.*]]: !fir.ref {fir.bindc_name = "v"}) +! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}} +! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}} +! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}} +! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}} +! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref +! CHECK: omp.atomic.capture memory_order(relaxed) { +! CHECK: omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref, !fir.ref, i32 +! CHECK: omp.atomic.compare %[[X_DECL]]#0 : !fir.ref { +! CHECK: ^bb0(%[[XVAL:.*]]: i32): +! CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32 +! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref +! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32 +! CHECK: omp.yield(%[[SEL]] : i32) +! CHECK: } +! CHECK: } +subroutine atomic_compare_capture_int_eq(x, e, d, v) + integer :: x, e, d, v + !$omp atomic compare capture + v = x + if (x .eq. e) x = d + !$omp end atomic +end + +! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt( +! CHECK-SAME: %[[X:.*]]: !fir.ref {fir.bindc_name = "x"}, +! CHECK-SAME: %[[E:.*]]: !fir.ref {fir.bindc_name = "e"}, +! CHECK-SAME: %[[D:.*]]: !fir.ref {fir.bindc_name = "d"}, +! CHECK-SAME: %[[V:.*]]: !fir.ref {fir.bindc_name = "v"}) +! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}} +! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}} +! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}} +! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}} +! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref +! CHECK: omp.atomic.capture memory_order(relaxed) { +! CHECK: omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref, !fir.ref, i32 +! CHECK: omp.atomic.compare %[[X_DECL]]#0 : !fir.ref { +! CHECK: ^bb0(%[[XVAL:.*]]: i32): +! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32 +! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref +! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32 +! CHECK: omp.yield(%[[SEL]] : i32) +! CHECK: } {{.*}}weak{{.*}} +! CHECK: } +subroutine atomic_compare_capture_int_gt(x, e, d, v) + integer :: x, e, d, v + !$omp atomic compare capture weak + v = x + if (x > e) x = d + !$omp end atomic +end diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90 index 4f3cf0eac0338..dc0cc1a62f6c2 100644 --- a/flang/test/Parser/OpenMP/atomic-unparse.f90 +++ b/flang/test/Parser/OpenMP/atomic-unparse.f90 @@ -192,6 +192,26 @@ program main i = j end if +!COMPARE CAPTURE +!$omp atomic compare capture + k = i + if (i .eq. j) then + i = k + end if +!$omp end atomic +!$omp atomic capture compare + k = i + if (i .eq. j) then + i = k + end if +!$omp end atomic +!$omp atomic capture compare weak + k = i + if (i < j) then + i = k + end if +!$omp end atomic + !ATOMIC !$omp atomic i = j @@ -296,6 +316,15 @@ end program main !CHECK: !$OMP ATOMIC WEAK COMPARE !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK +!COMPARE CAPTURE + +!CHECK: !$OMP ATOMIC COMPARE CAPTURE +!CHECK: !$OMP END ATOMIC +!CHECK: !$OMP ATOMIC CAPTURE COMPARE +!CHECK: !$OMP END ATOMIC +!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK +!CHECK: !$OMP END ATOMIC + !ATOMIC !CHECK: !$OMP ATOMIC !CHECK: !$OMP ATOMIC SEQ_CST diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td index abb21705b3c1c..8c9015f05bb72 100644 --- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td @@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> { auto secondReadStmt = dyn_cast(secondOp); auto secondUpdateStmt = dyn_cast(secondOp); auto secondWriteStmt = dyn_cast(secondOp); + auto secondCompareStmt = dyn_cast(secondOp); if (!((firstUpdateStmt && secondReadStmt) || (firstReadStmt && secondUpdateStmt) || - (firstReadStmt && secondWriteStmt))) + (firstReadStmt && secondWriteStmt) || + (firstReadStmt && secondCompareStmt))) return ops.front().emitError() << "invalid sequence of operations in the capture region"; if (firstUpdateStmt && secondReadStmt && @@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> { return firstReadStmt.emitError() << "captured variable in atomic.read must be updated in " "second operation"; + if (firstReadStmt && secondCompareStmt && + firstReadStmt.getX() != secondCompareStmt.getX()) + return firstReadStmt.emitError() + << "captured variable in atomic.read must be updated in " + "second operation"; return mlir::success(); }] diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 51e7080db5b29..ebee887f2afd1 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1929,6 +1929,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [ omp.atomic.write ... omp.terminator } + + omp.atomic.capture { + omp.atomic.read ... + omp.atomic.compare ... + omp.terminator + } ``` }] # clausesDescription; @@ -1947,6 +1953,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [ /// Returns the `atomic.update` operation inside the region, if any. /// Otherwise, it returns nullptr. AtomicUpdateOp getAtomicUpdateOp(); + + /// Returns the `atomic.compare` operation inside the region, if any. + /// Otherwise, it returns nullptr. + AtomicCompareOp getAtomicCompareOp(); }] # clausesExtraClassDeclaration; let hasRegionVerifier = 1; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 7c75791611218..f7a96fd2c2d17 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4655,6 +4655,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { return dyn_cast(getSecondOp()); } +AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() { + if (auto op = dyn_cast(getFirstOp())) + return op; + return dyn_cast(getSecondOp()); +} + LogicalResult AtomicCaptureOp::verify() { return verifySynchronizationHint(*this, getHint()); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index d35e8612e158b..dff1ae30e20aa 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4807,6 +4807,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, return success(); } +/// Helper to extract the OMPAtomicCompareOp from an integer comparison +/// predicate. Returns std::nullopt for unsupported predicates. +static std::optional +convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) { + switch (predicate) { + case LLVM::ICmpPredicate::eq: + return llvm::omp::OMPAtomicCompareOp::EQ; + case LLVM::ICmpPredicate::slt: + case LLVM::ICmpPredicate::ult: + return llvm::omp::OMPAtomicCompareOp::MIN; + case LLVM::ICmpPredicate::sgt: + case LLVM::ICmpPredicate::ugt: + return llvm::omp::OMPAtomicCompareOp::MAX; + default: + return std::nullopt; + } +} + +/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison +/// predicate. Returns std::nullopt for unsupported predicates. +static std::optional +convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) { + switch (predicate) { + case LLVM::FCmpPredicate::oeq: + case LLVM::FCmpPredicate::ueq: + return llvm::omp::OMPAtomicCompareOp::EQ; + case LLVM::FCmpPredicate::olt: + case LLVM::FCmpPredicate::ult: + return llvm::omp::OMPAtomicCompareOp::MIN; + case LLVM::FCmpPredicate::ogt: + case LLVM::FCmpPredicate::ugt: + return llvm::omp::OMPAtomicCompareOp::MAX; + default: + return std::nullopt; + } +} + static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, @@ -4815,13 +4852,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, if (failed(checkImplementationStatus(*atomicCaptureOp))) return failure(); + omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp(); + omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp(); + omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp(); + + // If the capture contains an atomic.compare, delegate to + // createAtomicCompare with the capture variable (V) set. + if (atomicCompareOp) { + omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp(); + assert(atomicReadOp && "expected atomic.read in capture+compare"); + + Region ®ion = atomicCompareOp.getRegion(); + Block &block = region.front(); + + llvm::Type *llvmXElementType = + moduleTranslation.convertType(block.getArgument(0).getType()); + llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX()); + llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV()); + + bool isSigned = false; + llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = { + llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false}; + llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = { + llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false}; + llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false, + false}; + + llvm::AtomicOrdering atomicOrdering = + convertAtomicOrdering(atomicCaptureOp.getMemoryOrder()); + + // Pre-translate non-pattern operations inside the compare region. + auto isAtomicComparePatternOp = [](Operation &op) { + return llvm::isa(op); + }; + for (Operation &op : block.without_terminator()) { + if (isAtomicComparePatternOp(op)) + continue; + bool allOperandsMapped = + llvm::all_of(op.getOperands(), [&](mlir::Value v) { + return moduleTranslation.lookupValue(v) != nullptr; + }); + if (!allOperandsMapped) + continue; + if (failed(moduleTranslation.convertOperation(op, builder))) + return atomicCompareOp.emitError( + "failed to translate operation inside atomic compare region"); + } + + auto materializeValue = [&](mlir::Value val) -> llvm::Value * { + if (llvm::Value *existing = moduleTranslation.lookupValue(val)) + return existing; + if (auto loadOp = val.getDefiningOp()) { + if (loadOp->getParentRegion() == ®ion) { + llvm::Value *loadAddr = + moduleTranslation.lookupValue(loadOp.getAddr()); + if (!loadAddr) + return nullptr; + llvm::Type *loadType = + moduleTranslation.convertType(loadOp.getResult().getType()); + return builder.CreateLoad(loadType, loadAddr); + } + } + return nullptr; + }; + + // Extract comparison predicate, eVal, and dVal from the region. + llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ; + llvm::Value *eVal = nullptr; + llvm::Value *dVal = nullptr; + bool isXBinopExpr = false; + + for (Operation &op : block.getOperations()) { + if (auto icmpOp = dyn_cast(op)) { + auto maybeOp = + convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate()); + if (!maybeOp) + return atomicCompareOp.emitError( + "unsupported comparison predicate in atomic compare"); + compareOp = *maybeOp; + + LLVM::ICmpPredicate pred = icmpOp.getPredicate(); + isSigned = (pred == LLVM::ICmpPredicate::slt || + pred == LLVM::ICmpPredicate::sgt || + pred == LLVM::ICmpPredicate::sle || + pred == LLVM::ICmpPredicate::sge); + + isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0)); + mlir::Value eOperand = + isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0); + eVal = materializeValue(eOperand); + } else if (auto fcmpOp = dyn_cast(op)) { + auto maybeOp = + convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate()); + if (!maybeOp) + return atomicCompareOp.emitError( + "unsupported comparison predicate in atomic compare"); + compareOp = *maybeOp; + + isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0)); + mlir::Value eOperand = + isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0); + eVal = materializeValue(eOperand); + } else if (auto selectOp = dyn_cast(op)) { + if (!dVal) + dVal = materializeValue(selectOp.getTrueValue()); + } + } + + if (!eVal) + return atomicCompareOp.emitError( + "failed to extract expected value (e) from atomic compare region"); + if (!dVal) { + auto yieldOp = cast(block.getTerminator()); + if (yieldOp.getResults().empty()) + return atomicCompareOp.emitError( + "failed to extract desired value (d) from atomic compare region"); + dVal = materializeValue(yieldOp.getResults()[0]); + } + + llvmAtomicX.IsSigned = isSigned; + + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + bool isPostfixUpdate = true; + bool isWeak = atomicCompareOp.getWeak(); + + bool savedHandleFPNegZero = ompBuilder->setHandleFPNegZero(true); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = + ompBuilder->createAtomicCompare( + ompLoc, llvmAtomicX, llvmAtomicV, llvmAtomicR, eVal, dVal, + atomicOrdering, compareOp, isXBinopExpr, isPostfixUpdate, + /*IsFailOnly=*/false, isWeak); + ompBuilder->setHandleFPNegZero(savedHandleFPNegZero); + + if (failed(handleError(afterIP, *atomicCaptureOp))) + return failure(); + + builder.restoreIP(*afterIP); + return success(); + } + mlir::Value mlirExpr; bool isXBinopExpr = false, isPostfixUpdate = false; llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP; - omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp(); - omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp(); - assert((atomicUpdateOp || atomicWriteOp) && "internal op must be an atomic.update or atomic.write op"); @@ -4908,43 +5082,6 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, return success(); } -/// Helper to extract the OMPAtomicCompareOp from an integer comparison -/// predicate. Returns std::nullopt for unsupported predicates. -static std::optional -convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) { - switch (predicate) { - case LLVM::ICmpPredicate::eq: - return llvm::omp::OMPAtomicCompareOp::EQ; - case LLVM::ICmpPredicate::slt: - case LLVM::ICmpPredicate::ult: - return llvm::omp::OMPAtomicCompareOp::MIN; - case LLVM::ICmpPredicate::sgt: - case LLVM::ICmpPredicate::ugt: - return llvm::omp::OMPAtomicCompareOp::MAX; - default: - return std::nullopt; - } -} - -/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison -/// predicate. Returns std::nullopt for unsupported predicates. -static std::optional -convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) { - switch (predicate) { - case LLVM::FCmpPredicate::oeq: - case LLVM::FCmpPredicate::ueq: - return llvm::omp::OMPAtomicCompareOp::EQ; - case LLVM::FCmpPredicate::olt: - case LLVM::FCmpPredicate::ult: - return llvm::omp::OMPAtomicCompareOp::MIN; - case LLVM::FCmpPredicate::ogt: - case LLVM::FCmpPredicate::ugt: - return llvm::omp::OMPAtomicCompareOp::MAX; - default: - return std::nullopt; - } -} - /// Converts an omp.atomic.compare operation to LLVM IR. /// /// if (x == e) x = d diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4af59f9aff297..87f9a6f8e4119 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2065,6 +2065,54 @@ func.func @omp_atomic_compare(%x: memref, %e: i32, %d: i32) { return } +// CHECK-LABEL: omp_atomic_compare_capture +// CHECK-SAME: (%[[V:.*]]: memref, %[[X:.*]]: memref, %[[E:.*]]: i32, %[[D:.*]]: i32) +func.func @omp_atomic_compare_capture(%v: memref, %x: memref, %e: i32, %d: i32) { + // CHECK: omp.atomic.capture { + // CHECK-NEXT: omp.atomic.read %[[V]] = %[[X]] : memref, memref, i32 + // CHECK-NEXT: omp.atomic.compare %[[X]] : memref { + // CHECK-NEXT: ^bb0(%[[XVAL:.*]]: i32): + // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[E]] : i32 + // CHECK-NEXT: %[[SEL:.*]] = arith.select %[[CMP]], %[[D]], %[[XVAL]] : i32 + // CHECK-NEXT: omp.yield(%[[SEL]] : i32) + // CHECK-NEXT: } + // CHECK-NEXT: } + omp.atomic.capture { + omp.atomic.read %v = %x : memref, memref, i32 + omp.atomic.compare %x : memref { + ^bb0(%xval: i32): + %cmp = arith.cmpi eq, %xval, %e : i32 + %sel = arith.select %cmp, %d, %xval : i32 + omp.yield(%sel : i32) + } + } + return +} + +// CHECK-LABEL: omp_atomic_compare_capture_weak +// CHECK-SAME: (%[[V:.*]]: memref, %[[X:.*]]: memref, %[[E:.*]]: i32, %[[D:.*]]: i32) +func.func @omp_atomic_compare_capture_weak(%v: memref, %x: memref, %e: i32, %d: i32) { + // CHECK: omp.atomic.capture { + // CHECK-NEXT: omp.atomic.read %[[V]] = %[[X]] : memref, memref, i32 + // CHECK-NEXT: omp.atomic.compare %[[X]] : memref { + // CHECK-NEXT: ^bb0(%[[XVAL:.*]]: i32): + // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[E]] : i32 + // CHECK-NEXT: %[[SEL:.*]] = arith.select %[[CMP]], %[[D]], %[[XVAL]] : i32 + // CHECK-NEXT: omp.yield(%[[SEL]] : i32) + // CHECK-NEXT: } {weak} + // CHECK-NEXT: } + omp.atomic.capture { + omp.atomic.read %v = %x : memref, memref, i32 + omp.atomic.compare %x : memref { + ^bb0(%xval: i32): + %cmp = arith.cmpi sgt, %xval, %e : i32 + %sel = arith.select %cmp, %d, %xval : i32 + omp.yield(%sel : i32) + } {weak} + } + return +} + // CHECK-LABEL: omp_sectionsop func.func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref, %redn_var : !llvm.ptr) { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 1890eaa6a3f0b..6ce57b304bfd9 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2795,6 +2795,47 @@ llvm.func @omp_atomic_compare_float_neg_zero(%xf : !llvm.ptr, %ef : f32, %df : f // ----- +// CHECK-LABEL: @omp_atomic_compare_capture_int_eq +// CHECK-SAME: (ptr %[[X:.*]], ptr %[[V:.*]], i32 %[[E:.*]], i32 %[[D:.*]]) +llvm.func @omp_atomic_compare_capture_int_eq(%x : !llvm.ptr, %v : !llvm.ptr, %e : i32, %d : i32) { + // Integer equality compare+capture → cmpxchg + extractvalue + store + // CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[E]], i32 %[[D]] monotonic monotonic + // CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0 + // CHECK: store i32 %[[OLD]], ptr %[[V]] + omp.atomic.capture { + omp.atomic.read %v = %x : !llvm.ptr, !llvm.ptr, i32 + omp.atomic.compare %x : !llvm.ptr { + ^bb0(%xval : i32): + %cmp = llvm.icmp "eq" %xval, %e : i32 + %sel = llvm.select %cmp, %d, %xval : i1, i32 + omp.yield(%sel : i32) + } + } + llvm.return +} + +// ----- + +// CHECK-LABEL: @omp_atomic_compare_capture_weak_int_eq +// CHECK-SAME: (ptr %[[X:.*]], ptr %[[V:.*]], i32 %[[E:.*]], i32 %[[D:.*]]) +llvm.func @omp_atomic_compare_capture_weak_int_eq(%x : !llvm.ptr, %v : !llvm.ptr, %e : i32, %d : i32) { + // Integer equality compare+capture → cmpxchg + extractvalue + store + // CHECK: %[[RES:.*]] = cmpxchg weak ptr %[[X]], i32 %[[E]], i32 %[[D]] monotonic monotonic + // CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0 + // CHECK: store i32 %[[OLD]], ptr %[[V]] + omp.atomic.capture { + omp.atomic.read %v = %x : !llvm.ptr, !llvm.ptr, i32 + omp.atomic.compare %x : !llvm.ptr { + ^bb0(%xval : i32): + %cmp = llvm.icmp "eq" %xval, %e : i32 + %sel = llvm.select %cmp, %d, %xval : i1, i32 + omp.yield(%sel : i32) + } {weak} + } + llvm.return +} +// ----- + // CHECK-LABEL: @omp_sections_empty llvm.func @omp_sections_empty() -> () { omp.sections {