|
15 | 15 | #include "dd/Node.hpp" |
16 | 16 | #include "dd/Operations.hpp" |
17 | 17 | #include "dd/StateGeneration.hpp" |
18 | | -#include "ir/operations/ClassicControlledOperation.hpp" |
| 18 | +#include "ir/Definitions.hpp" |
| 19 | +#include "ir/operations/IfElseOperation.hpp" |
19 | 20 | #include "ir/operations/NonUnitaryOperation.hpp" |
20 | 21 | #include "ir/operations/OpType.hpp" |
21 | 22 |
|
@@ -85,7 +86,7 @@ auto CircuitSimulator::analyseCircuit() -> CircuitAnalysis { |
85 | 86 | auto analysis = CircuitAnalysis{}; |
86 | 87 |
|
87 | 88 | for (auto& op : *qc) { |
88 | | - if (op->isClassicControlledOperation() || op->getType() == qc::Reset) { |
| 89 | + if (op->isIfElseOperation() || op->getType() == qc::Reset) { |
89 | 90 | analysis.isDynamic = true; |
90 | 91 | } |
91 | 92 | if (const auto* measure = dynamic_cast<qc::NonUnitaryOperation*>(op.get()); |
@@ -156,7 +157,7 @@ CircuitSimulator::singleShot(const bool ignoreNonUnitaries) { |
156 | 157 | (static_cast<double>(approximationInfo.stepNumber + 1)))); |
157 | 158 |
|
158 | 159 | for (auto& op : *qc) { |
159 | | - if (op->isNonUnitaryOperation()) { |
| 160 | + if (op->isNonUnitaryOperation() && !op->isIfElseOperation()) { |
160 | 161 | if (ignoreNonUnitaries) { |
161 | 162 | continue; |
162 | 163 | } |
@@ -186,32 +187,61 @@ CircuitSimulator::singleShot(const bool ignoreNonUnitaries) { |
186 | 187 | } |
187 | 188 | dd->garbageCollect(); |
188 | 189 | } else { |
189 | | - if (op->isClassicControlledOperation()) { |
190 | | - if (auto* classicallyControlledOp = |
191 | | - dynamic_cast<qc::ClassicControlledOperation*>(op.get())) { |
192 | | - const auto startIndex = static_cast<std::uint16_t>( |
193 | | - classicallyControlledOp->getParameter().at(0)); |
194 | | - const auto length = static_cast<std::uint16_t>( |
195 | | - classicallyControlledOp->getParameter().at(1)); |
196 | | - const auto expectedValue = |
197 | | - classicallyControlledOp->getExpectedValue(); |
198 | | - unsigned int actualValue = 0; |
| 190 | + if (op->isIfElseOperation()) { |
| 191 | + if (auto* ifElseOp = dynamic_cast<qc::IfElseOperation*>(op.get())) { |
| 192 | + const auto& comparisonKind = ifElseOp->getComparisonKind(); |
| 193 | + |
| 194 | + std::size_t startIndex = 0; |
| 195 | + std::size_t length = 0; |
| 196 | + std::uint64_t expectedValue = 0; |
| 197 | + if (ifElseOp->getControlBit().has_value()) { |
| 198 | + startIndex = ifElseOp->getControlBit().value(); |
| 199 | + length = 1; |
| 200 | + expectedValue = ifElseOp->getExpectedValueBit() ? 1U : 0U; |
| 201 | + } else { |
| 202 | + startIndex = ifElseOp->getControlRegister()->getStartIndex(); |
| 203 | + length = ifElseOp->getControlRegister()->getSize(); |
| 204 | + expectedValue = ifElseOp->getExpectedValueRegister(); |
| 205 | + } |
| 206 | + |
| 207 | + std::uint64_t actualValue = 0; |
199 | 208 | for (std::size_t i = 0; i < length; i++) { |
200 | 209 | actualValue |= (classicValues[startIndex + i] ? 1U : 0U) << i; |
201 | 210 | } |
202 | 211 |
|
203 | | - // std::clog << "expected " << expected_value << " and actual value |
204 | | - // was " << actual_value << "\n"; |
205 | | - |
206 | | - if (actualValue != expectedValue) { |
| 212 | + const auto control = [actualValue, expectedValue, comparisonKind]() { |
| 213 | + switch (comparisonKind) { |
| 214 | + case qc::ComparisonKind::Eq: |
| 215 | + return actualValue == expectedValue; |
| 216 | + case qc::ComparisonKind::Neq: |
| 217 | + return actualValue != expectedValue; |
| 218 | + case qc::ComparisonKind::Lt: |
| 219 | + return actualValue < expectedValue; |
| 220 | + case qc::ComparisonKind::Leq: |
| 221 | + return actualValue <= expectedValue; |
| 222 | + case qc::ComparisonKind::Gt: |
| 223 | + return actualValue > expectedValue; |
| 224 | + case qc::ComparisonKind::Geq: |
| 225 | + return actualValue >= expectedValue; |
| 226 | + } |
| 227 | + qc::unreachable(); |
| 228 | + }(); |
| 229 | + |
| 230 | + if (control) { |
| 231 | + auto thenOp = ifElseOp->getThenOp()->clone(); |
| 232 | + applyOperationToState(thenOp); |
| 233 | + } else if (ifElseOp->getElseOp() != nullptr) { |
| 234 | + auto elseOp = ifElseOp->getElseOp()->clone(); |
| 235 | + applyOperationToState(elseOp); |
| 236 | + } else { |
207 | 237 | continue; |
208 | 238 | } |
209 | 239 | } else { |
210 | | - throw std::runtime_error( |
211 | | - "Dynamic cast to ClassicControlledOperation failed."); |
| 240 | + throw std::runtime_error("Dynamic cast to IfElseOperation failed."); |
212 | 241 | } |
| 242 | + } else { |
| 243 | + applyOperationToState(op); |
213 | 244 | } |
214 | | - applyOperationToState(op); |
215 | 245 |
|
216 | 246 | if (approximationInfo.stepNumber > 0 && |
217 | 247 | approximationInfo.stepFidelity < 1.0) { |
|
0 commit comments