Skip to content
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ with the exception that minor releases may include breaking changes.
[#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623],
[#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700],
[#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765],
[#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1806], [#1807])
([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**],
[**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**],
[#1774], [#1780], [#1781], [#1782], [#1787], [#1802], [#1806], [#1807],
[#1808]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**],
[**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**],
[**@simon1hofmann**])

### Changed
Expand Down Expand Up @@ -598,6 +598,7 @@ changelogs._

<!-- PR links -->

[#1808]: https://github.com/munich-quantum-toolkit/core/pull/1808
[#1807]: https://github.com/munich-quantum-toolkit/core/pull/1807
[#1806]: https://github.com/munich-quantum-toolkit/core/pull/1806
[#1802]: https://github.com/munich-quantum-toolkit/core/pull/1802
Expand Down
135 changes: 81 additions & 54 deletions mlir/include/mlir/Dialect/QCO/Utils/Drivers.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@

#pragma once

#include "mlir/Dialect/QCO/IR/QCODialect.h"
#include "mlir/Dialect/QCO/IR/QCOInterfaces.h"
#include "mlir/Dialect/QCO/IR/QCOOps.h"
#include "mlir/Dialect/QCO/Utils/Qubits.h"
#include "mlir/Dialect/QCO/Utils/WireIterator.h"
#include "mlir/Dialect/QTensor/IR/QTensorOps.h"

#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/TypeSwitch.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/WalkResult.h>

#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
Expand Down Expand Up @@ -88,13 +92,37 @@ LogicalResult walkProgram(Region& region, const WalkProgramFn& fn) {
return success();
}

using ReleasedOps = SmallVector<UnitaryOpInterface, 8>;
using PendingWiresMap =
DenseMap<UnitaryOpInterface, SmallVector<std::size_t, 2>>;
using ReleasedOps = SmallVector<Operation*, 8>;
using PendingWiresMap = DenseMap<Operation*, SmallVector<size_t>>;

namespace impl {

/// Return the number of qubit arguments of unitary-like operation.
inline size_t getNumQubitArgs(Operation* op) {
return TypeSwitch<Operation*, size_t>(op)
.Case<UnitaryOpInterface>(
[&](UnitaryOpInterface op) { return op.getNumQubits(); })
.Case<scf::ForOp, scf::WhileOp>([&](auto op) {
return llvm::count_if(
op.getInits(), [](Value v) { return isa<QubitType>(v.getType()); });
})
.Case<qco::IfOp>([&](qco::IfOp op) {
return llvm::count_if(op.getQubits(), [](Value v) {
return isa<QubitType>(v.getType());
});
})
.Default([&](Operation* op) {
const auto name = op->getName().getStringRef();
reportFatalInternalError("unknown op: " + name);
return 0;
});
}
} // namespace impl

struct IsReady {
bool operator()(PendingWiresMap::value_type& kv) const {
return kv.second.size() == kv.first.getNumQubits();
const auto npending = kv.second.size();
return impl::getNumQubitArgs(kv.first) == npending;
}
};

Expand Down Expand Up @@ -134,67 +162,66 @@ LogicalResult walkProgramGraph(MutableArrayRef<WireIterator> wires,
using Traits = WireTraversalTraits<Direction>;

ReleasedOps released;

PendingWiresMap pending;
pending.reserve(wires.size());

SmallVector<std::size_t> curr(wires.size());
SmallVector<size_t> curr(wires.size());
std::iota(curr.begin(), curr.end(), 0UL);

SmallVector<std::size_t> next;
SmallVector<size_t> next;
next.reserve(wires.size());

while (!curr.empty()) {
for (std::size_t i : curr) {
for (size_t i : curr) {
auto& it = wires[i];

if (it.operation() == nullptr) { // isa<BlockArgument>
std::ranges::advance(it, Traits::stride());
}

while (Traits::isActive(it)) {
const auto res =
TypeSwitch<Operation*, WalkResult>(it.operation())
.template Case<UnitaryOpInterface>([&](UnitaryOpInterface& op) {
// If there are fewer wires than the qubit requires inputs,
// it's impossible to release the operation. Hence, fail.
if (op.getNumQubits() > wires.size()) {
return WalkResult::interrupt();
}

if (op.getNumQubits() == 1) {
std::ranges::advance(it, Traits::stride());
return WalkResult::advance();
}

// Insert the unitary to the pending map.
// The caller decides if this op should be released.
const auto [mapIt, inserted] = pending.try_emplace(op);
auto& indices = mapIt->second;

if (inserted) {
indices.reserve(op.getNumQubits());
}

indices.emplace_back(i);

return WalkResult::skip(); // Stop at multi-qubit gate.
})
// AllocOp, StaticOp, and qtensor::ExtractOp are only reachable
// on the forward path; backward isActive() halts before
// reaching them (decrementing at a source op is a no-op).
.template Case<AllocOp, StaticOp, qtensor::ExtractOp, ResetOp,
MeasureOp, SinkOp, qtensor::InsertOp>([&](auto) {
std::ranges::advance(it, Traits::stride());
return WalkResult::advance();
})
.Default([&](Operation* op) -> WalkResult {
const auto name = op->getName().getStringRef();
report_fatal_error("unknown op encountered: " + name);
});

if (res.wasSkipped()) {
break;

// For source-like (AllocOp, StaticOp, qtensor::ExtractOp),
// sink-like (SinkOp, YieldOp, qtensor::InsertOp, scf::YieldOp), and
// one-qubit non-unitary (ResetOp, MeasureOp) operations, simply advance
// the iterator.

if (isa<AllocOp, StaticOp, ResetOp, MeasureOp, SinkOp, YieldOp,
qtensor::ExtractOp, qtensor::InsertOp, scf::YieldOp>(
it.operation())) {
std::ranges::advance(it, Traits::stride());
continue;
}

const auto nqubits = impl::getNumQubitArgs(it.operation());

// Advance past one-qubit operations.

if (nqubits == 1) {
std::ranges::advance(it, Traits::stride());
continue;
}

if (res.wasInterrupted()) {
// If there are fewer wires than the operation requires inputs,
// it's impossible to release the operation. Hence, fail.

if (nqubits > wires.size()) {
return failure();
}

// Insert the unitary to the pending map.
// The caller decides if this op should be released.

const auto [mapIt, inserted] = pending.try_emplace(it.operation());
auto& indices = mapIt->second;

if (inserted) {
indices.reserve(nqubits);
}

indices.emplace_back(i);

break; // Stop at multi-qubit unitary.
}
}

Expand All @@ -212,11 +239,11 @@ LogicalResult walkProgramGraph(MutableArrayRef<WireIterator> wires,
}
}

for (const UnitaryOpInterface& op : released) {
for (Operation* op : released) {
const auto mapIt = pending.find(op);
assert(mapIt != pending.end());

for (std::size_t i : mapIt->second) {
for (size_t i : mapIt->second) {
std::ranges::advance(wires[i], Traits::stride());
next.emplace_back(i);
}
Expand Down
93 changes: 79 additions & 14 deletions mlir/unittests/Dialect/QCO/Utils/test_drivers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
#include <llvm/ADT/SmallVector.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/DialectRegistry.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/ValueRange.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/WalkResult.h>

Expand All @@ -36,7 +38,8 @@ class DriversTest : public testing::Test {
protected:
void SetUp() override {
DialectRegistry registry;
registry.insert<qco::QCODialect, arith::ArithDialect, func::FuncDialect>();
registry.insert<qco::QCODialect, scf::SCFDialect, arith::ArithDialect,
func::FuncDialect>();

context = std::make_unique<MLIRContext>();
context->appendDialectRegistry(registry);
Expand Down Expand Up @@ -99,6 +102,27 @@ TEST_F(DriversTest, ProgramWalk) {
ASSERT_EQ(ex3, q31);
}

TEST_F(DriversTest, ProgramGraphWalkTooFewWires) {
qco::QCOProgramBuilder builder(context.get());
builder.initialize();

const auto q00 = builder.allocQubit();
const auto q10 = builder.allocQubit();
const auto [q01, q11] = builder.cx(q00, q10);

[[maybe_unused]] auto mod = builder.finalize();

// Collect just one wire.
SmallVector<qco::WireIterator> wires;
wires.emplace_back(q00);

auto res = qco::walkProgramGraph<qco::WireDirection::Forward>(
wires, [&](const qco::ReadyRange&, qco::ReleasedOps&) {
return WalkResult::skip();
});
ASSERT_TRUE(res.failed());
}

TEST_F(DriversTest, ProgramGraphWalk) {
qco::QCOProgramBuilder builder(context.get());
builder.initialize();
Expand All @@ -121,10 +145,27 @@ TEST_F(DriversTest, ProgramGraphWalk) {
const auto [q04, q13] = builder.cx(q03, q12);
const auto q14 = builder.h(q13);

builder.measure(q04);
builder.measure(q14);
builder.measure(q23);
builder.measure(q31);
Value iterQ0;
Value iterQ1;
ValueRange blockArgs;
const auto forResults = builder.scfFor(
0, 3, 1, {q04, q14, q23, q31}, [&](Value, ValueRange args) {
blockArgs = args;
std::tie(iterQ0, iterQ1) = builder.cx(args[0], args[1]);
return SmallVector<Value>{iterQ0, iterQ1, args[2], args[3]};
});

const auto q05 = builder.qcoIf(
false, forResults[0],
[&](ValueRange args) { return SmallVector<Value>{builder.h(args[0])}; },
[&](ValueRange args) {
return SmallVector<Value>{builder.id(args[0])};
})[0];

builder.measure(q05);
builder.measure(forResults[1]);
builder.measure(forResults[2]);
builder.measure(forResults[3]);

auto mod = builder.finalize();
auto func = *(mod->getOps<func::FuncOp>().begin());
Expand All @@ -151,11 +192,12 @@ TEST_F(DriversTest, ProgramGraphWalk) {
});

ASSERT_TRUE(res.succeeded());
ASSERT_GE(readyPerLayer.size(), 3);
ASSERT_GE(readyPerLayer.size(), 4);
ASSERT_TRUE(readyPerLayer[0].contains(q02.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[0].contains(q21.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[2].contains(q04.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[3].contains(forResults[0].getDefiningOp()));

// Backward pass.
readyPerLayer.clear();
Expand All @@ -171,11 +213,12 @@ TEST_F(DriversTest, ProgramGraphWalk) {
});

ASSERT_TRUE(res.succeeded());
ASSERT_GE(readyPerLayer.size(), 3);
ASSERT_TRUE(readyPerLayer[0].contains(q04.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[2].contains(q02.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[2].contains(q21.getDefiningOp()));
ASSERT_GE(readyPerLayer.size(), 4);
ASSERT_TRUE(readyPerLayer[0].contains(forResults[0].getDefiningOp()));
ASSERT_TRUE(readyPerLayer[1].contains(q04.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[2].contains(q12.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[3].contains(q02.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[3].contains(q21.getDefiningOp()));

// Forward, but instead of releasing all, we use ::skip().
readyPerLayer.clear();
Expand All @@ -186,16 +229,16 @@ TEST_F(DriversTest, ProgramGraphWalk) {
layer.insert(op);
}
readyPerLayer.emplace_back(layer);

return WalkResult::skip();
});

ASSERT_TRUE(res.succeeded());
ASSERT_GE(readyPerLayer.size(), 3);
ASSERT_GE(readyPerLayer.size(), 4);
ASSERT_TRUE(readyPerLayer[0].contains(q02.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[0].contains(q21.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[1].contains(q12.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[2].contains(q04.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[3].contains(forResults[0].getDefiningOp()));

// Backward, but stop after first layer.
readyPerLayer.clear();
Expand All @@ -212,5 +255,27 @@ TEST_F(DriversTest, ProgramGraphWalk) {

ASSERT_TRUE(res.failed());
ASSERT_EQ(readyPerLayer.size(), 1);
ASSERT_TRUE(readyPerLayer[0].contains(q04.getDefiningOp()));
ASSERT_TRUE(readyPerLayer[0].contains(forResults[0].getDefiningOp()));

// Forward, but start at block arguments.
wires.clear();
for (Value arg : blockArgs) {
wires.emplace_back(arg);
}

readyPerLayer.clear();
res = qco::walkProgramGraph<qco::WireDirection::Forward>(
wires, [&](const qco::ReadyRange& ready, qco::ReleasedOps& released) {
DenseSet<Operation*> layer;
for (const auto& [op, progs] : ready) {
layer.insert(op);
released.emplace_back(op);
}
readyPerLayer.emplace_back(layer);
return WalkResult::advance();
});

ASSERT_TRUE(res.succeeded());
ASSERT_GE(readyPerLayer.size(), 1);
ASSERT_TRUE(readyPerLayer[0].contains(iterQ0.getDefiningOp()));
}
Loading