|
| 1 | +#include <cstdlib> |
| 2 | +#include <functional> |
| 3 | +#include <memory> |
| 4 | +#include <string> |
| 5 | + |
| 6 | +// KeyMemRT dialects |
| 7 | +#include "lib/Dialect/HEIRInterfaces.h" |
| 8 | +#include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h" |
| 9 | +#include "lib/Dialect/CKKS/IR/CKKSDialect.h" |
| 10 | +#include "lib/Dialect/KMRT/IR/KMRTDialect.h" |
| 11 | +#include "lib/Dialect/KMRT/Transforms/Passes.h" |
| 12 | +#include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" |
| 13 | +#include "lib/Dialect/LWE/IR/LWEDialect.h" |
| 14 | +#include "lib/Dialect/LWE/Transforms/Passes.h" |
| 15 | +#include "lib/Dialect/ModArith/IR/ModArithDialect.h" |
| 16 | +#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" |
| 17 | +#include "lib/Dialect/Openfhe/Transforms/Passes.h" |
| 18 | +#include "lib/Dialect/Polynomial/IR/PolynomialDialect.h" |
| 19 | +#include "lib/Dialect/RNS/IR/RNSDialect.h" |
| 20 | +#include "lib/Dialect/RNS/IR/RNSTypeInterfaces.h" |
| 21 | +#include "lib/Dialect/Random/IR/RandomDialect.h" |
| 22 | +#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" |
| 23 | + |
| 24 | +// KeyMemRT transforms |
| 25 | +#include "lib/Transforms/AddRotationKeys/AddRotationKeys.h" |
| 26 | +#include "lib/Transforms/AnnotateModule/AnnotateModule.h" |
| 27 | +#include "lib/Transforms/BootstrapRotationAnalysis/BootstrapRotationAnalysis.h" |
| 28 | +#include "lib/Transforms/LowerLinearTransform/LowerLinearTransform.h" |
| 29 | +#include "lib/Transforms/ProfileAnnotator/ProfileAnnotator.h" |
| 30 | +#include "lib/Transforms/SymbolicBSGSDecomposition/SymbolicBSGSDecomposition.h" |
| 31 | +#include "lib/Transforms/UnnecessaryBootstrapRemoval/UnnecessaryBootstrapRemoval.h" |
| 32 | + |
| 33 | +// MLIR core |
| 34 | +#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project |
| 35 | +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project |
| 36 | +#include "mlir/include/mlir/Dialect/Affine/Passes.h" // from @llvm-project |
| 37 | +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project |
| 38 | +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| 39 | +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project |
| 40 | +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project |
| 41 | +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project |
| 42 | +#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project |
| 43 | +#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project |
| 44 | +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project |
| 45 | +#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project |
| 46 | +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project |
| 47 | + |
| 48 | +using namespace mlir; |
| 49 | +using namespace heir; |
| 50 | + |
| 51 | +int main(int argc, char **argv) { |
| 52 | + mlir::DialectRegistry registry; |
| 53 | + |
| 54 | + // Register KeyMemRT dialects |
| 55 | + registry.insert<ckks::CKKSDialect>(); |
| 56 | + registry.insert<kmrt::KMRTDialect>(); |
| 57 | + registry.insert<lwe::LWEDialect>(); |
| 58 | + registry.insert<mod_arith::ModArithDialect>(); |
| 59 | + registry.insert<openfhe::OpenfheDialect>(); |
| 60 | + registry.insert<::mlir::heir::polynomial::PolynomialDialect>(); |
| 61 | + registry.insert<random::RandomDialect>(); |
| 62 | + registry.insert<rns::RNSDialect>(); |
| 63 | + registry.insert<tensor_ext::TensorExtDialect>(); |
| 64 | + |
| 65 | + // Register MLIR dialects |
| 66 | + registry.insert<affine::AffineDialect>(); |
| 67 | + registry.insert<mlir::arith::ArithDialect>(); |
| 68 | + registry.insert<func::FuncDialect>(); |
| 69 | + registry.insert<memref::MemRefDialect>(); |
| 70 | + registry.insert<scf::SCFDialect>(); |
| 71 | + registry.insert<tensor::TensorDialect>(); |
| 72 | + |
| 73 | + // Register MLIR passes |
| 74 | + registerTransformsPasses(); // canonicalize, cse, etc. |
| 75 | + affine::registerAffinePasses(); // loop unrolling, lower-affine |
| 76 | + |
| 77 | + // Register affine-to-standard conversion |
| 78 | + registerPass( |
| 79 | + []() -> std::unique_ptr<Pass> { return createLowerAffinePass(); }); |
| 80 | + |
| 81 | + // Register KeyMemRT transforms |
| 82 | + kmrt::registerKMRTPasses(); |
| 83 | + lwe::registerLWEPasses(); |
| 84 | + openfhe::registerOpenfhePasses(); |
| 85 | + registerAddRotationKeysPasses(); |
| 86 | + registerAnnotateModulePasses(); |
| 87 | + registerBootstrapRotationAnalysisPasses(); |
| 88 | + registerLowerLinearTransformPasses(); |
| 89 | + registerProfileAnnotatorPasses(); |
| 90 | + registerSymbolicBSGSDecompositionPasses(); |
| 91 | + registerUnnecessaryBootstrapRemovalPasses(); |
| 92 | + |
| 93 | + // Register KeyMemRT conversions |
| 94 | + ckks::registerCKKSToLWEPasses(); |
| 95 | + lwe::registerLWEToOpenfhePasses(); |
| 96 | + |
| 97 | + // Register KeyMemRT interfaces |
| 98 | + rns::registerExternalRNSTypeInterfaces(registry); |
| 99 | + registerOperandAndResultAttrInterface(registry); |
| 100 | + |
| 101 | + return asMainReturnCode( |
| 102 | + MlirOptMain(argc, argv, "KeyMemRT Pass Driver", registry)); |
| 103 | +} |
0 commit comments