Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions flang/lib/Lower/OpenMP/Atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
int action1 = analysis.op1.what & analysis.Action;
memOrder = makeValidForAction(memOrder, action0, action1, version);

// --- Shared capture scaffolding ---
mlir::Operation *captureOp = nullptr;
fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
fir::FirOpBuilder::InsertPoint atomicAt, postAt;

if (construct.IsCapture()) {
assert(action0 != analysis.None && action1 != analysis.None &&
"Expexcing two actions");
Comment on lines +569 to +570
(void)action0;
(void)action1;
captureOp = mlir::omp::AtomicCaptureOp::create(
builder, loc, hint, makeMemOrderAttr(converter, memOrder));
// Set the non-atomic insertion point to before the atomic.capture.
preAt = getInsertionPointBefore(captureOp);

mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
builder.setInsertionPointToEnd(block);
// Set the atomic insertion point to before the terminator inside
// atomic.capture.
mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
atomicAt = getInsertionPointBefore(term);
postAt = getInsertionPointAfter(captureOp);
hint = nullptr;
memOrder = std::nullopt;
}

if (auto *cond = get(analysis.cond)) {
// atomic compare: if (x == e) x = d
// e : expecteVal
// d : desiredVal

// Check for compound clauses (fail, capture) that are not yet
// Restore insertion point so pre-processing code (e.g. computing
// expectedVal) is emitted before the capture op, not after the terminator.
builder.restoreInsertionPoint(preAt);

// Check for compound clause (fail) that is not yet
// supported with atomic compare.
if (llvm::any_of(clauses, [](const omp::Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_fail ||
clause.id == llvm::omp::Clause::OMPC_capture;
return clause.id == llvm::omp::Clause::OMPC_fail;
})) {
TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
}
Expand Down Expand Up @@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
}

// If this is a compare+capture, generate the read op first.
if (construct.IsCapture()) {
assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
"Expected a read operation for compare capture");
mlir::Operation *readOp = genAtomicRead(
converter, semaCtx, loc, stmtCtx, atomAddr, atom,
*get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
assert(readOp && "Should have created an atomic read operation");
builder.setInsertionPointAfter(readOp);
}

