Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ with the exception that minor releases may include breaking changes.
[#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623],
[#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700],
[#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765],
[#1774], [#1780], [#1781], [#1782], [#1787], [#1802])
[#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1806])
([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**],
[**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**],
[**@simon1hofmann**])
Expand Down Expand Up @@ -598,6 +598,7 @@ changelogs._

<!-- PR links -->

[#1806]: https://github.com/munich-quantum-toolkit/core/pull/1806
[#1802]: https://github.com/munich-quantum-toolkit/core/pull/1802
[#1787]: https://github.com/munich-quantum-toolkit/core/pull/1787
[#1782]: https://github.com/munich-quantum-toolkit/core/pull/1782
Expand Down
18 changes: 14 additions & 4 deletions mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/QTensor/IR/QTensorOps.h"

#include <mlir/IR/Operation.h>
#include <mlir/IR/Value.h>

#include <cstdint>
#include <iterator>
Expand All @@ -32,9 +33,11 @@ class [[nodiscard]] WireIterator {
using difference_type = std::ptrdiff_t;
using value_type = Operation*;

WireIterator() : op_(nullptr), qubit_(nullptr), isSentinel_(false) {}
WireIterator()
: op_(nullptr), qubit_(nullptr), isFinal_(false), isSentinel_(false) {}
explicit WireIterator(Value qubit)
: op_(qubit.getDefiningOp()), qubit_(qubit), isSentinel_(false) {}
: op_(qubit.getDefiningOp()), qubit_(qubit), isFinal_(false),
isSentinel_(false) {}

/// @returns the operation the iterator points to.
[[nodiscard]] Operation* operation() const { return op_; }
Expand Down Expand Up @@ -77,14 +80,21 @@ class [[nodiscard]] WireIterator {
}

private:
/// @brief Move to the next operation on the qubit wire.
/// Return true, if an op doesn't return, but only consumes, a qubit value.
static bool isSinkLikeOperation(Operation* op);

/// Return true, if an op doesn't consume, but only returns, a qubit value.
static bool isSourceLikeOperation(Operation* op);

/// Move to the next operation on the qubit wire.
void forward();

/// @brief Move to the previous operation on the qubit wire.
/// Move to the previous operation on the qubit wire.
void backward();

Operation* op_;
Value qubit_;
bool isFinal_;
bool isSentinel_;
};

Expand Down
80 changes: 65 additions & 15 deletions mlir/lib/Dialect/QCO/Utils/WireIterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/QCO/IR/QCOOps.h"
#include "mlir/Dialect/QTensor/IR/QTensorOps.h"

#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/TypeSwitch.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
Expand All @@ -25,13 +26,20 @@
#include <iterator>

namespace mlir::qco {

bool WireIterator::isSinkLikeOperation(Operation* op) {
return isa<SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp>(op);
}

bool WireIterator::isSourceLikeOperation(Operation* op) {
return isa<AllocOp, StaticOp, qtensor::ExtractOp>(op);
}

Value WireIterator::qubit() const {
// Boundary ops (sink/deallocation/insert/yield) consume the wire via an
// operand and have no OpResult, matching the boundaries in forward/backward.
if (op_ != nullptr &&
(isa<SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp>(op_))) {
if (op_ != nullptr && isSinkLikeOperation(op_)) {
return nullptr;
}

return qubit_;
}

Expand All @@ -41,25 +49,38 @@ void WireIterator::forward() {
return;
}

// After the final operation comes the sentinel.
if (isFinal_) {
isSentinel_ = true;
return;
}

// Find the user-operation of the qubit SSA value.
assert(qubit_.hasOneUse() && "expected linear typing");
op_ = *(qubit_.user_begin());

// A sink/insert/yield or region entry defines the end of the qubit wire.
if (isa<SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp, scf::ForOp,
scf::IfOp, scf::WhileOp>(op_)) {
isSentinel_ = true;
if (isSinkLikeOperation(op_)) {
isFinal_ = true;
return;
}

if (!(isa<AllocOp, StaticOp, qtensor::ExtractOp>(op_))) {
if (!isSourceLikeOperation(op_)) {
// Find the output from the input qubit SSA value.
TypeSwitch<Operation*>(op_)
.Case<UnitaryOpInterface>([&](UnitaryOpInterface op) {
qubit_ = op.getOutputForInput(qubit_);
})
.Case<MeasureOp>([&](MeasureOp op) { qubit_ = op.getQubitOut(); })
.Case<ResetOp>([&](ResetOp op) { qubit_ = op.getQubitOut(); })
.Case<scf::ForOp, scf::WhileOp>([&](auto op) {
qubit_ = op.getTiedLoopResult(&*(qubit_.use_begin()));
})
.Case<qco::IfOp>([&](qco::IfOp op) {
auto it = llvm::find(op.getQubits(), qubit_);
assert(it != op.getQubits().end());
const auto idx = std::distance(op.getQubits().begin(), it);
qubit_ = op.getResults()[idx];
})
.Default([&](Operation* op) {
llvm::reportFatalInternalError("unknown op in def-use chain: " +
op->getName().getStringRef());
Expand All @@ -71,20 +92,26 @@ void WireIterator::backward() {
// If the iterator is a sentinel, reactivate the iterator.
if (isSentinel_) {
isSentinel_ = false;
isFinal_ = true;
return;
}

// For sinks/deallocations/inserts/yields, qubit_ is an OpOperand. Hence, only
// get the def-op.
if (isa<SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp, scf::ForOp,
scf::IfOp, scf::WhileOp>(op_)) {
// If the op is a nullptr, the qubit value is a block argument and thus the
// beginning of the qubit wire.
if (op_ == nullptr) {
return;
}

// For these operations, qubit_ is an OpOperand. Hence, only get the def-op.
if (isSinkLikeOperation(op_)) {
op_ = qubit_.getDefiningOp();
isFinal_ = false;
return;
}

// Allocations or static definitions define the start of the qubit wire.
// Source-like ops define the start of the qubit wire.
// Consequently, stop and early exit.
if (isa<AllocOp, StaticOp, qtensor::ExtractOp>(op_)) {
if (isSourceLikeOperation(op_)) {
return;
}

Expand All @@ -94,6 +121,28 @@ void WireIterator::backward() {
[&](UnitaryOpInterface op) { qubit_ = op.getInputForOutput(qubit_); })
.Case<MeasureOp>([&](MeasureOp op) { qubit_ = op.getQubitIn(); })
.Case<ResetOp>([&](ResetOp op) { qubit_ = op.getQubitIn(); })
.Case<scf::ForOp, scf::WhileOp>([&](auto op) {
if (auto res = dyn_cast<OpResult>(qubit_)) {
OpOperand* operand = op.getTiedLoopInit(res);
qubit_ = operand->get();
return;
}

llvm::reportFatalInternalError(
"expected scf.for result for tied init lookup");
})
.Case<qco::IfOp>([&](qco::IfOp op) {
if (auto res = dyn_cast<OpResult>(qubit_)) {
auto it = llvm::find(op.getResults(), res);
assert(it != op.getResults().end());
const auto idx = std::distance(op.getResults().begin(), it);
qubit_ = op.getQubits()[idx];
return;
}

llvm::reportFatalInternalError(
"expected scf.if result for tied init lookup");
})
Comment thread
MatthiasReumann marked this conversation as resolved.
.Default([&](Operation* op) {
llvm::reportFatalInternalError("unknown op in def-use chain: " +
op->getName().getStringRef());
Expand All @@ -103,6 +152,7 @@ void WireIterator::backward() {
// If the current qubit SSA value is a BlockArgument (no defining op), the
// operation will be a nullptr.
op_ = qubit_.getDefiningOp();
isFinal_ = false;
}

static_assert(std::bidirectional_iterator<WireIterator>);
Expand Down
98 changes: 92 additions & 6 deletions mlir/unittests/Dialect/QCO/Utils/test_wireiterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
#include <gtest/gtest.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/DialectRegistry.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Support/LLVM.h>

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

using namespace mlir;
Expand All @@ -29,7 +32,8 @@ class WireIteratorTest : public testing::TestWithParam<bool> {
protected:
void SetUp() override {
DialectRegistry registry;
registry.insert<qco::QCODialect, arith::ArithDialect, func::FuncDialect>();
registry.insert<qco::QCODialect, scf::SCFDialect, arith::ArithDialect,
func::FuncDialect>();

context = std::make_unique<MLIRContext>();
context->appendDialectRegistry(registry);
Expand All @@ -40,20 +44,44 @@ class WireIteratorTest : public testing::TestWithParam<bool> {
};
} // namespace

TEST_P(WireIteratorTest, MixedUse) {
TEST_P(WireIteratorTest, Traversal) {
const bool isDynamic = GetParam();

// Build circuit.
qco::QCOProgramBuilder builder(context.get());
builder.initialize();

const auto q00 = isDynamic ? builder.allocQubit() : builder.staticQubit(0);
const auto q10 = isDynamic ? builder.allocQubit() : builder.staticQubit(1);
const auto q01 = builder.h(q00);
const auto [q02, q11] = builder.cx(q01, q10);
const auto [q03, c0] = builder.measure(q02);
const auto q04 = builder.reset(q03);
builder.sink(q04);
builder.sink(q11);

Value iterQ00;
Value iterQ01;
Value iterQ02;
Value iterQ10;
Value iterQ11;

const auto loopOut =
builder.scfFor(1, 4, 1, {q04, q11}, [&](Value, ValueRange iterArgs) {
iterQ00 = iterArgs[0];
iterQ10 = iterArgs[1];
iterQ01 = builder.h(iterQ00);
std::tie(iterQ02, iterQ11) = builder.cx(iterQ01, iterQ10);
return SmallVector{iterQ02, iterQ11};
});
const auto q05 = loopOut[0];
const auto q12 = loopOut[1];
const auto ifOut = builder.qcoIf(
true, {q05, q12},
[&](ValueRange args) { return SmallVector{args[0], args[1]}; },
[&](ValueRange args) { return SmallVector{args[0], args[1]}; });
const auto q06 = ifOut[0];
const auto q13 = ifOut[1];
builder.sink(q06);
builder.sink(q13);
[[maybe_unused]] auto module = builder.finalize();

// Setup WireIterator.
Expand Down Expand Up @@ -83,7 +111,15 @@ TEST_P(WireIteratorTest, MixedUse) {
ASSERT_EQ(it.qubit(), q04);

++it;
ASSERT_EQ(it.operation(), *(q04.getUsers().begin())); // qco.sink
ASSERT_EQ(it.operation(), q05.getDefiningOp()); // scf.for
ASSERT_EQ(it.qubit(), q05);

++it;
ASSERT_EQ(it.operation(), q06.getDefiningOp()); // qco.if
ASSERT_EQ(it.qubit(), q06);

++it;
ASSERT_EQ(it.operation(), *(q06.getUsers().begin())); // qco.sink
ASSERT_EQ(it.qubit(), nullptr);

++it;
Expand All @@ -97,9 +133,17 @@ TEST_P(WireIteratorTest, MixedUse) {
//

--it;
ASSERT_EQ(it.operation(), *(q04.getUsers().begin())); // qco.sink
ASSERT_EQ(it.operation(), *(q06.getUsers().begin())); // qco.sink
ASSERT_EQ(it.qubit(), nullptr);

--it;
ASSERT_EQ(it.operation(), q06.getDefiningOp()); // qco.if
ASSERT_EQ(it.qubit(), q06);

--it;
ASSERT_EQ(it.operation(), q05.getDefiningOp()); // scf.for
ASSERT_EQ(it.qubit(), q05);

--it;
ASSERT_EQ(it.operation(), q04.getDefiningOp()); // qco.reset
ASSERT_EQ(it.qubit(), q04);
Expand All @@ -123,6 +167,48 @@ TEST_P(WireIteratorTest, MixedUse) {
--it;
ASSERT_EQ(it.operation(), q00.getDefiningOp()); // qco.alloc or qco.static
ASSERT_EQ(it.qubit(), q00);

//
// Test: Recursive use with block-argument.
//

qco::WireIterator recIt(iterQ00);
ASSERT_EQ(recIt.operation(), nullptr); // Blockargument
ASSERT_EQ(recIt.qubit(), iterQ00);

++recIt;
ASSERT_EQ(recIt.operation(), iterQ01.getDefiningOp()); // qco.h
ASSERT_EQ(recIt.qubit(), iterQ01);

++recIt;
ASSERT_EQ(recIt.operation(), iterQ02.getDefiningOp()); // qco.ctrl
ASSERT_EQ(recIt.qubit(), iterQ02);

++recIt;
ASSERT_EQ(recIt.operation(), *(iterQ02.getUsers().begin())); // scf.yield
ASSERT_EQ(recIt.qubit(), nullptr);

++recIt;
ASSERT_EQ(recIt, std::default_sentinel);

++recIt;
ASSERT_EQ(recIt, std::default_sentinel);

--recIt;
ASSERT_EQ(recIt.operation(), *(iterQ02.getUsers().begin())); // scf.yield
ASSERT_EQ(recIt.qubit(), nullptr);

--recIt;
ASSERT_EQ(recIt.operation(), iterQ02.getDefiningOp()); // qco.ctrl
ASSERT_EQ(recIt.qubit(), iterQ02);

--recIt;
ASSERT_EQ(recIt.operation(), iterQ01.getDefiningOp()); // qco.h
ASSERT_EQ(recIt.qubit(), iterQ01);

--recIt;
ASSERT_EQ(recIt.operation(), nullptr); // Blockargument
ASSERT_EQ(recIt.qubit(), iterQ00);
}

INSTANTIATE_TEST_SUITE_P(DynamicAndStatic, WireIteratorTest, ::testing::Bool(),
Expand Down
Loading