Skip to content

Commit bb53604

Browse files
committed
add basic test case for UnitaryOpInterface::getFastUnitaryMatrix()
1 parent 5891136 commit bb53604

2 files changed

Lines changed: 137 additions & 1 deletion

File tree

mlir/unittests/dialect/qco/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#
77
# Licensed under the MIT License
88

9-
add_executable(mqt-core-mlir-qco-dialect-test test_unitary_matrix.cpp)
9+
add_executable(mqt-core-mlir-qco-dialect-test test_unitary_matrix.cpp test_unitary_op_interface.cpp)
1010

1111
target_link_libraries(mqt-core-mlir-qco-dialect-test PRIVATE GTest::gtest_main MLIRQCODialect
1212
MLIRLLVMDialect MLIRQCOProgramBuilder)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright (c) 2023 - 2026 Chair for Design Automation, TUM
3+
* Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH
4+
* All rights reserved.
5+
*
6+
* SPDX-License-Identifier: MIT
7+
*
8+
* Licensed under the MIT License
9+
*/
10+
11+
#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h"
12+
13+
#include <gtest/gtest.h>
14+
#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
15+
#include <mlir/Dialect/Func/IR/FuncOps.h>
16+
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
17+
#include <mlir/Dialect/SCF/IR/SCF.h>
18+
19+
namespace {
20+
21+
using namespace mlir;
22+
23+
class QcoUnitaryOpInterfaceTest : public testing::Test {
24+
protected:
25+
void SetUp() override {
26+
DialectRegistry registry;
27+
registry
28+
.insert<qco::QCODialect, arith::ArithDialect, cf::ControlFlowDialect,
29+
func::FuncDialect, scf::SCFDialect, LLVM::LLVMDialect>();
30+
31+
context = std::make_unique<MLIRContext>();
32+
context->appendDialectRegistry(registry);
33+
context->loadAllAvailableDialects();
34+
}
35+
36+
/**
37+
* @brief Build expected QCO IR programmatically and run canonicalization
38+
*/
39+
[[nodiscard]] OwningOpRef<ModuleOp> buildQCOIR(
40+
const std::function<void(qco::QCOProgramBuilder&)>& buildFunc) const {
41+
qco::QCOProgramBuilder builder(context.get());
42+
builder.initialize();
43+
buildFunc(builder);
44+
return builder.finalize();
45+
}
46+
47+
/**
48+
* @brief Get first operation of given type in a module containing a function
49+
* as its first operation.
50+
*/
51+
template <typename OpType>
52+
[[nodiscard]] OpType getFirstOp(ModuleOp moduleOp) {
53+
auto funcOp = llvm::dyn_cast<func::FuncOp>(
54+
moduleOp.getBody()->getOperations().front());
55+
assert(funcOp);
56+
57+
auto ops = funcOp.getOps<OpType>();
58+
if (ops.empty()) {
59+
return nullptr;
60+
}
61+
62+
return *ops.begin();
63+
}
64+
65+
/**
66+
* @brief Get text representation of given module.
67+
*/
68+
[[nodiscard]] static std::string toString(ModuleOp moduleOp) {
69+
std::string buffer;
70+
llvm::raw_string_ostream serializeStream{buffer};
71+
moduleOp->print(serializeStream);
72+
return serializeStream.str();
73+
}
74+
75+
private:
76+
std::unique_ptr<MLIRContext> context;
77+
};
78+
79+
} // namespace
80+
81+
TEST_F(QcoUnitaryOpInterfaceTest, getFastUnitaryMatrix2x2) {
82+
auto moduleOp = buildQCOIR([](qco::QCOProgramBuilder& builder) {
83+
auto reg = builder.allocQubitRegister(1, "q");
84+
reg[0] = builder.id(reg[0]);
85+
reg[0] = builder.rx(1.0, reg[0]);
86+
reg[0] = builder.u(0.2, 0.3, 0.4, reg[0]);
87+
});
88+
89+
auto&& moduleOps = moduleOp->getBody()->getOperations();
90+
ASSERT_FALSE(moduleOps.empty());
91+
auto funcOp = llvm::dyn_cast<func::FuncOp>(moduleOps.begin());
92+
for (auto&& op : funcOp.getOps()) {
93+
auto unitaryOp = llvm::dyn_cast<qco::UnitaryOpInterface>(op);
94+
if (unitaryOp) {
95+
EXPECT_EQ(unitaryOp.getUnitaryMatrix(),
96+
unitaryOp.getFastUnitaryMatrix<Eigen::Matrix2cd>());
97+
}
98+
}
99+
}
100+
101+
TEST_F(QcoUnitaryOpInterfaceTest, getFastUnitaryMatrix4x4) {
102+
auto moduleOp = buildQCOIR([](qco::QCOProgramBuilder& builder) {
103+
auto reg = builder.allocQubitRegister(2, "q");
104+
std::tie(reg[0], reg[1]) = builder.rxx(2.0, reg[0], reg[1]);
105+
std::tie(reg[0], reg[1]) = builder.rzx(1.0, reg[0], reg[1]);
106+
});
107+
108+
auto&& moduleOps = moduleOp->getBody()->getOperations();
109+
ASSERT_FALSE(moduleOps.empty());
110+
auto funcOp = llvm::dyn_cast<func::FuncOp>(moduleOps.begin());
111+
for (auto&& op : funcOp.getOps()) {
112+
auto unitaryOp = llvm::dyn_cast<qco::UnitaryOpInterface>(op);
113+
if (unitaryOp) {
114+
EXPECT_EQ(unitaryOp.getUnitaryMatrix(),
115+
unitaryOp.getFastUnitaryMatrix<Eigen::Matrix4cd>());
116+
}
117+
}
118+
}
119+
120+
TEST_F(QcoUnitaryOpInterfaceTest, getFastUnitaryMatrixDynamic) {
121+
auto moduleOp = buildQCOIR([](qco::QCOProgramBuilder& builder) {
122+
auto reg = builder.allocQubitRegister(2, "q");
123+
std::tie(reg[1], reg[0]) = builder.ch(reg[1], reg[0]);
124+
});
125+
126+
auto&& moduleOps = moduleOp->getBody()->getOperations();
127+
ASSERT_FALSE(moduleOps.empty());
128+
auto funcOp = llvm::dyn_cast<func::FuncOp>(moduleOps.begin());
129+
for (auto&& op : funcOp.getOps()) {
130+
auto unitaryOp = llvm::dyn_cast<qco::UnitaryOpInterface>(op);
131+
if (unitaryOp) {
132+
EXPECT_EQ(unitaryOp.getUnitaryMatrix(),
133+
unitaryOp.getFastUnitaryMatrix<Eigen::MatrixXcd>());
134+
}
135+
}
136+
}

0 commit comments

Comments
 (0)