Skip to content

Commit 86a043a

Browse files
committed
cleanup
1 parent d59ad46 commit 86a043a

1 file changed

Lines changed: 28 additions & 36 deletions

File tree

mlir/include/mlir/Dialect/Common/IR/CommonTraits.h

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,47 +13,35 @@
1313
#include <array>
1414
#include <cmath>
1515
#include <cstddef>
16+
#include <functional>
1617
#include <mlir/IR/OpDefinition.h>
1718
#include <mlir/IR/Operation.h>
1819
#include <mlir/Support/LLVM.h>
1920
#include <stdexcept>
2021

2122
namespace mqt::ir::common {
22-
struct DefinitionMatrixElement {
23-
enum class Type {
24-
Value,
25-
ParameterIndex,
26-
};
27-
enum class Transformation {
28-
Identity,
29-
Sin,
30-
Cos,
31-
};
3223

33-
double value;
34-
Type type;
35-
Transformation transformation = Transformation::Identity;
24+
template <std::size_t NumQubits> struct DefinitionMatrix {
25+
static constexpr std::size_t MatrixSize = 1 << NumQubits;
26+
27+
template<typename T>
28+
using MatrixType = std::array<T, MatrixSize * MatrixSize>;
3629

37-
double operator()() {
38-
if (type == Type::Value) {
39-
switch (transformation) {
40-
case Transformation::Identity:
41-
return value;
42-
case Transformation::Sin:
43-
return std::sin(value);
44-
case Transformation::Cos:
45-
return std::cos(value);
46-
}
47-
return value;
48-
} else {
49-
// TODO
50-
}
30+
MatrixType<double(*)(mlir::ValueRange)> matrix;
31+
32+
static constexpr std::size_t index(std::size_t x, std::size_t y) {
33+
return (y * MatrixSize) + x;
5134
}
52-
};
5335

54-
template <std::size_t NumQubits> struct DefinitionMatrix {
55-
static constexpr std::size_t MatrixSize = 1 << NumQubits;
56-
std::array<double(*)(mlir::ValueRange), MatrixSize * MatrixSize> matrix;
36+
constexpr MatrixType<double> getMatrix(mlir::ValueRange params) {
37+
// TODO? lazy-initialized cache
38+
MatrixType<double> result;
39+
static_assert(result.size() == matrix.size());
40+
for (std::size_t i = 0; i < result.size(); ++i) {
41+
result[i] = matrix[i](params);
42+
}
43+
return result;
44+
}
5745
};
5846

5947
template <size_t N, DefinitionMatrix<N> Matrix> class TargetArityTrait {
@@ -70,11 +58,15 @@ template <size_t N, DefinitionMatrix<N> Matrix> class TargetArityTrait {
7058
return mlir::success();
7159
}
7260

73-
[[nodiscard]] static auto getDefinitionMatrix() { return Matrix; }
74-
[[nodiscard]] static double getDefinitionMatrix(mlir::Operation* op, int x,
75-
int y) {
76-
return 0.0;
77-
// return Matrix[y * (1 << N) + x]();
61+
[[nodiscard]] static auto getDefinitionMatrix() {
62+
return Matrix;
63+
}
64+
[[nodiscard]] static auto getDefinitionMatrix(mlir::Operation* op) {
65+
auto concreteOp = mlir::cast<ConcreteOp>(op);
66+
return Matrix.getMatrix(concreteOp.getParams());
67+
}
68+
[[nodiscard]] static double getDefinitionMatrixElement(mlir::Operation* op, std::size_t x, std::size_t y) {
69+
return getDefinitionMatrix(op).at(DefinitionMatrix<N>::index(x, y));
7870
}
7971
};
8072
};

0 commit comments

Comments
 (0)