Skip to content
72 changes: 36 additions & 36 deletions mlir/include/mlir/Dialect/QCO/Utils/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ struct Matrix2x2 {
* @brief Embed this single-qubit matrix into an @p numQubits-qubit Hilbert
* space.
*
* Wire @p qubitIndex uses the same MSB-first convention as @ref
* Matrix4x4::kron (high bit first operand, low bit second). For each basis
* pair whose untouched wires match, copies this matrix at the target qubit's
* row/column bits.
* Wire @p qubitIndex uses the same convention as @ref QuantumComputation:
* qubit @p i is bit @p i of the basis index. For each basis pair whose
* untouched wires match, copies this matrix at the target qubit's row/column
* bits.
*
* @param numQubits Number of qubits in the target Hilbert space.
* @param qubitIndex Wire index to act on.
Expand All @@ -253,7 +253,7 @@ struct Matrix2x2 {
/**
* @brief Embed this single-qubit matrix into a two-qubit Hilbert space.
*
* @param qubitIndex Wire index (`0` = high bit / MSB, `1` = low bit).
* @param qubitIndex Wire index to act on (qubit @p i = bit @p i).
* @return The `4x4` embedded unitary.
*/
[[nodiscard]] Matrix4x4 embedInTwoQubit(std::size_t qubitIndex) const;
Expand Down Expand Up @@ -407,11 +407,12 @@ struct Matrix4x4 {
/**
* @brief Kronecker product `lhs (x) rhs` of two single-qubit matrices.
*
* Uses the computational-basis bit order where the first operand labels the
* high bit, matching `UnitaryOpInterface::getUnitaryMatrix4x4`.
* Uses the computational-basis bit order where qubit @p i is bit @p i,
* matching @ref QuantumComputation. For @f$A \otimes B@f$, @p lhs acts on
* wire @f$1@f$ and @p rhs on wire @f$0@f$.
*
* @param lhs Left factor (acts on the high bit / qubit 0).
* @param rhs Right factor (acts on the low bit / qubit 1).
* @param lhs Left factor (acts on wire 1).
* @param rhs Right factor (acts on wire 0).
* @return The `4x4` Kronecker product.
*/
[[nodiscard]] static Matrix4x4 kron(const Matrix2x2& lhs,
Expand Down Expand Up @@ -482,9 +483,9 @@ struct Matrix4x4 {
* @brief Embed this two-qubit matrix into an @p numQubits-qubit Hilbert
* space.
*
* Operand 0 labels the high bit of the pair and acts on @p q0Index; operand 1
* labels the low bit and acts on @p q1Index. For each basis pair whose other
* wires match, copies this matrix at the packed two-qubit row/column indices.
* Operand 0 labels wire @p q0Index and operand 1 labels wire @p q1Index.
* For each basis pair whose other wires match, copies this matrix at the
* packed two-qubit row/column indices.
*
* @param numQubits Number of qubits in the target Hilbert space.
* @param q0Index Wire index of operand 0.
Expand All @@ -496,9 +497,10 @@ struct Matrix4x4 {
std::size_t q1Index) const;

/**
* @brief Reorder this matrix to act on qubits `{0, 1}`.
* @brief Reorder this matrix to act on wires @p q0Index and @p q1Index.
*
* @param q0Index Wire index of operand 0; @p q1Index wire index of operand 1.
* @param q0Index Wire index of operand 0.
* @param q1Index Wire index of operand 1.
* @return Reordered copy of this matrix.
*/
[[nodiscard]] Matrix4x4 reorderForQubits(std::size_t q0Index,
Expand Down Expand Up @@ -614,28 +616,6 @@ class DynamicMatrix {
*/
[[nodiscard]] Complex operator()(std::int64_t row, std::int64_t col) const;

/**
* @brief Copies a 2x2 block into the bottom-right corner.
* @param block Source block placed at indices `(dim-2, dim-2)` through
* `(dim-1, dim-1)`.
*/
void setBottomRightCorner(const Matrix2x2& block);

/**
* @brief Copies a 4x4 block into the bottom-right corner.
* @param block Source block placed at indices `(dim-4, dim-4)` through
* `(dim-1, dim-1)`.
*/
void setBottomRightCorner(const Matrix4x4& block);

/**
* @brief Copies a dynamic block into the bottom-right corner.
* @param block Source block placed at indices `(dim - block.rows(), ...)`
* through
* `(dim-1, dim-1)`.
*/
void setBottomRightCorner(const DynamicMatrix& block);

/**
* @brief Returns the conjugate transpose (adjoint) of this matrix.
* @return Adjoint matrix `A^\dagger`.
Expand Down Expand Up @@ -754,6 +734,26 @@ class DynamicMatrix {
std::unique_ptr<Impl> impl_;
};

/**
* @brief Embeds @p targetUnitary as multi-controlled on @p controlQubits.
*
* Qubit @p i is bit @p i of the basis index, matching @ref QuantumComputation.
* When every control qubit is \f$|1\rangle\f$, applies @p targetUnitary on
* @p targetQubits; otherwise the identity.
*
* @param numQubits Size of the embedding Hilbert space (number of local wires).
* @param controlQubits Local wire indices in \f$[0, \texttt{numQubits})\f$ for
* the control wires (compact positions after remapping, not sparse
* program-register indices).
* @param targetQubits Local wire indices in \f$[0, \texttt{numQubits})\f$ for
* the target wires, in the order used to index @p targetUnitary.
* @param targetUnitary Local unitary on the target subspace.
* @return The controlled unitary on the @p numQubits-qubit Hilbert space.
*/
[[nodiscard]] DynamicMatrix embedControlledUnitary(
std::size_t numQubits, ArrayRef<std::size_t> controlQubits,
ArrayRef<std::size_t> targetQubits, const DynamicMatrix& targetUnitary);

/**
* @brief Type trait for the four supported matrix types.
*
Expand Down
106 changes: 80 additions & 26 deletions mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/QTensor/IR/QTensorOps.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/Builders.h>
Expand All @@ -30,14 +31,66 @@
#include <mlir/IR/Value.h>
#include <mlir/Support/LLVM.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>

using namespace mlir;
using namespace mlir::qco;

/**
* @brief Returns the program register index of @p qubit when known at compile
* time.
*
* Supports @c qco.static and @c qtensor.extract with an @c arith.constant
* index. Dynamic or negative indices yield @c std::nullopt.
*/
[[nodiscard]] static std::optional<std::size_t> programQubitIndex(Value qubit) {
auto* definingOp = qubit.getDefiningOp();
if (definingOp == nullptr) {
return std::nullopt;
}
if (auto staticOp = dyn_cast<StaticOp>(definingOp)) {
return static_cast<std::size_t>(staticOp.getIndex());
}
auto extractOp = dyn_cast<qtensor::ExtractOp>(definingOp);
if (!extractOp) {
return std::nullopt;
}
auto indexOp = extractOp.getIndex().getDefiningOp<arith::ConstantOp>();
if (!indexOp) {
return std::nullopt;
}
const auto indexAttr = dyn_cast<IntegerAttr>(indexOp.getValue());
if (!indexAttr) {
return std::nullopt;
}
const auto index = indexAttr.getInt();
if (index < 0) {
return std::nullopt;
}
return static_cast<std::size_t>(index);
}

/**
* @brief Maps each SSA qubit in @p qubits to its program register index.
*
* @return Indices in operand order, or @c std::nullopt if any wire is not
* resolved by @ref programQubitIndex.
*/
[[nodiscard]] static std::optional<SmallVector<std::size_t>>
resolveQubitIndices(ValueRange qubits) {
SmallVector<std::size_t> indices;
indices.reserve(qubits.size());
for (const auto qubit : qubits) {
if (const auto index = programQubitIndex(qubit)) {
indices.push_back(*index);
} else {
return std::nullopt;
}
}
return indices;
}

namespace {

/**
Expand Down Expand Up @@ -308,36 +361,37 @@ std::optional<DynamicMatrix> CtrlOp::getUnitaryMatrix() {
"is not supported due to memory constraints.");
}

const auto numControls = getNumControls();

// Build `I_{2^controls} ⊗ U` by placing the target block in the bottom-right
// corner of a `2^controls * targetDim` identity.
const auto controlledMatrix =
[numControls](const std::int64_t targetDim,
const auto& targetBlock) -> DynamicMatrix {
auto matrix = DynamicMatrix::identity(static_cast<int64_t>(
(1ULL << numControls) * static_cast<std::size_t>(targetDim)));
matrix.setBottomRightCorner(targetBlock);
return matrix;
};
const auto controlQubits = resolveQubitIndices(getInputControls());
const auto targetQubits = resolveQubitIndices(getInputTargets());
if (!controlQubits || !targetQubits) {
return std::nullopt;
}

// Single inner unitary (e.g. `ctrl { h }`, `ctrl { cx }`).
// Inner unitary on targets: one body op or a composed single-qubit sequence.
std::optional<DynamicMatrix> targetMatrix;
if (auto bodyUnitary =
utils::getSoleBodyUnitary<UnitaryOpInterface>(*getBody())) {
if (const auto targetMatrix =
bodyUnitary.getUnitaryMatrix<DynamicMatrix>()) {
assert(targetMatrix->cols() == targetMatrix->rows());
return controlledMatrix(targetMatrix->cols(), *targetMatrix);
targetMatrix = bodyUnitary.getUnitaryMatrix<DynamicMatrix>();
} else if (getNumTargets() == 1) {
if (const auto composed = composeSingleQubitBodyMatrix(*getBody())) {
targetMatrix = DynamicMatrix(*composed);
}
}
if (!targetMatrix) {
return std::nullopt;
}

// Composed single-qubit body (e.g. `ctrl { h; x }`); embed the 2x2 directly.
if (getNumTargets() == 1) {
if (const auto composed = composeSingleQubitBodyMatrix(*getBody())) {
return controlledMatrix(2, *composed);
}
}
SmallVector<std::size_t> participating;
participating.append(*controlQubits);
participating.append(*targetQubits);
llvm::sort(participating);
participating.erase(llvm::unique(participating), participating.end());

return std::nullopt;
const auto toLocal = [&](const std::size_t wire) {
return static_cast<std::size_t>(llvm::find(participating, wire) -
participating.begin());
};
return embedControlledUnitary(
participating.size(), llvm::map_to_vector(*controlQubits, toLocal),
llvm::map_to_vector(*targetQubits, toLocal), *targetMatrix);
}
Loading
Loading