mlir::UnitAttr weakAttr = nullptr;
if (llvm::any_of(clauses, [](const omp::Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_weak;
Expand Down Expand Up @@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
// Generate omp.yield
mlir::omp::YieldOp::create(builder, loc, newVal);
builder.setInsertionPointAfter(atomicOp);

// END omp atomic compare
} else {
mlir::Operation *captureOp = nullptr;
fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
fir::FirOpBuilder::InsertPoint atomicAt, postAt;

if (construct.IsCapture()) {
// Capturing operation.
assert(action0 != analysis.None && action1 != analysis.None &&
"Expexcing two actions");
(void)action0;
(void)action1;
captureOp = mlir::omp::AtomicCaptureOp::create(
builder, loc, hint, makeMemOrderAttr(converter, memOrder));
// Set the non-atomic insertion point to before the atomic.capture.
preAt = getInsertionPointBefore(captureOp);

mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
builder.setInsertionPointToEnd(block);
// Set the atomic insertion point to before the terminator inside
// atomic.capture.
mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
atomicAt = getInsertionPointBefore(term);
postAt = getInsertionPointAfter(captureOp);
hint = nullptr;
memOrder = std::nullopt;
} else {
if (!construct.IsCapture()) {
// Non-capturing operation.
assert(action0 != analysis.None && action1 == analysis.None &&
"Expexcing single action");
Comment on lines 732 to 733
Expand All @@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
*get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
}

if (construct.IsCapture()) {
// If this is a capture operation, the first/second ops will be inside
// of it. Set the insertion point to past the capture op itself.
builder.restoreInsertionPoint(postAt);
} else {
if (secondOp) {
builder.setInsertionPointAfter(secondOp);
} else {
builder.setInsertionPointAfter(firstOp);
}
if (!construct.IsCapture()) {
builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
}
}

// Shared capture cleanup.
if (construct.IsCapture()) {
builder.restoreInsertionPoint(postAt);
}
}
32 changes: 32 additions & 0 deletions flang/test/Integration/OpenMP/atomic-compare.f90
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
if (x == e) x = d
end

! Integer equality compare+capture: cmpxchg + store old value
!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
!CHECK: store i32 %[[OLD]], ptr %[[V]]
subroutine atomic_compare_capture_int_eq(x, e, d, v)
integer :: x, e, d, v
!$omp atomic compare capture
v = x
if (x == e) x = d
!$omp end atomic
end

! Compare+capture with clause order reversed: capture compare
!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
!CHECK: store i32 %[[OLD]], ptr %[[V]]
subroutine atomic_capture_compare_int_eq(x, e, d, v)
integer :: x, e, d, v
!$omp atomic capture compare
v = x
if (x == e) x = d
!$omp end atomic
end

56 changes: 56 additions & 0 deletions flang/test/Lower/OpenMP/atomic-compare.f90
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
!$omp atomic compare weak
if (x .eq. e) x = d
end

! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
! CHECK-SAME: %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
! CHECK-SAME: %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: omp.atomic.capture memory_order(relaxed) {
! CHECK: omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
! CHECK: omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
! CHECK: ^bb0(%[[XVAL:.*]]: i32):
! CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
! CHECK: omp.yield(%[[SEL]] : i32)
! CHECK: }
! CHECK: }
subroutine atomic_compare_capture_int_eq(x, e, d, v)
integer :: x, e, d, v
!$omp atomic compare capture
v = x
if (x .eq. e) x = d
!$omp end atomic
end

! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
! CHECK-SAME: %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
! CHECK-SAME: %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: omp.atomic.capture memory_order(relaxed) {
! CHECK: omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
! CHECK: omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
! CHECK: ^bb0(%[[XVAL:.*]]: i32):
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
! CHECK: omp.yield(%[[SEL]] : i32)
! CHECK: } {{.*}}weak{{.*}}
! CHECK: }
subroutine atomic_compare_capture_int_gt(x, e, d, v)
integer :: x, e, d, v
!$omp atomic compare capture weak
v = x
if (x > e) x = d
!$omp end atomic
end
29 changes: 29 additions & 0 deletions flang/test/Parser/OpenMP/atomic-unparse.f90
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,26 @@ program main
i = j
end if

!COMPARE CAPTURE
!$omp atomic compare capture
k = i
if (i .eq. j) then
i = k
end if
!$omp end atomic
!$omp atomic capture compare
k = i
if (i .eq. j) then
i = k
end if
!$omp end atomic
!$omp atomic capture compare weak
k = i
if (i < j) then
i = k
end if
!$omp end atomic

!ATOMIC
!$omp atomic
i = j
Expand Down Expand Up @@ -296,6 +316,15 @@ end program main
!CHECK: !$OMP ATOMIC WEAK COMPARE
!CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK

!COMPARE CAPTURE

!CHECK: !$OMP ATOMIC COMPARE CAPTURE
!CHECK: !$OMP END ATOMIC
!CHECK: !$OMP ATOMIC CAPTURE COMPARE
!CHECK: !$OMP END ATOMIC
!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
!CHECK: !$OMP END ATOMIC

!ATOMIC
!CHECK: !$OMP ATOMIC
!CHECK: !$OMP ATOMIC SEQ_CST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);

if (!((firstUpdateStmt && secondReadStmt) ||
(firstReadStmt && secondUpdateStmt) ||
(firstReadStmt && secondWriteStmt)))
(firstReadStmt && secondWriteStmt) ||
(firstReadStmt && secondCompareStmt)))
return ops.front().emitError()
<< "invalid sequence of operations in the capture region";
if (firstUpdateStmt && secondReadStmt &&
Expand All @@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
return firstReadStmt.emitError()
<< "captured variable in atomic.read must be updated in "
"second operation";
if (firstReadStmt && secondCompareStmt &&
firstReadStmt.getX() != secondCompareStmt.getX())
return firstReadStmt.emitError()
<< "captured variable in atomic.read must be updated in "
"second operation";
Comment on lines +317 to +319

return mlir::success();
}]
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
omp.atomic.write ...
omp.terminator
}

omp.atomic.capture {
omp.atomic.read ...
omp.atomic.compare ...
omp.terminator
}
```
}] # clausesDescription;

Expand All @@ -1947,6 +1953,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
/// Returns the `atomic.update` operation inside the region, if any.
/// Otherwise, it returns nullptr.
AtomicUpdateOp getAtomicUpdateOp();

/// Returns the `atomic.compare` operation inside the region, if any.
/// Otherwise, it returns nullptr.
AtomicCompareOp getAtomicCompareOp();
}] # clausesExtraClassDeclaration;

let hasRegionVerifier = 1;
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4655,6 +4655,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
return dyn_cast<AtomicUpdateOp>(getSecondOp());
}

AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp()))
return op;
return dyn_cast<AtomicCompareOp>(getSecondOp());
}

LogicalResult AtomicCaptureOp::verify() {
return verifySynchronizationHint(*this, getHint());
}
Expand Down
Loading