diff --git a/lib/Analysis/LevelAnalysis/BUILD b/lib/Analysis/LevelAnalysis/BUILD index 80a6e53d6a..e409bc46b1 100644 --- a/lib/Analysis/LevelAnalysis/BUILD +++ b/lib/Analysis/LevelAnalysis/BUILD @@ -16,6 +16,7 @@ cc_library( "@heir//lib/Dialect:ModuleAttributes", "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", + "@heir//lib/Target/CompilationTarget", "@heir//lib/Utils", "@heir//lib/Utils:AttributeUtils", "@llvm-project//llvm:Support", diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp index addac46817..a0e85ae2a3 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp @@ -11,6 +11,7 @@ #include "lib/Dialect/Mgmt/IR/MgmtOps.h" #include "lib/Dialect/ModuleAttributes.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" +#include "lib/Target/CompilationTarget/CompilationTarget.h" #include "lib/Utils/AttributeUtils.h" #include "lib/Utils/Utils.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project @@ -96,12 +97,16 @@ LevelState transferForward(mgmt::LevelReduceMinOp op, LevelState transferForward(mgmt::BootstrapOp op, ArrayRef operands) { + auto module = op->getParentOfType(); + const CompilationTarget* target = getTargetConfig(module); + int levelsConsumed = target ? target->bootstrapLevelsConsumed : 0; + LevelState result = std::visit( Overloaded{ - [](MaxLevel) -> LevelState { return LevelState(0); }, + [=](MaxLevel) -> LevelState { return LevelState(levelsConsumed); }, [](Uninit) -> LevelState { return LevelState(Invalid{}); }, [](Invalid) -> LevelState { return LevelState(Invalid{}); }, - [](int val) -> LevelState { return LevelState(0); }, + [=](int val) -> LevelState { return LevelState(levelsConsumed); }, }, operands[0]->getValue().get()); LLVM_DEBUG(debugLog("bootstrap", operands, result)); diff --git a/lib/Tablegen/BUILD b/lib/Tablegen/BUILD new file mode 100644 index 0000000000..9a855a1c9c --- /dev/null +++ b/lib/Tablegen/BUILD @@ -0,0 +1,27 @@ +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "CompilationTargetEmitter", + srcs = ["CompilationTargetEmitter.cpp"], + hdrs = ["CompilationTargetEmitter.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + ], +) + +cc_binary( + name = "heir-tblgen", + srcs = ["TablegenMain.cpp"], + deps = [ + ":CompilationTargetEmitter", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + ], +) diff --git a/lib/Tablegen/CompilationTargetEmitter.cpp b/lib/Tablegen/CompilationTargetEmitter.cpp new file mode 100644 index 0000000000..0ff39de790 --- /dev/null +++ b/lib/Tablegen/CompilationTargetEmitter.cpp @@ -0,0 +1,28 @@ +#include "lib/Tablegen/CompilationTargetEmitter.h" + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project +#include "llvm/include/llvm/TableGen/TableGenBackend.h" // from @llvm-project + +namespace mlir { +namespace heir { + +bool emitCompilationTargetRegistration(const llvm::RecordKeeper& records, + llvm::raw_ostream& os) { + auto targets = records.getAllDerivedDefinitions("CompilationTarget"); + + os << "CompilationTargetRegistry::CompilationTargetRegistry() {\n"; + for (auto* target : targets) { + auto backendName = target->getValueAsString("backendName"); + auto bootstrapLevelsConsumed = + target->getValueAsInt("bootstrapLevelsConsumed"); + + os << " targets[\"" << backendName << "\"] = CompilationTarget{\"" + << backendName << "\", " << (int)bootstrapLevelsConsumed << "};\n"; + } + os << "}\n"; + return false; +} + +} // namespace heir +} // namespace mlir \ No newline at end of file diff --git a/lib/Tablegen/CompilationTargetEmitter.h b/lib/Tablegen/CompilationTargetEmitter.h new file mode 100644 index 0000000000..3aa2140701 --- /dev/null +++ b/lib/Tablegen/CompilationTargetEmitter.h @@ -0,0 +1,15 @@ +#ifndef LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_ +#define LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_ + +#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project + +namespace mlir { +namespace heir { + +bool emitCompilationTargetRegistration(const llvm::RecordKeeper& records, + llvm::raw_ostream& os); + +} // namespace heir +} // namespace mlir + +#endif // LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_ diff --git a/lib/Tablegen/TablegenMain.cpp b/lib/Tablegen/TablegenMain.cpp new file mode 100644 index 0000000000..076ed7e384 --- /dev/null +++ b/lib/Tablegen/TablegenMain.cpp @@ -0,0 +1,35 @@ +#include "lib/Tablegen/CompilationTargetEmitter.h" +#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project +#include "llvm/include/llvm/Support/InitLLVM.h" // from @llvm-project +#include "llvm/include/llvm/TableGen/Main.h" // from @llvm-project +#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project + +using namespace mlir; +using namespace heir; + +enum ActionType { + None, + GenCompilationTargetRegistration, +}; + +static llvm::cl::opt action( + llvm::cl::desc("Action to perform:"), + llvm::cl::values(clEnumValN(GenCompilationTargetRegistration, + "gen-compilation-target-registration", + "Generate compilation target registration"))); + +bool heirTableGenMain(llvm::raw_ostream& os, + const llvm::RecordKeeper& records) { + switch (action) { + case GenCompilationTargetRegistration: + return emitCompilationTargetRegistration(records, os); + default: + return false; + } +} + +int main(int argc, char** argv) { + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv); + return llvm::TableGenMain(argv[0], &heirTableGenMain); +} diff --git a/lib/Target/CompilationTarget/BUILD b/lib/Target/CompilationTarget/BUILD new file mode 100644 index 0000000000..e8db3df27b --- /dev/null +++ b/lib/Target/CompilationTarget/BUILD @@ -0,0 +1,45 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "CompilationTarget", + srcs = ["CompilationTarget.cpp"], + hdrs = ["CompilationTarget.h"], + deps = [ + ":compilation_target_inc_gen", + "@heir//lib/Dialect:ModuleAttributes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +td_library( + name = "td_files", + srcs = ["HEIRTarget.td"], + includes = ["../../../.."], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "compilation_target_inc_gen", + tbl_outs = [ + ( + ["-gen-compilation-target-registration"], + "CompilationTarget.cpp.inc", + ), + ], + tblgen = "@heir//lib/Tablegen:heir-tblgen", + td_file = "HEIRTarget.td", + deps = [ + ":td_files", + ], +) + +exports_files(["HEIRTarget.td"]) diff --git a/lib/Target/CompilationTarget/CompilationTarget.cpp b/lib/Target/CompilationTarget/CompilationTarget.cpp new file mode 100644 index 0000000000..5324380fd6 --- /dev/null +++ b/lib/Target/CompilationTarget/CompilationTarget.cpp @@ -0,0 +1,38 @@ +#include "lib/Target/CompilationTarget/CompilationTarget.h" + +#include "lib/Dialect/ModuleAttributes.h" +#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#include "lib/Target/CompilationTarget/CompilationTarget.cpp.inc" + +CompilationTargetRegistry& CompilationTargetRegistry::getInstance() { + static CompilationTargetRegistry instance; + return instance; +} + +const CompilationTarget* CompilationTargetRegistry::get(llvm::StringRef name) { + auto& instance = getInstance(); + auto it = instance.targets.find(name); + if (it == instance.targets.end()) { + return nullptr; + } + return &it->second; +} + +const CompilationTarget* getTargetConfig(ModuleOp module) { + for (auto attr : module->getAttrs()) { + llvm::StringRef name = attr.getName().strref(); + if (name.consume_front("backend.")) { + return CompilationTargetRegistry::get(name); + } + } + return nullptr; +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Target/CompilationTarget/CompilationTarget.h b/lib/Target/CompilationTarget/CompilationTarget.h new file mode 100644 index 0000000000..fe61f81982 --- /dev/null +++ b/lib/Target/CompilationTarget/CompilationTarget.h @@ -0,0 +1,34 @@ +#ifndef LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_ +#define LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_ + +#include + +#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace heir { + +struct CompilationTarget { + std::string backendName; + int bootstrapLevelsConsumed; +}; + +class CompilationTargetRegistry { + public: + static const CompilationTarget* get(llvm::StringRef name); + + private: + CompilationTargetRegistry(); + static CompilationTargetRegistry& getInstance(); + + llvm::StringMap targets; +}; + +const CompilationTarget* getTargetConfig(ModuleOp module); + +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_ diff --git a/lib/Target/CompilationTarget/HEIRTarget.td b/lib/Target/CompilationTarget/HEIRTarget.td new file mode 100644 index 0000000000..14c802ad1d --- /dev/null +++ b/lib/Target/CompilationTarget/HEIRTarget.td @@ -0,0 +1,12 @@ +class CompilationTarget { + string backendName = name; + int bootstrapLevelsConsumed = 0; +} + +def OpenFHE : CompilationTarget<"openfhe"> { + let bootstrapLevelsConsumed = 3; +} + +def Lattigo : CompilationTarget<"lattigo"> { + let bootstrapLevelsConsumed = 1; +} \ No newline at end of file diff --git a/lib/Transforms/AnnotateModule/AnnotateModule.cpp b/lib/Transforms/AnnotateModule/AnnotateModule.cpp index c86dbc2983..90c5f04583 100644 --- a/lib/Transforms/AnnotateModule/AnnotateModule.cpp +++ b/lib/Transforms/AnnotateModule/AnnotateModule.cpp @@ -1,6 +1,7 @@ #include "lib/Transforms/AnnotateModule/AnnotateModule.h" #include "lib/Dialect/ModuleAttributes.h" +#include "lib/Target/CompilationTarget/CompilationTarget.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project @@ -26,10 +27,18 @@ struct AnnotateModule : impl::AnnotateModuleBase { moduleSetCGGI(module); } - if (backend == "openfhe") { - moduleSetOpenfhe(module); - } else if (backend == "lattigo") { - moduleSetLattigo(module); + if (!backend.empty()) { + if (!CompilationTargetRegistry::get(backend)) { + module.emitError() << "Unknown backend: " << backend; + signalPassFailure(); + return; + } + + if (backend == "openfhe") { + moduleSetOpenfhe(module); + } else if (backend == "lattigo") { + moduleSetLattigo(module); + } } } }; diff --git a/lib/Transforms/AnnotateModule/BUILD b/lib/Transforms/AnnotateModule/BUILD index 8939e4cd25..092db49258 100644 --- a/lib/Transforms/AnnotateModule/BUILD +++ b/lib/Transforms/AnnotateModule/BUILD @@ -15,6 +15,7 @@ cc_library( deps = [ ":pass_inc_gen", "@heir//lib/Dialect:ModuleAttributes", + "@heir//lib/Target/CompilationTarget", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/tests/Dialect/Mgmt/Transforms/bootstrap_levels.mlir b/tests/Dialect/Mgmt/Transforms/bootstrap_levels.mlir new file mode 100644 index 0000000000..0dfbf86f16 --- /dev/null +++ b/tests/Dialect/Mgmt/Transforms/bootstrap_levels.mlir @@ -0,0 +1,19 @@ +// RUN: heir-opt --annotate-module="backend=openfhe" --annotate-mgmt %s | FileCheck %s --check-prefix=CHECK-OPENFHE +// RUN: heir-opt --annotate-module="backend=lattigo" --annotate-mgmt %s | FileCheck %s --check-prefix=CHECK-LATTIGO + +func.func @main(%arg0: !secret.secret>) -> !secret.secret> { + %b = secret.generic(%arg0: !secret.secret>) { + ^body(%clear_a: tensor<8xi8>): + %c = mgmt.bootstrap %clear_a : tensor<8xi8> + secret.yield %c : tensor<8xi8> + } -> !secret.secret> + func.return %b : !secret.secret> +} + +// CHECK-OPENFHE: func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) +// CHECK-OPENFHE: mgmt.bootstrap +// CHECK-OPENFHE-SAME: {mgmt.mgmt = #mgmt.mgmt{{.*}}} + +// CHECK-LATTIGO: func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) +// CHECK-LATTIGO: mgmt.bootstrap +// CHECK-LATTIGO-SAME: {mgmt.mgmt = #mgmt.mgmt{{.*}}} diff --git a/tests/Transforms/annotate_module/invalid_backend.mlir b/tests/Transforms/annotate_module/invalid_backend.mlir new file mode 100644 index 0000000000..15f1596b1e --- /dev/null +++ b/tests/Transforms/annotate_module/invalid_backend.mlir @@ -0,0 +1,5 @@ +// RUN: heir-opt --annotate-module="backend=invalid_backend" --verify-diagnostics %s + +// expected-error @+1 {{Unknown backend: invalid_backend}} +module { +} \ No newline at end of file