From 710bca1335c198a058adb0af1d2af9bd0b1eb3f8 Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Wed, 24 Jun 2026 13:40:16 +0200 Subject: [PATCH 1/7] Update driver to support scf ops --- mlir/include/mlir/Dialect/QCO/Utils/Drivers.h | 121 ++++++++++-------- 1 file changed, 67 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h index 0a59e7d83e..ed0f77feba 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h @@ -10,7 +10,6 @@ #pragma once -#include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOInterfaces.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Utils/Qubits.h" @@ -18,11 +17,15 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include +#include +#include #include #include #include #include +#include +#include #include #include #include @@ -88,13 +91,33 @@ LogicalResult walkProgram(Region& region, const WalkProgramFn& fn) { return success(); } -using ReleasedOps = SmallVector; -using PendingWiresMap = - DenseMap>; +using ReleasedOps = SmallVector; +using PendingWiresMap = DenseMap>; + +namespace impl { +/// Return the number of qubits a operation produces/consumes. +inline size_t getNumQubits(Operation* op) { + return TypeSwitch(op) + .Case( + [&](UnitaryOpInterface op) { return op.getNumQubits(); }) + .Case([&](scf::ForOp op) { return op.getInits().size(); }) + .Case([&](scf::WhileOp op) { return op.getInits().size(); }) + .Case([&](qco::IfOp op) { return op.getQubits().size(); }) + .Case( + [&](auto) { return 1; }) + .Default([&](Operation* op) { + const auto name = op->getName().getStringRef(); + reportFatalInternalError("unknown op: " + name); + return 0; + }); +} +} // namespace impl struct IsReady { bool operator()(PendingWiresMap::value_type& kv) const { - return kv.second.size() == kv.first.getNumQubits(); + const auto npending = kv.second.size(); + return impl::getNumQubits(kv.first) == npending; } }; @@ -134,67 +157,57 @@ LogicalResult walkProgramGraph(MutableArrayRef wires, using Traits = WireTraversalTraits; ReleasedOps released; - PendingWiresMap pending; pending.reserve(wires.size()); - SmallVector curr(wires.size()); + SmallVector curr(wires.size()); std::iota(curr.begin(), curr.end(), 0UL); - SmallVector next; + SmallVector next; next.reserve(wires.size()); while (!curr.empty()) { - for (std::size_t i : curr) { + for (size_t i : curr) { auto& it = wires[i]; + + if (it.operation() == nullptr) { // isa + std::ranges::advance(it, Traits::stride()); + continue; + } + while (Traits::isActive(it)) { - const auto res = - TypeSwitch(it.operation()) - .template Case([&](UnitaryOpInterface& op) { - // If there are fewer wires than the qubit requires inputs, - // it's impossible to release the operation. Hence, fail. - if (op.getNumQubits() > wires.size()) { - return WalkResult::interrupt(); - } - - if (op.getNumQubits() == 1) { - std::ranges::advance(it, Traits::stride()); - return WalkResult::advance(); - } - - // Insert the unitary to the pending map. - // The caller decides if this op should be released. - const auto [mapIt, inserted] = pending.try_emplace(op); - auto& indices = mapIt->second; - - if (inserted) { - indices.reserve(op.getNumQubits()); - } - - indices.emplace_back(i); - - return WalkResult::skip(); // Stop at multi-qubit gate. - }) - // AllocOp, StaticOp, and qtensor::ExtractOp are only reachable - // on the forward path; backward isActive() halts before - // reaching them (decrementing at a source op is a no-op). - .template Case([&](auto) { - std::ranges::advance(it, Traits::stride()); - return WalkResult::advance(); - }) - .Default([&](Operation* op) -> WalkResult { - const auto name = op->getName().getStringRef(); - report_fatal_error("unknown op encountered: " + name); - }); - - if (res.wasSkipped()) { - break; + assert(it.operation() != nullptr); + + const auto nqubits = impl::getNumQubits(it.operation()); + + assert(nqubits != 0); + if (nqubits == 1) { + std::ranges::advance(it, Traits::stride()); + continue; } - if (res.wasInterrupted()) { + assert(nqubits >= 2); + + // If there are fewer wires than the operation requires inputs, + // it's impossible to release the operation. Hence, fail. + + if (nqubits > wires.size()) { return failure(); } + + // Insert the unitary to the pending map. + // The caller decides if this op should be released. + + const auto [mapIt, inserted] = pending.try_emplace(it.operation()); + auto& indices = mapIt->second; + + if (inserted) { + indices.reserve(nqubits); + } + + indices.emplace_back(i); + + break; // Stop at multi-qubit unitary. } } @@ -212,11 +225,11 @@ LogicalResult walkProgramGraph(MutableArrayRef wires, } } - for (const UnitaryOpInterface& op : released) { + for (Operation* op : released) { const auto mapIt = pending.find(op); assert(mapIt != pending.end()); - for (std::size_t i : mapIt->second) { + for (size_t i : mapIt->second) { std::ranges::advance(wires[i], Traits::stride()); next.emplace_back(i); } From b7bbc71ebc53073bb281f013c080caa4bacf7c81 Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Wed, 24 Jun 2026 13:46:31 +0200 Subject: [PATCH 2/7] Update CHANGELOG.md [no ci] --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 614fe3c174..24b4547851 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ with the exception that minor releases may include breaking changes. [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], - [#1774], [#1780], [#1781], [#1782], [#1787], [#1802]) + [#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1808]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) @@ -598,6 +598,7 @@ changelogs._ +[#1808]: https://github.com/munich-quantum-toolkit/core/pull/1808 [#1802]: https://github.com/munich-quantum-toolkit/core/pull/1802 [#1787]: https://github.com/munich-quantum-toolkit/core/pull/1787 [#1782]: https://github.com/munich-quantum-toolkit/core/pull/1782 From ff56c5583546a94a403c0206509109c29456930d Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Wed, 24 Jun 2026 14:53:04 +0200 Subject: [PATCH 3/7] Extend test suite --- mlir/include/mlir/Dialect/QCO/Utils/Drivers.h | 2 - .../mlir/Dialect/QCO/Utils/WireIterator.h | 18 +++- mlir/lib/Dialect/QCO/Utils/WireIterator.cpp | 80 +++++++++++++---- .../Dialect/QCO/Utils/test_drivers.cpp | 86 ++++++++++++++++--- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h index ed0f77feba..c098af10b4 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h @@ -172,12 +172,10 @@ LogicalResult walkProgramGraph(MutableArrayRef wires, if (it.operation() == nullptr) { // isa std::ranges::advance(it, Traits::stride()); - continue; } while (Traits::isActive(it)) { assert(it.operation() != nullptr); - const auto nqubits = impl::getNumQubits(it.operation()); assert(nqubits != 0); diff --git a/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h b/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h index 6e5b395413..22218a3874 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h @@ -14,6 +14,7 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include +#include #include #include @@ -32,9 +33,11 @@ class [[nodiscard]] WireIterator { using difference_type = std::ptrdiff_t; using value_type = Operation*; - WireIterator() : op_(nullptr), qubit_(nullptr), isSentinel_(false) {} + WireIterator() + : op_(nullptr), qubit_(nullptr), isFinal_(false), isSentinel_(false) {} explicit WireIterator(Value qubit) - : op_(qubit.getDefiningOp()), qubit_(qubit), isSentinel_(false) {} + : op_(qubit.getDefiningOp()), qubit_(qubit), isFinal_(false), + isSentinel_(false) {} /// @returns the operation the iterator points to. [[nodiscard]] Operation* operation() const { return op_; } @@ -77,14 +80,21 @@ class [[nodiscard]] WireIterator { } private: - /// @brief Move to the next operation on the qubit wire. + /// Return true, if an op doesn't return, but only consumes, a qubit value. + static bool isSinkLikeOperation(Operation* op); + + /// Return true, if an op doesn't consume, but only returns, a qubit value. + static bool isSourceLikeOperation(Operation* op); + + /// Move to the next operation on the qubit wire. void forward(); - /// @brief Move to the previous operation on the qubit wire. + /// Move to the previous operation on the qubit wire. void backward(); Operation* op_; Value qubit_; + bool isFinal_; bool isSentinel_; }; diff --git a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp index bb7bb6932f..7958567e55 100644 --- a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp +++ b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include @@ -25,13 +26,20 @@ #include namespace mlir::qco { + +bool WireIterator::isSinkLikeOperation(Operation* op) { + return isa(op); +} + +bool WireIterator::isSourceLikeOperation(Operation* op) { + return isa(op); +} + Value WireIterator::qubit() const { - // Boundary ops (sink/deallocation/insert/yield) consume the wire via an - // operand and have no OpResult, matching the boundaries in forward/backward. - if (op_ != nullptr && - (isa(op_))) { + if (op_ != nullptr && isSinkLikeOperation(op_)) { return nullptr; } + return qubit_; } @@ -41,18 +49,22 @@ void WireIterator::forward() { return; } + // After the final operation comes the sentinel. + if (isFinal_) { + isSentinel_ = true; + return; + } + // Find the user-operation of the qubit SSA value. assert(qubit_.hasOneUse() && "expected linear typing"); op_ = *(qubit_.user_begin()); - // A sink/insert/yield or region entry defines the end of the qubit wire. - if (isa(op_)) { - isSentinel_ = true; + if (isSinkLikeOperation(op_)) { + isFinal_ = true; return; } - if (!(isa(op_))) { + if (!isSourceLikeOperation(op_)) { // Find the output from the input qubit SSA value. TypeSwitch(op_) .Case([&](UnitaryOpInterface op) { @@ -60,6 +72,15 @@ void WireIterator::forward() { }) .Case([&](MeasureOp op) { qubit_ = op.getQubitOut(); }) .Case([&](ResetOp op) { qubit_ = op.getQubitOut(); }) + .Case([&](auto op) { + qubit_ = op.getTiedLoopResult(&*(qubit_.use_begin())); + }) + .Case([&](qco::IfOp op) { + auto it = llvm::find(op.getQubits(), qubit_); + assert(it != op.getQubits().end()); + const auto idx = std::distance(op.getQubits().begin(), it); + qubit_ = op.getResults()[idx]; + }) .Default([&](Operation* op) { llvm::reportFatalInternalError("unknown op in def-use chain: " + op->getName().getStringRef()); @@ -71,20 +92,26 @@ void WireIterator::backward() { // If the iterator is a sentinel, reactivate the iterator. if (isSentinel_) { isSentinel_ = false; + isFinal_ = true; return; } - // For sinks/deallocations/inserts/yields, qubit_ is an OpOperand. Hence, only - // get the def-op. - if (isa(op_)) { + // If the op is a nullptr, the qubit value is a block argument and thus the + // beginning of the qubit wire. + if (op_ == nullptr) { + return; + } + + // For these operations, qubit_ is an OpOperand. Hence, only get the def-op. + if (isSinkLikeOperation(op_)) { op_ = qubit_.getDefiningOp(); + isFinal_ = false; return; } - // Allocations or static definitions define the start of the qubit wire. + // Source-like ops define the start of the qubit wire. // Consequently, stop and early exit. - if (isa(op_)) { + if (isSourceLikeOperation(op_)) { return; } @@ -94,6 +121,28 @@ void WireIterator::backward() { [&](UnitaryOpInterface op) { qubit_ = op.getInputForOutput(qubit_); }) .Case([&](MeasureOp op) { qubit_ = op.getQubitIn(); }) .Case([&](ResetOp op) { qubit_ = op.getQubitIn(); }) + .Case([&](auto op) { + if (auto res = dyn_cast(qubit_)) { + OpOperand* operand = op.getTiedLoopInit(res); + qubit_ = operand->get(); + return; + } + + llvm::reportFatalInternalError( + "expected scf.for result for tied init lookup"); + }) + .Case([&](qco::IfOp op) { + if (auto res = dyn_cast(qubit_)) { + auto it = llvm::find(op.getResults(), res); + assert(it != op.getResults().end()); + const auto idx = std::distance(op.getResults().begin(), it); + qubit_ = op.getQubits()[idx]; + return; + } + + llvm::reportFatalInternalError( + "expected scf.if result for tied init lookup"); + }) .Default([&](Operation* op) { llvm::reportFatalInternalError("unknown op in def-use chain: " + op->getName().getStringRef()); @@ -103,6 +152,7 @@ void WireIterator::backward() { // If the current qubit SSA value is a BlockArgument (no defining op), the // operation will be a nullptr. op_ = qubit_.getDefiningOp(); + isFinal_ = false; } static_assert(std::bidirectional_iterator); diff --git a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp index 616e08d66d..b4aa99c677 100644 --- a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp +++ b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp @@ -18,11 +18,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -36,7 +38,8 @@ class DriversTest : public testing::Test { protected: void SetUp() override { DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); @@ -99,6 +102,27 @@ TEST_F(DriversTest, ProgramWalk) { ASSERT_EQ(ex3, q31); } +TEST_F(DriversTest, ProgramGraphWalkTooFewWires) { + qco::QCOProgramBuilder builder(context.get()); + builder.initialize(); + + const auto q00 = builder.allocQubit(); + const auto q10 = builder.allocQubit(); + const auto [q01, q11] = builder.cx(q00, q10); + + [[maybe_unused]] auto mod = builder.finalize(); + + // Collect just one wire. + SmallVector wires; + wires.emplace_back(q00); + + auto res = qco::walkProgramGraph( + wires, [&](const qco::ReadyRange&, qco::ReleasedOps&) { + return WalkResult::skip(); + }); + ASSERT_TRUE(res.failed()); +} + TEST_F(DriversTest, ProgramGraphWalk) { qco::QCOProgramBuilder builder(context.get()); builder.initialize(); @@ -121,10 +145,20 @@ TEST_F(DriversTest, ProgramGraphWalk) { const auto [q04, q13] = builder.cx(q03, q12); const auto q14 = builder.h(q13); - builder.measure(q04); - builder.measure(q14); - builder.measure(q23); - builder.measure(q31); + Value iterQ0; + Value iterQ1; + ValueRange blockArgs; + const auto forResults = builder.scfFor( + 0, 3, 1, {q04, q14, q23, q31}, [&](Value, ValueRange args) { + blockArgs = args; + std::tie(iterQ0, iterQ1) = builder.cx(args[0], args[1]); + return SmallVector{iterQ0, iterQ1, args[2], args[3]}; + }); + + builder.measure(forResults[0]); + builder.measure(forResults[1]); + builder.measure(forResults[2]); + builder.measure(forResults[3]); auto mod = builder.finalize(); auto func = *(mod->getOps().begin()); @@ -151,11 +185,12 @@ TEST_F(DriversTest, ProgramGraphWalk) { }); ASSERT_TRUE(res.succeeded()); - ASSERT_GE(readyPerLayer.size(), 3); + ASSERT_GE(readyPerLayer.size(), 4); ASSERT_TRUE(readyPerLayer[0].contains(q02.getDefiningOp())); ASSERT_TRUE(readyPerLayer[0].contains(q21.getDefiningOp())); ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp())); ASSERT_TRUE(readyPerLayer[2].contains(q04.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[3].contains(forResults[0].getDefiningOp())); // Backward pass. readyPerLayer.clear(); @@ -171,11 +206,12 @@ TEST_F(DriversTest, ProgramGraphWalk) { }); ASSERT_TRUE(res.succeeded()); - ASSERT_GE(readyPerLayer.size(), 3); - ASSERT_TRUE(readyPerLayer[0].contains(q04.getDefiningOp())); - ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp())); - ASSERT_TRUE(readyPerLayer[2].contains(q02.getDefiningOp())); - ASSERT_TRUE(readyPerLayer[2].contains(q21.getDefiningOp())); + ASSERT_GE(readyPerLayer.size(), 4); + ASSERT_TRUE(readyPerLayer[0].contains(forResults[0].getDefiningOp())); + ASSERT_TRUE(readyPerLayer[1].contains(q04.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[2].contains(q12.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[3].contains(q02.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[3].contains(q21.getDefiningOp())); // Forward, but instead of releasing all, we use ::skip(). readyPerLayer.clear(); @@ -186,16 +222,16 @@ TEST_F(DriversTest, ProgramGraphWalk) { layer.insert(op); } readyPerLayer.emplace_back(layer); - return WalkResult::skip(); }); ASSERT_TRUE(res.succeeded()); - ASSERT_GE(readyPerLayer.size(), 3); + ASSERT_GE(readyPerLayer.size(), 4); ASSERT_TRUE(readyPerLayer[0].contains(q02.getDefiningOp())); ASSERT_TRUE(readyPerLayer[0].contains(q21.getDefiningOp())); ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp())); ASSERT_TRUE(readyPerLayer[2].contains(q04.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[3].contains(forResults[0].getDefiningOp())); // Backward, but stop after first layer. readyPerLayer.clear(); @@ -212,5 +248,27 @@ TEST_F(DriversTest, ProgramGraphWalk) { ASSERT_TRUE(res.failed()); ASSERT_EQ(readyPerLayer.size(), 1); - ASSERT_TRUE(readyPerLayer[0].contains(q04.getDefiningOp())); + ASSERT_TRUE(readyPerLayer[0].contains(forResults[0].getDefiningOp())); + + // Forward, but start at block arguments. + wires.clear(); + for (Value arg : blockArgs) { + wires.emplace_back(arg); + } + + readyPerLayer.clear(); + res = qco::walkProgramGraph( + wires, [&](const qco::ReadyRange& ready, qco::ReleasedOps& released) { + DenseSet layer; + for (const auto& [op, progs] : ready) { + layer.insert(op); + released.emplace_back(op); + } + readyPerLayer.emplace_back(layer); + return WalkResult::advance(); + }); + + ASSERT_TRUE(res.succeeded()); + ASSERT_GE(readyPerLayer.size(), 1); + ASSERT_TRUE(readyPerLayer[0].contains(iterQ0.getDefiningOp())); } From 1e9228a973e36c456478f93858b762952b934165 Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Wed, 24 Jun 2026 14:59:46 +0200 Subject: [PATCH 4/7] Fix lint and fetch updated wire iterator tests --- .../Dialect/QCO/Utils/test_drivers.cpp | 2 +- .../Dialect/QCO/Utils/test_wireiterator.cpp | 98 +++++++++++++++++-- 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp index b4aa99c677..3c8c1480ab 100644 --- a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp +++ b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp @@ -18,9 +18,9 @@ #include #include #include -#include #include #include +#include #include #include #include diff --git a/mlir/unittests/Dialect/QCO/Utils/test_wireiterator.cpp b/mlir/unittests/Dialect/QCO/Utils/test_wireiterator.cpp index 7eb59db509..295ca3c016 100644 --- a/mlir/unittests/Dialect/QCO/Utils/test_wireiterator.cpp +++ b/mlir/unittests/Dialect/QCO/Utils/test_wireiterator.cpp @@ -15,11 +15,14 @@ #include #include #include +#include #include #include +#include #include #include +#include #include using namespace mlir; @@ -29,7 +32,8 @@ class WireIteratorTest : public testing::TestWithParam { protected: void SetUp() override { DialectRegistry registry; - registry.insert(); + registry.insert(); context = std::make_unique(); context->appendDialectRegistry(registry); @@ -40,20 +44,44 @@ class WireIteratorTest : public testing::TestWithParam { }; } // namespace -TEST_P(WireIteratorTest, MixedUse) { +TEST_P(WireIteratorTest, Traversal) { const bool isDynamic = GetParam(); // Build circuit. qco::QCOProgramBuilder builder(context.get()); builder.initialize(); + const auto q00 = isDynamic ? builder.allocQubit() : builder.staticQubit(0); const auto q10 = isDynamic ? builder.allocQubit() : builder.staticQubit(1); const auto q01 = builder.h(q00); const auto [q02, q11] = builder.cx(q01, q10); const auto [q03, c0] = builder.measure(q02); const auto q04 = builder.reset(q03); - builder.sink(q04); - builder.sink(q11); + + Value iterQ00; + Value iterQ01; + Value iterQ02; + Value iterQ10; + Value iterQ11; + + const auto loopOut = + builder.scfFor(1, 4, 1, {q04, q11}, [&](Value, ValueRange iterArgs) { + iterQ00 = iterArgs[0]; + iterQ10 = iterArgs[1]; + iterQ01 = builder.h(iterQ00); + std::tie(iterQ02, iterQ11) = builder.cx(iterQ01, iterQ10); + return SmallVector{iterQ02, iterQ11}; + }); + const auto q05 = loopOut[0]; + const auto q12 = loopOut[1]; + const auto ifOut = builder.qcoIf( + true, {q05, q12}, + [&](ValueRange args) { return SmallVector{args[0], args[1]}; }, + [&](ValueRange args) { return SmallVector{args[0], args[1]}; }); + const auto q06 = ifOut[0]; + const auto q13 = ifOut[1]; + builder.sink(q06); + builder.sink(q13); [[maybe_unused]] auto module = builder.finalize(); // Setup WireIterator. @@ -83,7 +111,15 @@ TEST_P(WireIteratorTest, MixedUse) { ASSERT_EQ(it.qubit(), q04); ++it; - ASSERT_EQ(it.operation(), *(q04.getUsers().begin())); // qco.sink + ASSERT_EQ(it.operation(), q05.getDefiningOp()); // scf.for + ASSERT_EQ(it.qubit(), q05); + + ++it; + ASSERT_EQ(it.operation(), q06.getDefiningOp()); // qco.if + ASSERT_EQ(it.qubit(), q06); + + ++it; + ASSERT_EQ(it.operation(), *(q06.getUsers().begin())); // qco.sink ASSERT_EQ(it.qubit(), nullptr); ++it; @@ -97,9 +133,17 @@ TEST_P(WireIteratorTest, MixedUse) { // --it; - ASSERT_EQ(it.operation(), *(q04.getUsers().begin())); // qco.sink + ASSERT_EQ(it.operation(), *(q06.getUsers().begin())); // qco.sink ASSERT_EQ(it.qubit(), nullptr); + --it; + ASSERT_EQ(it.operation(), q06.getDefiningOp()); // qco.if + ASSERT_EQ(it.qubit(), q06); + + --it; + ASSERT_EQ(it.operation(), q05.getDefiningOp()); // scf.for + ASSERT_EQ(it.qubit(), q05); + --it; ASSERT_EQ(it.operation(), q04.getDefiningOp()); // qco.reset ASSERT_EQ(it.qubit(), q04); @@ -123,6 +167,48 @@ TEST_P(WireIteratorTest, MixedUse) { --it; ASSERT_EQ(it.operation(), q00.getDefiningOp()); // qco.alloc or qco.static ASSERT_EQ(it.qubit(), q00); + + // + // Test: Recursive use with block-argument. + // + + qco::WireIterator recIt(iterQ00); + ASSERT_EQ(recIt.operation(), nullptr); // Blockargument + ASSERT_EQ(recIt.qubit(), iterQ00); + + ++recIt; + ASSERT_EQ(recIt.operation(), iterQ01.getDefiningOp()); // qco.h + ASSERT_EQ(recIt.qubit(), iterQ01); + + ++recIt; + ASSERT_EQ(recIt.operation(), iterQ02.getDefiningOp()); // qco.ctrl + ASSERT_EQ(recIt.qubit(), iterQ02); + + ++recIt; + ASSERT_EQ(recIt.operation(), *(iterQ02.getUsers().begin())); // scf.yield + ASSERT_EQ(recIt.qubit(), nullptr); + + ++recIt; + ASSERT_EQ(recIt, std::default_sentinel); + + ++recIt; + ASSERT_EQ(recIt, std::default_sentinel); + + --recIt; + ASSERT_EQ(recIt.operation(), *(iterQ02.getUsers().begin())); // scf.yield + ASSERT_EQ(recIt.qubit(), nullptr); + + --recIt; + ASSERT_EQ(recIt.operation(), iterQ02.getDefiningOp()); // qco.ctrl + ASSERT_EQ(recIt.qubit(), iterQ02); + + --recIt; + ASSERT_EQ(recIt.operation(), iterQ01.getDefiningOp()); // qco.h + ASSERT_EQ(recIt.qubit(), iterQ01); + + --recIt; + ASSERT_EQ(recIt.operation(), nullptr); // Blockargument + ASSERT_EQ(recIt.qubit(), iterQ00); } INSTANTIATE_TEST_SUITE_P(DynamicAndStatic, WireIteratorTest, ::testing::Bool(), From c8ba172909a4ec54bfeaac048040de46cc802126 Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Thu, 25 Jun 2026 07:39:14 +0200 Subject: [PATCH 5/7] Address PR comments --- mlir/include/mlir/Dialect/QCO/Utils/Drivers.h | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h index c098af10b4..3d360ebb78 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/Drivers.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/QCO/Utils/WireIterator.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include #include #include #include @@ -95,17 +96,21 @@ using ReleasedOps = SmallVector; using PendingWiresMap = DenseMap>; namespace impl { -/// Return the number of qubits a operation produces/consumes. -inline size_t getNumQubits(Operation* op) { + +/// Return the number of qubit arguments of unitary-like operation. +inline size_t getNumQubitArgs(Operation* op) { return TypeSwitch(op) .Case( [&](UnitaryOpInterface op) { return op.getNumQubits(); }) - .Case([&](scf::ForOp op) { return op.getInits().size(); }) - .Case([&](scf::WhileOp op) { return op.getInits().size(); }) - .Case([&](qco::IfOp op) { return op.getQubits().size(); }) - .Case( - [&](auto) { return 1; }) + .Case([&](auto op) { + return llvm::count_if( + op.getInits(), [](Value v) { return isa(v.getType()); }); + }) + .Case([&](qco::IfOp op) { + return llvm::count_if(op.getQubits(), [](Value v) { + return isa(v.getType()); + }); + }) .Default([&](Operation* op) { const auto name = op->getName().getStringRef(); reportFatalInternalError("unknown op: " + name); @@ -117,7 +122,7 @@ inline size_t getNumQubits(Operation* op) { struct IsReady { bool operator()(PendingWiresMap::value_type& kv) const { const auto npending = kv.second.size(); - return impl::getNumQubits(kv.first) == npending; + return impl::getNumQubitArgs(kv.first) == npending; } }; @@ -175,16 +180,27 @@ LogicalResult walkProgramGraph(MutableArrayRef wires, } while (Traits::isActive(it)) { - assert(it.operation() != nullptr); - const auto nqubits = impl::getNumQubits(it.operation()); - assert(nqubits != 0); - if (nqubits == 1) { + // For source-like (AllocOp, StaticOp, qtensor::ExtractOp), + // sink-like (SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp), and + // one-qubit non-unitary (ResetOp, MeasureOp) operations, simply advance + // the iterator. + + if (isa( + it.operation())) { std::ranges::advance(it, Traits::stride()); continue; } - assert(nqubits >= 2); + const auto nqubits = impl::getNumQubitArgs(it.operation()); + + // Advance past one-qubit operations. + + if (nqubits == 1) { + std::ranges::advance(it, Traits::stride()); + continue; + } // If there are fewer wires than the operation requires inputs, // it's impossible to release the operation. Hence, fail. From 483a5c077ca69821f8dd36812808c343e9ebb025 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jun 2026 08:22:14 +0000 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0efbe0e493..ef7be25d24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,9 +42,9 @@ with the exception that minor releases may include breaking changes. [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], - [#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1806], [#1807], [#1808]) - ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], - [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], + [#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1806], [#1807], + [#1808]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], + [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed From 33aadd3ecde147e1e3121ee4c47397526740d500 Mon Sep 17 00:00:00 2001 From: Matthias Reumann Date: Thu, 25 Jun 2026 10:37:34 +0200 Subject: [PATCH 7/7] Add qco.if op to driver test --- mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp index 3c8c1480ab..c9f84fefd0 100644 --- a/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp +++ b/mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp @@ -155,7 +155,14 @@ TEST_F(DriversTest, ProgramGraphWalk) { return SmallVector{iterQ0, iterQ1, args[2], args[3]}; }); - builder.measure(forResults[0]); + const auto q05 = builder.qcoIf( + false, forResults[0], + [&](ValueRange args) { return SmallVector{builder.h(args[0])}; }, + [&](ValueRange args) { + return SmallVector{builder.id(args[0])}; + })[0]; + + builder.measure(q05); builder.measure(forResults[1]); builder.measure(forResults[2]); builder.measure(forResults[3]);