1313#include < array>
1414#include < cmath>
1515#include < cstddef>
16+ #include < functional>
1617#include < mlir/IR/OpDefinition.h>
1718#include < mlir/IR/Operation.h>
1819#include < mlir/Support/LLVM.h>
1920#include < stdexcept>
2021
2122namespace mqt ::ir::common {
22- struct DefinitionMatrixElement {
23- enum class Type {
24- Value,
25- ParameterIndex,
26- };
27- enum class Transformation {
28- Identity,
29- Sin,
30- Cos,
31- };
3223
33- double value;
34- Type type;
35- Transformation transformation = Transformation::Identity;
24+ template <std::size_t NumQubits> struct DefinitionMatrix {
25+ static constexpr std::size_t MatrixSize = 1 << NumQubits;
26+
27+ template <typename T>
28+ using MatrixType = std::array<T, MatrixSize * MatrixSize>;
3629
37- double operator ()() {
38- if (type == Type::Value) {
39- switch (transformation) {
40- case Transformation::Identity:
41- return value;
42- case Transformation::Sin:
43- return std::sin (value);
44- case Transformation::Cos:
45- return std::cos (value);
46- }
47- return value;
48- } else {
49- // TODO
50- }
30+ MatrixType<double (*)(mlir::ValueRange)> matrix;
31+
32+ static constexpr std::size_t index (std::size_t x, std::size_t y) {
33+ return (y * MatrixSize) + x;
5134 }
52- };
5335
54- template <std::size_t NumQubits> struct DefinitionMatrix {
55- static constexpr std::size_t MatrixSize = 1 << NumQubits;
56- std::array<double (*)(mlir::ValueRange), MatrixSize * MatrixSize> matrix;
36+ constexpr MatrixType<double > getMatrix (mlir::ValueRange params) {
37+ // TODO? lazy-initialized cache
38+ MatrixType<double > result;
39+ static_assert (result.size () == matrix.size ());
40+ for (std::size_t i = 0 ; i < result.size (); ++i) {
41+ result[i] = matrix[i](params);
42+ }
43+ return result;
44+ }
5745};
5846
5947template <size_t N, DefinitionMatrix<N> Matrix> class TargetArityTrait {
@@ -70,11 +58,15 @@ template <size_t N, DefinitionMatrix<N> Matrix> class TargetArityTrait {
7058 return mlir::success ();
7159 }
7260
73- [[nodiscard]] static auto getDefinitionMatrix () { return Matrix; }
74- [[nodiscard]] static double getDefinitionMatrix (mlir::Operation* op, int x,
75- int y) {
76- return 0.0 ;
77- // return Matrix[y * (1 << N) + x]();
61+ [[nodiscard]] static auto getDefinitionMatrix () {
62+ return Matrix;
63+ }
64+ [[nodiscard]] static auto getDefinitionMatrix (mlir::Operation* op) {
65+ auto concreteOp = mlir::cast<ConcreteOp>(op);
66+ return Matrix.getMatrix (concreteOp.getParams ());
67+ }
68+ [[nodiscard]] static double getDefinitionMatrixElement (mlir::Operation* op, std::size_t x, std::size_t y) {
69+ return getDefinitionMatrix (op).at (DefinitionMatrix<N>::index (x, y));
7870 }
7971 };
8072};
0 commit comments